Move update

This commit is contained in:
augustin64 2022-11-16 10:38:01 +01:00
parent 50af676027
commit d03f7493b2
2 changed files with 14 additions and 12 deletions

View File

@ -4,7 +4,7 @@
#define DEF_TRAIN_H #define DEF_TRAIN_H
#define EPOCHS 10 #define EPOCHS 10
#define BATCHES 100 #define BATCHES 120
#define USE_MULTITHREADING #define USE_MULTITHREADING

View File

@ -3,6 +3,7 @@
#include <float.h> #include <float.h>
#include <pthread.h> #include <pthread.h>
#include <sys/sysinfo.h> #include <sys/sysinfo.h>
#include <time.h>
#include "../mnist/include/mnist.h" #include "../mnist/include/mnist.h"
#include "include/initialisation.h" #include "include/initialisation.h"
@ -46,19 +47,13 @@ void* train_thread(void* parameters) {
int start = param->start; int start = param->start;
int nb_images = param->nb_images; int nb_images = param->nb_images;
float accuracy = 0.; float accuracy = 0.;
int cpt=1;
for (int i=start; i < start+nb_images; i++) { for (int i=start; i < start+nb_images; i++) {
if (dataset_type == 0) { if (dataset_type == 0) {
write_image_in_network_32(images[i], height, width, network->input[0][0]); write_image_in_network_32(images[i], height, width, network->input[0][0]);
forward_propagation(network); forward_propagation(network);
maxi = indice_max(network->input[network->size-1][0][0], 10); maxi = indice_max(network->input[network->size-1][0][0], 10);
backward_propagation(network, labels[i]); backward_propagation(network, labels[i]);
if (cpt==16) { // Update the network
update_weights(network);
update_bias(network);
cpt = 0;
}
cpt++;
if (maxi == labels[i]) { if (maxi == labels[i]) {
accuracy += 1.; accuracy += 1.;
} }
@ -74,6 +69,7 @@ void* train_thread(void* parameters) {
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) { void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
srand(time(NULL));
int input_dim = -1; int input_dim = -1;
int input_depth = -1; int input_depth = -1;
float accuracy; float accuracy;
@ -190,19 +186,25 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
free_network(train_parameters[k]->network); free_network(train_parameters[k]->network);
} }
current_accuracy = accuracy * nb_images_total/(j*BATCHES); current_accuracy = accuracy * nb_images_total/(j*BATCHES);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100); printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
#else #else
train_params->start = j*BATCHES; train_params->start = j*BATCHES;
train_thread((void*)train_params); train_thread((void*)train_params);
accuracy += train_params->accuracy / (float) nb_images_total; accuracy += train_params->accuracy / (float) nb_images_total;
current_accuracy = accuracy * nb_images_total/(j*BATCHES); current_accuracy = accuracy * nb_images_total/(j*BATCHES);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
update_weights(network);
update_bias(network);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
#endif #endif
} }
#ifdef USE_MULTITHREADING #ifdef USE_MULTITHREADING
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100); printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
#else #else
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", i, epochs, nb_images_total, nb_images_total, accuracy*100); printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
#endif #endif
write_network(out, network); write_network(out, network);
} }