Add detailed train timings

This commit is contained in:
augustin64 2023-05-19 21:48:08 +02:00
parent 964687d1b4
commit 5738892142
2 changed files with 55 additions and 4 deletions

View File

@ -8,6 +8,8 @@
#define LEARNING_RATE 3e-4 // Taux d'apprentissage
#define USE_MULTITHREADING // Commenter pour utiliser un seul coeur durant l'apprentissage (meilleur pour des tailles de batchs traités rapidement)
//#define DETAILED_TRAIN_TIMINGS // Afficher le temps de forward/ backward
//* Paramètres d'ADAM optimizer
#define ALPHA 3e-4
#define BETA_1 0.9

View File

@ -62,6 +62,10 @@ void* train_thread(void* parameters) {
float accuracy = 0.;
float loss = 0.;
#ifdef DETAILED_TRAIN_TIMINGS
double start_time;
#endif
pthread_t tid;
LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters));
if (dataset_type != 0) {
@ -74,7 +78,20 @@ void* train_thread(void* parameters) {
for (int i=start; i < start+nb_images; i++) {
if (dataset_type == 0) {
write_image_in_network_32(images[index[i]], height, width, network->input[0][0], true);
#ifdef DETAILED_TRAIN_TIMINGS
start_time = omp_get_wtime();
#endif
forward_propagation(network);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de forward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
maxi = indice_max(network->input[network->size-1][0][0], 10);
if (maxi == -1) {
printf("\n");
@ -88,6 +105,13 @@ void* train_thread(void* parameters) {
backward_propagation(network, labels[index[i]]);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de backward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
if (maxi == labels[index[i]]) {
accuracy += 1.;
}
@ -104,9 +128,30 @@ void* train_thread(void* parameters) {
pthread_create(&tid, NULL, load_image, (void*) load_image_param);
}
write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]);
#ifdef DETAILED_TRAIN_TIMINGS
start_time = omp_get_wtime();
#endif
forward_propagation(network);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de forward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories);
backward_propagation(network, param->dataset->labels[index[i]]);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de backward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
if (maxi == (int)param->dataset->labels[index[i]]) {
accuracy += 1.;
@ -160,9 +205,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
int nb_images_total_remaining; // Images restantes dans un batch
int batches_epoques; // Batches par époque
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
unsigned int* labels; // Labels associés aux images du dataset MNIST
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
int*** images = NULL; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
unsigned int* labels = NULL; // Labels associés aux images du dataset MNIST
jpegDataset* dataset = NULL; // Structure de données décrivant un dataset d'images jpeg
if (dataset_type == 0) { // Type MNIST
// Chargement des images du set de données MNIST
int* parameters = read_mnist_images_parameters(images_file);
@ -186,10 +231,14 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
Network* network;
if (!recover) {
if (dataset_type == 0) {
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth);
network = create_network_lenet5(LEARNING_RATE, 0, LEAKY_RELU, HE, input_width, input_depth);
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
} else {
network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories);
#ifdef USE_MULTITHREADING
printf_warning("Utilisation de VGG16 avec multithreading. La quantité de RAM utilisée peut devenir excessive");
#endif
}
} else {
network = read_network(recover);