diff --git a/src/cnn/include/struct.h b/src/cnn/include/struct.h index fdb0f75..3a0cbf1 100644 --- a/src/cnn/include/struct.h +++ b/src/cnn/include/struct.h @@ -10,6 +10,9 @@ #define DOESNT_LINEARISE 0 #define DO_LINEARISE 1 + +//-------------------------- Réseau classique -------------------------- + typedef struct Kernel_cnn { // Noyau ayant une couche matricielle en sortie int k_size; // k_size = 2*padding + input_width + stride - output_width*stride @@ -63,6 +66,7 @@ typedef struct Kernel { } Kernel; + typedef struct Network{ int dropout; // Probabilité d'abandon d'un neurone dans [0, 100] (entiers) float learning_rate; // Taux d'apprentissage du réseau @@ -80,4 +84,55 @@ typedef struct Network{ float**** input; // input[i] = f(input_z[i]) où f est la fonction d'activation de la couche i } Network; + + +//------------------- Réseau pour la backpropagation ------------------- + +/* +* On définit ici la classe D_Network associé à la classe Network +* Elle permet la backpropagation des réseaux auxquels elle est associée +*/ + + +typedef struct D_Kernel_cnn { + // Noyau ayant une couche matricielle en sortie + + float*** d_bias; // d_bias[columns][output_width][output_width] + #ifdef ADAM_CNN_BIAS + float*** s_d_bias; // s_d_bias[columns][output_width][output_width] + float*** v_d_bias; // v_d_bias[columns][output_width][output_width] + #endif + + float**** d_weights; // d_weights[rows][columns][k_size][k_size] + #ifdef ADAM_CNN_WEIGHTS + float**** s_d_weights; // s_d_weights[rows][columns][k_size][k_size] + float**** v_d_weights; // v_d_weights[rows][columns][k_size][k_size] + #endif +} D_Kernel_cnn; + +typedef struct D_Kernel_nn { + // Noyau ayant une couche vectorielle en sortie + + float* d_bias; // d_bias[size_output] + #ifdef ADAM_DENSE_BIAS + float* s_d_bias; // s_d_bias[size_output] + float* v_d_bias; // v_d_bias[size_output] + #endif + + float** d_weights; // d_weights[size_input][size_output] + #ifdef ADAM_DENSE_WEIGHTS + float** s_d_weights; // s_d_weights[size_input][size_output] + float** v_d_weights; // v_d_weights[size_input][size_output] + #endif +} D_Kernel_nn; + +typedef struct D_Kernel { + D_Kernel_cnn* cnn; // NULL si ce n'est pas un cnn + D_Kernel_nn* nn; // NULL si ce n'est pas un nn +} D_Kernel; + +typedef struct D_Network{ + D_Kernel** kernel; // kernel[size], contient tous les kernels +} D_Network; + #endif \ No newline at end of file