Fix various multithreading-related issues

This commit is contained in:
augustin64 2023-01-14 15:02:57 +01:00
parent dd6fb046c7
commit df48b92cf2

View File

@@ -205,74 +205,76 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
batches_epoques = div_up(nb_images_total, BATCHES); batches_epoques = div_up(nb_images_total, BATCHES);
nb_images_total_remaining = nb_images_total; nb_images_total_remaining = nb_images_total;
#ifndef USE_MULTITHREADING #ifndef USE_MULTITHREADING
train_params->nb_images = BATCHES; train_params->nb_images = BATCHES;
#endif #endif
for (int j=0; j < batches_epoques; j++) { for (int j=0; j < batches_epoques; j++) {
#ifdef USE_MULTITHREADING #ifdef USE_MULTITHREADING
if (j == batches_epoques-1) { if (j == batches_epoques-1) {
nb_remaining_images = nb_images_total_remaining; nb_remaining_images = nb_images_total_remaining;
nb_images_total_remaining = 0; nb_images_total_remaining = 0;
} else {
nb_images_total_remaining -= BATCHES;
nb_remaining_images = BATCHES;
}
for (int k=0; k < nb_threads; k++) {
if (k == nb_threads-1) {
train_parameters[k]->nb_images = nb_remaining_images;
nb_remaining_images = 0;
} else { } else {
nb_remaining_images -= BATCHES / nb_threads; nb_images_total_remaining -= BATCHES;
nb_remaining_images = BATCHES;
} }
train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k;
train_parameters[k]->network = copy_network(network);
if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) { for (int k=0; k < nb_threads; k++) {
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1; if (k == nb_threads-1) {
} train_parameters[k]->nb_images = nb_remaining_images;
if (train_parameters[k]->nb_images > 0) { nb_remaining_images = 0;
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]); } else {
} else { nb_remaining_images -= BATCHES / nb_threads;
tid[k] = 0; }
} train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k;
}
for (int k=0; k < nb_threads; k++) {
// On attend la terminaison de chaque thread un à un
if (tid[k] != 0) {
pthread_join( tid[k], NULL );
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
}
}
// On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) {
for (int k=0; k < nb_threads; k++) { train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images); }
update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images); if (train_parameters[k]->nb_images > 0) {
free_network(train_parameters[k]->network); train_parameters[k]->network = copy_network(network);
} pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); } else {
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); train_parameters[k]->network = NULL;
fflush(stdout); }
}
for (int k=0; k < nb_threads; k++) {
// On attend la terminaison de chaque thread un à un
if (train_parameters[k]->network) {
pthread_join( tid[k], NULL );
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
}
}
// On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal
for (int k=0; k < nb_threads; k++) {
if (train_parameters[k]->network) { // Si le fil a été utilisé
update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images);
update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images);
free_network(train_parameters[k]->network);
}
}
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
fflush(stdout);
#else #else
(void)nb_images_total_remaining; // Juste pour enlever un warning (void)nb_images_total_remaining; // Juste pour enlever un warning
train_params->start = j*BATCHES; train_params->start = j*BATCHES;
// Ne pas dépasser le nombre d'images à cause de la partie entière // Ne pas dépasser le nombre d'images à cause de la partie entière
if (j == batches_epoques-1) { if (j == batches_epoques-1) {
train_params->nb_images = nb_images_total - j*BATCHES; train_params->nb_images = nb_images_total - j*BATCHES;
} }
train_thread((void*)train_params); train_thread((void*)train_params);
accuracy += train_params->accuracy / (float) nb_images_total; accuracy += train_params->accuracy / (float) nb_images_total;
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES); current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
update_weights(network, network, train_params->nb_images); update_weights(network, network, train_params->nb_images);
update_bias(network, network, train_params->nb_images); update_bias(network, network, train_params->nb_images);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
fflush(stdout); fflush(stdout);
#endif #endif
} }
end_time = omp_get_wtime(); end_time = omp_get_wtime();