diff --git a/src/cnn/cnn.c b/src/cnn/cnn.c index 6e300f0..8216fac 100644 --- a/src/cnn/cnn.c +++ b/src/cnn/cnn.c @@ -249,14 +249,15 @@ void backward_propagation(Network* network, int wanted_number) { int output_depth = network->depth[i+1]; int output_width = network->width[i+1]; - int activation = i==0?SIGMOID:network->kernel[i-1]->activation; + int is_last_layer = i==0; + int activation = is_last_layer?SIGMOID:network->kernel[i-1]->activation; if (k_i->cnn) { // Convolution - backward_convolution(k_i->cnn, input, input_z, output, input_depth, input_width, output_depth, output_width, -activation, i==0); + backward_convolution(k_i->cnn, input, input_z, output, input_depth, input_width, output_depth, output_width, -activation, is_last_layer); } else if (k_i->nn) { // Full connection if (k_i->linearisation == DOESNT_LINEARISE) { // Vecteur -> Vecteur - backward_dense(k_i->nn, input[0][0], input_z[0][0], output[0][0], input_width, output_width, -activation, i==0); + backward_dense(k_i->nn, input[0][0], input_z[0][0], output[0][0], input_width, output_width, -activation, is_last_layer); } else { // Matrice -> vecteur backward_linearisation(k_i->nn, input, input_z, output[0][0], input_depth, input_width, output_width, -activation); }