diff --git a/src/cnn/train.c b/src/cnn/train.c index a5bcf02..e161e2f 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -91,7 +91,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di float accuracy; float current_accuracy; - int nb_images_total; + int nb_images_total; // Images au total + int nb_images_total_remaining; // Images restantes dans un batch + int batches_epoques; // Batches par époque int*** images; unsigned int* labels; @@ -180,9 +182,17 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di // du multi-threading car chaque copie du réseau initiale sera légèrement différente // et donnera donc des résultats différents sur les mêmes images. accuracy = 0.; - for (int j=0; j < nb_images_total / BATCHES; j++) { + batches_epoques = nb_images_total / BATCHES; + nb_images_total_remaining = nb_images_total; + for (int j=0; j < batches_epoques; j++) { #ifdef USE_MULTITHREADING - nb_remaining_images = BATCHES; + if (j == batches_epoques-1) { + nb_remaining_images = nb_images_total_remaining; + nb_images_total_remaining = 0; + } else { + nb_images_total_remaining -= BATCHES; + nb_remaining_images = BATCHES; + } for (int k=0; k < nb_threads; k++) { if (k == nb_threads-1) { @@ -191,8 +201,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di } else { nb_remaining_images -= BATCHES / nb_threads; } + train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k; train_parameters[k]->network = copy_network(network); - train_parameters[k]->start = BATCHES*j + (nb_images_total/BATCHES)*k; + pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]); } for (int k=0; k < nb_threads; k++) { @@ -206,6 +217,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di } current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); + fflush(stdout); #else train_params->start = j*BATCHES;