2023-03-10 18:19:23 +01:00
|
|
|
#include <stdio.h>
|
|
|
|
#include <float.h>
|
2022-11-03 17:50:11 +01:00
|
|
|
#include <math.h>
|
2022-11-03 18:13:01 +01:00
|
|
|
|
|
|
|
#include "include/backpropagation.h"
|
|
|
|
#include "include/struct.h"
|
2022-11-03 17:50:11 +01:00
|
|
|
|
|
|
|
/* Returns the smaller of two ints. */
int min(int a, int b) {
    if (b < a) {
        return b;
    }
    return a;
}
|
|
|
|
|
|
|
|
/* Returns the larger of two ints. */
int max(int a, int b) {
    if (b > a) {
        return b;
    }
    return a;
}
|
|
|
|
|
2023-02-24 11:03:51 +01:00
|
|
|
/*
 * Backward step of a softmax layer paired with a squared-error loss.
 *
 * input and output both hold `size` elements; input is overwritten in place:
 *     input[i] = (output[i] - input[i]) * input[i] * (1 - input[i])
 *
 * NOTE(review): presumably input holds the softmax activations and output the
 * expected values — confirm against the caller. If so, the factor
 * input*(1-input) is only the diagonal term of the softmax Jacobian.
 */
void softmax_backward_mse(float* input, float* output, int size) {
    for (int idx = 0; idx < size; ++idx) {
        const float s = input[idx];
        const float target = output[idx];
        input[idx] = (target - s)*s*(1 - s);
    }
}
|
|
|
|
|
2023-02-24 11:03:51 +01:00
|
|
|
/*
 * Backward step of a softmax layer paired with a cross-entropy loss.
 *
 * input and output both hold `size` elements; input is overwritten in place
 * with the element-wise difference output[i] - input[i].
 * NOTE(review): presumably input holds the softmax activations and output the
 * expected values — confirm against the caller.
 */
void softmax_backward_cross_entropy(float* input, float* output, int size) {
    for (int idx = 0; idx < size; ++idx) {
        const float predicted = input[idx];
        input[idx] = output[idx] - predicted;
    }
}
|
|
|
|
|
2023-03-08 20:48:34 +01:00
|
|
|
/*
 * Backward pass of an average-pooling layer.
 *
 * input  : depth x input_width x input_width, overwritten with the gradient.
 * output : depth x output_width x output_width, gradient w.r.t. the pooled values.
 * Both have the same depth.
 *
 * The window side is input_width/output_width (integer division). Every cell
 * of input is first reset to 0, then each pooled gradient is spread evenly
 * (divided by the window's cell count) over its window. Cells not covered by
 * any window therefore end up at 0.
 */
void backward_average_pooling(float*** input, float*** output, int input_width, int output_width, int depth) {
    const int win = input_width/output_width; // Pooling window side
    const int count = win*win;                // Cells per window

    // Clear the whole gradient buffer.
    for (int d = 0; d < depth; ++d) {
        for (int row = 0; row < input_width; ++row) {
            float* line = input[d][row];
            for (int col = 0; col < input_width; ++col) {
                line[col] = 0;
            }
        }
    }

    if (win <= 0) {
        return; // Empty windows: nothing to spread.
    }

    // Spread each pooled gradient over its window.
    for (int d = 0; d < depth; ++d) {
        for (int oy = 0; oy < output_width; ++oy) {
            for (int ox = 0; ox < output_width; ++ox) {
                const float share = output[d][oy][ox]/count;
                for (int dy = 0; dy < win; ++dy) {
                    float* line = input[d][win*oy + dy];
                    for (int dx = 0; dx < win; ++dx) {
                        line[win*ox + dx] += share;
                    }
                }
            }
        }
    }
}
|
|
|
|
|
2023-03-10 18:19:23 +01:00
|
|
|
/*
 * Backward pass of a max-pooling layer.
 *
 * On entry, input holds the forward activations of the pooled layer
 * (depth x input_width x input_width) and output holds the gradient w.r.t.
 * the pooling result (depth x output_width x output_width).
 * On return, input holds the gradient w.r.t. the pooled layer. Inside every
 * size x size window (size = input_width/output_width, integer division),
 * the whole incoming gradient goes to the position of the window maximum
 * (first one on ties). Every other position is 0. Rows/columns beyond
 * size*output_width belong to no window and also receive 0.
 */
