#include <sys/sysinfo.h>
#include <pthread.h>
#include <stdlib.h>
#include <stdio.h>
#include <float.h>
#include <math.h>
#include <time.h>
#include <omp.h>

#include "../common/include/memory_management.h"
#include "../common/include/colors.h"
#include "../common/include/utils.h"
#include "../common/include/mnist.h"
#include "include/initialisation.h"
#include "include/test_network.h"
#include "include/neuron_io.h"
#include "include/function.h"
#include "include/update.h"
#include "include/models.h"
#include "include/utils.h"
#include "include/free.h"
#include "include/jpeg.h"
#include "include/cnn.h"

#include "include/train.h"

int div_up(int a, int b) { // Ceiling of a/b (integer division rounded up), e.g. div_up(7, 2) == 4
    return ((a % b) != 0) ? (a / b + 1) : (a / b);
}
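
/*
 * Thread routine: decodes the JPEG file dataset->fileNames[index] into
 * dataset->images[index] if it is not already cached, so that image loading
 * can run concurrently with training (see train_thread below).
 */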
void* load_image(void* parameters) {
    LoadImageParameters* param = (LoadImageParameters*)parameters;

    if (!param->dataset->images[param->index]) {
        imgRawImage* image = loadJpegImageFile(param->dataset->fileNames[param->index]);
        param->dataset->images[param->index] = image->lpData;
        free(image);
    } else {
        printf_warning((char*)"Image déjà chargée\n"); // Technically not possible, hence the warning
    }

    return NULL;
}
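
/*
 * Thread routine: runs forward and backward propagation on the images
 * shuffle_index[start .. start+nb_images-1] and accumulates the resulting
 * accuracy and loss into the TrainParameters structure.
 */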
void* train_thread(void* parameters) {
    TrainParameters* param = (TrainParameters*)parameters;
    Network* network = param->network;
    imgRawImage* image;
    int maxi;

    int*** images = param->images;
    int* labels = (int*)param->labels;
    int* index = param->index;

    int width = param->width;
    int height = param->height;
    int dataset_type = param->dataset_type;
    int start = param->start;
    int nb_images = param->nb_images;
    int finetuning = param->finetuning;

    float* wanted_output;
    float accuracy = 0.;
    float loss = 0.;

    #ifdef DETAILED_TRAIN_TIMINGS
    double start_time;
    #endif

    pthread_t tid;
    LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters));

    if (dataset_type != 0) {
        load_image_param->dataset = param->dataset;
        load_image_param->index = index[start];

        pthread_create(&tid, NULL, load_image, (void*)load_image_param);
    }
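
    /*
     * For the JPEG dataset, image loading is pipelined: while image i is fed
     * through the network, a load_image thread decodes image i+1 (see the
     * pthread_join / pthread_create pair in the loop below).
     */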
    for (int i = start; i < start + nb_images; i++) {
        if (dataset_type == 0) { // MNIST
            write_image_in_network_32(images[index[i]], height, width, network->input[0][0], param->offset);

            #ifdef DETAILED_TRAIN_TIMINGS
            start_time = omp_get_wtime();
            #endif

            forward_propagation(network);

            #ifdef DETAILED_TRAIN_TIMINGS
            printf("Temps de forward: ");
            printf_time(omp_get_wtime() - start_time);
            printf("\n");
            start_time = omp_get_wtime();
            #endif

            maxi = indice_max(network->input[network->size - 1][0][0], 10);
            if (maxi == -1) {
                printf("\n");
                printf_error((char*)"Le réseau sature.\n");
                exit(1);
            }

            wanted_output = generate_wanted_output(labels[index[i]], 10);
            loss += compute_mean_squared_error(network->input[network->size - 1][0][0], wanted_output, 10);
            gree(wanted_output, false);

            backward_propagation(network, labels[index[i]], finetuning);

            #ifdef DETAILED_TRAIN_TIMINGS
            printf("Temps de backward: ");
            printf_time(omp_get_wtime() - start_time);
            printf("\n");
            start_time = omp_get_wtime();
            #endif

            if (maxi == labels[index[i]]) {
                accuracy += 1.;
            }
        } else { // JPEG dataset
            pthread_join(tid, NULL); // Wait for the prefetched image

            if (!param->dataset->images[index[i]]) {
                image = loadJpegImageFile(param->dataset->fileNames[index[i]]);
                param->dataset->images[index[i]] = image->lpData;
                free(image);
            }

            if (i != start + nb_images - 1) { // Start loading the next image
                load_image_param->index = index[i + 1];
                pthread_create(&tid, NULL, load_image, (void*)load_image_param);
            }

            write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]);

            #ifdef DETAILED_TRAIN_TIMINGS
            start_time = omp_get_wtime();
            #endif

            forward_propagation(network);

            #ifdef DETAILED_TRAIN_TIMINGS
            printf("Temps de forward: ");
            printf_time(omp_get_wtime() - start_time);
            printf("\n");
            start_time = omp_get_wtime();
            #endif

            maxi = indice_max(network->input[network->size - 1][0][0], param->dataset->numCategories);
            backward_propagation(network, param->dataset->labels[index[i]], finetuning);

            #ifdef DETAILED_TRAIN_TIMINGS
            printf("Temps de backward: ");
            printf_time(omp_get_wtime() - start_time);
            printf("\n");
            start_time = omp_get_wtime();
            #endif

            if (maxi == (int)param->dataset->labels[index[i]]) {
                accuracy += 1.;
            }

            free(param->dataset->images[index[i]]);
            param->dataset->images[index[i]] = NULL;
        }
    }

    free(load_image_param);

    param->accuracy = accuracy;
    param->loss = loss;

    return NULL;
}
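
/*
 * Trains the network. dataset_type 0 selects the MNIST binary files
 * (images_file / labels_file); any other value selects a directory of JPEG
 * images (data_dir). The network is created from scratch unless `recover`
 * points to a saved network, trained for `epochs` epochs, and written to
 * `out` after every epoch. `offset` and `finetuning` are forwarded to the
 * per-image training routine.
 */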
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover, bool offset, int finetuning) {
    #ifdef USE_CUDA
    bool compatibility = cuda_setup(true);
    if (!compatibility) {
        printf("Exiting.\n");
        exit(1);
    }
    #endif

    srand(time(NULL));

    float loss;
    float batch_loss; // May be redundant with loss, but gives more information
    float test_accuracy = 0.; // Used to decrease the learning rate
    (void)test_accuracy; // To avoid warnings when not used

    float accuracy;
    float batch_accuracy;
    float current_accuracy;

    //* Various timers to measure speed performance
    double start_time, end_time;
    double elapsed_time;

    double algo_start = omp_get_wtime();

    start_time = omp_get_wtime();

    //* Dataset loading
    int input_width = -1;
    int input_depth = -1;

    int nb_images_total;           // Total number of images
    int nb_images_total_remaining; // Images remaining in a batch
    int batches_epoques;           // Batches per epoch

    int*** images = NULL;          // Images as an array of arrays of arrays of pixels (grayscale, MNIST)
    unsigned int* labels = NULL;   // Labels associated with the MNIST dataset images
    jpegDataset* dataset = NULL;   // Data structure describing a dataset of JPEG images

    if (dataset_type == 0) { // MNIST
        // Load the images of the MNIST dataset
        int* parameters = read_mnist_images_parameters(images_file);
        nb_images_total = parameters[0];
        free(parameters);

        images = read_mnist_images(images_file);
        labels = read_mnist_labels(labels_file);

        input_width = 32;
        input_depth = 1;
    } else { // JPEG
        dataset = loadJpegDataset(data_dir);
        input_width = dataset->height + 4; // image_size + padding
        input_depth = dataset->numComponents;

        nb_images_total = dataset->numImages;
    }
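
    // In both branches the network input is slightly larger than the raw image
    // (MNIST: 28x28 images written into a 32x32 input plane), matching the
    // "image_size + padding" convention used for the JPEG datasets above.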

    //* Network creation
    Network* network;
    if (!recover) {
        if (dataset_type == 0) {
            network = create_network_lenet5(LEARNING_RATE, 0, LEAKY_RELU, HE, input_width, input_depth);
            //network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
        } else {
            network = create_network_VGG16(LEARNING_RATE, 0, RELU, HE, dataset->numCategories);

            #ifdef USE_MULTITHREADING
            printf_warning("Utilisation de VGG16 avec multithreading. La quantité de RAM utilisée peut devenir excessive\n");
            #endif
        }
    } else {
        network = read_network(recover);
        network->learning_rate = LEARNING_RATE;
    }

    /*
     * shuffle_index[i] contains the new index of the element that was at position i before shuffling.
     * This reorders the training set to avoid biases
     * that could come from the original ordering.
     */
    int* shuffle_index = (int*)malloc(sizeof(int) * nb_images_total);
    for (int i = 0; i < nb_images_total; i++) {
        shuffle_index[i] = i;
    }
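
    // shuffle_index is re-shuffled at the start of every epoch (knuth_shuffle,
    // i.e. a Fisher-Yates shuffle), so each epoch visits the images in a new order.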

    //* Creation of the train_thread input parameters
    #ifdef USE_MULTITHREADING
    int nb_remaining_images; // Number of images left to dispatch for a round of threads

    // Get the number of available threads
    int nb_threads = get_nprocs();
    pthread_t* tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t));

    // Creation of the parameters given to each thread in the multi-threaded case
    TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*) * nb_threads);
    TrainParameters* param;
    bool* thread_used = (bool*)malloc(sizeof(bool) * nb_threads);

    for (int k = 0; k < nb_threads; k++) {
        train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
        param = train_parameters[k];
        param->dataset_type = dataset_type;
        if (dataset_type == 0) {
            param->images = images;
            param->labels = labels;
            param->dataset = NULL;
            param->width = 28;
            param->height = 28;
        } else {
            param->dataset = dataset;
            param->width = dataset->width;
            param->height = dataset->height;
            param->images = NULL;
            param->labels = NULL;
        }
        param->nb_images = BATCHES / nb_threads;
        param->index = shuffle_index;
        param->network = copy_network(network);
        param->offset = offset;
        param->finetuning = finetuning;
    }
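
    /*
     * Each thread trains on its own copy of the network: the weights are
     * re-synchronized from the main network before every batch with
     * copy_network_parameters, and the per-thread updates are merged back
     * with update_weights / update_bias once the threads have joined.
     */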
    #else
    // Creation of the parameters given to the single thread
    // when multi-threading is not used.
    // This is useful for debugging in particular,
    // where the use of threads quickly makes things more complicated than they need to be.
    TrainParameters* train_params = (TrainParameters*)malloc(sizeof(TrainParameters));

    train_params->network = network;
    train_params->dataset_type = dataset_type;
    if (dataset_type == 0) {
        train_params->images = images;
        train_params->labels = labels;
        train_params->width = 28;
        train_params->height = 28;
        train_params->dataset = NULL;
    } else {
        train_params->dataset = dataset;
        train_params->width = dataset->width;
        train_params->height = dataset->height;
        train_params->images = NULL;
        train_params->labels = NULL;
    }
    train_params->nb_images = BATCHES;
    train_params->index = shuffle_index;
    train_params->offset = offset;
    train_params->finetuning = finetuning;
    #endif

    end_time = omp_get_wtime();
    elapsed_time = end_time - start_time;
    printf("Taux d'apprentissage initial: %0.2e\n", network->learning_rate);
    printf("Initialisation: ");
    printf_time(elapsed_time);
    printf("\n\n");

    //* Training loop
    for (int i = 0; i < epochs; i++) {
        start_time = omp_get_wtime();
        // The accuracy variable gives an ESTIMATE
        // of the network's success rate during training,
        // but is in no way an exact value when multi-threading is used,
        // because each copy of the initial network will be slightly different
        // and will therefore give different results on the same images.
        accuracy = 0.;
        loss = 0.;

        knuth_shuffle(shuffle_index, nb_images_total);
        batches_epoques = div_up(nb_images_total, BATCHES);
        nb_images_total_remaining = nb_images_total;

        #ifndef USE_MULTITHREADING
        train_params->nb_images = BATCHES; // Reset: the last batch of the previous epoch may have shrunk it
        #endif

        for (int j = 0; j < batches_epoques; j++) {
            batch_loss = 0.;
            batch_accuracy = 0.;

            #ifdef USE_MULTITHREADING
            if (j == batches_epoques - 1) {
                nb_remaining_images = nb_images_total_remaining;
                nb_images_total_remaining = 0;
            } else {
                nb_images_total_remaining -= BATCHES;
                nb_remaining_images = BATCHES;
            }

            for (int k = 0; k < nb_threads; k++) {
                if (k == nb_threads - 1) {
                    train_parameters[k]->nb_images = nb_remaining_images;
                    nb_remaining_images = 0;
                } else {
                    nb_remaining_images -= BATCHES / nb_threads;
                }
                train_parameters[k]->start = BATCHES * j + (BATCHES / nb_threads) * k;
                if (train_parameters[k]->start + train_parameters[k]->nb_images >= nb_images_total) {
                    train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start - 1;
                }
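                // Thread k trains on the slice of the shuffled index starting at
                // BATCHES*j + (BATCHES/nb_threads)*k; the clamp above keeps the
                // last slice from running past the end of the dataset.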
                if (train_parameters[k]->nb_images > 0) {
                    thread_used[k] = true;
                    copy_network_parameters(network, train_parameters[k]->network);
                    pthread_create(&tid[k], NULL, train_thread, (void*)train_parameters[k]);
                } else {
                    thread_used[k] = false;
                }
            }
            for (int k = 0; k < nb_threads; k++) {
                // Wait for each thread to finish, one by one
                if (thread_used[k]) {
                    pthread_join(tid[k], NULL);
                    accuracy += train_parameters[k]->accuracy / (float)nb_images_total;
                    loss += train_parameters[k]->loss / nb_images_total;
                    batch_loss += train_parameters[k]->loss / BATCHES;
                    batch_accuracy += train_parameters[k]->accuracy / (float)BATCHES; // Wrong for the last batch, but it is only displayed very briefly for that one
                }
            }

            // Wait until all the threads have finished before applying any modification to the main network
            for (int k = 0; k < nb_threads; k++) {
                if (train_parameters[k]->network) { // If the thread was used
                    update_weights(network, train_parameters[k]->network);
                    update_bias(network, train_parameters[k]->network);
                }
            }

            current_accuracy = accuracy * nb_images_total / ((j + 1) * BATCHES);
            printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " YELLOW "%0.2f%%" RESET "\tBatch Accuracy: " YELLOW "%0.2f%%" RESET, nb_threads, i, epochs, BATCHES * (j + 1), nb_images_total, current_accuracy * 100, batch_accuracy * 100);
            #else
            (void)nb_images_total_remaining; // Just to suppress a warning

            train_params->start = j * BATCHES;

            // Do not go past the total number of images (because of the ceiling in batches_epoques)
            if (j == batches_epoques - 1) {
                train_params->nb_images = nb_images_total - j * BATCHES;
            }

            train_thread((void*)train_params);

            accuracy += train_params->accuracy / (float)nb_images_total;
            current_accuracy = accuracy * nb_images_total / ((j + 1) * BATCHES);
            batch_accuracy += train_params->accuracy / (float)BATCHES;
            loss += train_params->loss / nb_images_total;
            batch_loss += train_params->loss / BATCHES;

            update_weights(network, network);
            update_bias(network, network);

            printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " YELLOW "%0.4f%%" RESET "\tBatch Accuracy: " YELLOW "%0.2f%%" RESET, i, epochs, BATCHES * (j + 1), nb_images_total, current_accuracy * 100, batch_accuracy * 100);
            #endif
        }

        //* End of an epoch: display the results and save the network
        end_time = omp_get_wtime();
        elapsed_time = end_time - start_time;
        #ifdef USE_MULTITHREADING
        printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " GREEN "%0.4f%%" RESET "\tLoss: %lf\tTemps: ", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy * 100, loss);
        printf_time(elapsed_time);
        printf("\n");
        #else
        printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " GREEN "%0.4f%%" RESET "\tLoss: %lf\tTemps: ", i, epochs, nb_images_total, nb_images_total, accuracy * 100, loss);
        printf_time(elapsed_time);
        printf("\n");
        #endif
        write_network(out, network);

        // If you want to test the network between each epoch, uncomment the following lines:
        /*
        float* test_results = test_network(0, out, "data/mnist/t10k-images-idx3-ubyte", "data/mnist/t10k-labels-idx1-ubyte", NULL, false, false, offset);
        printf("Tests: Accuracy: %0.2lf%%\tLoss: %lf\n", test_results[0], test_results[1]);

        if (test_results[0] < test_accuracy) {
            network->learning_rate *= 0.1;
            printf("Decreased learning rate to %0.2e\n", network->learning_rate);
        }
        if (test_results[0] == test_accuracy) {
            network->learning_rate *= 2;
            printf("Increased learning rate to %0.2e\n", network->learning_rate);
        }
        test_accuracy = test_results[0];
        free(test_results);
        */
    }

    //* End of the algorithm
    // To generate a new network and compare performance with scripts/benchmark_binary.py
    if (epochs == 0) {
        write_network(out, network);
    }

    free(shuffle_index);
    free_network(network);

    #ifdef USE_MULTITHREADING
    free(tid);
    for (int i = 0; i < nb_threads; i++) {
        free_network(train_parameters[i]->network);
        free(train_parameters[i]); // Each TrainParameters struct was malloc'd individually
    }
    free(train_parameters);
    free(thread_used);
    #else
    free(train_params);
    #endif

    if (dataset_type == 0) {
        for (int i = 0; i < nb_images_total; i++) {
            for (int j = 0; j < 28; j++) {
                free(images[i][j]);
            }
            free(images[i]);
        }
        free(images);
        free(labels);
    } else {
        free_dataset(dataset);
    }

    end_time = omp_get_wtime();
    elapsed_time = end_time - algo_start;

    printf("\nTemps total: ");
    printf_time(elapsed_time);
    printf("\n");
}