test_network: Fix number of categories

This commit is contained in:
augustin64 2023-05-31 10:39:10 +02:00
parent d63fb2c870
commit 38bcfb700e

View File

@ -189,8 +189,7 @@ void recognize_jpg(Network* network, char* input_file, char* out) {
imgRawImage* image = loadJpegImageFile(input_file); imgRawImage* image = loadJpegImageFile(input_file);
int height = image->height; int height = image->height;
int width = image->width; int width = image->width;
int nb_categories = network->width[network->size-1];
assert(image->width == image->height);
if (! strcmp(out, "json")) { if (! strcmp(out, "json")) {
printf("{\n"); printf("{\n");
@ -203,15 +202,15 @@ void recognize_jpg(Network* network, char* input_file, char* out) {
if (! strcmp(out, "json")) { if (! strcmp(out, "json")) {
for (int j=0; j < 50; j++) { for (int j=0; j < nb_categories; j++) {
printf("%f", network->input[network->size-1][0][0][j]); printf("%f", network->input[network->size-1][0][0][j]);
if (j+1 < 10) { if (j+1 < nb_categories) {
printf(", "); printf(", ");
} }
} }
} else { } else {
maxi = indice_max(network->input[network->size-1][0][0], 50); maxi = indice_max(network->input[network->size-1][0][0], nb_categories);
printf("Catégorie reconnue: %d\n", maxi); printf("Catégorie reconnue: %d\n", maxi);
} }