tipe/src/cnn/train.c

491 lines
19 KiB
C
Raw Normal View History

2023-01-28 13:09:52 +01:00
#include <sys/sysinfo.h>
#include <pthread.h>
2022-10-01 17:53:14 +02:00
#include <stdlib.h>
#include <stdio.h>
#include <float.h>
2023-01-28 13:09:52 +01:00
#include <math.h>
2022-11-16 10:38:01 +01:00
#include <time.h>
#include <omp.h>
2022-10-01 17:53:14 +02:00
2023-05-12 16:16:34 +02:00
#include "../common/include/memory_management.h"
#include "../common/include/colors.h"
#include "../common/include/utils.h"
2023-05-12 16:16:34 +02:00
#include "../common/include/mnist.h"
2022-10-24 12:54:51 +02:00
#include "include/initialisation.h"
2023-02-24 14:36:48 +01:00
#include "include/test_network.h"
2022-10-24 12:54:51 +02:00
#include "include/neuron_io.h"
#include "include/function.h"
2022-11-15 12:50:38 +01:00
#include "include/update.h"
#include "include/models.h"
2022-10-24 12:54:51 +02:00
#include "include/utils.h"
#include "include/free.h"
2022-11-19 16:09:07 +01:00
#include "include/jpeg.h"
2022-10-24 12:54:51 +02:00
#include "include/cnn.h"
2022-10-01 17:53:14 +02:00
#include "include/train.h"
2022-11-25 15:17:47 +01:00
/*
 * Ceiling of the integer quotient a / b, i.e. the smallest integer that is
 * greater than or equal to the exact value of a/b.
 *
 * a: dividend
 * b: divisor, must be non-zero (division by zero is undefined behavior)
 *
 * Returns ceil(a / b) for every sign combination of a and b.
 */
int div_up(int a, int b) {
    int quotient = a / b;
    int remainder = a % b;

    // C integer division truncates toward zero. When the exact quotient is
    // negative, truncation already rounds up, so we only need to add one when
    // the exact quotient is positive and inexact, which is precisely when the
    // remainder is non-zero and has the same sign as the divisor.
    if (remainder != 0 && ((remainder > 0) == (b > 0))) {
        quotient++;
    }
    return quotient;
}
2023-04-02 17:34:31 +02:00
/*
 * Thread entry point that loads one JPEG image of a dataset into memory.
 *
 * parameters: pointer to a LoadImageParameters holding the dataset and the
 *             index of the image to load.
 *
 * If the slot dataset->images[index] is empty, the file named
 * dataset->fileNames[index] is decoded and its pixel buffer (lpData) is stored
 * in that slot. Only the imgRawImage wrapper is freed here; the pixel buffer
 * stays referenced by the dataset.
 * If the slot is already filled, a warning is printed and nothing is loaded.
 *
 * Always returns NULL.
 */
