From be8d87d4be34080b6e998145ac10e7c48db33fe7 Mon Sep 17 00:00:00 2001 From: Julien Chemillier Date: Tue, 26 Apr 2022 17:46:41 +0200 Subject: [PATCH] Add batches --- src/mnist/main.c | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/mnist/main.c b/src/mnist/main.c index a602c4f..9011f24 100644 --- a/src/mnist/main.c +++ b/src/mnist/main.c @@ -7,6 +7,9 @@ #include "neuron_io.c" #include "mnist.c" +#define EPOCHS 10 +#define BATCHES 50 + void print_image(unsigned int width, unsigned int height, int** image, float* previsions) { char tab[] = {' ', '.', ':', '%', '#', '\0'}; @@ -125,8 +128,15 @@ void train(int batches, int layers, int neurons, char* recovery, char* image_fil } loss += loss_computing(network, labels[j]) / (float)nb_images; free(desired_output); + + if (j%BATCHES==BATCHES-1) + network_modification(network, BATCHES); + } - network_modification(network, nb_images); + + if (nb_images%BATCHES != 0) + network_modification(network, nb_images%BATCHES); + printf("\rBatch [%d/%d]\tImage [%d/%d]\tAccuracy: %0.1f%%\tLoss: %f\n",i, batches, nb_images, nb_images, accuracy*100, loss); write_network(out, network); } @@ -238,7 +248,7 @@ int main(int argc, char* argv[]) { exit(1); } if (! strcmp(argv[1], "train")) { - int batches = 100; + int batches = EPOCHS; int layers = 2; int neurons = 784; char* images = NULL;