Update linearisation detection

This commit is contained in:
augustin64 2023-01-17 12:49:35 +01:00
parent 721d9597d4
commit c78c8a0ade
4 changed files with 6 additions and 7 deletions

View File

@ -106,7 +106,7 @@ void forward_propagation(Network* network) {
choose_apply_function_matrix(activation, output, output_depth, output_width); choose_apply_function_matrix(activation, output, output_depth, output_width);
} }
else if (k_i->nn) { // Full connection else if (k_i->nn) { // Full connection
if (input_depth==1) { // Vecteur -> Vecteur if (k_i->linearisation == 0) { // Vecteur -> Vecteur
make_dense(k_i->nn, input[0][0], output[0][0], input_width, output_width); make_dense(k_i->nn, input[0][0], output[0][0], input_width, output_width);
} else { // Matrice -> Vecteur } else { // Matrice -> Vecteur
make_dense_linearised(k_i->nn, input, output[0][0], input_depth, input_width, output_width); make_dense_linearised(k_i->nn, input, output[0][0], input_depth, input_width, output_width);
@ -155,7 +155,7 @@ void backward_propagation(Network* network, float wanted_number) {
backward_convolution(k_i->cnn, input, input_z, output, input_depth, input_width, output_depth, output_width, d_f, i==0); backward_convolution(k_i->cnn, input, input_z, output, input_depth, input_width, output_depth, output_width, d_f, i==0);
} else if (k_i->nn) { // Full connection } else if (k_i->nn) { // Full connection
ptr d_f = get_function_activation(activation); ptr d_f = get_function_activation(activation);
if (input_depth==1) { // Vecteur -> Vecteur if (k_i->linearisation == 0) { // Vecteur -> Vecteur
backward_fully_connected(k_i->nn, input[0][0], input_z[0][0], output[0][0], input_width, output_width, d_f, i==0); backward_fully_connected(k_i->nn, input[0][0], input_z[0][0], output[0][0], input_width, output_width, d_f, i==0);
} else { // Matrice -> vecteur } else { // Matrice -> vecteur
backward_linearisation(k_i->nn, input, input_z, output[0][0], input_depth, input_width, output_width, d_f); backward_linearisation(k_i->nn, input, input_z, output[0][0], input_depth, input_width, output_width, d_f);

View File

@ -95,7 +95,6 @@ void add_2d_average_pooling(Network* network, int dim_output) {
printf("Impossible de rajouter une couche d'average pooling, le réseau est déjà plein\n"); printf("Impossible de rajouter une couche d'average pooling, le réseau est déjà plein\n");
return; return;
} }
int kernel_size = dim_input/dim_output;
if (dim_input%dim_output != 0) { if (dim_input%dim_output != 0) {
printf("Erreur de dimension dans l'average pooling\n"); printf("Erreur de dimension dans l'average pooling\n");
return; return;

View File

@ -117,7 +117,7 @@ void free_network(Network* network) {
if (network->kernel[i]->cnn != NULL) { // Convolution if (network->kernel[i]->cnn != NULL) { // Convolution
free_convolution(network, i); free_convolution(network, i);
} else if (network->kernel[i]->nn != NULL) { } else if (network->kernel[i]->nn != NULL) {
if (network->depth[i]==1) { // Dense non linearised if (network->kernel[i]->linearisation == 0) { // Dense non linearised
free_dense(network, i); free_dense(network, i);
} else { // Dense lineariation } else { // Dense lineariation
free_dense_linearisation(network, i); free_dense_linearisation(network, i);

View File

@ -31,7 +31,7 @@ void update_weights(Network* network, Network* d_network, int nb_images) {
} }
} }
} else if (k_i->nn) { // Full connection } else if (k_i->nn) { // Full connection
if (input_depth==1) { // Vecteur -> Vecteur if (k_i->linearisation == 0) { // Vecteur -> Vecteur
Kernel_nn* nn = k_i->nn; Kernel_nn* nn = k_i->nn;
Kernel_nn* d_nn = dk_i->nn; Kernel_nn* d_nn = dk_i->nn;
for (int a=0; a<input_width; a++) { for (int a=0; a<input_width; a++) {
@ -119,7 +119,7 @@ void reset_d_weights(Network* network) {
} }
} }
} else if (k_i->nn) { // Full connection } else if (k_i->nn) { // Full connection
if (input_depth==1) { // Vecteur -> Vecteur if (k_i->linearisation == 0) { // Vecteur -> Vecteur
Kernel_nn* nn = k_i_1->nn; Kernel_nn* nn = k_i_1->nn;
for (int a=0; a<input_width; a++) { for (int a=0; a<input_width; a++) {
for (int b=0; b<output_width; b++) { for (int b=0; b<output_width; b++) {