diff --git a/src/cnn/include/train.h b/src/cnn/include/train.h index 2cfe564..9057f8b 100644 --- a/src/cnn/include/train.h +++ b/src/cnn/include/train.h @@ -35,6 +35,6 @@ void* train_thread(void* parameters); /* * Fonction principale d'entraînement du réseau neuronal convolutif */ -void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out); +void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover); #endif \ No newline at end of file diff --git a/src/cnn/main.c b/src/cnn/main.c index 07915df..ba90b6f 100644 --- a/src/cnn/main.c +++ b/src/cnn/main.c @@ -24,6 +24,7 @@ void help(char* call) { printf("\t(mnist)\t--images | -i [FILENAME]\tFichier contenant les images.\n"); printf("\t(mnist)\t--labels | -l [FILENAME]\tFichier contenant les labels.\n"); printf("\t (jpg) \t--datadir | -dd [FOLDER]\tDossier contenant les images.\n"); + printf("\t\t--recover | -r [FILENAME]\tRécupérer depuis un modèle existant.\n"); printf("\t\t--epochs | -e [int]\t\tNombre d'époques.\n"); printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n"); printf("\trecognize:\n"); @@ -55,6 +56,7 @@ int main(int argc, char* argv[]) { int epochs = EPOCHS; int dataset_type = 0; char* out = NULL; + char* recover = NULL; int i = 2; while (i < argc) { if ((! strcmp(argv[i], "--dataset"))||(! strcmp(argv[i], "-d"))) { @@ -80,6 +82,9 @@ int main(int argc, char* argv[]) { else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) { out = argv[i+1]; i += 2; + } else if ((! strcmp(argv[i], "--recover"))||(! strcmp(argv[i], "-r"))) { + recover = argv[i+1]; + i += 2; } else { printf("Option choisie inconnue: %s\n", argv[i]); i++; @@ -111,7 +116,7 @@ int main(int argc, char* argv[]) { printf("Pas de fichier de sortie spécifié, défaut: out.bin\n"); out = "out.bin"; } - train(dataset_type, images_file, labels_file, data_dir, epochs, out); + train(dataset_type, images_file, labels_file, data_dir, epochs, out, recover); return 0; } if (! strcmp(argv[1], "test")) { diff --git a/src/cnn/train.c b/src/cnn/train.c index e86434b..bb807de 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -75,8 +75,9 @@ void* train_thread(void* parameters) { } -void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) { +void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover) { srand(time(NULL)); + Network* network; int input_dim = -1; int input_depth = -1; float accuracy; @@ -111,7 +112,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di } // Initialisation du réseau - Network* network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth); + if (!recover) { + network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth); + } else { + network = read_network(recover); + } + shuffle_index = (int*)malloc(sizeof(int)*nb_images_total); for (int i=0; i < nb_images_total; i++) { @@ -184,6 +190,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di knuth_shuffle(shuffle_index, nb_images_total); batches_epoques = div_up(nb_images_total, BATCHES); nb_images_total_remaining = nb_images_total; + #ifndef USE_MULTITHREADING + train_params->nb_images = BATCHES; + #endif for (int j=0; j < batches_epoques; j++) { #ifdef USE_MULTITHREADING if (j == batches_epoques-1) { @@ -222,6 +231,11 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di (void)nb_images_total_remaining; // Juste pour enlever un warning train_params->start = j*BATCHES; + + // Ne pas dépasser le nombre d'images à cause de la partie entière + if (j == batches_epoques-1) { + train_params->nb_images = nb_images_total - j*BATCHES; + } train_thread((void*)train_params);