Add current_accuracy estimation

This commit is contained in:
augustin64 2022-11-15 17:50:33 +01:00
parent ef4c5087b7
commit 3e7309c6e1

View File

@ -54,7 +54,6 @@ void* train_thread(void* parameters) {
maxi = indice_max(network->input[network->size-1][0][0], 10); maxi = indice_max(network->input[network->size-1][0][0], 10);
backward_propagation(network, labels[i]); backward_propagation(network, labels[i]);
if (cpt==16) { // Update the network if (cpt==16) { // Update the network
printf("a\n");
update_weights(network); update_weights(network);
update_bias(network); update_bias(network);
cpt = 0; cpt = 0;
@ -78,9 +77,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
int input_dim = -1; int input_dim = -1;
int input_depth = -1; int input_depth = -1;
float accuracy; float accuracy;
float current_accuracy;
int nb_images_total; int nb_images_total;
int nb_remaining_images;
int*** images; int*** images;
unsigned int* labels; unsigned int* labels;
@ -109,6 +108,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
Network* network = create_network_lenet5(0, 0, TANH, GLOROT, input_dim, input_depth); Network* network = create_network_lenet5(0, 0, TANH, GLOROT, input_dim, input_depth);
#ifdef USE_MULTITHREADING #ifdef USE_MULTITHREADING
int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads
// Récupération du nombre de threads disponibles // Récupération du nombre de threads disponibles
int nb_threads = get_nprocs(); int nb_threads = get_nprocs();
pthread_t *tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t)); pthread_t *tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t));
@ -168,9 +168,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// et donnera donc des résultats différents sur les mêmes images. // et donnera donc des résultats différents sur les mêmes images.
accuracy = 0.; accuracy = 0.;
for (int j=0; j < nb_images_total / BATCHES; j++) { for (int j=0; j < nb_images_total / BATCHES; j++) {
#ifdef USE_MULTITHREADING
nb_remaining_images = BATCHES; nb_remaining_images = BATCHES;
#ifdef USE_MULTITHREADING
for (int k=0; k < nb_threads; k++) { for (int k=0; k < nb_threads; k++) {
if (k == nb_threads-1) { if (k == nb_threads-1) {
train_parameters[k]->nb_images = nb_remaining_images; train_parameters[k]->nb_images = nb_remaining_images;
@ -189,18 +189,20 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// TODO patch_network(network, train_parameters[k]->network, train_parameters[k]->nb_images); // TODO patch_network(network, train_parameters[k]->network, train_parameters[k]->nb_images);
free_network(train_parameters[k]->network); free_network(train_parameters[k]->network);
} }
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, accuracy*100); current_accuracy = accuracy * nb_images_total/(j*BATCHES);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
#else #else
train_params->start = j*BATCHES; train_params->start = j*BATCHES;
train_thread((void*)train_params); train_thread((void*)train_params);
accuracy += train_params->accuracy / (float) nb_images_total; accuracy += train_params->accuracy / (float) nb_images_total;
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%", i, epochs, BATCHES*(j+1), nb_images_total, accuracy*100); current_accuracy = accuracy * nb_images_total/(j*BATCHES);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
#endif #endif
} }
#ifdef USE_MULTITHREADING #ifdef USE_MULTITHREADING
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100); printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
#else #else
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%\n", i, epochs, nb_images_total, nb_images_total, accuracy*100); printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
#endif #endif
write_network(out, network); write_network(out, network);
} }