mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-24 07:36:24 +01:00
Add copy_network_parameters
This commit is contained in:
parent
01459a729e
commit
fe880f9aae
@ -24,8 +24,12 @@ void knuth_shuffle(int* tab, int n);
|
|||||||
bool equals_networks(Network* network1, Network* network2);
|
bool equals_networks(Network* network1, Network* network2);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Duplique un réseau
|
* Duplique un réseau
|
||||||
*/
|
*/
|
||||||
Network* copy_network(Network* network);
|
Network* copy_network(Network* network);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Copie les paramètres d'un réseau dans un réseau déjà alloué en mémoire
|
||||||
|
*/
|
||||||
|
void copy_network_parameters(Network* network_src, Network* network_dest);
|
||||||
#endif
|
#endif
|
@ -1,8 +1,9 @@
|
|||||||
|
#include <sys/sysinfo.h>
|
||||||
|
#include <pthread.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <float.h>
|
#include <float.h>
|
||||||
#include <pthread.h>
|
#include <math.h>
|
||||||
#include <sys/sysinfo.h>
|
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
|
|
||||||
@ -139,8 +140,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
|
|
||||||
// Initialisation du réseau
|
// Initialisation du réseau
|
||||||
if (!recover) {
|
if (!recover) {
|
||||||
network = create_network_lenet5(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth);
|
// Le nouveau TA calculé à partir du loss est majoré par 0.75*TA
|
||||||
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth);
|
network = create_network_lenet5(LEARNING_RATE*0.75, 0, TANH, GLOROT, input_dim, input_depth);
|
||||||
|
//network = create_simple_one(LEARNING_RATE*0.75, 0, RELU, GLOROT, input_dim, input_depth);
|
||||||
} else {
|
} else {
|
||||||
network = read_network(recover);
|
network = read_network(recover);
|
||||||
network->learning_rate = LEARNING_RATE;
|
network->learning_rate = LEARNING_RATE;
|
||||||
@ -161,6 +163,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
// Création des paramètres donnés à chaque thread dans le cas du multi-threading
|
// Création des paramètres donnés à chaque thread dans le cas du multi-threading
|
||||||
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
|
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
|
||||||
TrainParameters* param;
|
TrainParameters* param;
|
||||||
|
bool* thread_used = (bool*)malloc(sizeof(bool)*nb_threads);
|
||||||
|
|
||||||
for (int k=0; k < nb_threads; k++) {
|
for (int k=0; k < nb_threads; k++) {
|
||||||
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
|
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
|
||||||
@ -181,6 +184,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
}
|
}
|
||||||
param->nb_images = BATCHES / nb_threads;
|
param->nb_images = BATCHES / nb_threads;
|
||||||
param->index = shuffle_index;
|
param->index = shuffle_index;
|
||||||
|
param->network = copy_network(network);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
// Création des paramètres donnés à l'unique
|
// Création des paramètres donnés à l'unique
|
||||||
@ -210,6 +214,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
end_time = omp_get_wtime();
|
end_time = omp_get_wtime();
|
||||||
|
|
||||||
elapsed_time = end_time - start_time;
|
elapsed_time = end_time - start_time;
|
||||||
|
printf("Taux d'apprentissage initial: %lf\n", network->learning_rate);
|
||||||
printf("Initialisation: %0.2lf s\n\n", elapsed_time);
|
printf("Initialisation: %0.2lf s\n\n", elapsed_time);
|
||||||
|
|
||||||
for (int i=0; i < epochs; i++) {
|
for (int i=0; i < epochs; i++) {
|
||||||
@ -253,15 +258,16 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
|
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
|
||||||
}
|
}
|
||||||
if (train_parameters[k]->nb_images > 0) {
|
if (train_parameters[k]->nb_images > 0) {
|
||||||
train_parameters[k]->network = copy_network(network);
|
thread_used[k] = true;
|
||||||
|
copy_network_parameters(network, train_parameters[k]->network);
|
||||||
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
|
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
|
||||||
} else {
|
} else {
|
||||||
train_parameters[k]->network = NULL;
|
thread_used[k] = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int k=0; k < nb_threads; k++) {
|
for (int k=0; k < nb_threads; k++) {
|
||||||
// On attend la terminaison de chaque thread un à un
|
// On attend la terminaison de chaque thread un à un
|
||||||
if (train_parameters[k]->network) {
|
if (thread_used[k]) {
|
||||||
pthread_join( tid[k], NULL );
|
pthread_join( tid[k], NULL );
|
||||||
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
|
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
|
||||||
loss += train_parameters[k]->loss/nb_images_total;
|
loss += train_parameters[k]->loss/nb_images_total;
|
||||||
@ -274,11 +280,10 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
if (train_parameters[k]->network) { // Si le fil a été utilisé
|
if (train_parameters[k]->network) { // Si le fil a été utilisé
|
||||||
update_weights(network, train_parameters[k]->network);
|
update_weights(network, train_parameters[k]->network);
|
||||||
update_bias(network, train_parameters[k]->network);
|
update_bias(network, train_parameters[k]->network);
|
||||||
free_network(train_parameters[k]->network);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
||||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" \tRéussies: %d", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, (int)(accuracy*nb_images_total));
|
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET, nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
#else
|
#else
|
||||||
(void)nb_images_total_remaining; // Juste pour enlever un warning
|
(void)nb_images_total_remaining; // Juste pour enlever un warning
|
||||||
@ -300,27 +305,36 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
update_weights(network, network);
|
update_weights(network, network);
|
||||||
update_bias(network, network);
|
update_bias(network, network);
|
||||||
|
|
||||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" \tRéussies: %d", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, (int)(accuracy*nb_images_total));
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
#endif
|
#endif
|
||||||
// Il serait intéressant d'utiliser la perte calculée pour
|
// Il serait intéressant d'utiliser la perte calculée pour
|
||||||
// savoir l'avancement dans l'apprentissage et donc comment adapter le taux d'apprentissage
|
// savoir l'avancement dans l'apprentissage et donc comment adapter le taux d'apprentissage
|
||||||
network->learning_rate = 10*LEARNING_RATE*batch_loss;
|
network->learning_rate = LEARNING_RATE*log(batch_loss+1);
|
||||||
}
|
}
|
||||||
end_time = omp_get_wtime();
|
end_time = omp_get_wtime();
|
||||||
elapsed_time = end_time - start_time;
|
elapsed_time = end_time - start_time;
|
||||||
#ifdef USE_MULTITHREADING
|
#ifdef USE_MULTITHREADING
|
||||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tRéussies: %d\tTemps: %0.2f s\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, (int)(accuracy*nb_images_total), elapsed_time);
|
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tTemps: %0.2f s\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
|
||||||
#else
|
#else
|
||||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tRéussies: %d\tTemps: %0.2f s\n", i, epochs, nb_images_total, nb_images_total, accuracy*100, (int)(accuracy*nb_images_total), elapsed_time);
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tTemps: %0.2f s\n", i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
|
||||||
#endif
|
#endif
|
||||||
write_network(out, network);
|
write_network(out, network);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// To generate a new neural and compare performances with scripts/benchmark_binary.py
|
||||||
|
if (epochs == 0) {
|
||||||
|
write_network(out, network);
|
||||||
|
}
|
||||||
free(shuffle_index);
|
free(shuffle_index);
|
||||||
free_network(network);
|
free_network(network);
|
||||||
|
|
||||||
#ifdef USE_MULTITHREADING
|
#ifdef USE_MULTITHREADING
|
||||||
free(tid);
|
free(tid);
|
||||||
|
for (int i=0; i < nb_threads; i++) {
|
||||||
|
free(train_parameters[i]->network);
|
||||||
|
}
|
||||||
|
free(train_parameters);
|
||||||
#else
|
#else
|
||||||
free(train_params);
|
free(train_params);
|
||||||
#endif
|
#endif
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include "include/struct.h"
|
#include "include/struct.h"
|
||||||
|
|
||||||
#define copyVar(var) network_cp->var = network->var
|
#define copyVar(var) network_cp->var = network->var
|
||||||
|
#define copyVarParams(var) network_dest->var = network_src->var
|
||||||
|
|
||||||
#define checkEquals(var, name, indice) \
|
#define checkEquals(var, name, indice) \
|
||||||
if (network1->var != network2->var) { \
|
if (network1->var != network2->var) { \
|
||||||
@ -241,3 +242,60 @@ Network* copy_network(Network* network) {
|
|||||||
|
|
||||||
return network_cp;
|
return network_cp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void copy_network_parameters(Network* network_src, Network* network_dest) {
|
||||||
|
// Paramètre du réseau
|
||||||
|
int size = network_src->size;
|
||||||
|
// Paramètres des couches NN
|
||||||
|
int input_units;
|
||||||
|
int output_units;
|
||||||
|
// Paramètres des couches CNN
|
||||||
|
int rows;
|
||||||
|
int k_size;
|
||||||
|
int columns;
|
||||||
|
int output_dim;
|
||||||
|
|
||||||
|
copyVarParams(learning_rate);
|
||||||
|
|
||||||
|
for (int i=0; i < size-1; i++) {
|
||||||
|
if (!network_src->kernel[i]->cnn && network_src->kernel[i]->nn) { // Cas du NN
|
||||||
|
|
||||||
|
input_units = network_src->kernel[i]->nn->input_units;
|
||||||
|
output_units = network_src->kernel[i]->nn->output_units;
|
||||||
|
|
||||||
|
for (int j=0; j < output_units; j++) {
|
||||||
|
copyVarParams(kernel[i]->nn->bias[j]);
|
||||||
|
}
|
||||||
|
for (int j=0; j < input_units; j++) {
|
||||||
|
for (int k=0; k < output_units; k++) {
|
||||||
|
copyVarParams(kernel[i]->nn->weights[j][k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (network_src->kernel[i]->cnn && !network_src->kernel[i]->nn) { // Cas du CNN
|
||||||
|
|
||||||
|
rows = network_src->kernel[i]->cnn->rows;
|
||||||
|
k_size = network_src->kernel[i]->cnn->k_size;
|
||||||
|
columns = network_src->kernel[i]->cnn->columns;
|
||||||
|
output_dim = network_src->width[i+1];
|
||||||
|
|
||||||
|
for (int j=0; j < columns; j++) {
|
||||||
|
for (int k=0; k < output_dim; k++) {
|
||||||
|
for (int l=0; l < output_dim; l++) {
|
||||||
|
copyVarParams(kernel[i]->cnn->bias[j][k][l]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j=0; j < rows; j++) {
|
||||||
|
for (int k=0; k < columns; k++) {
|
||||||
|
for (int l=0; l < k_size; l++) {
|
||||||
|
for (int m=0; m < k_size; m++) {
|
||||||
|
copyVarParams(kernel[i]->cnn->w[j][k][l][m]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user