Add copy_network_parameters

This commit is contained in:
augustin64 2023-01-28 13:09:52 +01:00
parent 01459a729e
commit fe880f9aae
3 changed files with 90 additions and 14 deletions

View File

@ -24,8 +24,12 @@ void knuth_shuffle(int* tab, int n);
bool equals_networks(Network* network1, Network* network2); bool equals_networks(Network* network1, Network* network2);
/* /*
* Duplique un réseau * Duplique un réseau
*/ */
Network* copy_network(Network* network); Network* copy_network(Network* network);
/*
* Copies the parameters of a network into a network that is already allocated in memory.
* network_src is only read; network_dest receives the learning rate and the
* weights and biases of its NN and CNN layers. All dimensions are taken from
* network_src, so network_dest presumably has to share the same architecture.
*/
void copy_network_parameters(Network* network_src, Network* network_dest);
#endif #endif

View File

@ -1,8 +1,9 @@
#include <sys/sysinfo.h>
#include <pthread.h>
#include <stdlib.h> #include <stdlib.h>
#include <stdio.h> #include <stdio.h>
#include <float.h> #include <float.h>
#include <pthread.h> #include <math.h>
#include <sys/sysinfo.h>
#include <time.h> #include <time.h>
#include <omp.h> #include <omp.h>
@ -139,8 +140,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// Initialisation du réseau // Initialisation du réseau
if (!recover) { if (!recover) {
network = create_network_lenet5(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth); // Le nouveau TA calculé à partir du loss est majoré par 0.75*TA
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth); network = create_network_lenet5(LEARNING_RATE*0.75, 0, TANH, GLOROT, input_dim, input_depth);
//network = create_simple_one(LEARNING_RATE*0.75, 0, RELU, GLOROT, input_dim, input_depth);
} else { } else {
network = read_network(recover); network = read_network(recover);
network->learning_rate = LEARNING_RATE; network->learning_rate = LEARNING_RATE;
@ -161,6 +163,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// Création des paramètres donnés à chaque thread dans le cas du multi-threading // Création des paramètres donnés à chaque thread dans le cas du multi-threading
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads); TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
TrainParameters* param; TrainParameters* param;
bool* thread_used = (bool*)malloc(sizeof(bool)*nb_threads);
for (int k=0; k < nb_threads; k++) { for (int k=0; k < nb_threads; k++) {
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters)); train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
@ -181,6 +184,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
} }
param->nb_images = BATCHES / nb_threads; param->nb_images = BATCHES / nb_threads;
param->index = shuffle_index; param->index = shuffle_index;
param->network = copy_network(network);
} }
#else #else
// Création des paramètres donnés à l'unique // Création des paramètres donnés à l'unique
@ -210,6 +214,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
end_time = omp_get_wtime(); end_time = omp_get_wtime();
elapsed_time = end_time - start_time; elapsed_time = end_time - start_time;
printf("Taux d'apprentissage initial: %lf\n", network->learning_rate);
printf("Initialisation: %0.2lf s\n\n", elapsed_time); printf("Initialisation: %0.2lf s\n\n", elapsed_time);
for (int i=0; i < epochs; i++) { for (int i=0; i < epochs; i++) {
@ -253,15 +258,16 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1; train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
} }
if (train_parameters[k]->nb_images > 0) { if (train_parameters[k]->nb_images > 0) {
train_parameters[k]->network = copy_network(network); thread_used[k] = true;
copy_network_parameters(network, train_parameters[k]->network);
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]); pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
} else { } else {
train_parameters[k]->network = NULL; thread_used[k] = false;
} }
} }
for (int k=0; k < nb_threads; k++) { for (int k=0; k < nb_threads; k++) {
// On attend la terminaison de chaque thread un à un // On attend la terminaison de chaque thread un à un
if (train_parameters[k]->network) { if (thread_used[k]) {
pthread_join( tid[k], NULL ); pthread_join( tid[k], NULL );
accuracy += train_parameters[k]->accuracy / (float) nb_images_total; accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
loss += train_parameters[k]->loss/nb_images_total; loss += train_parameters[k]->loss/nb_images_total;
@ -274,11 +280,10 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
if (train_parameters[k]->network) { // Si le fil a été utilisé if (train_parameters[k]->network) { // Si le fil a été utilisé
update_weights(network, train_parameters[k]->network); update_weights(network, train_parameters[k]->network);
update_bias(network, train_parameters[k]->network); update_bias(network, train_parameters[k]->network);
free_network(train_parameters[k]->network);
} }
} }
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" \tRéussies: %d", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, (int)(accuracy*nb_images_total)); printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET, nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
fflush(stdout); fflush(stdout);
#else #else
(void)nb_images_total_remaining; // Juste pour enlever un warning (void)nb_images_total_remaining; // Juste pour enlever un warning
@ -300,27 +305,36 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
update_weights(network, network); update_weights(network, network);
update_bias(network, network); update_bias(network, network);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" \tRéussies: %d", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, (int)(accuracy*nb_images_total)); printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
fflush(stdout); fflush(stdout);
#endif #endif
// Il serait intéressant d'utiliser la perte calculée pour // Il serait intéressant d'utiliser la perte calculée pour
// savoir l'avancement dans l'apprentissage et donc comment adapter le taux d'apprentissage // savoir l'avancement dans l'apprentissage et donc comment adapter le taux d'apprentissage
network->learning_rate = 10*LEARNING_RATE*batch_loss; network->learning_rate = LEARNING_RATE*log(batch_loss+1);
} }
end_time = omp_get_wtime(); end_time = omp_get_wtime();
elapsed_time = end_time - start_time; elapsed_time = end_time - start_time;
#ifdef USE_MULTITHREADING #ifdef USE_MULTITHREADING
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tRéussies: %d\tTemps: %0.2f s\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, (int)(accuracy*nb_images_total), elapsed_time); printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tTemps: %0.2f s\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
#else #else
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tRéussies: %d\tTemps: %0.2f s\n", i, epochs, nb_images_total, nb_images_total, accuracy*100, (int)(accuracy*nb_images_total), elapsed_time); printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tTemps: %0.2f s\n", i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
#endif #endif
write_network(out, network); write_network(out, network);
} }
// To generate a new neural network and compare performance with scripts/benchmark_binary.py
if (epochs == 0) {
write_network(out, network);
}
free(shuffle_index); free(shuffle_index);
free_network(network); free_network(network);
#ifdef USE_MULTITHREADING #ifdef USE_MULTITHREADING
free(tid); free(tid);
for (int i=0; i < nb_threads; i++) {
free(train_parameters[i]->network);
}
free(train_parameters);
#else #else
free(train_params); free(train_params);
#endif #endif

