Update train.c

This commit is contained in:
augustin64 2022-12-19 15:49:03 +01:00
parent 80e82c16f6
commit c913bf3195

View File

@ -133,7 +133,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// Création des paramètres donnés à chaque thread dans le cas du multi-threading
TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads);
TrainParameters* param;
for (int k=0; k < nb_threads; k++) {
train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters));
param = train_parameters[k];
@ -219,7 +219,10 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// On attend la terminaison de chaque thread un à un
pthread_join( tid[k], NULL );
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
}
// On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal
for (int k=0; k < nb_threads; k++) {
update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images);
update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images);
free_network(train_parameters[k]->network);