From 65505858da09d49d88630e797cbdbf60f539b293 Mon Sep 17 00:00:00 2001
From: julienChemillier
Date: Sat, 12 Nov 2022 14:20:13 +0100
Subject: [PATCH] Update backprop

---
 src/cnn/backpropagation.c  | 12 +++++-----
 src/cnn/cnn.c              |  1 -
 src/cnn/creation.c         |  4 ++--
 src/cnn/function.c         | 45 ++++++++++++++++++++++++++++----------
 src/cnn/include/function.h | 18 ++++++++++-----
 src/cnn/include/struct.h   |  2 +-
 6 files changed, 54 insertions(+), 28 deletions(-)

diff --git a/src/cnn/backpropagation.c b/src/cnn/backpropagation.c
index 49ea293..3ab7b59 100644
--- a/src/cnn/backpropagation.c
+++ b/src/cnn/backpropagation.c
@@ -146,19 +146,19 @@ void backward_convolution(Kernel_cnn* ker, float*** input, float*** input_z, flo
     // Input
     if (is_first==1) // Pas besoin de backpropager dans l'input
         return;
-    
+    int min_m, max_m, min_n, max_n;
     for (int i=0; i < depth_input; i++) {
         for (int j=0; j < dim_input; j++) {
             for (int k=0; k < dim_input; k++) {
                 float tmp = 0;
                 for (int l=0; l < depth_output; l++) {
-                    int min_m = k_size - max(k_size, dim_input-i);
-                    int max_m = min(k_size, i+1);
-                    int min_n = k_size - max(k_size, dim_input-j);
-                    int max_n = min(k_size, j+1);
+                    min_m = max(0, k_size-1-j);
+                    max_m = min(k_size, dim_input - j);
+                    min_n = max(0, k_size-1-k);
+                    max_n = min(k_size, dim_input-k);
                     for (int m=min_m; m < max_m; m++) {
                         for (int n=min_n; n < max_n; n++) {
-                            tmp += output[l][i-m][j-n]*ker->w[i][l][m][n];
+                            tmp += output[l][j-k_size+m+1][k-k_size+n+1]*ker->w[i][l][m][n];
                         }
                     }
                 }
diff --git a/src/cnn/cnn.c b/src/cnn/cnn.c
index 0a06f46..8ee3782 100644
--- a/src/cnn/cnn.c
+++ b/src/cnn/cnn.c
@@ -90,7 +90,6 @@ void forward_propagation(Network* network) {
 }
 
 void backward_propagation(Network* network, float wanted_number) {
-    printf_warning("Appel de backward_propagation, incomplet\n");
     float* wanted_output = generate_wanted_output(wanted_number);
     int n = network->size;
     int activation, input_depth, input_width, output_depth, output_width;
diff --git a/src/cnn/creation.c b/src/cnn/creation.c
index 81e418d..26b258c 100644
--- a/src/cnn/creation.c
+++ b/src/cnn/creation.c
@@ -102,8 +102,8 @@ void add_2d_average_pooling(Network* network, int dim_output) {
     }
     network->kernel[k_pos]->cnn = NULL;
     network->kernel[k_pos]->nn = NULL;
-    network->kernel[k_pos]->activation = 100*kernel_size; // Ne contient pas de fonction d'activation
-    network->kernel[k_pos]->linearisation = 0;
+    network->kernel[k_pos]->activation = IDENTITY; // Ne contient pas de fonction d'activation
+    network->kernel[k_pos]->linearisation = kernel_size;
     create_a_cube_input_layer(network, n, network->depth[n-1], network->width[n-1]/2);
     create_a_cube_input_z_layer(network, n, network->depth[n-1], network->width[n-1]/2); // Will it be used ?
     network->size++;
diff --git a/src/cnn/function.c b/src/cnn/function.c
index 6afda6c..40d8f22 100644
--- a/src/cnn/function.c
+++ b/src/cnn/function.c
@@ -9,6 +9,15 @@ float max_float(float a, float b) {
     return a < b ? b:a;
 }
 
+float identity(float x) {
+    return x;
+}
+
+float identity_derivative(float x) {
+    (void)x;
+    return 1;
+}
+
 float sigmoid(float x) {
     return 1/(1 + exp(-x));
 }
@@ -105,26 +114,38 @@ void choose_apply_function_vector(int activation, float*** input, int dim) {
 ptr get_function_activation(int activation) {
     if (activation == RELU) {
         return &relu;
-    } else if (activation == -RELU) {
+    }
+    if (activation == -RELU) {
         return &relu_derivative;
-    } else if (activation == SIGMOID) {
+    }
+    if (activation == -IDENTITY) {
+        return &identity_derivative;
+    }
+    if (activation == IDENTITY) {
+        return &identity;
+    }
+    if (activation == SIGMOID) {
         return &sigmoid;
-    } else if (activation == -SIGMOID) {
+    }
+    if (activation == -SIGMOID) {
         return &sigmoid_derivative;
-    } else if (activation == SOFTMAX) {
+    }
+    if (activation == SOFTMAX) {
         printf("Erreur, impossible de renvoyer la fonction softmax\n");
         return NULL;
-    } else if (activation == -SOFTMAX) {
+    }
+    if (activation == -SOFTMAX) {
         printf("Erreur, impossible de renvoyer la dérivée de la fonction softmax\n");
         return NULL;
-    } else if (activation == TANH) {
-        return &tanh_;
-    } else if (activation == -TANH) {
-        return &tanh_derivative;
-    } else {
-        printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
-        return NULL;
     }
+    if (activation == TANH) {
+        return &tanh_;
+    }
+    if (activation == -TANH) {
+        return &tanh_derivative;
+    }
+    printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
+    return NULL;
 }
 // to use:
 // float a = 5; int activation;
diff --git a/src/cnn/include/function.h b/src/cnn/include/function.h
index 86da475..5f9b8ee 100644
--- a/src/cnn/include/function.h
+++ b/src/cnn/include/function.h
@@ -3,10 +3,12 @@
 
 // Les dérivées sont l'opposé
 
-#define TANH 1
-#define SIGMOID 2
-#define RELU 3
-#define SOFTMAX 4
+#define IDENTITY 1
+#define TANH 2
+#define SIGMOID 3
+#define RELU 4
+#define SOFTMAX 5
+
 
 typedef float (*ptr)(float);
 typedef ptr (*pm)();
@@ -16,6 +18,10 @@ typedef ptr (*pm)();
 */
 float max_float(float a, float b);
 
+float identity(float x);
+
+float identity_derivative(float x);
+
 float sigmoid(float x);
 
 float sigmoid_derivative(float x);
@@ -29,12 +35,12 @@ float tanh_(float x);
 float tanh_derivative(float x);
 
 /*
-* Applique softmax sur ????
+* Applique softmax sur input[depth][rows][columns]
 */
 void apply_softmax_input(float ***input, int depth, int rows, int columns);
 
 /*
-* Applique la fonction f sur ????
+* Applique la fonction f sur input[depth][rows][columns]
 */
 void apply_function_input(float (*f)(float), float*** input, int depth, int rows, int columns);
 
diff --git a/src/cnn/include/struct.h b/src/cnn/include/struct.h
index fb42c37..267294e 100644
--- a/src/cnn/include/struct.h
+++ b/src/cnn/include/struct.h
@@ -23,7 +23,7 @@ typedef struct Kernel_nn {
 typedef struct Kernel {
     Kernel_cnn* cnn; // NULL si ce n'est pas un cnn
     Kernel_nn* nn; // NULL si ce n'est pas un nn
-    int activation; // Vaut l'activation sauf pour un pooling où il: vaut pooling_size*100
+    int activation; // Vaut l'identifiant de la fonction d'activation
     int linearisation; // Vaut 1 si c'est la linéarisation d'une couche, 0 sinon ?? Ajouter dans les autres
 } Kernel;
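
Reviewer note on the backward_convolution hunk (not part of the patch): the old bounds mixed the channel index i into the spatial bound computation and, combined with the output[l][i-m][j-n] indexing, could address the output gradient out of range. The new bounds follow from requiring 0 <= j-k_size+m+1 <= dim_output-1 (and the same for n against k). The standalone sketch below brute-force checks that property; it assumes a stride-1 "valid" convolution, i.e. dim_output = dim_input - k_size + 1, which is what the new indexing implies but which the patch itself does not state.

/* Standalone sanity check for the new bounds in backward_convolution.
 * Assumption (not stated in the patch): stride-1 "valid" convolution,
 * so dim_output = dim_input - k_size + 1. */
#include <stdio.h>

static int max(int a, int b) { return a > b ? a : b; }
static int min(int a, int b) { return a < b ? a : b; }

int main(void) {
    int errors = 0;
    for (int dim_input = 1; dim_input <= 8; dim_input++) {
        for (int k_size = 1; k_size <= dim_input; k_size++) {
            int dim_output = dim_input - k_size + 1;
            for (int j = 0; j < dim_input; j++) {
                // Bounds exactly as written in the patch
                int min_m = max(0, k_size-1-j);
                int max_m = min(k_size, dim_input-j);
                for (int m = 0; m < k_size; m++) {
                    int out = j - k_size + m + 1;              // output row touched by (j, m)
                    int in_range = (out >= 0 && out < dim_output);
                    int visited  = (m >= min_m && m < max_m);
                    if (in_range != visited) {                 // bounds must select exactly the valid offsets
                        printf("mismatch: dim_input=%d k_size=%d j=%d m=%d\n", dim_input, k_size, j, m);
                        errors++;
                    }
                }
            }
        }
    }
    if (errors == 0)
        printf("min_m/max_m select exactly the in-range kernel offsets\n");
    return 0;
}

The same argument applies to min_n/max_n with k in place of j.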
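Reviewer note on the new IDENTITY constant (not part of the patch): average-pooling kernels previously overloaded the activation field to store 100*kernel_size; the patch gives them a real activation identifier and moves the pooling size into the linearisation field. Below is a hypothetical usage sketch, assuming get_function_activation is declared in function.h and that the include path matches the build setup; neither is shown in the patch.

/* Hypothetical calling code, not from the repository; the include path and the
 * declaration of get_function_activation in function.h are assumptions. */
#include <stdio.h>
#include "include/function.h"

void identity_example(void) {
    ptr f  = get_function_activation(IDENTITY);   // &identity: forward pass leaves the value unchanged
    ptr df = get_function_activation(-IDENTITY);  // &identity_derivative: negated id returns the derivative
    printf("%f %f\n", f(3.5f), df(3.5f));         // prints 3.500000 1.000000
}

Looking up the derivative through the negated identifier follows the convention already noted at the top of function.h ("Les dérivées sont l'opposé").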