void backward_max_pooling(float*** input, float*** output, int input_width, int output_width, int depth) {
    int size = output_width > 0 ? input_width/output_width : 0; // Pooling window side
    int covered = size*output_width; // Side of the square actually covered by windows

    if (size > 0) {
        for (int i=0; i < depth; i++) {
            for (int j=0; j < output_width; j++) {
                for (int k=0; k < output_width; k++) {
                    int row0 = size*j;
                    int col0 = size*k;

                    // Seed with the window's first element so the argmax always lies
                    // inside the window (a -FLT_MAX / -1 seed stays at -1 when every
                    // value is <= -FLT_MAX or NaN, which indexes out of bounds).
                    float m = input[i][row0][col0]; // Maximum
                    int a_max = 0, b_max = 0;       // Indices of the maximum

                    for (int a=0; a < size; a++) {
                        for (int b=0; b < size; b++) {
                            if (input[i][row0 +a][col0 +b] > m) {
                                m = input[i][row0 +a][col0 +b];
                                a_max = a;
                                b_max = b;
                            }
                            // Each cell is read exactly once before being cleared.
                            input[i][row0 +a][col0 +b] = 0;
                        }
                    }
                    // d max(x_1..x_n) / d x_argmax = 1: route the gradient unscaled.
                    input[i][row0 +a_max][col0 +b_max] = output[i][j][k];
                }
            }
        }
    }

    // Clear cells outside every window so no stale activation is left behind.
    for (int i=0; i < depth; i++) {
        for (int j=0; j < input_width; j++) {
            for (int k = (j < covered ? covered : 0); k < input_width; k++) {
                input[i][j][k] = 0;
            }
        }
    }
}
|
|
|
|
|
2023-02-17 14:56:05 +01:00
|
|
|
/*
 * Backward pass of a fully-connected layer.
 *
 * - ker->d_bias[j]       += output[j]              for every output j
 * - ker->d_weights[i][j] += input[i] * output[j]   for every (i, j)
 * - unless is_first == 1, input is then overwritten with
 *       input[i] = (sum_j output[j] * ker->weights[i][j]) * d_function(input_z[i])
 *
 * NOTE(review): output presumably holds the gradient w.r.t. this layer's
 * result and input_z the pre-activation values of the previous layer —
 * confirm against the caller.
 */
void backward_dense(Kernel_nn* ker, float* input, float* input_z, float* output, int size_input, int size_output, ptr d_function, int is_first) {
    // Bias gradient
    for (int j_out = 0; j_out < size_output; ++j_out) {
        ker->d_bias[j_out] += output[j_out];
    }

    // Weight gradient: outer product of input and output
    for (int i_in = 0; i_in < size_input; ++i_in) {
        for (int j_out = 0; j_out < size_output; ++j_out) {
            ker->d_weights[i_in][j_out] += input[i_in]*output[j_out];
        }
    }

    // Gradient w.r.t. the input; skipped for the first layer (no need to
    // backpropagate into the network input).
    if (is_first != 1) {
        for (int i_in = 0; i_in < size_input; ++i_in) {
            float acc = 0;
            for (int j_out = 0; j_out < size_output; ++j_out) {
                acc += output[j_out]*ker->weights[i_in][j_out];
            }
            input[i_in] = acc*d_function(input_z[i_in]);
        }
    }
}
|
|
|
|
|
|
|
|
/*
 * Backward pass of a fully-connected layer whose input is a 3D volume
 * (depth_input x dim_input x dim_input) flattened in (channel, row, column)
 * order: the cell input[c][y][x] uses weight row (c*dim_input + y)*dim_input + x.
 *
 * - ker->d_bias[o]          += output[o]
 * - ker->d_weights[flat][o] += input[c][y][x] * output[o]
 * - input is then always overwritten with
 *       input[c][y][x] = (sum_o output[o] * ker->weights[flat][o]) * d_function(input_z[c][y][x])
 */
