diff --git a/src/cnn/include/train.h b/src/cnn/include/train.h index b7c8cdb..f9901e5 100644 --- a/src/cnn/include/train.h +++ b/src/cnn/include/train.h @@ -8,7 +8,7 @@ /* -* Structure donnée en argument à la fonction 'train_thread' + * Structure donnée en argument à la fonction 'train_thread' */ typedef struct TrainParameters { Network* network; // Réseau @@ -25,11 +25,24 @@ typedef struct TrainParameters { float loss; // Loss (à renvoyer) } TrainParameters; +/* + * Structure donnée en argument à la fonction 'load_image' +*/ +typedef struct LoadImageParameters { + jpegDataset* dataset; // Dataset si de type JPEG + int index; // Numéro de l'image à charger +} LoadImageParameters; + /* * Partie entière supérieure de a/b */ int div_up(int a, int b); +/* + * Fonction auxiliaire pour charger (ouvrir et décompresser) les images de manière asynchrone + * économise environ 20ms par image pour des images de taille 256*256*3 +*/ +void* load_image(void* parameters); /* * Fonction auxiliaire d'entraînement destinée à être exécutée sur plusieurs threads à la fois diff --git a/src/cnn/train.c b/src/cnn/train.c index 3ccc891..62515db 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -28,6 +28,19 @@ int div_up(int a, int b) { // Partie entière supérieure de a/b return ((a % b) != 0) ? (a / b + 1) : (a / b); } +void* load_image(void* parameters) { + LoadImageParameters* param = (LoadImageParameters*)parameters; + + if (!param->dataset->images[param->index]) { + imgRawImage* image = loadJpegImageFile(param->dataset->fileNames[param->index]); + param->dataset->images[param->index] = image->lpData; + free(image); + } else { + printf_warning((char*)"Image déjà chargée\n"); // Pas possible techniquement, donc on met un warning + } + + return NULL; +} void* train_thread(void* parameters) { TrainParameters* param = (TrainParameters*)parameters; @@ -49,6 +62,16 @@ void* train_thread(void* parameters) { float accuracy = 0.; float loss = 0.; + pthread_t tid; + LoadImageParameters* load_image_param; + if (dataset_type != 0) { + load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters)); + load_image_param->dataset = param->dataset; + load_image_param->index = index[start]; + + pthread_create(&tid, NULL, load_image, (void*) load_image_param); + } + for (int i=start; i < start+nb_images; i++) { if (dataset_type == 0) { write_image_in_network_32(images[index[i]], height, width, network->input[0][0], true); @@ -70,11 +93,17 @@ void* train_thread(void* parameters) { accuracy += 1.; } } else { + pthread_join(tid, NULL); if (!param->dataset->images[index[i]]) { image = loadJpegImageFile(param->dataset->fileNames[index[i]]); param->dataset->images[index[i]] = image->lpData; free(image); } + + if (i != start+nb_images-1) { + load_image_param->index = index[i+1]; + pthread_create(&tid, NULL, load_image, (void*) load_image_param); + } write_image_in_network_260(param->dataset->images[index[i]], height, width, network->input[0]); forward_propagation(network); maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories); @@ -89,6 +118,10 @@ void* train_thread(void* parameters) { } } + if (dataset_type != 0) { + free(load_image_param); + } + param->accuracy = accuracy; param->loss = loss; return NULL;