Add count_null_weights

This commit is contained in:
augustin64 2023-02-19 10:22:42 +01:00
parent db92b367ad
commit 63ef37dc56
2 changed files with 67 additions and 0 deletions

View File

@ -32,4 +32,9 @@ Network* copy_network(Network* network);
* Copie les paramètres d'un réseau dans un réseau déjà alloué en mémoire
*/
void copy_network_parameters(Network* network_src, Network* network_dest);
/*
* Compte le nombre de poids nuls dans un réseau
*/
int count_null_weights(Network* network);
#endif

View File

@ -2,6 +2,7 @@
#include <stdio.h>
#include <stdbool.h>
#include <string.h>
#include <math.h>
#include "../include/memory_management.h"
#include "../include/colors.h"
@ -303,4 +304,65 @@ void copy_network_parameters(Network* network_src, Network* network_dest) {
}
}
}
}
int count_null_weights(Network* network) {
float epsilon = 0.000001;
int null_weights = 0;
int null_bias = 0;
int size = network->size;
// Paramètres des couches NN
int input_units;
int output_units;
// Paramètres des couches CNN
int rows;
int k_size;
int columns;
int output_dim;
for (int i=0; i < size-1; i++) {
if (!network->kernel[i]->cnn && network->kernel[i]->nn) { // Cas du NN
input_units = network->kernel[i]->nn->input_units;
output_units = network->kernel[i]->nn->output_units;
for (int j=0; j < output_units; j++) {
null_bias += fabs(network->kernel[i]->nn->bias[j]) <= epsilon;
}
for (int j=0; j < input_units; j++) {
for (int k=0; k < output_units; k++) {
null_weights += fabs(network->kernel[i]->nn->weights[j][k]) <= epsilon;
}
}
}
else if (network->kernel[i]->cnn && !network->kernel[i]->nn) { // Cas du CNN
rows = network->kernel[i]->cnn->rows;
k_size = network->kernel[i]->cnn->k_size;
columns = network->kernel[i]->cnn->columns;
output_dim = network->width[i+1];
for (int j=0; j < columns; j++) {
for (int k=0; k < output_dim; k++) {
for (int l=0; l < output_dim; l++) {
null_bias += fabs(network->kernel[i]->cnn->bias[j][k][l]) <= epsilon;
}
}
}
for (int j=0; j < rows; j++) {
for (int k=0; k < columns; k++) {
for (int l=0; l < k_size; l++) {
for (int m=0; m < k_size; m++) {
null_weights = fabs(network->kernel[i]->cnn->w[j][k][l][m]) <= epsilon;
}
}
}
}
}
}
return null_weights;
}