mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-24 15:36:25 +01:00
Fix various multithreading-related issues
This commit is contained in:
parent
dd6fb046c7
commit
df48b92cf2
110
src/cnn/train.c
110
src/cnn/train.c
@ -205,74 +205,76 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
batches_epoques = div_up(nb_images_total, BATCHES);
|
batches_epoques = div_up(nb_images_total, BATCHES);
|
||||||
nb_images_total_remaining = nb_images_total;
|
nb_images_total_remaining = nb_images_total;
|
||||||
#ifndef USE_MULTITHREADING
|
#ifndef USE_MULTITHREADING
|
||||||
train_params->nb_images = BATCHES;
|
train_params->nb_images = BATCHES;
|
||||||
#endif
|
#endif
|
||||||
for (int j=0; j < batches_epoques; j++) {
|
for (int j=0; j < batches_epoques; j++) {
|
||||||
#ifdef USE_MULTITHREADING
|
#ifdef USE_MULTITHREADING
|
||||||
if (j == batches_epoques-1) {
|
if (j == batches_epoques-1) {
|
||||||
nb_remaining_images = nb_images_total_remaining;
|
nb_remaining_images = nb_images_total_remaining;
|
||||||
nb_images_total_remaining = 0;
|
nb_images_total_remaining = 0;
|
||||||
} else {
|
|
||||||
nb_images_total_remaining -= BATCHES;
|
|
||||||
nb_remaining_images = BATCHES;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int k=0; k < nb_threads; k++) {
|
|
||||||
if (k == nb_threads-1) {
|
|
||||||
train_parameters[k]->nb_images = nb_remaining_images;
|
|
||||||
nb_remaining_images = 0;
|
|
||||||
} else {
|
} else {
|
||||||
nb_remaining_images -= BATCHES / nb_threads;
|
nb_images_total_remaining -= BATCHES;
|
||||||
|
nb_remaining_images = BATCHES;
|
||||||
}
|
}
|
||||||
train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k;
|
|
||||||
train_parameters[k]->network = copy_network(network);
|
|
||||||
|
|
||||||
if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) {
|
for (int k=0; k < nb_threads; k++) {
|
||||||
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
|
if (k == nb_threads-1) {
|
||||||
}
|
train_parameters[k]->nb_images = nb_remaining_images;
|
||||||
if (train_parameters[k]->nb_images > 0) {
|
nb_remaining_images = 0;
|
||||||
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
|
} else {
|
||||||
} else {
|
nb_remaining_images -= BATCHES / nb_threads;
|
||||||
tid[k] = 0;
|
}
|
||||||
}
|
train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k;
|
||||||
}
|
|
||||||
for (int k=0; k < nb_threads; k++) {
|
|
||||||
// On attend la terminaison de chaque thread un à un
|
|
||||||
if (tid[k] != 0) {
|
|
||||||
pthread_join( tid[k], NULL );
|
|
||||||
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal
|
if (train_parameters[k]->start+train_parameters[k]->nb_images >= nb_images_total) {
|
||||||
for (int k=0; k < nb_threads; k++) {
|
train_parameters[k]->nb_images = nb_images_total - train_parameters[k]->start -1;
|
||||||
update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images);
|
}
|
||||||
update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images);
|
if (train_parameters[k]->nb_images > 0) {
|
||||||
free_network(train_parameters[k]->network);
|
train_parameters[k]->network = copy_network(network);
|
||||||
}
|
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
|
||||||
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
} else {
|
||||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
train_parameters[k]->network = NULL;
|
||||||
fflush(stdout);
|
}
|
||||||
|
}
|
||||||
|
for (int k=0; k < nb_threads; k++) {
|
||||||
|
// On attend la terminaison de chaque thread un à un
|
||||||
|
if (train_parameters[k]->network) {
|
||||||
|
pthread_join( tid[k], NULL );
|
||||||
|
accuracy += train_parameters[k]->accuracy / (float) nb_images_total;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// On attend que tous les fils aient fini avant d'appliquer des modifications au réseau principal
|
||||||
|
for (int k=0; k < nb_threads; k++) {
|
||||||
|
if (train_parameters[k]->network) { // Si le fil a été utilisé
|
||||||
|
update_weights(network, train_parameters[k]->network, train_parameters[k]->nb_images);
|
||||||
|
update_bias(network, train_parameters[k]->network, train_parameters[k]->nb_images);
|
||||||
|
free_network(train_parameters[k]->network);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
||||||
|
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||||
|
fflush(stdout);
|
||||||
#else
|
#else
|
||||||
(void)nb_images_total_remaining; // Juste pour enlever un warning
|
(void)nb_images_total_remaining; // Juste pour enlever un warning
|
||||||
|
|
||||||
train_params->start = j*BATCHES;
|
train_params->start = j*BATCHES;
|
||||||
|
|
||||||
// Ne pas dépasser le nombre d'images à cause de la partie entière
|
// Ne pas dépasser le nombre d'images à cause de la partie entière
|
||||||
if (j == batches_epoques-1) {
|
if (j == batches_epoques-1) {
|
||||||
train_params->nb_images = nb_images_total - j*BATCHES;
|
train_params->nb_images = nb_images_total - j*BATCHES;
|
||||||
}
|
}
|
||||||
|
|
||||||
train_thread((void*)train_params);
|
train_thread((void*)train_params);
|
||||||
|
|
||||||
accuracy += train_params->accuracy / (float) nb_images_total;
|
accuracy += train_params->accuracy / (float) nb_images_total;
|
||||||
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
|
||||||
|
|
||||||
update_weights(network, network, train_params->nb_images);
|
update_weights(network, network, train_params->nb_images);
|
||||||
update_bias(network, network, train_params->nb_images);
|
update_bias(network, network, train_params->nb_images);
|
||||||
|
|
||||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
end_time = omp_get_wtime();
|
end_time = omp_get_wtime();
|
||||||
|
Loading…
Reference in New Issue
Block a user