From d03f7493b2e33a6fb6c148446a35199695de0a4c Mon Sep 17 00:00:00 2001 From: augustin64 Date: Wed, 16 Nov 2022 10:38:01 +0100 Subject: [PATCH] Move update --- src/cnn/include/train.h | 2 +- src/cnn/train.c | 24 +++++++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/cnn/include/train.h b/src/cnn/include/train.h index 58f7c3a..95356b8 100644 --- a/src/cnn/include/train.h +++ b/src/cnn/include/train.h @@ -4,7 +4,7 @@ #define DEF_TRAIN_H #define EPOCHS 10 -#define BATCHES 100 +#define BATCHES 120 #define USE_MULTITHREADING diff --git a/src/cnn/train.c b/src/cnn/train.c index 7084191..9a19e9a 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -3,6 +3,7 @@ #include #include #include +#include #include "../mnist/include/mnist.h" #include "include/initialisation.h" @@ -46,19 +47,13 @@ void* train_thread(void* parameters) { int start = param->start; int nb_images = param->nb_images; float accuracy = 0.; - int cpt=1; for (int i=start; i < start+nb_images; i++) { if (dataset_type == 0) { write_image_in_network_32(images[i], height, width, network->input[0][0]); forward_propagation(network); maxi = indice_max(network->input[network->size-1][0][0], 10); backward_propagation(network, labels[i]); - if (cpt==16) { // Update the network - update_weights(network); - update_bias(network); - cpt = 0; - } - cpt++; + if (maxi == labels[i]) { accuracy += 1.; } @@ -74,6 +69,7 @@ void* train_thread(void* parameters) { void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) { + srand(time(NULL)); int input_dim = -1; int input_depth = -1; float accuracy; @@ -190,19 +186,25 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di free_network(train_parameters[k]->network); } current_accuracy = accuracy * nb_images_total/(j*BATCHES); - printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); + printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); #else train_params->start = j*BATCHES; + train_thread((void*)train_params); + accuracy += train_params->accuracy / (float) nb_images_total; current_accuracy = accuracy * nb_images_total/(j*BATCHES); - printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); + + update_weights(network); + update_bias(network); + + printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); #endif } #ifdef USE_MULTITHREADING - printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100); + printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100); #else - printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", i, epochs, nb_images_total, nb_images_total, accuracy*100); + printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", i, epochs, nb_images_total, nb_images_total, accuracy*100); #endif write_network(out, network); }