mirror of
https://github.com/augustin64/projet-tipe
synced 2025-02-03 10:48:01 +01:00
Add recovery option
This commit is contained in:
parent
963a4afcff
commit
cedb240df2
@ -35,6 +35,6 @@ void* train_thread(void* parameters);
|
|||||||
/*
|
/*
|
||||||
* Fonction principale d'entraînement du réseau neuronal convolutif
|
* Fonction principale d'entraînement du réseau neuronal convolutif
|
||||||
*/
|
*/
|
||||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out);
|
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover);
|
||||||
|
|
||||||
#endif
|
#endif
|
@ -24,6 +24,7 @@ void help(char* call) {
|
|||||||
printf("\t(mnist)\t--images | -i [FILENAME]\tFichier contenant les images.\n");
|
printf("\t(mnist)\t--images | -i [FILENAME]\tFichier contenant les images.\n");
|
||||||
printf("\t(mnist)\t--labels | -l [FILENAME]\tFichier contenant les labels.\n");
|
printf("\t(mnist)\t--labels | -l [FILENAME]\tFichier contenant les labels.\n");
|
||||||
printf("\t (jpg) \t--datadir | -dd [FOLDER]\tDossier contenant les images.\n");
|
printf("\t (jpg) \t--datadir | -dd [FOLDER]\tDossier contenant les images.\n");
|
||||||
|
printf("\t\t--recover | -r [FILENAME]\tRécupérer depuis un modèle existant.\n");
|
||||||
printf("\t\t--epochs | -e [int]\t\tNombre d'époques.\n");
|
printf("\t\t--epochs | -e [int]\t\tNombre d'époques.\n");
|
||||||
printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n");
|
printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n");
|
||||||
printf("\trecognize:\n");
|
printf("\trecognize:\n");
|
||||||
@ -55,6 +56,7 @@ int main(int argc, char* argv[]) {
|
|||||||
int epochs = EPOCHS;
|
int epochs = EPOCHS;
|
||||||
int dataset_type = 0;
|
int dataset_type = 0;
|
||||||
char* out = NULL;
|
char* out = NULL;
|
||||||
|
char* recover = NULL;
|
||||||
int i = 2;
|
int i = 2;
|
||||||
while (i < argc) {
|
while (i < argc) {
|
||||||
if ((! strcmp(argv[i], "--dataset"))||(! strcmp(argv[i], "-d"))) {
|
if ((! strcmp(argv[i], "--dataset"))||(! strcmp(argv[i], "-d"))) {
|
||||||
@ -80,6 +82,9 @@ int main(int argc, char* argv[]) {
|
|||||||
else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) {
|
else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) {
|
||||||
out = argv[i+1];
|
out = argv[i+1];
|
||||||
i += 2;
|
i += 2;
|
||||||
|
} else if ((! strcmp(argv[i], "--recover"))||(! strcmp(argv[i], "-r"))) {
|
||||||
|
recover = argv[i+1];
|
||||||
|
i += 2;
|
||||||
} else {
|
} else {
|
||||||
printf("Option choisie inconnue: %s\n", argv[i]);
|
printf("Option choisie inconnue: %s\n", argv[i]);
|
||||||
i++;
|
i++;
|
||||||
@ -111,7 +116,7 @@ int main(int argc, char* argv[]) {
|
|||||||
printf("Pas de fichier de sortie spécifié, défaut: out.bin\n");
|
printf("Pas de fichier de sortie spécifié, défaut: out.bin\n");
|
||||||
out = "out.bin";
|
out = "out.bin";
|
||||||
}
|
}
|
||||||
train(dataset_type, images_file, labels_file, data_dir, epochs, out);
|
train(dataset_type, images_file, labels_file, data_dir, epochs, out, recover);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
if (! strcmp(argv[1], "test")) {
|
if (! strcmp(argv[1], "test")) {
|
||||||
|
@ -75,8 +75,9 @@ void* train_thread(void* parameters) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
|
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover) {
|
||||||
srand(time(NULL));
|
srand(time(NULL));
|
||||||
|
Network* network;
|
||||||
int input_dim = -1;
|
int input_dim = -1;
|
||||||
int input_depth = -1;
|
int input_depth = -1;
|
||||||
float accuracy;
|
float accuracy;
|
||||||
@ -111,7 +112,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialisation du réseau
|
// Initialisation du réseau
|
||||||
Network* network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth);
|
if (!recover) {
|
||||||
|
network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth);
|
||||||
|
} else {
|
||||||
|
network = read_network(recover);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
||||||
for (int i=0; i < nb_images_total; i++) {
|
for (int i=0; i < nb_images_total; i++) {
|
||||||
@ -184,6 +190,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
knuth_shuffle(shuffle_index, nb_images_total);
|
knuth_shuffle(shuffle_index, nb_images_total);
|
||||||
batches_epoques = div_up(nb_images_total, BATCHES);
|
batches_epoques = div_up(nb_images_total, BATCHES);
|
||||||
nb_images_total_remaining = nb_images_total;
|
nb_images_total_remaining = nb_images_total;
|
||||||
|
#ifndef USE_MULTITHREADING
|
||||||
|
train_params->nb_images = BATCHES;
|
||||||
|
#endif
|
||||||
for (int j=0; j < batches_epoques; j++) {
|
for (int j=0; j < batches_epoques; j++) {
|
||||||
#ifdef USE_MULTITHREADING
|
#ifdef USE_MULTITHREADING
|
||||||
if (j == batches_epoques-1) {
|
if (j == batches_epoques-1) {
|
||||||
@ -223,6 +232,11 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
|
|
||||||
train_params->start = j*BATCHES;
|
train_params->start = j*BATCHES;
|
||||||
|
|
||||||
|
// Ne pas dépasser le nombre d'images à cause de la partie entière
|
||||||
|
if (j == batches_epoques-1) {
|
||||||
|
train_params->nb_images = nb_images_total - j*BATCHES;
|
||||||
|
}
|
||||||
|
|
||||||
train_thread((void*)train_params);
|
train_thread((void*)train_params);
|
||||||
|
|
||||||
accuracy += train_params->accuracy / (float) nb_images_total;
|
accuracy += train_params->accuracy / (float) nb_images_total;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user