diff --git a/src/cnn/include/utils.h b/src/cnn/include/utils.h index 20e9678..bb1579e 100644 --- a/src/cnn/include/utils.h +++ b/src/cnn/include/utils.h @@ -32,4 +32,9 @@ Network* copy_network(Network* network); * Copie les paramètres d'un réseau dans un réseau déjà alloué en mémoire */ void copy_network_parameters(Network* network_src, Network* network_dest); + +/* +* Compte le nombre de poids nuls dans un réseau +*/ +int count_null_weights(Network* network); #endif \ No newline at end of file diff --git a/src/cnn/utils.c b/src/cnn/utils.c index 7a1031f..37241ac 100644 --- a/src/cnn/utils.c +++ b/src/cnn/utils.c @@ -2,6 +2,7 @@ #include #include #include +#include #include "../include/memory_management.h" #include "../include/colors.h" @@ -303,4 +304,65 @@ void copy_network_parameters(Network* network_src, Network* network_dest) { } } } +} + + +int count_null_weights(Network* network) { + float epsilon = 0.000001; + + int null_weights = 0; + int null_bias = 0; + + int size = network->size; + // Paramètres des couches NN + int input_units; + int output_units; + // Paramètres des couches CNN + int rows; + int k_size; + int columns; + int output_dim; + + for (int i=0; i < size-1; i++) { + if (!network->kernel[i]->cnn && network->kernel[i]->nn) { // Cas du NN + + input_units = network->kernel[i]->nn->input_units; + output_units = network->kernel[i]->nn->output_units; + + for (int j=0; j < output_units; j++) { + null_bias += fabs(network->kernel[i]->nn->bias[j]) <= epsilon; + } + for (int j=0; j < input_units; j++) { + for (int k=0; k < output_units; k++) { + null_weights += fabs(network->kernel[i]->nn->weights[j][k]) <= epsilon; + } + } + } + else if (network->kernel[i]->cnn && !network->kernel[i]->nn) { // Cas du CNN + + rows = network->kernel[i]->cnn->rows; + k_size = network->kernel[i]->cnn->k_size; + columns = network->kernel[i]->cnn->columns; + output_dim = network->width[i+1]; + + for (int j=0; j < columns; j++) { + for (int k=0; k < output_dim; k++) { + for (int l=0; l < output_dim; l++) { + null_bias += fabs(network->kernel[i]->cnn->bias[j][k][l]) <= epsilon; + } + } + } + for (int j=0; j < rows; j++) { + for (int k=0; k < columns; k++) { + for (int l=0; l < k_size; l++) { + for (int m=0; m < k_size; m++) { + null_weights = fabs(network->kernel[i]->cnn->w[j][k][l][m]) <= epsilon; + } + } + } + } + } + } + + return null_weights; } \ No newline at end of file