This commit is contained in:
julienChemillier 2022-11-17 13:02:34 +01:00
commit ea03e3a464
2 changed files with 14 additions and 12 deletions

View File

@ -4,7 +4,7 @@
#define DEF_TRAIN_H
#define EPOCHS 10
#define BATCHES 100
#define BATCHES 120
#define USE_MULTITHREADING

View File

@ -3,6 +3,7 @@
#include <float.h>
#include <pthread.h>
#include <sys/sysinfo.h>
#include <time.h>
#include "../mnist/include/mnist.h"
#include "include/initialisation.h"
@ -46,19 +47,13 @@ void* train_thread(void* parameters) {
int start = param->start;
int nb_images = param->nb_images;
float accuracy = 0.;
int cpt=1;
for (int i=start; i < start+nb_images; i++) {
if (dataset_type == 0) {
write_image_in_network_32(images[i], height, width, network->input[0][0]);
forward_propagation(network);
maxi = indice_max(network->input[network->size-1][0][0], 10);
backward_propagation(network, labels[i]);
if (cpt==16) { // Update the network
update_weights(network);
update_bias(network);
cpt = 0;
}
cpt++;
if (maxi == labels[i]) {
accuracy += 1.;
}
@ -74,6 +69,7 @@ void* train_thread(void* parameters) {
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
srand(time(NULL));
int input_dim = -1;
int input_depth = -1;
float accuracy;
@ -190,19 +186,25 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
free_network(train_parameters[k]->network);
}
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
#else
train_params->start = j*BATCHES;
train_thread((void*)train_params);
accuracy += train_params->accuracy / (float) nb_images_total;
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
update_weights(network);
update_bias(network);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
#endif
}
#ifdef USE_MULTITHREADING
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
#else
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
#endif
write_network(out, network);
}