diff --git a/src/mnist/main.c b/src/mnist/main.c index 3ced45b..47e66cb 100644 --- a/src/mnist/main.c +++ b/src/mnist/main.c @@ -33,9 +33,13 @@ void help(char* call) { printf("\t\t--labels | -l [FILENAME]\tFichier contenant les labels.\n"); printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n"); printf("\trecognize:\n"); - printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau de neurones.\n"); - printf("\t\t--in | -i [FILENAME]\tFichier contenant les images à reconnaître.\n"); - printf("\t\t--out | -o (text|json)\tFormat de sortie.\n"); + printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau de neurones.\n"); + printf("\t\t--in | -i [FILENAME]\tFichier contenant les images à reconnaître.\n"); + printf("\t\t--out | -o (text|json)\tFormat de sortie.\n"); + printf("\ttest:\n"); + printf("\t\t--images | -i [FILENAME]\tFichier contenant les images.\n"); + printf("\t\t--labels | -l [FILENAME]\tFichier contenant les labels.\n"); + printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau de neurones.\n"); } @@ -111,9 +115,10 @@ void train(int batches, int couches, int neurons, char* recovery, char* image_fi modification_du_reseau_neuronal(reseau); ecrire_reseau(out, reseau); } + suppression_du_reseau_neuronal(reseau); } -void recognize(char* modele, char* entree, char* sortie) { +float** recognize(char* modele, char* entree) { Reseau* reseau = lire_reseau(modele); Couche* derniere_couche = reseau->couches[reseau->nb_couches-1]; @@ -123,6 +128,33 @@ void recognize(char* modele, char* entree, char* sortie) { int width = parameters[2]; int*** images = read_mnist_images(entree); + float** results = malloc(sizeof(float*)*nb_images); + + for (int i=0; i < nb_images; i++) { + results[i] = malloc(sizeof(float)*derniere_couche->nb_neurones); + + ecrire_image_dans_reseau(images[i], reseau, height, width); + forward_propagation(reseau); + + for (int j=0; j < derniere_couche->nb_neurones; j++) { + results[i][j] = derniere_couche->neurones[j]->z; + } + } + suppression_du_reseau_neuronal(reseau); + + return results; +} + +void print_recognize(char* modele, char* entree, char* sortie) { + Reseau* reseau = lire_reseau(modele); + int nb_der_couche = reseau->couches[reseau->nb_couches-1]->nb_neurones; + + suppression_du_reseau_neuronal(reseau); + + int* parameters = read_mnist_images_parameters(entree); + int nb_images = parameters[0]; + + float** resultats = recognize(modele, entree); if (! strcmp(sortie, "json")) { printf("{\n"); @@ -133,18 +165,15 @@ void recognize(char* modele, char* entree, char* sortie) { else printf("\"%d\" : [", i); - ecrire_image_dans_reseau(images[i], reseau, height, width); - forward_propagation(reseau); - - for (int j=0; j < derniere_couche->nb_neurones; j++) { + for (int j=0; j < nb_der_couche; j++) { if (! strcmp(sortie, "json")) { - printf("%f", derniere_couche->neurones[j]->z); + printf("%f", resultats[i][j]); - if (j+1 < derniere_couche->nb_neurones) { + if (j+1 < nb_der_couche) { printf(", "); } } else - printf("Probabilité %d: %f\n", j, derniere_couche->neurones[j]->z); + printf("Probabilité %d: %f\n", j, resultats[i][j]); } if (! strcmp(sortie, "json")) { if (i+1 < nb_images) { @@ -159,6 +188,27 @@ void recognize(char* modele, char* entree, char* sortie) { } +void test(char* modele, char* fichier_images, char* fichier_labels) { + Reseau* reseau = lire_reseau(modele); + int nb_der_couche = reseau->couches[reseau->nb_couches-1]->nb_neurones; + + suppression_du_reseau_neuronal(reseau); + + int* parameters = read_mnist_images_parameters(fichier_images); + int nb_images = parameters[0]; + + float** resultats = recognize(modele, fichier_images); + unsigned int* labels = read_mnist_labels(fichier_labels); + float accuracy; + + for (int i=0; i < nb_images; i++) { + if (indice_max(resultats[i], nb_der_couche) == labels[i]) { + accuracy += 1. / (float)nb_images; + } + } + printf("%d Images\tAccuracy: %0.1f%%\n", nb_images, accuracy*100); +} + int main(int argc, char* argv[]) { if (argc < 2) { @@ -174,7 +224,7 @@ int main(int argc, char* argv[]) { char* labels = NULL; char* recovery = NULL; char* out = NULL; - int i=2; + int i = 2; while (i < argc) { // Utiliser un switch serait sans doute plus élégant if ((! strcmp(argv[i], "--batches"))||(! strcmp(argv[i], "-b"))) { @@ -224,7 +274,7 @@ int main(int argc, char* argv[]) { char* in = NULL; char* modele = NULL; char* out = NULL; - int i=2; + int i = 2; while(i < argc) { if ((! strcmp(argv[i], "--in"))||(! strcmp(argv[i], "-i"))) { in = argv[i+1]; @@ -251,10 +301,30 @@ int main(int argc, char* argv[]) { if (! out) { out = "text"; } - recognize(modele, in, out); + print_recognize(modele, in, out); // Reconnaissance puis affichage des données sous le format spécifié exit(0); } + if (! strcmp(argv[1], "test")) { + char* modele = NULL; + char* images = NULL; + char* labels = NULL; + int i = 2; + while (i < argc) { + if ((! strcmp(argv[i], "--images"))||(! strcmp(argv[i], "-i"))) { + images = argv[i+1]; + i += 2; + } else if ((! strcmp(argv[i], "--labels"))||(! strcmp(argv[i], "-l"))) { + labels = argv[i+1]; + i += 2; + } else if ((! strcmp(argv[i], "--modele"))||(! strcmp(argv[i], "-m"))) { + modele = argv[i+1]; + i += 2; + } + } + test(modele, images, labels); + exit(0); + } printf("Option choisie non reconnue: %s\n", argv[1]); help(argv[0]); return 1;