mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-23 23:26:25 +01:00
Add detailed train timings
This commit is contained in:
parent
964687d1b4
commit
5738892142
@ -8,6 +8,8 @@
|
|||||||
#define LEARNING_RATE 3e-4 // Taux d'apprentissage
|
#define LEARNING_RATE 3e-4 // Taux d'apprentissage
|
||||||
#define USE_MULTITHREADING // Commenter pour utiliser un seul coeur durant l'apprentissage (meilleur pour des tailles de batchs traités rapidement)
|
#define USE_MULTITHREADING // Commenter pour utiliser un seul coeur durant l'apprentissage (meilleur pour des tailles de batchs traités rapidement)
|
||||||
|
|
||||||
|
//#define DETAILED_TRAIN_TIMINGS // Afficher le temps de forward/ backward
|
||||||
|
|
||||||
//* Paramètres d'ADAM optimizer
|
//* Paramètres d'ADAM optimizer
|
||||||
#define ALPHA 3e-4
|
#define ALPHA 3e-4
|
||||||
#define BETA_1 0.9
|
#define BETA_1 0.9
|
||||||
|
@ -62,6 +62,10 @@ void* train_thread(void* parameters) {
|
|||||||
float accuracy = 0.;
|
float accuracy = 0.;
|
||||||
float loss = 0.;
|
float loss = 0.;
|
||||||
|
|
||||||
|
#ifdef DETAILED_TRAIN_TIMINGS
|
||||||
|
double start_time;
|
||||||
|
#endif
|
||||||
|
|
||||||
pthread_t tid;
|
pthread_t tid;
|
||||||
LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters));
|
LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters));
|
||||||
if (dataset_type != 0) {
|
if (dataset_type != 0) {
|
||||||
@ -74,7 +78,20 @@ void* train_thread(void* parameters) {
|
|||||||
for (int i=start; i < start+nb_images; i++) {
|
for (int i=start; i < start+nb_images; i++) {
|
||||||
if (dataset_type == 0) {
|
if (dataset_type == 0) {
|
||||||
write_image_in_network_32(images[index[i]], height, width, network->input[0][0], true);
|
write_image_in_network_32(images[index[i]], height, width, network->input[0][0], true);
|
||||||
|
|
||||||
|
#ifdef DETAILED_TRAIN_TIMINGS
|
||||||
|
start_time = omp_get_wtime();
|
||||||
|
#endif
|
||||||
|
|
||||||
forward_propagation(network);
|
forward_propagation(network);
|
||||||
|
|
||||||
|
#ifdef DETAILED_TRAIN_TIMINGS
|
||||||
|
printf("Temps de forward: ");
|
||||||
|
printf_time(omp_get_wtime() - start_time);
|
||||||
|
printf("\n");
|
||||||
|
start_time = omp_get_wtime();
|
||||||
|
#endif
|
||||||
|
|
||||||
maxi = indice_max(network->input[network->size-1][0][0], 10);
|
maxi = indice_max(network->input[network->size-1][0][0], 10);
|
||||||
if (maxi == -1) {
|
if (maxi == -1) {
|
||||||
printf("\n");
|
printf("\n");
|
||||||
@ -88,6 +105,13 @@ void* train_thread(void* parameters) {
|
|||||||
|
|
||||||
backward_propagation(network, labels[index[i]]);
|
backward_propagation(network, labels[index[i]]);
|
||||||
|
|
||||||
|
#ifdef DETAILED_TRAIN_TIMINGS
|
||||||
|
printf("Temps de backward: ");
|
||||||
|
printf_time(omp_get_wtime() - start_time);
|
||||||
|
printf("\n");
|
||||||
|
start_time = omp_get_wtime();
|
||||||
|
#endif
|
||||||
|
|
||||||
if (maxi == labels[index[i]]) {
|
if (maxi == labels[index[i]]) {
|
||||||
accuracy += 1.;
|
accuracy += 1.;
|
||||||
}
|
}
|
||||||
@ -104,10 +128,31 @@ void* train_thread(void* parameters) {
|
|||||||
pthread_create(&tid, NULL, load_image, (void*) load_image_param);
|
pthread_create(&tid, NULL, load_image, (void*) load_image_param);
|
||||||
}
|
}
|
||||||
write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]);
|
write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]);
|
||||||
|
|
||||||
|
#ifdef DETAILED_TRAIN_TIMINGS
|
||||||
|
start_time = omp_get_wtime();
|
||||||
|
#endif
|
||||||
|
|
||||||
forward_propagation(network);
|
forward_propagation(network);
|
||||||
|
|
||||||
|
#ifdef DETAILED_TRAIN_TIMINGS
|
||||||
|
printf("Temps de forward: ");
|
||||||
|
printf_time(omp_get_wtime() - start_time);
|
||||||
|
printf("\n");
|
||||||
|
start_time = omp_get_wtime();
|
||||||
|
#endif
|
||||||
|
|
||||||
maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories);
|
maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories);
|
||||||
backward_propagation(network, param->dataset->labels[index[i]]);
|
backward_propagation(network, param->dataset->labels[index[i]]);
|
||||||
|
|
||||||
|
#ifdef DETAILED_TRAIN_TIMINGS
|
||||||
|
printf("Temps de backward: ");
|
||||||
|
printf_time(omp_get_wtime() - start_time);
|
||||||
|
printf("\n");
|
||||||
|
start_time = omp_get_wtime();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
if (maxi == (int)param->dataset->labels[index[i]]) {
|
if (maxi == (int)param->dataset->labels[index[i]]) {
|
||||||
accuracy += 1.;
|
accuracy += 1.;
|
||||||
}
|
}
|
||||||
@ -160,9 +205,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
int nb_images_total_remaining; // Images restantes dans un batch
|
int nb_images_total_remaining; // Images restantes dans un batch
|
||||||
int batches_epoques; // Batches par époque
|
int batches_epoques; // Batches par époque
|
||||||
|
|
||||||
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
int*** images = NULL; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
||||||
unsigned int* labels; // Labels associés aux images du dataset MNIST
|
unsigned int* labels = NULL; // Labels associés aux images du dataset MNIST
|
||||||
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
|
jpegDataset* dataset = NULL; // Structure de données décrivant un dataset d'images jpeg
|
||||||
if (dataset_type == 0) { // Type MNIST
|
if (dataset_type == 0) { // Type MNIST
|
||||||
// Chargement des images du set de données MNIST
|
// Chargement des images du set de données MNIST
|
||||||
int* parameters = read_mnist_images_parameters(images_file);
|
int* parameters = read_mnist_images_parameters(images_file);
|
||||||
@ -186,10 +231,14 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
Network* network;
|
Network* network;
|
||||||
if (!recover) {
|
if (!recover) {
|
||||||
if (dataset_type == 0) {
|
if (dataset_type == 0) {
|
||||||
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth);
|
network = create_network_lenet5(LEARNING_RATE, 0, LEAKY_RELU, HE, input_width, input_depth);
|
||||||
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
|
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
|
||||||
} else {
|
} else {
|
||||||
network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories);
|
network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories);
|
||||||
|
|
||||||
|
#ifdef USE_MULTITHREADING
|
||||||
|
printf_warning("Utilisation de VGG16 avec multithreading. La quantité de RAM utilisée peut devenir excessive");
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
network = read_network(recover);
|
network = read_network(recover);
|
||||||
|
Loading…
Reference in New Issue
Block a user