diff --git a/src/cnn/backpropagation.c b/src/cnn/backpropagation.c index 339482d..ee46fd9 100644 --- a/src/cnn/backpropagation.c +++ b/src/cnn/backpropagation.c @@ -3,8 +3,6 @@ #include "include/backpropagation.h" #include "include/struct.h" -// The first layer needs to be a convolution or a fully connected one - int min(int a, int b) { return a b ? a : b; } -// Euh..... tout peut être faux à cause de la source void softmax_backward(float* input, float* input_z, float* output, int size) { /* Input et output ont la même taille - On considère que la dernière couche a utilisée softmax */ + On considère que la dernière couche a utilisée softmax + et que l'erreur est MSE */ + for (int i=0; i < size; i++){ - input[i] = (output[i]-input[i])*input[i]; // ∂E/∂out_i * ∂out_i/∂net_i = 𝛿_i + input[i] = (output[i]-input[i])*input[i]*(1-input[i]); } } void backward_2d_pooling(float*** input, float*** output, int input_width, int output_width, int depth) { /* Input et output ont la même profondeur (depth) */ - // Inventé par moi-même (et que moi (vraiment que moi (lvdmm))) - int size = output_width - input_width + 1; // Taille du pooling + //int size = output_width - input_width +1; + int size = input_width/output_width; // Taille du pooling int n = size*size; // Nombre d'éléments dans le pooling for (int a=0; a < depth; a++) @@ -82,12 +81,11 @@ void backward_linearisation(Kernel_nn* ker, float*** input, float*** input_z, fl // Weights int cpt = 0; - int nb_elem = depth_input*dim_input*dim_input*size_output; for (int i=0; i < depth_input; i++) { for (int k=0; k < dim_input; k++) { for (int l=0; l < dim_input; l++) { for (int j=0; j < size_output; j++) { - ker->d_weights[cpt][j] += input[i][k][l]*output[j]/nb_elem; + ker->d_weights[cpt][j] += input[i][k][l]*output[j]; } cpt++; } @@ -122,7 +120,7 @@ void backward_convolution(Kernel_cnn* ker, float*** input, float*** input_z, flo // Weights int k_size = dim_input - dim_output +1; - + for (int h=0; h < depth_input; h++) { for (int i=0; i < depth_output; i++) { for (int j=0; j < k_size; j++) {