Update cnn utils

This commit is contained in:
augustin64 2022-11-08 19:57:13 +01:00
parent 70e0ed08b7
commit 5a34311dfb

View File

@ -18,6 +18,7 @@ if (network1->var != network2->var) {
} }
bool equals_networks(Network* network1, Network* network2) { bool equals_networks(Network* network1, Network* network2) {
int output_dim;
checkEquals(size, "size", -1); checkEquals(size, "size", -1);
checkEquals(initialisation, "initialisation", -1); checkEquals(initialisation, "initialisation", -1);
checkEquals(dropout, "dropout", -1); checkEquals(dropout, "dropout", -1);
@ -50,12 +51,13 @@ bool equals_networks(Network* network1, Network* network2) {
} }
} else { } else {
// Type CNN // Type CNN
output_dim = network1->width[i];
checkEquals(kernel[i]->cnn->k_size, "kernel[i]->k_size", i); checkEquals(kernel[i]->cnn->k_size, "kernel[i]->k_size", i);
checkEquals(kernel[i]->cnn->rows, "kernel[i]->rows", i); checkEquals(kernel[i]->cnn->rows, "kernel[i]->rows", i);
checkEquals(kernel[i]->cnn->columns, "kernel[i]->columns", i); checkEquals(kernel[i]->cnn->columns, "kernel[i]->columns", i);
for (int j=0; j < network1->kernel[i]->cnn->columns; j++) { for (int j=0; j < network1->kernel[i]->cnn->columns; j++) {
for (int k=0; k < network1->kernel[i]->cnn->k_size; k++) { for (int k=0; k < output_dim; k++) {
for (int l=0; l < network1->kernel[i]->cnn->k_size; l++) { for (int l=0; l < output_dim; l++) {
checkEquals(kernel[i]->cnn->bias[j][k][l], "kernel[i]->cnn->bias[j][k][l]", l); checkEquals(kernel[i]->cnn->bias[j][k][l], "kernel[i]->cnn->bias[j][k][l]", l);
} }
} }
@ -87,6 +89,7 @@ Network* copy_network(Network* network) {
int rows; int rows;
int k_size; int k_size;
int columns; int columns;
int output_dim;
copyVar(dropout); copyVar(dropout);
copyVar(learning_rate); copyVar(learning_rate);
@ -149,6 +152,8 @@ Network* copy_network(Network* network) {
rows = network->kernel[i]->cnn->rows; rows = network->kernel[i]->cnn->rows;
k_size = network->kernel[i]->cnn->k_size; k_size = network->kernel[i]->cnn->k_size;
columns = network->kernel[i]->cnn->columns; columns = network->kernel[i]->cnn->columns;
output_dim = network->width[i];
network_cp->kernel[i]->nn = NULL; network_cp->kernel[i]->nn = NULL;
network_cp->kernel[i]->cnn = (Kernel_cnn*)malloc(sizeof(Kernel_cnn)); network_cp->kernel[i]->cnn = (Kernel_cnn*)malloc(sizeof(Kernel_cnn));
@ -160,12 +165,12 @@ Network* copy_network(Network* network) {
network_cp->kernel[i]->cnn->bias = (float***)malloc(sizeof(float**)*columns); network_cp->kernel[i]->cnn->bias = (float***)malloc(sizeof(float**)*columns);
network_cp->kernel[i]->cnn->d_bias = (float***)malloc(sizeof(float**)*columns); network_cp->kernel[i]->cnn->d_bias = (float***)malloc(sizeof(float**)*columns);
for (int j=0; j < columns; j++) { for (int j=0; j < columns; j++) {
network_cp->kernel[i]->cnn->bias[j] = (float**)malloc(sizeof(float*)*k_size); network_cp->kernel[i]->cnn->bias[j] = (float**)malloc(sizeof(float*)*output_dim);
network_cp->kernel[i]->cnn->d_bias[j] = (float**)malloc(sizeof(float*)*k_size); network_cp->kernel[i]->cnn->d_bias[j] = (float**)malloc(sizeof(float*)*output_dim);
for (int k=0; k < k_size; k++) { for (int k=0; k < output_dim; k++) {
network_cp->kernel[i]->cnn->bias[j][k] = (float*)malloc(sizeof(float)*k_size); network_cp->kernel[i]->cnn->bias[j][k] = (float*)malloc(sizeof(float)*output_dim);
network_cp->kernel[i]->cnn->d_bias[j][k] = (float*)malloc(sizeof(float)*k_size); network_cp->kernel[i]->cnn->d_bias[j][k] = (float*)malloc(sizeof(float)*output_dim);
for (int l=0; l < k_size; l++) { for (int l=0; l < output_dim; l++) {
copyVar(kernel[i]->cnn->bias[j][k][l]); copyVar(kernel[i]->cnn->bias[j][k][l]);
network_cp->kernel[i]->cnn->d_bias[j][k][l] = 0.; network_cp->kernel[i]->cnn->d_bias[j][k][l] = 0.;
} }