mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-24 07:36:24 +01:00
Add recovery option
This commit is contained in:
parent
963a4afcff
commit
cedb240df2
@ -35,6 +35,6 @@ void* train_thread(void* parameters);
|
||||
/*
|
||||
* Fonction principale d'entraînement du réseau neuronal convolutif
|
||||
*/
|
||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out);
|
||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover);
|
||||
|
||||
#endif
|
@ -24,6 +24,7 @@ void help(char* call) {
|
||||
printf("\t(mnist)\t--images | -i [FILENAME]\tFichier contenant les images.\n");
|
||||
printf("\t(mnist)\t--labels | -l [FILENAME]\tFichier contenant les labels.\n");
|
||||
printf("\t (jpg) \t--datadir | -dd [FOLDER]\tDossier contenant les images.\n");
|
||||
printf("\t\t--recover | -r [FILENAME]\tRécupérer depuis un modèle existant.\n");
|
||||
printf("\t\t--epochs | -e [int]\t\tNombre d'époques.\n");
|
||||
printf("\t\t--out | -o [FILENAME]\tFichier où écrire le réseau de neurones.\n");
|
||||
printf("\trecognize:\n");
|
||||
@ -55,6 +56,7 @@ int main(int argc, char* argv[]) {
|
||||
int epochs = EPOCHS;
|
||||
int dataset_type = 0;
|
||||
char* out = NULL;
|
||||
char* recover = NULL;
|
||||
int i = 2;
|
||||
while (i < argc) {
|
||||
if ((! strcmp(argv[i], "--dataset"))||(! strcmp(argv[i], "-d"))) {
|
||||
@ -80,6 +82,9 @@ int main(int argc, char* argv[]) {
|
||||
else if ((! strcmp(argv[i], "--out"))||(! strcmp(argv[i], "-o"))) {
|
||||
out = argv[i+1];
|
||||
i += 2;
|
||||
} else if ((! strcmp(argv[i], "--recover"))||(! strcmp(argv[i], "-r"))) {
|
||||
recover = argv[i+1];
|
||||
i += 2;
|
||||
} else {
|
||||
printf("Option choisie inconnue: %s\n", argv[i]);
|
||||
i++;
|
||||
@ -111,7 +116,7 @@ int main(int argc, char* argv[]) {
|
||||
printf("Pas de fichier de sortie spécifié, défaut: out.bin\n");
|
||||
out = "out.bin";
|
||||
}
|
||||
train(dataset_type, images_file, labels_file, data_dir, epochs, out);
|
||||
train(dataset_type, images_file, labels_file, data_dir, epochs, out, recover);
|
||||
return 0;
|
||||
}
|
||||
if (! strcmp(argv[1], "test")) {
|
||||
|
@ -75,8 +75,9 @@ void* train_thread(void* parameters) {
|
||||
}
|
||||
|
||||
|
||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
|
||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover) {
|
||||
srand(time(NULL));
|
||||
Network* network;
|
||||
int input_dim = -1;
|
||||
int input_depth = -1;
|
||||
float accuracy;
|
||||
@ -111,7 +112,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
}
|
||||
|
||||
// Initialisation du réseau
|
||||
Network* network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth);
|
||||
if (!recover) {
|
||||
network = create_network_lenet5(1, 0, TANH, GLOROT, input_dim, input_depth);
|
||||
} else {
|
||||
network = read_network(recover);
|
||||
}
|
||||
|
||||
|
||||
shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
||||
for (int i=0; i < nb_images_total; i++) {
|
||||
@ -184,6 +190,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
knuth_shuffle(shuffle_index, nb_images_total);
|
||||
batches_epoques = div_up(nb_images_total, BATCHES);
|
||||
nb_images_total_remaining = nb_images_total;
|
||||
#ifndef USE_MULTITHREADING
|
||||
train_params->nb_images = BATCHES;
|
||||
#endif
|
||||
for (int j=0; j < batches_epoques; j++) {
|
||||
#ifdef USE_MULTITHREADING
|
||||
if (j == batches_epoques-1) {
|
||||
@ -222,6 +231,11 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
(void)nb_images_total_remaining; // Juste pour enlever un warning
|
||||
|
||||
train_params->start = j*BATCHES;
|
||||
|
||||
// Ne pas dépasser le nombre d'images à cause de la partie entière
|
||||
if (j == batches_epoques-1) {
|
||||
train_params->nb_images = nb_images_total - j*BATCHES;
|
||||
}
|
||||
|
||||
train_thread((void*)train_params);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user