diff --git a/src/cnn/include/config.h b/src/cnn/include/config.h index 97f5ed4..89789c1 100644 --- a/src/cnn/include/config.h +++ b/src/cnn/include/config.h @@ -8,6 +8,8 @@ #define LEARNING_RATE 3e-4 // Taux d'apprentissage #define USE_MULTITHREADING // Commenter pour utiliser un seul coeur durant l'apprentissage (meilleur pour des tailles de batchs traités rapidement) +//#define DETAILED_TRAIN_TIMINGS // Afficher le temps de forward/ backward + //* Paramètres d'ADAM optimizer #define ALPHA 3e-4 #define BETA_1 0.9 diff --git a/src/cnn/train.c b/src/cnn/train.c index 72cd5a7..248284b 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -62,6 +62,10 @@ void* train_thread(void* parameters) { float accuracy = 0.; float loss = 0.; + #ifdef DETAILED_TRAIN_TIMINGS + double start_time; + #endif + pthread_t tid; LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters)); if (dataset_type != 0) { @@ -74,7 +78,20 @@ void* train_thread(void* parameters) { for (int i=start; i < start+nb_images; i++) { if (dataset_type == 0) { write_image_in_network_32(images[index[i]], height, width, network->input[0][0], true); + + #ifdef DETAILED_TRAIN_TIMINGS + start_time = omp_get_wtime(); + #endif + forward_propagation(network); + + #ifdef DETAILED_TRAIN_TIMINGS + printf("Temps de forward: "); + printf_time(omp_get_wtime() - start_time); + printf("\n"); + start_time = omp_get_wtime(); + #endif + maxi = indice_max(network->input[network->size-1][0][0], 10); if (maxi == -1) { printf("\n"); @@ -88,6 +105,13 @@ void* train_thread(void* parameters) { backward_propagation(network, labels[index[i]]); + #ifdef DETAILED_TRAIN_TIMINGS + printf("Temps de backward: "); + printf_time(omp_get_wtime() - start_time); + printf("\n"); + start_time = omp_get_wtime(); + #endif + if (maxi == labels[index[i]]) { accuracy += 1.; } @@ -104,9 +128,30 @@ void* train_thread(void* parameters) { pthread_create(&tid, NULL, load_image, (void*) load_image_param); } write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]); + + #ifdef DETAILED_TRAIN_TIMINGS + start_time = omp_get_wtime(); + #endif + forward_propagation(network); + + #ifdef DETAILED_TRAIN_TIMINGS + printf("Temps de forward: "); + printf_time(omp_get_wtime() - start_time); + printf("\n"); + start_time = omp_get_wtime(); + #endif + maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories); backward_propagation(network, param->dataset->labels[index[i]]); + + #ifdef DETAILED_TRAIN_TIMINGS + printf("Temps de backward: "); + printf_time(omp_get_wtime() - start_time); + printf("\n"); + start_time = omp_get_wtime(); + #endif + if (maxi == (int)param->dataset->labels[index[i]]) { accuracy += 1.; @@ -160,9 +205,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di int nb_images_total_remaining; // Images restantes dans un batch int batches_epoques; // Batches par époque - int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST) - unsigned int* labels; // Labels associés aux images du dataset MNIST - jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg + int*** images = NULL; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST) + unsigned int* labels = NULL; // Labels associés aux images du dataset MNIST + jpegDataset* dataset = NULL; // Structure de données décrivant un dataset d'images jpeg if (dataset_type == 0) { // Type MNIST // Chargement des images du set de données MNIST int* parameters = read_mnist_images_parameters(images_file); @@ -186,10 +231,14 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di Network* network; if (!recover) { if (dataset_type == 0) { - network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth); + network = create_network_lenet5(LEARNING_RATE, 0, LEAKY_RELU, HE, input_width, input_depth); //network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth); } else { network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories); + + #ifdef USE_MULTITHREADING + printf_warning("Utilisation de VGG16 avec multithreading. La quantité de RAM utilisée peut devenir excessive"); + #endif } } else { network = read_network(recover);