diff --git a/src/mnist/main.c b/src/mnist/main.c index 24fafbd..06ec50a 100644 --- a/src/mnist/main.c +++ b/src/mnist/main.c @@ -64,6 +64,8 @@ void help(char* call) { printf("\t\t--labels | -l [FILENAME]\tFichier contenant les labels.\n"); printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n"); printf("\t\t--delta | -d [FILENAME]\tFichier où écrire le réseau différentiel.\n"); + printf("\t\t--nb-images | -N [int]\tNombres d'images à traiter.\n"); + printf("\t\t--start | -s [int]\tPremière image à traiter.\n"); printf("\trecognize:\n"); printf("\t\t--modele | -m [FILENAME]\tFichier contenant le réseau de neurones.\n"); printf("\t\t--in | -i [FILENAME]\tFichier contenant les images à reconnaître.\n"); @@ -120,7 +122,7 @@ void* train_images(void* parameters) { } -void train(int epochs, int layers, int neurons, char* recovery, char* image_file, char* label_file, char* out, char* delta) { +void train(int epochs, int layers, int neurons, char* recovery, char* image_file, char* label_file, char* out, char* delta, int nb_images_to_process, int start) { // Entraînement du réseau sur le set de données MNIST Network* network; Network* delta_network; @@ -172,6 +174,10 @@ void train(int epochs, int layers, int neurons, char* recovery, char* image_file int*** images = read_mnist_images(image_file); unsigned int* labels = read_mnist_labels(label_file); + if (nb_images_to_process != -1) { + nb_images_total = nb_images_to_process; + } + TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads); for (int i=0; i < epochs; i++) { accuracy = 0.; @@ -184,7 +190,7 @@ void train(int epochs, int layers, int neurons, char* recovery, char* image_file train_parameters[j]->images = (int***)images; train_parameters[j]->labels = (int*)labels; train_parameters[j]->nb_images = BATCHES / nb_threads; - train_parameters[j]->start = nb_images_total - BATCHES*(nb_images_total / BATCHES - k -1) - nb_remaining_images; + train_parameters[j]->start = nb_images_total - BATCHES*(nb_images_total / BATCHES - k -1) - nb_remaining_images + start; train_parameters[j]->height = height; train_parameters[j]->width = width; @@ -327,6 +333,8 @@ int main(int argc, char* argv[]) { int epochs = EPOCHS; int layers = 2; int neurons = 784; + int nb_images = -1; + int start = 0; char* images = NULL; char* labels = NULL; char* recovery = NULL; @@ -360,6 +368,12 @@ int main(int argc, char* argv[]) { } else if ((! strcmp(argv[i], "--delta"))||(! strcmp(argv[i], "-d"))) { delta = argv[i+1]; i += 2; + } else if ((! strcmp(argv[i], "--nb-images"))||(! strcmp(argv[i], "-N"))) { + nb_images = strtol(argv[i+1], NULL, 10); + i += 2; + } else if ((! strcmp(argv[i], "--start"))||(! strcmp(argv[i], "-s"))) { + start = strtol(argv[i+1], NULL, 10); + i += 2; } else { printf("%s : Argument non reconnu\n", argv[i]); i++; @@ -378,7 +392,7 @@ int main(int argc, char* argv[]) { out = "out.bin"; } // Entraînement en sourçant neural_network.c - train(epochs, layers, neurons, recovery, images, labels, out, delta); + train(epochs, layers, neurons, recovery, images, labels, out, delta, nb_images, start); exit(0); } if (! strcmp(argv[1], "recognize")) {