mirror of
https://github.com/augustin64/projet-tipe
synced 2025-02-02 19:39:39 +01:00
Trying to improve train.c readability
This commit is contained in:
parent
7d7cd2e3a7
commit
81ff4f4d00
@ -104,15 +104,25 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
srand(time(NULL));
|
srand(time(NULL));
|
||||||
Network* network;
|
|
||||||
int input_dim = -1;
|
|
||||||
int input_depth = -1;
|
|
||||||
|
|
||||||
float loss;
|
float loss;
|
||||||
float batch_loss; // May be redundant with loss, but gives more informations
|
float batch_loss; // May be redundant with loss, but gives more informations
|
||||||
float accuracy;
|
float accuracy;
|
||||||
float current_accuracy;
|
float current_accuracy;
|
||||||
|
|
||||||
|
|
||||||
|
//* Différents timers pour mesurer les performance en terme de vitesse
|
||||||
|
double start_time, end_time;
|
||||||
|
double elapsed_time;
|
||||||
|
|
||||||
|
double algo_start = omp_get_wtime();
|
||||||
|
|
||||||
|
start_time = omp_get_wtime();
|
||||||
|
|
||||||
|
|
||||||
|
//* Chargement du dataset
|
||||||
|
int input_dim = -1;
|
||||||
|
int input_depth = -1;
|
||||||
|
|
||||||
int nb_images_total; // Images au total
|
int nb_images_total; // Images au total
|
||||||
int nb_images_total_remaining; // Images restantes dans un batch
|
int nb_images_total_remaining; // Images restantes dans un batch
|
||||||
int batches_epoques; // Batches par époque
|
int batches_epoques; // Batches par époque
|
||||||
@ -120,15 +130,6 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
||||||
unsigned int* labels; // Labels associés aux images du dataset MNIST
|
unsigned int* labels; // Labels associés aux images du dataset MNIST
|
||||||
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
|
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
|
||||||
int* shuffle_index; // shuffle_index[i] contient le nouvel index de l'élément à l'emplacement i avant mélange
|
|
||||||
|
|
||||||
double start_time, end_time;
|
|
||||||
double elapsed_time;
|
|
||||||
|
|
||||||
double algo_start = omp_get_wtime();
|
|
||||||
|
|
||||||
start_time = omp_get_wtime();
|
|
||||||
|
|
||||||
if (dataset_type == 0) { // Type MNIST
|
if (dataset_type == 0) { // Type MNIST
|
||||||
// Chargement des images du set de données MNIST
|
// Chargement des images du set de données MNIST
|
||||||
int* parameters = read_mnist_images_parameters(images_file);
|
int* parameters = read_mnist_images_parameters(images_file);
|
||||||
@ -148,22 +149,28 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
nb_images_total = dataset->numImages;
|
nb_images_total = dataset->numImages;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialisation du réseau
|
//* Création du réseau
|
||||||
|
Network* network;
|
||||||
if (!recover) {
|
if (!recover) {
|
||||||
// Le nouveau TA calculé à partir du loss est majoré par 0.75*TA
|
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_dim, input_depth);
|
||||||
network = create_network_lenet5(LEARNING_RATE*0.75, 0, TANH, GLOROT, input_dim, input_depth);
|
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth);
|
||||||
//network = create_simple_one(LEARNING_RATE*0.75, 0, RELU, GLOROT, input_dim, input_depth);
|
|
||||||
} else {
|
} else {
|
||||||
network = read_network(recover);
|
network = read_network(recover);
|
||||||
network->learning_rate = LEARNING_RATE;
|
network->learning_rate = LEARNING_RATE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
shuffle_index[i] contient le nouvel index de l'élément à l'emplacement i avant mélange
|
||||||
|
Cela permet de réordonner le jeu d'apprentissage pour éviter certains biais
|
||||||
|
qui pourraient provenir de l'ordre établi.
|
||||||
|
*/
|
||||||
|
int* shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
||||||
for (int i=0; i < nb_images_total; i++) {
|
for (int i=0; i < nb_images_total; i++) {
|
||||||
shuffle_index[i] = i;
|
shuffle_index[i] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//* Création des paramètres d'entrée de train_thread
|
||||||
#ifdef USE_MULTITHREADING
|
#ifdef USE_MULTITHREADING
|
||||||
int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads
|
int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads
|
||||||
// Récupération du nombre de threads disponibles
|
// Récupération du nombre de threads disponibles
|
||||||
@ -221,6 +228,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
train_params->nb_images = BATCHES;
|
train_params->nb_images = BATCHES;
|
||||||
train_params->index = shuffle_index;
|
train_params->index = shuffle_index;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
end_time = omp_get_wtime();
|
end_time = omp_get_wtime();
|
||||||
|
|
||||||
elapsed_time = end_time - start_time;
|
elapsed_time = end_time - start_time;
|
||||||
@ -229,6 +237,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
printf_time(elapsed_time);
|
printf_time(elapsed_time);
|
||||||
printf("\n\n");
|
printf("\n\n");
|
||||||
|
|
||||||
|
//* Boucle d'apprentissage
|
||||||
for (int i=0; i < epochs; i++) {
|
for (int i=0; i < epochs; i++) {
|
||||||
|
|
||||||
start_time = omp_get_wtime();
|
start_time = omp_get_wtime();
|
||||||
@ -320,10 +329,8 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " YELLOW "%0.4f%%" RESET, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " YELLOW "%0.4f%%" RESET, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
#endif
|
#endif
|
||||||
// Il serait intéressant d'utiliser la perte calculée pour
|
|
||||||
// savoir l'avancement dans l'apprentissage et donc comment adapter le taux d'apprentissage
|
|
||||||
network->learning_rate = LEARNING_RATE*log(batch_loss+1);
|
|
||||||
}
|
}
|
||||||
|
//* Fin d'une époque: affichage des résultats et sauvegarde du réseau
|
||||||
end_time = omp_get_wtime();
|
end_time = omp_get_wtime();
|
||||||
elapsed_time = end_time - start_time;
|
elapsed_time = end_time - start_time;
|
||||||
#ifdef USE_MULTITHREADING
|
#ifdef USE_MULTITHREADING
|
||||||
@ -338,8 +345,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
write_network(out, network);
|
write_network(out, network);
|
||||||
// If you want to test the network between each epoch, uncomment the following line:
|
// If you want to test the network between each epoch, uncomment the following line:
|
||||||
//test_network(0, out, "data/mnist/t10k-images-idx3-ubyte", "data/mnist/t10k-labels-idx1-ubyte", NULL, false);
|
//test_network(0, out, "data/mnist/t10k-images-idx3-ubyte", "data/mnist/t10k-labels-idx1-ubyte", NULL, false);
|
||||||
|
|
||||||
|
// Learning Rate decay
|
||||||
|
network->learning_rate -= LEARNING_RATE*(1./(float)(epochs+1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//* Fin de l'algo
|
||||||
// To generate a new neural and compare performances with scripts/benchmark_binary.py
|
// To generate a new neural and compare performances with scripts/benchmark_binary.py
|
||||||
if (epochs == 0) {
|
if (epochs == 0) {
|
||||||
write_network(out, network);
|
write_network(out, network);
|
||||||
|
Loading…
Reference in New Issue
Block a user