diff --git a/src/mnist/utils.c b/src/mnist/utils.c index 6ba556f..1a8f2e1 100644 --- a/src/mnist/utils.c +++ b/src/mnist/utils.c @@ -27,6 +27,9 @@ void help(char* call) { printf("\t\t--delta | -d [FILENAME]\tFichier de patch à utiliser.\n"); printf("\tprint-images:\n"); printf("\t\t--images | -i [FILENAME]\tFichier contenant les images.\n"); + printf("\tprint-poids-neurone:\n"); + printf("\t\t--reseau | -r [FILENAME]\tFichier contenant le réseau de neurones.\n"); + printf("\t\t--neurone | -n [int]\tNuméro du neurone dont il faut afficher les poids.\n"); } @@ -172,6 +175,24 @@ void print_images(char* filename) { } +void print_poids_neurone(char* filename, int num_neurone) { + Network* network = read_network(filename); + int nb_layers = network->nb_layers; + + Layer* layer = network->layers[nb_layers-2]; + int nb_neurons = layer->nb_neurons; + printf("[\n"); + for (int i=0; i < nb_neurons; i++) { + printf("%f", layer->neurons[i]->weights[num_neurone]); + if (i != nb_neurons -1) + printf(", "); + else + printf("\n"); + } + printf("]\n"); +} + + int main(int argc, char* argv[]) { if (argc < 2) { printf("Pas d'action spécifiée\n"); @@ -294,6 +315,28 @@ int main(int argc, char* argv[]) { } print_images(images); exit(0); + } else if (! strcmp(argv[1], "print-poids-neurone")) { + char* reseau = NULL; + int neurone = 0; + int i = 2; + while (i < argc) { + if ((! strcmp(argv[i], "--reseau"))||(! strcmp(argv[i], "-r"))) { + reseau = argv[i+1]; + i += 2; + } else if ((! strcmp(argv[i], "--neurone"))||(! strcmp(argv[i], "-n"))) { + neurone = strtol(argv[i+1], NULL, 10); + i += 2; + } else { + printf("%s : Argument non reconnu\n", argv[i]); + i++; + } + } + if (!reseau) { + printf("--reseau: Argument obligatoire.\n"); + exit(1); + } + print_poids_neurone(reseau, neurone); + exit(0); } printf("Option choisie non reconnue: %s\n", argv[1]); help(argv[0]);