From 3e7309c6e13bda4d1dac28b92f2a97aabfb68bf0 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Tue, 15 Nov 2022 17:50:33 +0100 Subject: [PATCH] Add current_accuracy estimation --- src/cnn/train.c | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/cnn/train.c b/src/cnn/train.c index 22b4c07..00067ed 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -54,7 +54,6 @@ void* train_thread(void* parameters) { maxi = indice_max(network->input[network->size-1][0][0], 10); backward_propagation(network, labels[i]); if (cpt==16) { // Update the network - printf("a\n"); update_weights(network); update_bias(network); cpt = 0; @@ -78,9 +77,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di int input_dim = -1; int input_depth = -1; float accuracy; + float current_accuracy; int nb_images_total; - int nb_remaining_images; int*** images; unsigned int* labels; @@ -109,6 +108,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di Network* network = create_network_lenet5(0, 0, TANH, GLOROT, input_dim, input_depth); #ifdef USE_MULTITHREADING + int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads // Récupération du nombre de threads disponibles int nb_threads = get_nprocs(); pthread_t *tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t)); @@ -168,9 +168,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di // et donnera donc des résultats différents sur les mêmes images. accuracy = 0.; for (int j=0; j < nb_images_total / BATCHES; j++) { + #ifdef USE_MULTITHREADING nb_remaining_images = BATCHES; - #ifdef USE_MULTITHREADING for (int k=0; k < nb_threads; k++) { if (k == nb_threads-1) { train_parameters[k]->nb_images = nb_remaining_images; @@ -189,18 +189,20 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di // TODO patch_network(network, train_parameters[k]->network, train_parameters[k]->nb_images); free_network(train_parameters[k]->network); } - printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, accuracy*100); + current_accuracy = accuracy * nb_images_total/(j*BATCHES); + printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); #else train_params->start = j*BATCHES; train_thread((void*)train_params); accuracy += train_params->accuracy / (float) nb_images_total; - printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%", i, epochs, BATCHES*(j+1), nb_images_total, accuracy*100); + current_accuracy = accuracy * nb_images_total/(j*BATCHES); + printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); #endif } #ifdef USE_MULTITHREADING - printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100); + printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100); #else - printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%\n", i, epochs, nb_images_total, nb_images_total, accuracy*100); + printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", i, epochs, nb_images_total, nb_images_total, accuracy*100); #endif write_network(out, network); }