2022-10-01 17:53:14 +02:00
|
|
|
#include <stdlib.h>
|
|
|
|
#include <stdio.h>
|
|
|
|
#include <float.h>
|
|
|
|
#include <pthread.h>
|
|
|
|
#include <sys/sysinfo.h>
|
2022-11-16 10:38:01 +01:00
|
|
|
#include <time.h>
|
2023-01-13 15:58:11 +01:00
|
|
|
#include <omp.h>
|
2022-10-01 17:53:14 +02:00
|
|
|
|
2022-10-24 12:54:51 +02:00
|
|
|
#include "../mnist/include/mnist.h"
|
|
|
|
#include "include/initialisation.h"
|
|
|
|
#include "include/neuron_io.h"
|
|
|
|
#include "../include/colors.h"
|
|
|
|
#include "include/function.h"
|
|
|
|
#include "include/creation.h"
|
2022-11-15 12:50:38 +01:00
|
|
|
#include "include/update.h"
|
2022-10-24 12:54:51 +02:00
|
|
|
#include "include/utils.h"
|
|
|
|
#include "include/free.h"
|
2022-11-19 16:09:07 +01:00
|
|
|
#include "include/jpeg.h"
|
2022-10-24 12:54:51 +02:00
|
|
|
#include "include/cnn.h"
|
2022-10-01 17:53:14 +02:00
|
|
|
|
|
|
|
#include "include/train.h"
|
|
|
|
|
2022-11-25 15:17:47 +01:00
|
|
|
/*
 * Rounds the quotient a/b up: the truncated integer quotient, plus one
 * whenever the division leaves a non-zero remainder.
 * This is the mathematical ceiling for a >= 0 and b > 0 (how it is used here,
 * e.g. to count batches per epoch). b must not be 0.
 */
