Update train.c

This commit is contained in:
augustin64 2022-11-23 11:37:26 +01:00
parent 65a91dd441
commit 4cafabdbee

View File

@ -91,7 +91,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
float accuracy;
float current_accuracy;
int nb_images_total;
int nb_images_total; // Images au total
int nb_images_total_remaining; // Images restantes dans un batch
int batches_epoques; // Batches par époque
int*** images;
unsigned int* labels;
@ -180,9 +182,17 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// du multi-threading car chaque copie du réseau initiale sera légèrement différente
// et donnera donc des résultats différents sur les mêmes images.
accuracy = 0.;
for (int j=0; j < nb_images_total / BATCHES; j++) {
batches_epoques = nb_images_total / BATCHES;
nb_images_total_remaining = nb_images_total;
for (int j=0; j < batches_epoques; j++) {
#ifdef USE_MULTITHREADING
nb_remaining_images = BATCHES;
if (j == batches_epoques-1) {
nb_remaining_images = nb_images_total_remaining;
nb_images_total_remaining = 0;
} else {
nb_images_total_remaining -= BATCHES;
nb_remaining_images = BATCHES;
}
for (int k=0; k < nb_threads; k++) {
if (k == nb_threads-1) {
@ -191,8 +201,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
} else {
nb_remaining_images -= BATCHES / nb_threads;
}
train_parameters[k]->start = BATCHES*j + (BATCHES/nb_threads)*k;
train_parameters[k]->network = copy_network(network);
train_parameters[k]->start = BATCHES*j + (nb_images_total/BATCHES)*k;
pthread_create( &tid[k], NULL, train_thread, (void*) train_parameters[k]);
}
for (int k=0; k < nb_threads; k++) {
@ -206,6 +217,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
}
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
fflush(stdout);
#else
train_params->start = j*BATCHES;