From 3a50b08179f86cc167c70917897bd5accd5e56ea Mon Sep 17 00:00:00 2001 From: augustin64 Date: Fri, 10 Mar 2023 18:19:23 +0100 Subject: [PATCH] Add max_pooling backward --- src/cnn/backpropagation.c | 32 ++++++++++++++++++++++++++++++- src/cnn/cnn.c | 4 ++-- src/cnn/include/backpropagation.h | 6 ++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/cnn/backpropagation.c b/src/cnn/backpropagation.c index 4a2045d..8a89f59 100644 --- a/src/cnn/backpropagation.c +++ b/src/cnn/backpropagation.c @@ -1,3 +1,5 @@ +#include +#include #include #include "include/backpropagation.h" @@ -30,7 +32,6 @@ void softmax_backward_cross_entropy(float* input, float* output, int size) { void backward_average_pooling(float*** input, float*** output, int input_width, int output_width, int depth) { /* Input et output ont la même profondeur (depth) */ - //int size = output_width - input_width +1; int size = input_width/output_width; // Taille du pooling int n = size*size; // Nombre d'éléments dans le pooling @@ -52,6 +53,35 @@ void backward_average_pooling(float*** input, float*** output, int input_width, } } +void backward_max_pooling(float*** input, float*** output, int input_width, int output_width, int depth) { + int size = input_width/output_width; + + float m; // Maximum + int a_max, b_max; // Indices du maximum + + for (int i=0; i < depth; i++) { + for (int j=0; j < output_width; j++) { + for (int k=0; k < output_width; k++) { + m = -FLT_MAX; + a_max = -1; + b_max = -1; + + for (int a=0; a < size; a++) { + for (int b=0; b < size; b++) { + if (input[i][size*j +a][size*k +b] > m) { + m = input[i][size*j +a][size*k +b]; + a_max = a; + b_max = b; + input[i][size*j +a][size*k +b] = 0; + } + } + } + input[i][size*j +a_max][size*k +b_max] = output[i][j][k]; + } + } + } +} + void backward_dense(Kernel_nn* ker, float* input, float* input_z, float* output, int size_input, int size_output, ptr d_function, int is_first) { // Bias for (int j=0; j < size_output; j++) { diff --git a/src/cnn/cnn.c b/src/cnn/cnn.c index a9cdd10..c18a37b 100644 --- a/src/cnn/cnn.c +++ b/src/cnn/cnn.c @@ -18,7 +18,7 @@ int indice_max(float* tab, int n) { int indice = -1; - float maxi = FLT_MIN; + float maxi = -FLT_MAX; for (int i=0; i < n; i++) { if (tab[i] > maxi) { @@ -187,7 +187,7 @@ void backward_propagation(Network* network, int wanted_number) { if (k_i->pooling == AVG_POOLING) { backward_average_pooling(input, output, input_width, output_width, input_depth); // Depth pour input et output a la même valeur } else { - printf_error("La backpropagation de ce pooling n'est pas encore implémentée\n"); + backward_max_pooling(input, output, input_width, output_width, input_depth); // Depth pour input et output a la même valeur } } } diff --git a/src/cnn/include/backpropagation.h b/src/cnn/include/backpropagation.h index 15a6a85..dea99ff 100644 --- a/src/cnn/include/backpropagation.h +++ b/src/cnn/include/backpropagation.h @@ -31,6 +31,12 @@ void softmax_backward_cross_entropy(float* input, float* output, int size); */ void backward_average_pooling(float*** input, float*** output, int input_width, int output_width, int depth); +/* +* Transfert les informations d'erreur à travers une couche de max pooling +* en considérant cross_entropy comme fonction d'erreur +*/ +void backward_max_pooling(float*** input, float*** output, int input_width, int output_width, int depth); + /* * Transfert les informations d'erreur à travers une couche fully connected */