From e584dfc79115092103c24a0ab046b249f124c4c1 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Wed, 30 Nov 2022 10:21:56 +0100 Subject: [PATCH] Add test-network option implementation --- src/cnn/test_network.c | 85 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/src/cnn/test_network.c b/src/cnn/test_network.c index b51ad75..70b778b 100644 --- a/src/cnn/test_network.c +++ b/src/cnn/test_network.c @@ -9,8 +9,91 @@ #include "include/free.h" #include "include/cnn.h" -void test_network(int dataset_type, char* modele, char* images_file, char* labels_file, char* data_dir, bool preview_fails) { +void test_network_mnist(Network* network, char* images_file, char* labels_file, bool preview_fails) { + (void)preview_fails; // Inutilisé pour le moment + int width, height; // Dimensions des images + int nb_elem; // Nombre d'éléments + int maxi; // Catégorie reconnue + + int accuracy = 0; // Nombre d'images reconnues + + // Load image + int* mnist_parameters = read_mnist_images_parameters(images_file); + + int*** images = read_mnist_images(images_file); + unsigned int* labels = read_mnist_labels(labels_file); + + nb_elem = mnist_parameters[0]; + + width = mnist_parameters[1]; + height = mnist_parameters[2]; + free(mnist_parameters); + + // Load image in the first layer of the Network + for (int i=0; i < nb_elem; i++) { + if(i %(nb_elem/100) == 0) { + printf("Avancement: %.0f%%\r", 100*i/(float)nb_elem); + fflush(stdout); + } + write_image_in_network_32(images[i], height, width, network->input[0][0]); + forward_propagation(network); + maxi = indice_max(network->input[network->size-1][0][0], 10); + + if (maxi == (int)labels[i]) { + accuracy++; + } + + for (int j=0; j < height; j++) { + free(images[i][j]); + } + free(images[i]); + } + free(images); + printf("%d Images. Taux de réussite: %.2f%%\n", nb_elem, 100*accuracy/(float)nb_elem); +} + + +void test_network_jpg(Network* network, char* data_dir, bool preview_fails) { + (void)preview_fails; // Inutilisé pour le moment + jpegDataset* dataset = loadJpegDataset(data_dir); + + int accuracy = 0; + int maxi; + + for (int i=0; i < (int)dataset->numImages; i++) { + if(i %(dataset->numImages/100) == 0) { + printf("Avancement: %.1f%%\r", 1000*i/(float)dataset->numImages); + fflush(stdout); + } + write_image_in_network_260(dataset->images[i], dataset->height, dataset->height, network->input[0]); + forward_propagation(network); + maxi = indice_max(network->input[network->size-1][0][0], 50); + + if (maxi == (int)dataset->labels[i]) { + accuracy++; + } + + free(dataset->images[i]); + } + + printf("%d Images. Taux de réussite: %.2f%%\n", dataset->numImages, 100*accuracy/(float)dataset->numImages); + free(dataset->images); + free(dataset->labels); + free(dataset); +} + + +void test_network(int dataset_type, char* modele, char* images_file, char* labels_file, char* data_dir, bool preview_fails) { + Network* network = read_network(modele); + + if (dataset_type == 0) { + test_network_mnist(network, images_file, labels_file, preview_fails); + } else { + test_network_jpg(network, data_dir, preview_fails); + } + + free_network(network); }