mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-23 15:16:26 +01:00
Trying to improve train.c readability
This commit is contained in:
parent
7d7cd2e3a7
commit
81ff4f4d00
155
src/cnn/train.c
155
src/cnn/train.c
@ -104,15 +104,25 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
}
|
||||
#endif
|
||||
srand(time(NULL));
|
||||
Network* network;
|
||||
int input_dim = -1;
|
||||
int input_depth = -1;
|
||||
|
||||
float loss;
|
||||
float batch_loss; // May be redundant with loss, but gives more informations
|
||||
float accuracy;
|
||||
float current_accuracy;
|
||||
|
||||
|
||||
//* Différents timers pour mesurer les performance en terme de vitesse
|
||||
double start_time, end_time;
|
||||
double elapsed_time;
|
||||
|
||||
double algo_start = omp_get_wtime();
|
||||
|
||||
start_time = omp_get_wtime();
|
||||
|
||||
|
||||
//* Chargement du dataset
|
||||
int input_dim = -1;
|
||||
int input_depth = -1;
|
||||
|
||||
int nb_images_total; // Images au total
|
||||
int nb_images_total_remaining; // Images restantes dans un batch
|
||||
int batches_epoques; // Batches par époque
|
||||
@ -120,15 +130,6 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
||||
unsigned int* labels; // Labels associés aux images du dataset MNIST
|
||||
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
|
||||
int* shuffle_index; // shuffle_index[i] contient le nouvel index de l'élément à l'emplacement i avant mélange
|
||||
|
||||
double start_time, end_time;
|
||||
double elapsed_time;
|
||||
|
||||
double algo_start = omp_get_wtime();
|
||||
|
||||
start_time = omp_get_wtime();
|
||||
|
||||
if (dataset_type == 0) { // Type MNIST
|
||||
// Chargement des images du set de données MNIST
|
||||
int* parameters = read_mnist_images_parameters(images_file);
|
||||
@ -148,79 +149,86 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
nb_images_total = dataset->numImages;
|
||||
}
|
||||
|
||||
// Initialisation du réseau
|
||||
//* Création du réseau
|
||||
Network* network;
|
||||
if (!recover) {
|
||||
// Le nouveau TA calculé à partir du loss est majoré par 0.75*TA
|
||||
network = create_network_lenet5(LEARNING_RATE*0.75, 0, TANH, GLOROT, input_dim, input_depth);
|
||||
//network = create_simple_one(LEARNING_RATE*0.75, 0, RELU, GLOROT, input_dim, input_depth);
|
||||
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_dim, input_depth);
|
||||
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth);
|
||||
} else {
|
||||
network = read_network(recover);
|
||||
network->learning_rate = LEARNING_RATE;
|
||||
}
|
||||
|
||||
|
||||
shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
||||
/*
|
||||
shuffle_index[i] contient le nouvel index de l'élément à l'emplacement i avant mélange
|
||||
Cela permet de réordonner le jeu d'apprentissage pour éviter certains biais
|
||||
qui pourraient provenir de l'ordre établi.
|
||||
*/
|
||||
int* shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
||||
for (int i=0; i < nb_images_total; i++) {
|
||||
shuffle_index[i] = i;
|
||||
}
|
||||
|
||||
|
||||
//* Création des paramètres d'entrée de train_thread
|
||||
#ifdef USE_MULTITHREADING
|
||||
int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads
|
||||
// Récupération du nombre de threads disponibles
|
||||
int nb_threads = get_nprocs();
|
||||
pthread_t *tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t));
|
||||
int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads
|
||||
// Récupération du nombre de threads disponibles
|
||||
int nb_threads = get_nprocs();
|
||||
pthread_t *tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t));
|
||||
|
||||
// Création des paramètres donnés à chaque thread dans le cas du multi-threading
|
||||
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
|
||||
TrainParameters* param;
|
||||
bool* thread_used = (bool*)malloc(sizeof(bool)*nb_threads);
|
||||
// Création des paramètres donnés à chaque thread dans le cas du multi-threading
|
||||
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
|
||||
TrainParameters* param;
|
||||
bool* thread_used = (bool*)malloc(sizeof(bool)*nb_threads);
|
||||
|
||||
for (int k=0; k < nb_threads; k++) {
|
||||
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
|
||||
param = train_parameters[k];
|
||||
param->dataset_type = dataset_type;
|
||||
if (dataset_type == 0) {
|
||||
param->images = images;
|
||||
param->labels = labels;
|
||||
param->dataset = NULL;
|
||||
param->width = 28;
|
||||
param->height = 28;
|
||||
} else {
|
||||
param->dataset = dataset;
|
||||
param->width = dataset->width;
|
||||
param->height = dataset->height;
|
||||
param->images = NULL;
|
||||
param->labels = NULL;
|
||||
for (int k=0; k < nb_threads; k++) {
|
||||
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
|
||||
param = train_parameters[k];
|
||||
param->dataset_type = dataset_type;
|
||||
if (dataset_type == 0) {
|
||||
param->images = images;
|
||||
param->labels = labels;
|
||||
param->dataset = NULL;
|
||||
param->width = 28;
|
||||
param->height = 28;
|
||||
} else {
|
||||
param->dataset = dataset;
|
||||
param->width = dataset->width;
|
||||
param->height = dataset->height;
|
||||
param->images = NULL;
|
||||
param->labels = NULL;
|
||||
}
|
||||
param->nb_images = BATCHES / nb_threads;
|
||||
param->index = shuffle_index;
|
||||
param->network = copy_network(network);
|
||||
}
|
||||
param->nb_images = BATCHES / nb_threads;
|
||||
param->index = shuffle_index;
|
||||
param->network = copy_network(network);
|
||||
}
|
||||
#else
|
||||
// Création des paramètres donnés à l'unique
|
||||
// thread dans l'hypothèse ou le multi-threading n'est pas utilisé.
|
||||
// Cela est utile à des fins de débogage notamment,
|
||||
// où l'utilisation de threads rend vite les choses plus compliquées qu'elles ne le sont.
|
||||
TrainParameters* train_params = (TrainParameters*)malloc(sizeof(TrainParameters));
|
||||
// Création des paramètres donnés à l'unique
|
||||
// thread dans l'hypothèse ou le multi-threading n'est pas utilisé.
|
||||
// Cela est utile à des fins de débogage notamment,
|
||||
// où l'utilisation de threads rend vite les choses plus compliquées qu'elles ne le sont.
|
||||
TrainParameters* train_params = (TrainParameters*)malloc(sizeof(TrainParameters));
|
||||
|
||||
train_params->network = network;
|
||||
train_params->dataset_type = dataset_type;
|
||||
if (dataset_type == 0) {
|
||||
train_params->images = images;
|
||||
train_params->labels = labels;
|
||||
train_params->width = 28;
|
||||
train_params->height = 28;
|
||||
train_params->dataset = NULL;
|
||||
} else {
|
||||
train_params->dataset = dataset;
|
||||
train_params->width = dataset->width;
|
||||
train_params->height = dataset->height;
|
||||
train_params->images = NULL;
|
||||
train_params->labels = NULL;
|
||||
}
|
||||
train_params->nb_images = BATCHES;
|
||||
train_params->index = shuffle_index;
|
||||
train_params->network = network;
|
||||
train_params->dataset_type = dataset_type;
|
||||
if (dataset_type == 0) {
|
||||
train_params->images = images;
|
||||
train_params->labels = labels;
|
||||
train_params->width = 28;
|
||||
train_params->height = 28;
|
||||
train_params->dataset = NULL;
|
||||
} else {
|
||||
train_params->dataset = dataset;
|
||||
train_params->width = dataset->width;
|
||||
train_params->height = dataset->height;
|
||||
train_params->images = NULL;
|
||||
train_params->labels = NULL;
|
||||
}
|
||||
train_params->nb_images = BATCHES;
|
||||
train_params->index = shuffle_index;
|
||||
#endif
|
||||
|
||||
end_time = omp_get_wtime();
|
||||
|
||||
elapsed_time = end_time - start_time;
|
||||
@ -229,6 +237,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
printf_time(elapsed_time);
|
||||
printf("\n\n");
|
||||
|
||||
//* Boucle d'apprentissage
|
||||
for (int i=0; i < epochs; i++) {
|
||||
|
||||
start_time = omp_get_wtime();
|
||||
@ -320,10 +329,8 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " YELLOW "%0.4f%%" RESET, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||
fflush(stdout);
|
||||
#endif
|
||||
// Il serait intéressant d'utiliser la perte calculée pour
|
||||
// savoir l'avancement dans l'apprentissage et donc comment adapter le taux d'apprentissage
|
||||
network->learning_rate = LEARNING_RATE*log(batch_loss+1);
|
||||
}
|
||||
//* Fin d'une époque: affichage des résultats et sauvegarde du réseau
|
||||
end_time = omp_get_wtime();
|
||||
elapsed_time = end_time - start_time;
|
||||
#ifdef USE_MULTITHREADING
|
||||
@ -338,8 +345,12 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
write_network(out, network);
|
||||
// If you want to test the network between each epoch, uncomment the following line:
|
||||
//test_network(0, out, "data/mnist/t10k-images-idx3-ubyte", "data/mnist/t10k-labels-idx1-ubyte", NULL, false);
|
||||
|
||||
// Learning Rate decay
|
||||
network->learning_rate -= LEARNING_RATE*(1./(float)(epochs+1));
|
||||
}
|
||||
|
||||
//* Fin de l'algo
|
||||
// To generate a new neural and compare performances with scripts/benchmark_binary.py
|
||||
if (epochs == 0) {
|
||||
write_network(out, network);
|
||||
|
Loading…
Reference in New Issue
Block a user