diff --git a/src/cnn/creation.c b/src/cnn/creation.c index e565477..2c261ad 100644 --- a/src/cnn/creation.c +++ b/src/cnn/creation.c @@ -233,7 +233,7 @@ void add_dense(Network* network, int output_units, int activation) { } } - initialisation_1d_matrix(network->initialisation, nn->bias, output_units, input_units, output_units); + initialisation_1d_matrix(network->initialisation, nn->bias, output_units, input_units); initialisation_2d_matrix(network->initialisation, nn->weights, input_units, output_units, input_units, output_units); create_a_line_input_layer(network, n, output_units); create_a_line_input_z_layer(network, n, output_units); @@ -273,7 +273,7 @@ void add_dense_linearisation(Network* network, int output_units, int activation) nn->d_weights[i][j] = 0.; } } - initialisation_1d_matrix(network->initialisation, nn->bias, output_units, input_units, output_units); + initialisation_1d_matrix(network->initialisation, nn->bias, output_units, input_units); initialisation_2d_matrix(network->initialisation, nn->weights, input_units, output_units, input_units, output_units); create_a_line_input_layer(network, n, output_units); create_a_line_input_z_layer(network, n, output_units); diff --git a/src/cnn/include/initialisation.h b/src/cnn/include/initialisation.h index 80f6467..64dab61 100644 --- a/src/cnn/include/initialisation.h +++ b/src/cnn/include/initialisation.h @@ -12,7 +12,7 @@ /* * Initialise une matrice 1d dim de float en fonction du type d'initialisation */ -void initialisation_1d_matrix(int initialisation, float* matrix, int dim, int n_in, int n_out); +void initialisation_1d_matrix(int initialisation, float* matrix, int dim, int n_in); /* * Initialise une matrice 2d dim1*dim2 de float en fonction du type d'initialisation diff --git a/src/cnn/initialisation.c b/src/cnn/initialisation.c index 276a0d4..28d8b39 100644 --- a/src/cnn/initialisation.c +++ b/src/cnn/initialisation.c @@ -9,10 +9,10 @@ // LeCun initialisation: SELU (1/fan_in) // Only uniform for the moment -void initialisation_1d_matrix(int initialisation, float* matrix, int dim, int n_in, int n_out) { +void initialisation_1d_matrix(int initialisation, float* matrix, int dim, int n_in) { int n; if (initialisation == GLOROT) { - n = (n_in + n_out)/2; + n = (n_in + n)/2; } else if (initialisation == HE) { n = n_in/2;