diff --git a/src/cnn/train.c b/src/cnn/train.c index bb807de..3e3dd80 100644 --- a/src/cnn/train.c +++ b/src/cnn/train.c @@ -133,7 +133,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di // Création des paramètres donnés à chaque thread dans le cas du multi-threading TrainParameters** train_parameters = (TrainParameters**)malloc(sizeof(TrainParameters*)*nb_threads); TrainParameters* param; - + for (int k=0; k < nb_threads; k++) { train_parameters[k] = (TrainParameters*)malloc(sizeof(TrainParameters)); param = train_parameters[k]; @@ -219,7 +219,10 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di // On attend la terminaison de chaque thread un à un pthread_join( tid[k], NULL ); accuracy += train_parameters[k]->accuracy / (float) nb_images_total; - + } + + // On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal + for (int k=0; k < nb_threads; k++) { update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images); update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images); free_network(train_parameters[k]->network);