Add recovery option

This commit is contained in:
augustin64 2022-12-07 13:09:39 +01:00
parent 963a4afcff
commit cedb240df2
3 changed files with 23 additions and 4 deletions

View File

@ -35,6 +35,6 @@ void* train_thread(void* parameters);
/*
* Fonction principale d'entraînement du réseau neuronal convolutif
*/
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out);
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover);
#endif

View File

@ -24,6 +24,7 @@ void help(char* call) {
printf("\t(mnist)\t--images | -i [FILENAME]\tFichier contenant les images.\n");
printf("\t(mnist)\t--labels | -l [FILENAME]\tFichier contenant les labels.\n");
printf("\t (jpg) \t--datadir | -dd [FOLDER]\tDossier contenant les images.\n");
printf("\t\t--recover | -r [FILENAME]\tRécupérer depuis un modèle existant.\n");
printf("\t\t--epochs | -e [int]\t\tNombre d'époques.\n");
printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n");
printf("\trecognize:\n");
@ -55,6 +56,7 @@ int main(int argc, char* argv[]) {
int epochs = EPOCHS;
int dataset_type = 0;
char* out = NULL;
char* recover = NULL;
int i = 2;
while (i < argc) {
if ((! strcmp(argv[i], "--dataset"))||(! strcmp(argv[i], "-d"))) {
@ -80,6 +82,9 @@ int main(int argc, char* argv[]) {
else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) {
out = argv[i+1];
i += 2;
} else if ((! strcmp(argv[i], "--recover"))||(! strcmp(argv[i], "-r"))) {
recover = argv[i+1];
i += 2;
} else {
printf("Option choisie inconnue: %s\n", argv[i]);
i++;
@ -111,7 +116,7 @@ int main(int argc, char* argv[]) {
printf("Pas de fichier de sortie spécifié, défaut: out.bin\n");
out = "out.bin";
}
train(dataset_type, images_file, labels_file, data_dir, epochs, out);
train(dataset_type, images_file, labels_file, data_dir, epochs, out, recover);
return 0;
}
if (! strcmp(argv[1], "test")) {

View File

@ -75,8 +75,9 @@ void* train_thread(void* parameters) {
}
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover) {
srand(time(NULL));
Network* network;
int input_dim = -1;
int input_depth = -1;
float accuracy;
@ -111,7 +112,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
}
// Initialisation du réseau
Network* network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth);
if (!recover) {
network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth);
} else {
network = read_network(recover);
}
shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
for (int i=0; i < nb_images_total; i++) {
@ -184,6 +190,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
knuth_shuffle(shuffle_index, nb_images_total);
batches_epoques = div_up(nb_images_total, BATCHES);
nb_images_total_remaining = nb_images_total;
#ifndef USE_MULTITHREADING
train_params->nb_images = BATCHES;
#endif
for (int j=0; j < batches_epoques; j++) {
#ifdef USE_MULTITHREADING
if (j == batches_epoques-1) {
@ -222,6 +231,11 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
(void)nb_images_total_remaining; // Juste pour enlever un warning
train_params->start = j*BATCHES;
// Ne pas dépasser le nombre d'images à cause de la partie entière
if (j == batches_epoques-1) {
train_params->nb_images = nb_images_total - j*BATCHES;
}
train_thread((void*)train_params);