diff --git a/src/cnn/free.c b/src/cnn/free.c index 6c20b7f..600c969 100644 --- a/src/cnn/free.c +++ b/src/cnn/free.c @@ -118,6 +118,5 @@ void free_network(Network* network) { free_2d_average_pooling(network, i); } } - printf("Network freed successfully !\n"); free_network_creation(network); } diff --git a/src/cnn/include/struct.h b/src/cnn/include/struct.h index a153210..4ac7c1e 100644 --- a/src/cnn/include/struct.h +++ b/src/cnn/include/struct.h @@ -30,14 +30,14 @@ typedef struct Kernel { typedef struct Network{ int dropout; // Contient la probabilité d'abandon d'un neurone dans [0, 100] (entiers) - int learning_rate; + int learning_rate; // Taux d'apprentissage du réseau int initialisation; // Contient le type d'initialisation int max_size; // Taille du tableau contenant le réseau int size; // Taille actuelle du réseau (size ≤ max_size) int* width; // width[size] int* depth; // depth[size] - Kernel** kernel; // Tableau de tous les kernels - float**** input; // Tableau de toutes les couches du réseau input[nb couches][couche->depth][couche->dim][couche->dim] + Kernel** kernel; // kernel[size], contient tous les kernels + float**** input; // Tableau de toutes les couches du réseau input[size][couche->depth][couche->width][couche->width] } Network; #endif \ No newline at end of file diff --git a/src/cnn/include/utils.h b/src/cnn/include/utils.h index 5197c7d..190f0a2 100644 --- a/src/cnn/include/utils.h +++ b/src/cnn/include/utils.h @@ -14,4 +14,9 @@ */ bool equals_networks(Network* network1, Network* network2); +/* + * Duplique un réseau +*/ +Network* copy_network(Network* network); + #endif \ No newline at end of file diff --git a/src/cnn/utils.c b/src/cnn/utils.c index 1c50a6f..f834529 100644 --- a/src/cnn/utils.c +++ b/src/cnn/utils.c @@ -6,7 +6,16 @@ #include "../colors.h" #include "include/struct.h" -#define checkEquals(var, name, indice) if (network1->var != network2->var) { printf_error("network1->" name " et network2->" name " ne sont pas égaux\n"); if (indice != -1) {printf(BOLDBLUE"[ INFO_ ]"RESET" indice: %d\n", indice);} return false; } +#define copyVar(var) network_cp->var = network->var + +#define checkEquals(var, name, indice) \ +if (network1->var != network2->var) { \ + printf_error("network1->" name " et network2->" name " ne sont pas égaux\n"); \ + if (indice != -1) { \ + printf(BOLDBLUE"[ INFO_ ]"RESET" indice: %d\n", indice); \ + } \ + return false; \ +} bool equals_networks(Network* network1, Network* network2) { checkEquals(size, "size", -1); @@ -64,4 +73,139 @@ bool equals_networks(Network* network1, Network* network2) { } return true; +} + + +Network* copy_network(Network* network) { + Network* network_cp = (Network*)malloc(sizeof(Network)); + // Paramètre du réseau + int size = network->size; + // Paramètres des couches NN + int input_units; + int output_units; + // Paramètres des couches CNN + int rows; + int k_size; + int columns; + + copyVar(dropout); + copyVar(learning_rate); + copyVar(initialisation); + copyVar(max_size); + copyVar(size); + + network_cp->width = (int*)malloc(sizeof(int)*size); + network_cp->depth = (int*)malloc(sizeof(int)*size); + + for (int i=0; i < size; i++) { + copyVar(width[i]); + copyVar(depth[i]); + } + + network_cp->kernel = (Kernel**)malloc(sizeof(Kernel*)*size); + for (int i=0; i < size; i++) { + network_cp->kernel[i] = (Kernel*)malloc(sizeof(Kernel)); + if (!network->kernel[i]->nn && !network->kernel[i]->cnn) { // Cas de la couche de linéarisation + copyVar(kernel[i]->activation); + copyVar(kernel[i]->linearisation); // 1 + network_cp->kernel[i]->cnn = NULL; + network_cp->kernel[i]->nn = NULL; + } + else if (!network->kernel[i]->cnn) { // Cas du NN + copyVar(kernel[i]->activation); + copyVar(kernel[i]->linearisation); // 0 + + input_units = network->kernel[i]->nn->input_units; + output_units = network->kernel[i]->nn->output_units; + + network_cp->kernel[i]->cnn = NULL; + network_cp->kernel[i]->nn = (Kernel_nn*)malloc(sizeof(Kernel_nn)); + + copyVar(kernel[i]->nn->input_units); + copyVar(kernel[i]->nn->output_units); + + network_cp->kernel[i]->nn->bias = (float*)malloc(sizeof(float)*output_units); + network_cp->kernel[i]->nn->d_bias = (float*)malloc(sizeof(float)*output_units); + for (int j=0; j < output_units; j++) { + copyVar(kernel[i]->nn->bias[j]); + network_cp->kernel[i]->nn->d_bias[j] = 0.; + } + + network_cp->kernel[i]->nn->weights = (float**)malloc(sizeof(float*)*input_units); + network_cp->kernel[i]->nn->d_weights = (float**)malloc(sizeof(float*)*input_units); + for (int j=0; j < input_units; j++) { + network_cp->kernel[i]->nn->weights[j] = (float*)malloc(sizeof(float)*output_units); + network_cp->kernel[i]->nn->d_weights[j] = (float*)malloc(sizeof(float)*output_units); + for (int k=0; k < output_units; k++) { + copyVar(kernel[i]->nn->weights[j][k]); + network_cp->kernel[i]->nn->d_weights[j][k] = 0.; + } + } + } + else { // Cas du CNN + copyVar(kernel[i]->activation); + copyVar(kernel[i]->linearisation); // 0 + + rows = network->kernel[i]->cnn->rows; + k_size = network->kernel[i]->cnn->k_size; + columns = network->kernel[i]->cnn->columns; + + network_cp->kernel[i]->nn = NULL; + network_cp->kernel[i]->cnn = (Kernel_cnn*)malloc(sizeof(Kernel_cnn)); + + copyVar(kernel[i]->cnn->rows); + copyVar(kernel[i]->cnn->k_size); + copyVar(kernel[i]->cnn->columns); + + network_cp->kernel[i]->cnn->bias = (float***)malloc(sizeof(float**)*columns); + network_cp->kernel[i]->cnn->d_bias = (float***)malloc(sizeof(float**)*columns); + for (int j=0; j < columns; j++) { + network_cp->kernel[i]->cnn->bias[j] = (float**)malloc(sizeof(float*)*k_size); + network_cp->kernel[i]->cnn->d_bias[j] = (float**)malloc(sizeof(float*)*k_size); + for (int k=0; k < k_size; k++) { + network_cp->kernel[i]->cnn->bias[j][k] = (float*)malloc(sizeof(float)*k_size); + network_cp->kernel[i]->cnn->d_bias[j][k] = (float*)malloc(sizeof(float)*k_size); + for (int l=0; l < k_size; l++) { + copyVar(kernel[i]->cnn->bias[j][k][l]); + network_cp->kernel[i]->cnn->d_bias[j][k][l] = 0.; + } + } + } + + network_cp->kernel[i]->cnn->w = (float****)malloc(sizeof(float***)*rows); + network_cp->kernel[i]->cnn->d_w = (float****)malloc(sizeof(float***)*rows); + for (int j=0; j < rows; j++) { + network_cp->kernel[i]->cnn->w[j] = (float***)malloc(sizeof(float**)*columns); + network_cp->kernel[i]->cnn->d_w[j] = (float***)malloc(sizeof(float**)*columns); + for (int k=0; k < columns; k++) { + network_cp->kernel[i]->cnn->w[j][k] = (float**)malloc(sizeof(float*)*k_size); + network_cp->kernel[i]->cnn->d_w[j][k] = (float**)malloc(sizeof(float*)*k_size); + for (int l=0; l < k_size; l++) { + network_cp->kernel[i]->cnn->w[j][k][l] = (float*)malloc(sizeof(float)*k_size); + network_cp->kernel[i]->cnn->d_w[j][k][l] = (float*)malloc(sizeof(float)*k_size); + for (int m=0; m < k_size; m++) { + copyVar(kernel[i]->cnn->w[j][k][l][m]); + network_cp->kernel[i]->cnn->d_w[j][k][l][m] = 0.; + } + } + } + } + } + } + + network_cp->input = (float****)malloc(sizeof(float***)*size); + for (int i=0; i < size; i++) { // input[size][couche->depth][couche->dim][couche->dim] + network_cp->input[i] = (float***)malloc(sizeof(float**)*network->depth[i]); + for (int j=0; j < network->depth[i]; j++) { + network_cp->input[i][j] = (float**)malloc(sizeof(float*)*network->width[i]); + for (int k=0; k < network->width[i]; k++) { + network_cp->input[i][j][k] = (float*)malloc(sizeof(float)*network->width[i]); + for (int l=0; l < network->width[i]; l++) { + network_cp->input[i][j][k][l] = 0.; + } + } + } + } + + return network_cp; } \ No newline at end of file diff --git a/test/cnn_utils.c b/test/cnn_utils.c new file mode 100644 index 0000000..1783958 --- /dev/null +++ b/test/cnn_utils.c @@ -0,0 +1,25 @@ +#include +#include + +#include "../src/colors.h" +#include "../src/cnn/creation.c" +#include "../src/cnn/utils.c" + +int main() { + printf("Création du réseau\n"); + Network* network = create_network_lenet5(0, 0, 3, 2, 32, 1); + printf("OK\n"); + + printf("Copie du réseau\n"); + Network* network_cp = copy_network(network); + printf("OK\n"); + + printf("Vérification de l'égalité des réseaux\n"); + if (! equals_networks(network, network_cp)) { + printf_error("Les deux réseaux obtenus ne sont pas égaux.\n"); + exit(1); + } + printf("OK\n"); + + return 0; +} \ No newline at end of file