diff --git a/src/cnn/train.c b/src/cnn/train.c index 7b858e9..a26220b 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -205,74 +205,76 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di batches_epoques = div_up(nb_images_total, BATCHES); nb_images_total_remaining = nb_images_total; #ifndef USE_MULTITHREADING - train_params->nb_images = BATCHES; + train_params->nb_images = BATCHES; #endif for (int j=0; j < batches_epoques; j++) { #ifdef USE_MULTITHREADING - if (j == batches_epoques-1) { - nb_remaining_images = nb_images_total_remaining; - nb_images_total_remaining = 0; - } else { - nb_images_total_remaining -= BATCHES; - nb_remaining_images = BATCHES; - } - - for (int k=0; k < nb_threads; k++) { - if (k == nb_threads-1) { - train_parameters[k]->nb_images = nb_remaining_images; - nb_remaining_images = 0; + if (j == batches_epoques-1) { + nb_remaining_images = nb_images_total_remaining; + nb_images_total_remaining = 0; } else { - nb_remaining_images -= BATCHES / nb_threads; + nb_images_total_remaining -= BATCHES; + nb_remaining_images = BATCHES; } - train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k; - train_parameters[k]->network = copy_network(network); - if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) { - train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1; + for (int k=0; k < nb_threads; k++) { + if (k == nb_threads-1) { + train_parameters[k]->nb_images = nb_remaining_images; + nb_remaining_images = 0; + } else { + nb_remaining_images -= BATCHES / nb_threads; + } + train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k; + + if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) { + train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1; + } + if (train_parameters[k]->nb_images > 0) { + train_parameters[k]->network = copy_network(network); + pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]); + } else { + train_parameters[k]->network = NULL; + } } - if (train_parameters[k]->nb_images > 0) { - pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]); - } else { - tid[k] = 0; + for (int k=0; k < nb_threads; k++) { + // On attend la terminaison de chaque thread un à un + if (train_parameters[k]->network) { + pthread_join( tid[k], NULL ); + accuracy += train_parameters[k]->accuracy / (float) nb_images_total; + } } - } - for (int k=0; k < nb_threads; k++) { - // On attend la terminaison de chaque thread un à un - if (tid[k] != 0) { - pthread_join( tid[k], NULL ); - accuracy += train_parameters[k]->accuracy / (float) nb_images_total; + + // On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal + for (int k=0; k < nb_threads; k++) { + if (train_parameters[k]->network) { // Si le fil a été utilisé + update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images); + update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images); + free_network(train_parameters[k]->network); + } } - } - - // On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal - for (int k=0; k < nb_threads; k++) { - update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images); - update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images); - free_network(train_parameters[k]->network); - } - current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); - printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); - fflush(stdout); + current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); + printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); + fflush(stdout); #else - (void)nb_images_total_remaining; // Juste pour enlever un warning + (void)nb_images_total_remaining; // Juste pour enlever un warning - train_params->start = j*BATCHES; + train_params->start = j*BATCHES; - // Ne pas dépasser le nombre d'images à cause de la partie entière - if (j == batches_epoques-1) { - train_params->nb_images = nb_images_total - j*BATCHES; - } - - train_thread((void*)train_params); - - accuracy += train_params->accuracy / (float) nb_images_total; - current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); - - update_weights(network, network, train_params->nb_images); - update_bias(network, network, train_params->nb_images); - - printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); - fflush(stdout); + // Ne pas dépasser le nombre d'images à cause de la partie entière + if (j == batches_epoques-1) { + train_params->nb_images = nb_images_total - j*BATCHES; + } + + train_thread((void*)train_params); + + accuracy += train_params->accuracy / (float) nb_images_total; + current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); + + update_weights(network, network, train_params->nb_images); + update_bias(network, network, train_params->nb_images); + + printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); + fflush(stdout); #endif } end_time = omp_get_wtime();