void backward_linearisation(Kernel_nn* ker, float*** input, float*** input_z, float* output, int depth_input, int dim_input, int size_output, ptr d_function) {
    // Bias gradient
    for (int o = 0; o < size_output; ++o) {
        ker->d_bias[o] += output[o];
    }

    // Weight gradient (all weight updates happen before input is overwritten)
    int flat = 0;
    for (int c = 0; c < depth_input; ++c) {
        for (int y = 0; y < dim_input; ++y) {
            for (int x = 0; x < dim_input; ++x, ++flat) {
                for (int o = 0; o < size_output; ++o) {
                    ker->d_weights[flat][o] += input[c][y][x]*output[o];
                }
            }
        }
    }

    // Gradient w.r.t. the input volume
    flat = 0;
    for (int c = 0; c < depth_input; ++c) {
        for (int y = 0; y < dim_input; ++y) {
            for (int x = 0; x < dim_input; ++x, ++flat) {
                float acc = 0;
                for (int o = 0; o < size_output; ++o) {
                    acc += output[o]*ker->weights[flat][o];
                }
                input[c][y][x] = acc*d_function(input_z[c][y][x]);
            }
        }
    }
}
|
|
|
|
|
|
|
|
/*
 * Backward pass of a convolution layer (stride 1, no padding, square maps).
 *
 * Kernel side: k_size = dim_input - dim_output + 1.
 * The weight gradient below corresponds to the forward cross-correlation
 *     out[l][x][y] = bias[l][x][y] + sum_{h,m,n} in[h][x+m][y+n] * weights[h][l][m][n]
 * and the input gradient is computed consistently with that same formula:
 *     d_in[h][j][k] = sum_{l,m,n} d_out[l][j-m][k-n] * weights[h][l][m][n]
 * restricted to 0 <= j-m, k-n < dim_output.
 *
 * - ker->d_bias[i][j][k]      += output[i][j][k]
 * - ker->d_weights[h][i][j][k] += sum_{l,m} input[h][l+j][m+k] * output[i][l][m]
 * - unless is_first == 1, input is overwritten with d_in * d_function(input_z).
 *
 * NOTE(review): output presumably holds the gradient w.r.t. this layer's
 * result and input_z the previous layer's pre-activations — confirm with caller.
 */
void backward_convolution(Kernel_cnn* ker, float*** input, float*** input_z, float*** output, int depth_input, int dim_input, int depth_output, int dim_output, ptr d_function, int is_first) {
    // Bias
    for (int i=0; i < depth_output; i++) {
        for (int j=0; j < dim_output; j++) {
            for (int k=0; k < dim_output; k++) {
                ker->d_bias[i][j][k] += output[i][j][k];
            }
        }
    }

    // Weights
    int k_size = dim_input - dim_output +1;

    for (int h=0; h < depth_input; h++) {
        for (int i=0; i < depth_output; i++) {
            for (int j=0; j < k_size; j++) {
                for (int k=0; k < k_size; k++) {
                    float tmp = 0;
                    for (int l=0; l < dim_output; l++) {
                        for (int m=0; m < dim_output; m++) {
                            tmp += input[h][l+j][m+k]*output[i][l][m];
                        }
                    }
                    ker->d_weights[h][i][j][k] += tmp;
                }
            }
        }
    }

    // Input
    if (is_first==1) { // No need to backpropagate into the network input
        return;
    }

    for (int i=0; i < depth_input; i++) {
        for (int j=0; j < dim_input; j++) {
            // Kernel rows m such that the output row j-m lies in [0, dim_output)
            int min_m = max(0, j - dim_output + 1);
            int max_m = min(k_size, j + 1);
            for (int k=0; k < dim_input; k++) {
                // Kernel columns n such that the output column k-n lies in [0, dim_output)
                int min_n = max(0, k - dim_output + 1);
                int max_n = min(k_size, k + 1);
                float tmp = 0;
                for (int l=0; l < depth_output; l++) {
                    for (int m=min_m; m < max_m; m++) {
                        for (int n=min_n; n < max_n; n++) {
                            tmp += output[l][j-m][k-n]*ker->weights[i][l][m][n];
                        }
                    }
                }
                input[i][j][k] = tmp*d_function(input_z[i][j][k]);
            }
        }
    }
}
|