From 5a0f807a00e83c530592cd224df58bfd1ef65313 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Fri, 18 Nov 2022 14:09:49 +0100 Subject: [PATCH] Update update.c --- src/cnn/include/cnn.h | 9 --------- src/cnn/include/update.h | 4 ++-- src/cnn/train.c | 5 +++-- src/cnn/update.c | 35 +++++++++++++++++++++++------------ 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/cnn/include/cnn.h b/src/cnn/include/cnn.h index d7dfa08..f5b4f14 100644 --- a/src/cnn/include/cnn.h +++ b/src/cnn/include/cnn.h @@ -34,15 +34,6 @@ void drop_neurones(float*** input, int depth, int dim1, int dim2, int dropout); */ void copy_input_to_input_z(float*** output, float*** output_a, int output_depth, int output_rows, int output_columns); -/* -* Bascule les données de d_weights dans weights -*/ -void update_weights(Network* network); - -/* -* Bascule les données de d_bias dans bias -*/ -void update_bias(Network* network); /* * Renvoie l'erreur du réseau neuronal pour une sortie (RMS) */ diff --git a/src/cnn/include/update.h b/src/cnn/include/update.h index 9c16886..4b50be9 100644 --- a/src/cnn/include/update.h +++ b/src/cnn/include/update.h @@ -7,13 +7,13 @@ * Met à jours les poids à partir de données obtenus après plusieurs backpropagations * Puis met à 0 tous les d_weights */ -void update_weights(Network* network); +void update_weights(Network* network, Network* d_network); /* * Met à jours les biais à partir de données obtenus après plusieurs backpropagations * Puis met à 0 tous les d_bias */ -void update_bias(Network* network); +void update_bias(Network* network, Network* d_network); /* * Met à 0 toutes les données de backpropagation de poids diff --git a/src/cnn/train.c b/src/cnn/train.c index 9a19e9a..7731875 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -195,10 +195,11 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di accuracy += train_params->accuracy / (float) nb_images_total; current_accuracy = accuracy * nb_images_total/(j*BATCHES); - update_weights(network); - update_bias(network); + update_weights(network, network); + update_bias(network, network); printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); + fflush(stdout); #endif } #ifdef USE_MULTITHREADING diff --git a/src/cnn/update.c b/src/cnn/update.c index a251506..ac8086d 100644 --- a/src/cnn/update.c +++ b/src/cnn/update.c @@ -1,13 +1,16 @@ +#include #include "include/update.h" #include "include/struct.h" -void update_weights(Network* network) { +void update_weights(Network* network, Network* d_network) { int n = network->size; int input_depth, input_width, output_depth, output_width, k_size; Kernel* k_i; + Kernel* dk_i; for (int i=0; i<(n-1); i++) { k_i = network->kernel[i]; + dk_i = d_network->kernel[i]; input_depth = network->depth[i]; input_width = network->width[i]; output_depth = network->depth[i+1]; @@ -15,13 +18,14 @@ void update_weights(Network* network) { if (k_i->cnn) { // Convolution Kernel_cnn* cnn = k_i->cnn; + Kernel_cnn* d_cnn = dk_i->cnn; k_size = cnn->k_size; for (int a=0; aw[a][b][c][d] -= network->learning_rate * cnn->d_w[a][b][c][d]; - cnn->d_w[a][b][c][d] = 0; + cnn->w[a][b][c][d] -= network->learning_rate * d_cnn->d_w[a][b][c][d]; + d_cnn->d_w[a][b][c][d] = 0; } } } @@ -29,19 +33,21 @@ void update_weights(Network* network) { } else if (k_i->nn) { // Full connection if (input_depth==1) { // Vecteur -> Vecteur Kernel_nn* nn = k_i->nn; + Kernel_nn* d_nn = dk_i->nn; for (int a=0; aweights[a][b] -= network->learning_rate * nn->d_weights[a][b]; - nn->d_weights[a][b] = 0; + nn->weights[a][b] -= network->learning_rate * d_nn->d_weights[a][b]; + d_nn->d_weights[a][b] = 0; } } } else { // Matrice -> vecteur Kernel_nn* nn = k_i->nn; + Kernel_nn* d_nn = dk_i->nn; int input_size = input_width*input_width*input_depth; for (int a=0; aweights[a][b] -= network->learning_rate * nn->d_weights[a][b]; - nn->d_weights[a][b] = 0; + nn->weights[a][b] -= network->learning_rate * d_nn->d_weights[a][b]; + d_nn->d_weights[a][b] = 0; } } } @@ -51,30 +57,35 @@ void update_weights(Network* network) { } } -void update_bias(Network* network) { +void update_bias(Network* network, Network* d_network) { + int n = network->size; int output_width, output_depth; Kernel* k_i; + Kernel* dk_i; for (int i=0; i<(n-1); i++) { k_i = network->kernel[i]; + dk_i = d_network->kernel[i]; output_width = network->width[i+1]; output_depth = network->depth[i+1]; if (k_i->cnn) { // Convolution Kernel_cnn* cnn = k_i->cnn; + Kernel_cnn* d_cnn = dk_i->cnn; for (int a=0; abias[a][b][c] -= network->learning_rate * cnn->d_bias[a][b][c]; - cnn->d_bias[a][b][c] = 0; + cnn->bias[a][b][c] -= network->learning_rate * d_cnn->d_bias[a][b][c]; + d_cnn->d_bias[a][b][c] = 0; } } } } else if (k_i->nn) { // Full connection Kernel_nn* nn = k_i->nn; + Kernel_nn* d_nn = dk_i->nn; for (int a=0; abias[a] -= network->learning_rate * nn->d_bias[a]; - nn->d_bias[a] = 0; + nn->bias[a] -= network->learning_rate * d_nn->d_bias[a]; + d_nn->d_bias[a] = 0; } } else { // Pooling (void)0; // Ne rien faire pour la couche pooling