From 5a34311dfbb7d7a9d0307238fb99efb7f38cd219 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Tue, 8 Nov 2022 19:57:13 +0100 Subject: [PATCH] Update cnn utils --- src/cnn/utils.c | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/cnn/utils.c b/src/cnn/utils.c index d794f6a..d02f73b 100644 --- a/src/cnn/utils.c +++ b/src/cnn/utils.c @@ -12,12 +12,13 @@ if (network1->var != network2->var) { \ printf_error("network1->" name " et network2->" name " ne sont pas égaux\n"); \ if (indice != -1) { \ - printf(BOLDBLUE"[ INFO_ ]"RESET" indice: %d\n", indice); \ + printf(BOLDBLUE"[ INFO_ ]" RESET " indice: %d\n", indice); \ } \ return false; \ } bool equals_networks(Network* network1, Network* network2) { + int output_dim; checkEquals(size, "size", -1); checkEquals(initialisation, "initialisation", -1); checkEquals(dropout, "dropout", -1); @@ -50,12 +51,13 @@ bool equals_networks(Network* network1, Network* network2) { } } else { // Type CNN + output_dim = network1->width[i]; checkEquals(kernel[i]->cnn->k_size, "kernel[i]->k_size", i); checkEquals(kernel[i]->cnn->rows, "kernel[i]->rows", i); checkEquals(kernel[i]->cnn->columns, "kernel[i]->columns", i); for (int j=0; j < network1->kernel[i]->cnn->columns; j++) { - for (int k=0; k < network1->kernel[i]->cnn->k_size; k++) { - for (int l=0; l < network1->kernel[i]->cnn->k_size; l++) { + for (int k=0; k < output_dim; k++) { + for (int l=0; l < output_dim; l++) { checkEquals(kernel[i]->cnn->bias[j][k][l], "kernel[i]->cnn->bias[j][k][l]", l); } } @@ -87,6 +89,7 @@ Network* copy_network(Network* network) { int rows; int k_size; int columns; + int output_dim; copyVar(dropout); copyVar(learning_rate); @@ -149,6 +152,8 @@ Network* copy_network(Network* network) { rows = network->kernel[i]->cnn->rows; k_size = network->kernel[i]->cnn->k_size; columns = network->kernel[i]->cnn->columns; + output_dim = network->width[i]; + network_cp->kernel[i]->nn = NULL; network_cp->kernel[i]->cnn = (Kernel_cnn*)malloc(sizeof(Kernel_cnn)); @@ -160,12 +165,12 @@ Network* copy_network(Network* network) { network_cp->kernel[i]->cnn->bias = (float***)malloc(sizeof(float**)*columns); network_cp->kernel[i]->cnn->d_bias = (float***)malloc(sizeof(float**)*columns); for (int j=0; j < columns; j++) { - network_cp->kernel[i]->cnn->bias[j] = (float**)malloc(sizeof(float*)*k_size); - network_cp->kernel[i]->cnn->d_bias[j] = (float**)malloc(sizeof(float*)*k_size); - for (int k=0; k < k_size; k++) { - network_cp->kernel[i]->cnn->bias[j][k] = (float*)malloc(sizeof(float)*k_size); - network_cp->kernel[i]->cnn->d_bias[j][k] = (float*)malloc(sizeof(float)*k_size); - for (int l=0; l < k_size; l++) { + network_cp->kernel[i]->cnn->bias[j] = (float**)malloc(sizeof(float*)*output_dim); + network_cp->kernel[i]->cnn->d_bias[j] = (float**)malloc(sizeof(float*)*output_dim); + for (int k=0; k < output_dim; k++) { + network_cp->kernel[i]->cnn->bias[j][k] = (float*)malloc(sizeof(float)*output_dim); + network_cp->kernel[i]->cnn->d_bias[j][k] = (float*)malloc(sizeof(float)*output_dim); + for (int l=0; l < output_dim; l++) { copyVar(kernel[i]->cnn->bias[j][k][l]); network_cp->kernel[i]->cnn->d_bias[j][k][l] = 0.; }