From a21481d1cccfe4013e58981e3dff0d277d2ba52a Mon Sep 17 00:00:00 2001 From: julienChemillier Date: Wed, 18 Jan 2023 10:25:46 +0100 Subject: [PATCH] Change 'wanted_number' from float to int --- src/cnn/cnn.c | 10 +++++----- src/cnn/include/cnn.h | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/cnn/cnn.c b/src/cnn/cnn.c index 744a1df..c4a8357 100644 --- a/src/cnn/cnn.c +++ b/src/cnn/cnn.c @@ -126,8 +126,8 @@ void forward_propagation(Network* network) { } } -void backward_propagation(Network* network, float wanted_number) { - float* wanted_output = generate_wanted_output(wanted_number); +void backward_propagation(Network* network, int wanted_number) { + float* wanted_output = generate_wanted_output(wanted_number, 10); int n = network->size; int activation, input_depth, input_width, output_depth, output_width; float*** input; @@ -215,9 +215,9 @@ float compute_cross_entropy_loss(float* output, float* wanted_output, int len) { return loss; } -float* generate_wanted_output(float wanted_number) { - float* wanted_output = (float*)malloc(sizeof(float)*10); - for (int i=0; i < 10; i++) { +float* generate_wanted_output(int wanted_number, int size_output) { + float* wanted_output = (float*)malloc(sizeof(float)*size_output); + for (int i=0; i < size_output; i++) { if (i==wanted_number) { wanted_output[i]=1; } diff --git a/src/cnn/include/cnn.h b/src/cnn/include/cnn.h index 884d13a..749084a 100644 --- a/src/cnn/include/cnn.h +++ b/src/cnn/include/cnn.h @@ -32,7 +32,7 @@ void forward_propagation(Network* network); /* * Propage en arrière le cnn */ -void backward_propagation(Network* network, float wanted_number); +void backward_propagation(Network* network, int wanted_number); /* * Met à 0 chaque valeur de l'input avec une probabilité de dropout % @@ -57,6 +57,6 @@ float compute_cross_entropy_loss(float* output, float* wanted_output, int len); /* * On considère que la sortie voulue comporte 10 éléments */ -float* generate_wanted_output(float wanted_number); +float* generate_wanted_output(int wanted_number, int size_output); #endif \ No newline at end of file