int div_up(int a, int b) {
    int quotient = a / b;
    int has_remainder = (a % b) != 0;
    return quotient + has_remainder;
}
|
|
|
|
|
|
|
|
|
2022-10-01 17:53:14 +02:00
|
|
|
void* train_thread(void* parameters) {
|
|
|
|
TrainParameters* param = (TrainParameters*)parameters;
|
|
|
|
Network* network = param->network;
|
2022-11-19 16:09:07 +01:00
|
|
|
imgRawImage* image;
|
2022-11-03 18:13:01 +01:00
|
|
|
int maxi;
|
2022-10-01 17:53:14 +02:00
|
|
|
|
|
|
|
int*** images = param->images;
|
2022-10-24 12:54:51 +02:00
|
|
|
int* labels = (int*)param->labels;
|
2022-12-07 10:44:28 +01:00
|
|
|
int* index = param->index;
|
2022-10-01 17:53:14 +02:00
|
|
|
|
|
|
|
int width = param->width;
|
|
|
|
int height = param->height;
|
|
|
|
int dataset_type = param->dataset_type;
|
|
|
|
int start = param->start;
|
|
|
|
int nb_images = param->nb_images;
|
|
|
|
float accuracy = 0.;
|
|
|
|
for (int i=start; i < start+nb_images; i++) {
|
|
|
|
if (dataset_type == 0) {
|
2022-12-07 10:44:28 +01:00
|
|
|
write_image_in_network_32(images[index[i]], height, width, network->input[0][0]);
|
2022-10-07 14:26:36 +02:00
|
|
|
forward_propagation(network);
|
2022-11-15 12:58:00 +01:00
|
|
|
maxi = indice_max(network->input[network->size-1][0][0], 10);
|
2022-10-24 12:54:51 +02:00
|
|
|
backward_propagation(network, labels[i]);
|
2022-11-16 10:38:01 +01:00
|
|
|
|
2022-12-07 10:44:28 +01:00
|
|
|
if (maxi == labels[index[i]]) {
|
2022-11-03 18:13:01 +01:00
|
|
|
accuracy += 1.;
|
|
|
|
}
|
2022-10-01 17:53:14 +02:00
|
|
|
} else {
|
2022-12-07 10:44:28 +01:00
|
|
|
if (!param->dataset->images[index[i]]) {
|
|
|
|
image = loadJpegImageFile(param->dataset->fileNames[index[i]]);
|
|
|
|
param->dataset->images[index[i]] = image->lpData;
|
2022-11-19 16:09:07 +01:00
|
|
|
free(image);
|
|
|
|
}
|
2022-12-07 10:44:28 +01:00
|
|
|
write_image_in_network_260(param->dataset->images[index[i]], height, width, network->input[0]);
|
2022-11-19 16:09:07 +01:00
|
|
|
forward_propagation(network);
|
2022-11-19 22:22:24 +01:00
|
|
|
maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories);
|
2022-12-07 10:44:28 +01:00
|
|
|
backward_propagation(network, param->dataset->labels[index[i]]);
|
2022-11-19 16:09:07 +01:00
|
|
|
|
2022-12-07 10:44:28 +01:00
|
|
|
if (maxi == (int)param->dataset->labels[index[i]]) {
|
2022-11-19 16:09:07 +01:00
|
|
|
accuracy += 1.;
|
|
|
|
}
|
|
|
|
|
2022-12-07 10:44:28 +01:00
|
|
|
free(param->dataset->images[index[i]]);
|
|
|
|
param->dataset->images[index[i]] = NULL;
|
2022-10-01 17:53:14 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
param->accuracy = accuracy;
|
|
|
|
return NULL;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2022-12-07 13:09:39 +01:00
|
|
|
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover) {
|
2022-11-16 10:38:01 +01:00
|
|
|
srand(time(NULL));
|
2022-12-07 13:09:39 +01:00
|
|
|
Network* network;
|
2022-10-01 17:53:14 +02:00
|
|
|
int input_dim = -1;
|
|
|
|
int input_depth = -1;
|
|
|
|
float accuracy;
|
2022-11-15 17:50:33 +01:00
|
|
|
float current_accuracy;
|
2022-10-01 17:53:14 +02:00
|
|
|
|
2022-11-23 11:37:26 +01:00
|
|
|
int nb_images_total; // Images au total
|
|
|
|
int nb_images_total_remaining; // Images restantes dans un batch
|
|
|
|
int batches_epoques; // Batches par époque
|
2022-10-01 17:53:14 +02:00
|
|
|
|
2022-12-07 10:44:28 +01:00
|
|
|
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
|
|
|
unsigned int* labels; // Labels associés aux images du dataset MNIST
|
|
|
|
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
|
|
|
|
int* shuffle_index; // shuffle_index[i] contient le nouvel index de l'élément à l'emplacement i avant mélange
|
2022-10-01 17:53:14 +02:00
|
|
|
|
2023-01-13 15:58:11 +01:00
|
|
|
double start_time, end_time;
|
|
|
|
double elapsed_time;
|
|
|
|
|
|
|
|
double algo_start = omp_get_wtime();
|
|
|
|
|
|
|
|
start_time = omp_get_wtime();
|
|
|
|
|
2022-10-01 17:53:14 +02:00
|
|
|
if (dataset_type == 0) { // Type MNIST
|
|
|
|
// Chargement des images du set de données MNIST
|
|
|
|
int* parameters = read_mnist_images_parameters(images_file);
|
|
|
|
nb_images_total = parameters[0];
|
|
|
|
free(parameters);
|
|
|
|
|
|
|
|
images = read_mnist_images(images_file);
|
|
|
|
labels = read_mnist_labels(labels_file);
|
|
|
|
|
|
|
|
input_dim = 32;
|
|
|
|
input_depth = 1;
|
2022-11-19 16:09:07 +01:00
|
|
|
} else { // Type JPG
|
|
|
|
dataset = loadJpegDataset(data_dir);
|
|
|
|
input_dim = dataset->height + 4; // image_size + padding
|
|
|
|
input_depth = dataset->numComponents;
|
2022-10-01 17:53:14 +02:00
|
|
|
|
2022-11-19 16:09:07 +01:00
|
|
|
nb_images_total = dataset->numImages;
|
2022-10-01 17:53:14 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
// Initialisation du réseau
|
2022-12-07 13:09:39 +01:00
|
|
|
if (!recover) {
|
2023-01-13 15:58:11 +01:00
|
|
|
network = create_network_lenet5(0.1, 0, TANH, GLOROT, input_dim, input_depth);
|
2022-12-07 13:09:39 +01:00
|
|
|
} else {
|
|
|
|
network = read_network(recover);
|
|
|
|
}
|
|
|
|
|
2022-12-07 10:44:28 +01:00
|
|
|
|
|
|
|
shuffle_index = (int*)malloc(sizeof(int)*nb_images_total);
|
|
|
|
for (int i=0; i < nb_images_total; i++) {
|
|
|
|
shuffle_index[i] = i;
|
|
|
|
}
|
2022-10-01 17:53:14 +02:00
|
|
|
|
|
|
|
#ifdef USE_MULTITHREADING
|
2022-11-15 17:50:33 +01:00
|
|
|
int nb_remaining_images; // Nombre d'images restantes à lancer pour une série de threads
|
2022-10-01 17:53:14 +02:00
|
|
|
// Récupération du nombre de threads disponibles
|
|
|
|
int nb_threads = get_nprocs();
|
|
|
|
pthread_t *tid = (pthread_t*)malloc(nb_threads * sizeof(pthread_t));
|
|
|
|
|
|
|
|
// Création des paramètres donnés à chaque thread dans le cas du multi-threading
|
|
|
|
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
|
|
|
|
TrainParameters* param;
|
2022-12-19 15:49:03 +01:00
|
|
|
|
2022-10-01 17:53:14 +02:00
|
|
|
for (int k=0; k < nb_threads; k++) {
|
|
|
|
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
|
|
|
|
param = train_parameters[k];
|
|
|
|
param->dataset_type = dataset_type;
|
|
|
|
if (dataset_type == 0) {
|
|
|
|
param->images = images;
|
|
|
|
param->labels = labels;
|
2022-11-19 16:09:07 +01:00
|
|
|
param->dataset = NULL;
|
2022-10-01 17:53:14 +02:00
|
|
|
param->width = 28;
|
|
|
|
param->height = 28;
|
|
|
|
} else {
|
2022-11-19 16:09:07 +01:00
|
|
|
param->dataset = dataset;
|
|
|
|
param->width = dataset->width;
|
|
|
|
param->height = dataset->height;
|
2022-10-01 17:53:14 +02:00
|
|
|
param->images = NULL;
|
|
|
|
param->labels = NULL;
|
|
|
|
}
|
|
|
|
param->nb_images = BATCHES / nb_threads;
|
2022-12-07 10:44:28 +01:00
|
|
|
param->index = shuffle_index;
|
2022-10-01 17:53:14 +02:00
|
|
|
}
|
|
|
|
#else
|
|
|
|
// Création des paramètres donnés à l'unique
|
|
|
|
// thread dans l'hypothèse ou le multi-threading n'est pas utilisé.
|
|
|
|
// Cela est utile à des fins de débogage notamment,
|
|
|
|
// où l'utilisation de threads rend vite les choses plus compliquées qu'elles ne le sont.
|
|
|
|
TrainParameters* train_params = (TrainParameters*)malloc(sizeof(TrainParameters));
|
|
|
|
|
|
|
|
train_params->network = network;
|
|
|
|
train_params->dataset_type = dataset_type;
|
|
|
|
if (dataset_type == 0) {
|
|
|
|
train_params->images = images;
|
|
|
|
train_params->labels = labels;
|
2022-10-07 14:26:36 +02:00
|
|
|
train_params->width = 28;
|
|
|
|
train_params->height = 28;
|
2022-11-19 16:09:07 +01:00
|
|
|
train_params->dataset = NULL;
|
2022-10-01 17:53:14 +02:00
|
|
|
} else {
|
2022-11-19 16:09:07 +01:00
|
|
|
train_params->dataset = dataset;
|
|
|
|
train_params->width = dataset->width;
|
|
|
|
train_params->height = dataset->height;
|
2022-10-01 17:53:14 +02:00
|
|
|
train_params->images = NULL;
|
|
|
|
train_params->labels = NULL;
|
|
|
|
}
|
|
|
|
train_params->nb_images = BATCHES;
|
2022-12-07 10:44:28 +01:00
|
|
|
train_params->index = shuffle_index;
|
2022-10-01 17:53:14 +02:00
|
|
|
#endif
|
2023-01-13 15:58:11 +01:00
|
|
|
end_time = omp_get_wtime();
|
|
|
|
|
|
|
|
elapsed_time = end_time - start_time;
|
|
|
|
printf("Initialisation: %0.2lf s\n\n", elapsed_time);
|
2022-10-01 17:53:14 +02:00
|
|
|
|
|
|
|
for (int i=0; i < epochs; i++) {
|
2023-01-13 15:58:11 +01:00
|
|
|
|
|
|
|
start_time = omp_get_wtime();
|
2022-10-01 17:53:14 +02:00
|
|
|
// La variable accuracy permet d'avoir une ESTIMATION
|
|
|
|
// du taux de réussite et de l'entraînement du réseau,
|
|
|
|
// mais n'est en aucun cas une valeur réelle dans le cas
|
|
|
|
// du multi-threading car chaque copie du réseau initiale sera légèrement différente
|
|
|
|
// et donnera donc des résultats différents sur les mêmes images.
|
|
|
|
accuracy = 0.;
|
2022-12-07 10:44:28 +01:00
|
|
|
knuth_shuffle(shuffle_index, nb_images_total);
|
2022-11-25 15:17:47 +01:00
|
|
|
batches_epoques = div_up(nb_images_total, BATCHES);
|
2022-11-23 11:37:26 +01:00
|
|
|
nb_images_total_remaining = nb_images_total;
|
2022-12-07 13:09:39 +01:00
|
|
|
#ifndef USE_MULTITHREADING
|
2023-01-14 15:02:57 +01:00
|
|
|
train_params->nb_images = BATCHES;
|
2022-12-07 13:09:39 +01:00
|
|
|
#endif
|
2022-11-23 11:37:26 +01:00
|
|
|
for (int j=0; j < batches_epoques; j++) {
|
2022-11-15 17:50:33 +01:00
|
|
|
#ifdef USE_MULTITHREADING
|
2023-01-14 15:02:57 +01:00
|
|
|
if (j == batches_epoques-1) {
|
|
|
|
nb_remaining_images = nb_images_total_remaining;
|
|
|
|
nb_images_total_remaining = 0;
|
2022-10-01 17:53:14 +02:00
|
|
|
} else {
|
2023-01-14 15:02:57 +01:00
|
|
|
nb_images_total_remaining -= BATCHES;
|
|
|
|
nb_remaining_images = BATCHES;
|
2022-10-01 17:53:14 +02:00
|
|
|
}
|
2022-11-23 11:37:26 +01:00
|
|
|
|
2023-01-14 15:02:57 +01:00
|
|
|
for (int k=0; k < nb_threads; k++) {
|
|
|
|
if (k == nb_threads-1) {
|
|
|
|
train_parameters[k]->nb_images = nb_remaining_images;
|
|
|
|
nb_remaining_images = 0;
|
|
|
|
} else {
|
|
|
|
nb_remaining_images -= BATCHES / nb_threads;
|
|
|
|
}
|
|
|
|
train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k;
|
|
|
|
|
|
|
|
if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) {
|
|
|
|
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
|
|
|
|
}
|
|
|
|
if (train_parameters[k]->nb_images > 0) {
|
|
|
|
train_parameters[k]->network = copy_network(network);
|
|
|
|
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
|
|
|
|
} else {
|
|
|
|
train_parameters[k]->network = NULL;
|
|
|
|
}
|
2023-01-13 15:58:11 +01:00
|
|
|
}
|
2023-01-14 15:02:57 +01:00
|
|
|
for (int k=0; k < nb_threads; k++) {
|
|
|
|
// On attend la terminaison de chaque thread un à un
|
|
|
|
if (train_parameters[k]->network) {
|
|
|
|
pthread_join( tid[k], NULL );
|
|
|
|
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
|
|
|
|
}
|
2023-01-13 15:58:11 +01:00
|
|
|
}
|
2023-01-14 15:02:57 +01:00
|
|
|
|
|
|
|
// On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal
|
|
|
|
for (int k=0; k < nb_threads; k++) {
|
|
|
|
if (train_parameters[k]->network) { // Si le fil a été utilisé
|
|
|
|
update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images);
|
|
|
|
update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images);
|
|
|
|
free_network(train_parameters[k]->network);
|
|
|
|
}
|
2023-01-13 15:58:11 +01:00
|
|
|
}
|
2023-01-14 15:02:57 +01:00
|
|
|
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
|
|
|
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
|
|
|
fflush(stdout);
|
2022-10-01 17:53:14 +02:00
|
|
|
#else
|
2023-01-14 15:02:57 +01:00
|
|
|
(void)nb_images_total_remaining; // Juste pour enlever un warning
|
2022-12-07 10:44:28 +01:00
|
|
|
|
2023-01-14 15:02:57 +01:00
|
|
|
train_params->start = j*BATCHES;
|
2022-12-07 13:09:39 +01:00
|
|
|
|
2023-01-14 15:02:57 +01:00
|
|
|
// Ne pas dépasser le nombre d'images à cause de la partie entière
|
|
|
|
if (j == batches_epoques-1) {
|
|
|
|
train_params->nb_images = nb_images_total - j*BATCHES;
|
|
|
|
}
|
|
|
|
|
|
|
|
train_thread((void*)train_params);
|
|
|
|
|
|
|
|
accuracy += train_params->accuracy / (float) nb_images_total;
|
|
|
|
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
|
|
|
|
|
|
|
update_weights(network, network, train_params->nb_images);
|
|
|
|
update_bias(network, network, train_params->nb_images);
|
|
|
|
|
|
|
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
|
|
|
fflush(stdout);
|
2022-10-01 17:53:14 +02:00
|
|
|
#endif
|
|
|
|
}
|
2023-01-13 15:58:11 +01:00
|
|
|
end_time = omp_get_wtime();
|
|
|
|
elapsed_time = end_time - start_time;
|
2022-10-01 17:53:14 +02:00
|
|
|
#ifdef USE_MULTITHREADING
|
2023-01-13 15:58:11 +01:00
|
|
|
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET"\tTemps: %0.2f s\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
|
2022-10-01 17:53:14 +02:00
|
|
|
#else
|
2023-01-13 15:58:11 +01:00
|
|
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET"\tTemps: %0.2f s\n", i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
|
2022-10-01 17:53:14 +02:00
|
|
|
#endif
|
|
|
|
write_network(out, network);
|
|
|
|
}
|
2022-12-07 10:44:28 +01:00
|
|
|
free(shuffle_index);
|
2022-10-07 14:26:36 +02:00
|
|
|
free_network(network);
|
2022-10-01 17:53:14 +02:00
|
|
|
#ifdef USE_MULTITHREADING
|
|
|
|
free(tid);
|
|
|
|
#else
|
|
|
|
free(train_params);
|
|
|
|
#endif
|
2023-01-13 15:58:11 +01:00
|
|
|
end_time = omp_get_wtime();
|
|
|
|
elapsed_time = end_time - algo_start;
|
|
|
|
printf("\nTemps total: %0.1f s\n", elapsed_time);
|
2022-10-01 17:53:14 +02:00
|
|
|
}
|