void* load_image(void* parameters) {
    LoadImageParameters* request = (LoadImageParameters*)parameters;

    if (request->dataset->images[request->index]) {
        // Not expected to happen in practice, hence a warning rather than an error
        printf_warning((char*)"Image déjà chargée\n");
        return NULL;
    }

    imgRawImage* decoded = loadJpegImageFile(request->dataset->fileNames[request->index]);
    request->dataset->images[request->index] = decoded->lpData;
    free(decoded);

    return NULL;
}
2022-11-03 18:13:01 +01:00
2022-10-01 17:53:14 +02:00
void* train_thread(void* parameters) {
TrainParameters* param = (TrainParameters*)parameters;
Network* network = param->network;
2022-11-19 16:09:07 +01:00
imgRawImage* image;
2022-11-03 18:13:01 +01:00
int maxi;
2022-10-01 17:53:14 +02:00
int*** images = param->images;
2022-10-24 12:54:51 +02:00
int* labels = (int*)param->labels;
2022-12-07 10:44:28 +01:00
int* index = param->index;
2022-10-01 17:53:14 +02:00
int width = param->width;
int height = param->height;
int dataset_type = param->dataset_type;
int start = param->start;
int nb_images = param->nb_images;
2023-05-25 16:32:37 +02:00
int finetuning = param->finetuning;
2023-01-20 13:41:38 +01:00
float* wanted_output;
2022-10-01 17:53:14 +02:00
float accuracy = 0.;
2023-01-20 13:41:38 +01:00
float loss = 0.;
2023-05-19 21:48:08 +02:00
#ifdef DETAILED_TRAIN_TIMINGS
double start_time;
#endif
2023-04-02 17:34:31 +02:00
pthread_t tid;
2023-05-13 10:05:54 +02:00
LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters));
2023-04-02 17:34:31 +02:00
if (dataset_type != 0) {
load_image_param->dataset = param->dataset;
load_image_param->index = index[start];
pthread_create(&tid, NULL, load_image, (void*) load_image_param);
}
2022-10-01 17:53:14 +02:00
for (int i=start; i < start+nb_images; i++) {
if (dataset_type == 0) {
write_image_in_network_32(images[index[i]], height, width, network->input[0][0], param->offset);
2023-05-19 21:48:08 +02:00
#ifdef DETAILED_TRAIN_TIMINGS
start_time = omp_get_wtime();
#endif
2022-10-07 14:26:36 +02:00
forward_propagation(network);
2023-05-19 21:48:08 +02:00
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de forward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
2022-11-15 12:58:00 +01:00
maxi = indice_max(network->input[network->size-1][0][0], 10);
2023-01-21 18:59:59 +01:00
if (maxi == -1) {
printf("\n");
printf_error((char*)"Le réseau sature.\n");
2023-01-21 18:59:59 +01:00
exit(1);
}
2023-01-20 13:41:38 +01:00
wanted_output = generate_wanted_output(labels[index[i]], 10);
loss += compute_mean_squared_error(network->input[network->size-1][0][0], wanted_output, 10);
gree(wanted_output, false);
2023-01-20 13:41:38 +01:00
2023-05-25 16:32:37 +02:00
backward_propagation(network, labels[index[i]], finetuning);
2022-11-16 10:38:01 +01:00
2023-05-19 21:48:08 +02:00
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de backward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
2022-12-07 10:44:28 +01:00
if (maxi == labels[index[i]]) {
2022-11-03 18:13:01 +01:00
accuracy += 1.;
}
2022-10-01 17:53:14 +02:00
} else {
2023-04-02 17:34:31 +02:00
pthread_join(tid, NULL);
2022-12-07 10:44:28 +01:00
if (!param->dataset->images[index[i]]) {
image = loadJpegImageFile(param->dataset->fileNames[index[i]]);
param->dataset->images[index[i]] = image->lpData;
free(image);
2022-11-19 16:09:07 +01:00
}
2023-04-02 17:34:31 +02:00
if (i != start+nb_images-1) {
load_image_param->index = index[i+1];
pthread_create(&tid, NULL, load_image, (void*) load_image_param);
}
write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]);
2023-05-19 21:48:08 +02:00
#ifdef DETAILED_TRAIN_TIMINGS
start_time = omp_get_wtime();
#endif
2022-11-19 16:09:07 +01:00
forward_propagation(network);
2023-05-19 21:48:08 +02:00
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de forward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
2022-11-19 22:22:24 +01:00
maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories);
2023-05-25 16:32:37 +02:00
backward_propagation(network, param->dataset->labels[index[i]], finetuning);
2023-05-19 21:48:08 +02:00
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de backward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
2022-11-19 16:09:07 +01:00
2022-12-07 10:44:28 +01:00
if (maxi == (int)param->dataset->labels[index[i]]) {
2022-11-19 16:09:07 +01:00
accuracy += 1.;
}
free(param->dataset->images[index[i]]);
2022-12-07 10:44:28 +01:00
param->dataset->images[index[i]] = NULL;
2022-10-01 17:53:14 +02:00
}
}
2023-05-13 10:05:54 +02:00
free(load_image_param);
2023-04-02 17:34:31 +02:00
2022-10-01 17:53:14 +02:00
param->accuracy = accuracy;
2023-01-20 13:41:38 +01:00
param->loss = loss;
2022-10-01 17:53:14 +02:00
return NULL;
}
2023-05-25 16:32:37 +02:00
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover, bool offset, int finetuning) {
2023-01-29 09:40:55 +01:00
#ifdef USE_CUDA
2023-05-16 13:21:26 +02:00
bool compatibility = cuda_setup(true);
2023-01-29 09:40:55 +01:00
if (!compatibility) {
printf("Exiting.\n");
2023-01-29 09:47:29 +01:00
exit(1);
2023-01-29 09:40:55 +01:00
}
#endif
2022-11-16 10:38:01 +01:00
srand(time(NULL));
2023-01-20 13:41:38 +01:00
float loss;
float batch_loss; // May be redundant with loss, but gives more informations
float test_accuracy = 0.; // Used to decrease Learning rate
2023-05-13 10:05:54 +02:00
(void)test_accuracy; // To avoid warnings when not used
2022-10-01 17:53:14 +02:00
float accuracy;
2023-03-13 13:55:09 +01:00
float batch_accuracy;
2022-11-15 17:50:33 +01:00
float current_accuracy;
2022-10-01 17:53:14 +02:00
2023-03-11 13:45:00 +01:00
//* Différents timers pour mesurer les performance en terme de vitesse
double start_time, end_time;
double elapsed_time;
double algo_start = omp_get_wtime();
start_time = omp_get_wtime();
2023-03-11 13:45:00 +01:00
//* Chargement du dataset
2023-05-13 17:22:47 +02:00
int input_width = -1;
2023-03-11 13:45:00 +01:00
int input_depth = -1;
int nb_images_total; // Images au total
int nb_images_total_remaining; // Images restantes dans un batch
int batches_epoques; // Batches par époque
2023-05-19 21:48:08 +02:00
int*** images = NULL; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
unsigned int* labels = NULL; // Labels associés aux images du dataset MNIST
jpegDataset* dataset = NULL; // Structure de données décrivant un dataset d'images jpeg
2022-10-01 17:53:14 +02:00
if (dataset_type == 0) { // Type MNIST
// Chargement des images du set de données MNIST
int* parameters = read_mnist_images_parameters(images_file);
nb_images_total = parameters[0];
2023-02-18 13:10:00 +01:00
free(parameters);
2022-10-01 17:53:14 +02:00
images = read_mnist_images(images_file);
labels = read_mnist_labels(labels_file);
2023-05-13 17:22:47 +02:00
input_width = 32;
2022-10-01 17:53:14 +02:00
input_depth = 1;
2022-11-19 16:09:07 +01:00
} else { // Type JPG
dataset = loadJpegDataset(data_dir);
2023-05-13 17:22:47 +02:00
input_width = dataset->height + 4; // image_size + padding
2022-11-19 16:09:07 +01:00
input_depth = dataset->numComponents;
2022-10-01 17:53:14 +02:00
2022-11-19 16:09:07 +01:00
nb_images_total = dataset->numImages;
2022-10-01 17:53:14 +02:00
}
2023-03-11 13:45:00 +01:00
//* Création du réseau
Network* network;
2022-12-07 13:09:39 +01:00
if (!recover) {
if (dataset_type == 0) {
2023-05-19 21:48:08 +02:00
network = create_network_lenet5(LEARNING_RATE, 0, LEAKY_RELU, HE, input_width, input_depth);
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
} else {
network = create_network_VGG16(LEARNING_RATE, 0, RELU, HE, dataset->numCategories);
2023-05-19 21:48:08 +02:00
#ifdef USE_MULTITHREADING
printf_warning("Utilisation de VGG16 avec multithreading. La quantité de RAM utilisée peut devenir excessive\n");
2023-05-19 21:48:08 +02:00
#endif
}
2022-12-07 13:09:39 +01:00
} else {
network = read_network(recover);
2023-01-20 13:41:38 +01:00
network->learning_rate = LEARNING_RATE;
2022-12-07 13:09:39 +01:00
}
2023-01-17 15:34:29 +01:00
2023-03-11 13:45:00 +01:00
/*
shuffle_index[i] contient le nouvel index de l'élément à l'emplacement i avant mélange
Cela permet de réordonner le jeu d'apprentissage pour éviter certains biais
qui pourraient provenir de l'ordre établi.
*/
int* shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
2022-12-07 10:44:28 +01:00
for (int i=0; i < nb_images_total; i++) {
shuffle_index[i] = i;
}
2022-10-01 17:53:14 +02:00
2023-03-11 13:45:00 +01:00
//* Création des paramètres d'entrée de train_thread
2022-10-01 17:53:14 +02:00
#ifdef USE_MULTITHREADING
2023-03-11 13:45:00 +01:00
int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads
// Récupération du nombre de threads disponibles
int nb_threads = get_nprocs();
pthread_t *tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t));
// Création des paramètres donnés à chaque thread dans le cas du multi-threading
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
TrainParameters* param;
bool* thread_used = (bool*)malloc(sizeof(bool)*nb_threads);
for (int k=0; k < nb_threads; k++) {
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
param = train_parameters[k];
param->dataset_type = dataset_type;
if (dataset_type == 0) {
param->images = images;
param->labels = labels;
param->dataset = NULL;
param->width = 28;
param->height = 28;
} else {
param->dataset = dataset;
param->width = dataset->width;
param->height = dataset->height;
param->images = NULL;
param->labels = NULL;
}
param->nb_images = BATCHES / nb_threads;
param->index = shuffle_index;
param->network = copy_network(network);
param->offset = offset;
2023-05-25 16:32:37 +02:00
param->finetuning = finetuning;
2023-03-11 13:45:00 +01:00
}
#else
// Création des paramètres donnés à l'unique
// thread dans l'hypothèse ou le multi-threading n'est pas utilisé.
// Cela est utile à des fins de débogage notamment,
// où l'utilisation de threads rend vite les choses plus compliquées qu'elles ne le sont.
TrainParameters* train_params = (TrainParameters*)malloc(sizeof(TrainParameters));
train_params->network = network;
train_params->dataset_type = dataset_type;
2022-10-01 17:53:14 +02:00
if (dataset_type == 0) {
2023-03-11 13:45:00 +01:00
train_params->images = images;
train_params->labels = labels;
train_params->width = 28;
train_params->height = 28;
train_params->dataset = NULL;
2022-10-01 17:53:14 +02:00
} else {
2023-03-11 13:45:00 +01:00
train_params->dataset = dataset;
train_params->width = dataset->width;
train_params->height = dataset->height;
train_params->images = NULL;
train_params->labels = NULL;
2022-10-01 17:53:14 +02:00
}
2023-03-11 13:45:00 +01:00
train_params->nb_images = BATCHES;
train_params->index = shuffle_index;
train_params->offset = offset;
2023-05-25 16:32:37 +02:00
train_params->finetuning = finetuning;
2022-10-01 17:53:14 +02:00
#endif
2023-03-11 13:45:00 +01:00
end_time = omp_get_wtime();
elapsed_time = end_time - start_time;
printf("Taux d'apprentissage initial: %0.2e\n", network->learning_rate);
printf("Initialisation: ");
printf_time(elapsed_time);
printf("\n\n");
2022-10-01 17:53:14 +02:00
2023-03-11 13:45:00 +01:00
//* Boucle d'apprentissage
2022-10-01 17:53:14 +02:00
for (int i=0; i < epochs; i++) {
start_time = omp_get_wtime();
2022-10-01 17:53:14 +02:00
// La variable accuracy permet d'avoir une ESTIMATION
// du taux de réussite et de l'entraînement du réseau,
// mais n'est en aucun cas une valeur réelle dans le cas
// du multi-threading car chaque copie du réseau initiale sera légèrement différente
// et donnera donc des résultats différents sur les mêmes images.
accuracy = 0.;
2023-01-20 13:41:38 +01:00
loss = 0.;
2022-12-07 10:44:28 +01:00
knuth_shuffle(shuffle_index, nb_images_total);
2022-11-25 15:17:47 +01:00
batches_epoques = div_up(nb_images_total, BATCHES);
2022-11-23 11:37:26 +01:00
nb_images_total_remaining = nb_images_total;
2022-12-07 13:09:39 +01:00
#ifndef USE_MULTITHREADING
train_params->nb_images = BATCHES;
2022-12-07 13:09:39 +01:00
#endif
2023-01-20 13:41:38 +01:00
2022-11-23 11:37:26 +01:00
for (int j=0; j < batches_epoques; j++) {
2023-01-20 13:41:38 +01:00
batch_loss = 0.;
2023-03-13 13:55:09 +01:00
batch_accuracy = 0.;
2022-11-15 17:50:33 +01:00
#ifdef USE_MULTITHREADING
if (j == batches_epoques-1) {
nb_remaining_images = nb_images_total_remaining;
nb_images_total_remaining = 0;
2022-10-01 17:53:14 +02:00
} else {
nb_images_total_remaining -= BATCHES;
nb_remaining_images = BATCHES;
2022-10-01 17:53:14 +02:00
}
2022-11-23 11:37:26 +01:00
for (int k=0; k < nb_threads; k++) {
if (k == nb_threads-1) {
train_parameters[k]->nb_images = nb_remaining_images;
nb_remaining_images = 0;
} else {
nb_remaining_images -= BATCHES / nb_threads;
}
train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k;
if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) {
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
}
if (train_parameters[k]->nb_images > 0) {
2023-01-28 13:09:52 +01:00
thread_used[k] = true;
copy_network_parameters(network, train_parameters[k]->network);
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
} else {
2023-01-28 13:09:52 +01:00
thread_used[k] = false;
}
}
for (int k=0; k < nb_threads; k++) {
// On attend la terminaison de chaque thread un à un
2023-01-28 13:09:52 +01:00
if (thread_used[k]) {
pthread_join( tid[k], NULL );
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
2023-01-20 13:41:38 +01:00
loss += train_parameters[k]->loss/nb_images_total;
batch_loss += train_parameters[k]->loss/BATCHES;
2023-03-13 13:55:09 +01:00
batch_accuracy += train_parameters[k]->accuracy / (float) BATCHES; // C'est faux pour le dernier batch mais on ne l'affiche pas pour lui (enfin très rapidement)
}
}
2023-01-17 15:34:29 +01:00
// On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal
for (int k=0; k < nb_threads; k++) {
if (train_parameters[k]->network) { // Si le fil a été utilisé
2023-01-20 13:41:38 +01:00
update_weights(network, train_parameters[k]->network);
update_bias(network, train_parameters[k]->network);
}
}
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
2023-03-13 13:55:09 +01:00
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " YELLOW "%0.2f%%" RESET " \tBatch Accuracy: " YELLOW "%0.2f%%" RESET, nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, batch_accuracy*100);
2022-10-01 17:53:14 +02:00
#else
(void)nb_images_total_remaining; // Juste pour enlever un warning
2022-12-07 10:44:28 +01:00
train_params->start = j*BATCHES;
2022-12-07 13:09:39 +01:00
// Ne pas dépasser le nombre d'images à cause de la partie entière
if (j == batches_epoques-1) {
train_params->nb_images = nb_images_total - j*BATCHES;
}
2023-01-17 15:34:29 +01:00
train_thread((void*)train_params);
2023-01-17 15:34:29 +01:00
accuracy += train_params->accuracy / (float) nb_images_total;
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
2023-03-13 13:55:09 +01:00
batch_accuracy += train_params->accuracy / (float)BATCHES;
2023-01-20 13:41:38 +01:00
loss += train_params->loss/nb_images_total;
batch_loss += train_params->loss/BATCHES;
2023-01-17 15:34:29 +01:00
2023-01-20 13:41:38 +01:00
update_weights(network, network);
update_bias(network, network);
2023-01-17 15:34:29 +01:00
2023-03-13 13:55:09 +01:00
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " YELLOW "%0.4f%%" RESET "\tBatch Accuracy: " YELLOW "%0.2f%%" RESET, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, batch_accuracy*100);
2022-10-01 17:53:14 +02:00
#endif
}
2023-03-11 13:45:00 +01:00
//* Fin d'une époque: affichage des résultats et sauvegarde du réseau
end_time = omp_get_wtime();
elapsed_time = end_time - start_time;
2022-10-01 17:53:14 +02:00
#ifdef USE_MULTITHREADING
2023-02-24 14:36:48 +01:00
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " GREEN "%0.4f%%" RESET " \tLoss: %lf\tTemps: ", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, loss);
printf_time(elapsed_time);
printf("\n");
2022-10-01 17:53:14 +02:00
#else
2023-02-24 14:36:48 +01:00
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: " GREEN "%0.4f%%" RESET " \tLoss: %lf\tTemps: ", i, epochs, nb_images_total, nb_images_total, accuracy*100, loss);
printf_time(elapsed_time);
printf("\n");
2022-10-01 17:53:14 +02:00
#endif
write_network(out, network);
// If you want to test the network between each epoch, uncomment the following lines:
/*
float* test_results = test_network(0, out, "data/mnist/t10k-images-idx3-ubyte", "data/mnist/t10k-labels-idx1-ubyte", NULL, false, false, offset);
printf("Tests: Accuracy: %0.2lf%%\tLoss: %lf\n", test_results[0], test_results[1]);
if (test_results[0] < test_accuracy) {
network->learning_rate *= 0.1;
printf("Decreased learning rate to %0.2e\n", network->learning_rate);
}
if (test_results[0] == test_accuracy) {
network->learning_rate *= 2;
printf("Increased learning rate to %0.2e\n", network->learning_rate);
}
test_accuracy = test_results[0];
free(test_results);
*/
2022-10-01 17:53:14 +02:00
}
2023-01-28 13:09:52 +01:00
2023-03-11 13:45:00 +01:00
//* Fin de l'algo
2023-01-28 13:09:52 +01:00
// To generate a new neural and compare performances with scripts/benchmark_binary.py
if (epochs == 0) {
write_network(out, network);
}
2022-12-07 10:44:28 +01:00
free(shuffle_index);
2022-10-07 14:26:36 +02:00
free_network(network);
2023-01-14 15:28:02 +01:00
2022-10-01 17:53:14 +02:00
#ifdef USE_MULTITHREADING
free(tid);
2023-01-28 13:09:52 +01:00
for (int i=0; i < nb_threads; i++) {
2023-01-28 22:04:38 +01:00
free_network(train_parameters[i]->network);
2023-01-28 13:09:52 +01:00
}
free(train_parameters);
2022-10-01 17:53:14 +02:00
#else
free(train_params);
#endif
2023-01-17 12:50:35 +01:00
2023-01-14 15:28:02 +01:00
if (dataset_type == 0) {
for (int i=0; i < nb_images_total; i++) {
for (int j=0; j < 28; j++) {
free(images[i][j]);
2023-01-14 15:28:02 +01:00
}
free(images[i]);
2023-01-14 15:28:02 +01:00
}
free(images);
free(labels);
2023-01-14 15:28:02 +01:00
} else {
free_dataset(dataset);
}
end_time = omp_get_wtime();
elapsed_time = end_time - algo_start;
printf("\nTemps total: ");
printf_time(elapsed_time);
printf("\n");
2023-05-25 16:32:37 +02:00
}