From 38bcfb700eb19707b781ccc9eed3cc28e06ecfe1 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Wed, 31 May 2023 10:39:10 +0200 Subject: [PATCH] test_network: Fix number of categories --- src/cnn/test_network.c | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/cnn/test_network.c b/src/cnn/test_network.c index 3212aeb..88c47f6 100644 --- a/src/cnn/test_network.c +++ b/src/cnn/test_network.c @@ -189,8 +189,7 @@ void recognize_jpg(Network* network, char* input_file, char* out) { imgRawImage* image = loadJpegImageFile(input_file); int height = image->height; int width = image->width; - - assert(image->width == image->height); + int nb_categories = network->width[network->size-1]; if (! strcmp(out, "json")) { printf("{\n"); @@ -203,15 +202,15 @@ void recognize_jpg(Network* network, char* input_file, char* out) { if (! strcmp(out, "json")) { - for (int j=0; j < 50; j++) { + for (int j=0; j < nb_categories; j++) { printf("%f", network->input[network->size-1][0][0][j]); - if (j+1 < 10) { + if (j+1 < nb_categories) { printf(", "); } } } else { - maxi = indice_max(network->input[network->size-1][0][0], 50); + maxi = indice_max(network->input[network->size-1][0][0], nb_categories); printf("Catégorie reconnue: %d\n", maxi); }