View File

@ -7,6 +7,7 @@
#include "include/struct.h" #include "include/struct.h"
#define copyVar(var) network_cp->var = network->var #define copyVar(var) network_cp->var = network->var
// Assigns the field `var` of network_src to the same field of network_dest;
// both identifiers must be in scope where the macro is expanded
#define copyVarParams(var) network_dest->var = network_src->var
#define checkEquals(var, name, indice) \ #define checkEquals(var, name, indice) \
if (network1->var != network2->var) { \ if (network1->var != network2->var) { \
@ -241,3 +242,60 @@ Network* copy_network(Network* network) {
return network_cp; return network_cp;
} }
/*
 * Copies the trainable parameters of network_src into network_dest.
 *
 * network_dest must already be allocated: nothing is allocated here, values
 * are only assigned element by element. The learning rate is copied first,
 * then, for each of the first (size - 1) kernels of network_src:
 *   - NN kernel (nn set, cnn unset): bias[output] and weights[input][output]
 *   - CNN kernel (cnn set, nn unset): bias[column][y][x] and w[row][column][ky][kx]
 * Kernels with neither or both pointers set are left untouched.
 *
 * Every dimension is read from network_src; network_dest is presumably
 * expected to have the same architecture (not checked here).
 */
void copy_network_parameters(Network* network_src, Network* network_dest) {
    const int nb_layers = network_src->size;

    copyVarParams(learning_rate);

    for (int layer = 0; layer < nb_layers - 1; layer++) {
        // A kernel is handled only when exactly one of its nn / cnn pointers is set
        const int is_dense = !network_src->kernel[layer]->cnn && network_src->kernel[layer]->nn;
        const int is_conv = network_src->kernel[layer]->cnn && !network_src->kernel[layer]->nn;

        if (is_dense) {
            const int nb_inputs = network_src->kernel[layer]->nn->input_units;
            const int nb_outputs = network_src->kernel[layer]->nn->output_units;

            // One bias per output unit
            for (int output = 0; output < nb_outputs; output++) {
                copyVarParams(kernel[layer]->nn->bias[output]);
            }

            // Weights are indexed [input][output]
            for (int input = 0; input < nb_inputs; input++) {
                for (int output = 0; output < nb_outputs; output++) {
                    copyVarParams(kernel[layer]->nn->weights[input][output]);
                }
            }
        } else if (is_conv) {
            const int nb_rows = network_src->kernel[layer]->cnn->rows;
            const int kernel_side = network_src->kernel[layer]->cnn->k_size;
            const int nb_columns = network_src->kernel[layer]->cnn->columns;
            // The bias extent is the width recorded for the following layer
            const int out_side = network_src->width[layer+1];

            // Biases are indexed [column][y][x] over an out_side x out_side grid
            for (int column = 0; column < nb_columns; column++) {
                for (int y = 0; y < out_side; y++) {
                    for (int x = 0; x < out_side; x++) {
                        copyVarParams(kernel[layer]->cnn->bias[column][y][x]);
                    }
                }
            }

            // Weights are indexed [row][column][ky][kx] over a kernel_side x kernel_side window
            for (int row = 0; row < nb_rows; row++) {
                for (int column = 0; column < nb_columns; column++) {
                    for (int ky = 0; ky < kernel_side; ky++) {
                        for (int kx = 0; kx < kernel_side; kx++) {
                            copyVarParams(kernel[layer]->cnn->w[row][column][ky][kx]);
                        }
                    }
                }
            }
        }
    }
}