From 61468044a97e3d87457e28c05b9b4aef473212a7 Mon Sep 17 00:00:00 2001 From: augustin64 Date: Tue, 17 Jan 2023 12:50:35 +0100 Subject: [PATCH] Update train.c --- src/cnn/train.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cnn/train.c b/src/cnn/train.c index 7ac13b1..87e6399 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -46,7 +46,7 @@ void* train_thread(void* parameters) { write_image_in_network_32(images[index[i]], height, width, network->input[0][0]); forward_propagation(network); maxi = indice_max(network->input[network->size-1][0][0], 10); - backward_propagation(network, labels[i]); + backward_propagation(network, labels[index[i]]); if (maxi == labels[index[i]]) { accuracy += 1.; @@ -294,7 +294,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di #else free(train_params); #endif - + if (dataset_type == 0) { for (int i=0; i < nb_images_total; i++) { for (int j=0; j < 28; j++) {