mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-24 07:36:24 +01:00
Move update
This commit is contained in:
parent
50af676027
commit
d03f7493b2
@ -4,7 +4,7 @@
|
||||
#define DEF_TRAIN_H
|
||||
|
||||
#define EPOCHS 10
|
||||
#define BATCHES 100
|
||||
#define BATCHES 120
|
||||
#define USE_MULTITHREADING
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <float.h>
|
||||
#include <pthread.h>
|
||||
#include <sys/sysinfo.h>
|
||||
#include <time.h>
|
||||
|
||||
#include "../mnist/include/mnist.h"
|
||||
#include "include/initialisation.h"
|
||||
@ -46,19 +47,13 @@ void* train_thread(void* parameters) {
|
||||
int start = param->start;
|
||||
int nb_images = param->nb_images;
|
||||
float accuracy = 0.;
|
||||
int cpt=1;
|
||||
for (int i=start; i < start+nb_images; i++) {
|
||||
if (dataset_type == 0) {
|
||||
write_image_in_network_32(images[i], height, width, network->input[0][0]);
|
||||
forward_propagation(network);
|
||||
maxi = indice_max(network->input[network->size-1][0][0], 10);
|
||||
backward_propagation(network, labels[i]);
|
||||
if (cpt==16) { // Update the network
|
||||
update_weights(network);
|
||||
update_bias(network);
|
||||
cpt = 0;
|
||||
}
|
||||
cpt++;
|
||||
|
||||
if (maxi == labels[i]) {
|
||||
accuracy += 1.;
|
||||
}
|
||||
@ -74,6 +69,7 @@ void* train_thread(void* parameters) {
|
||||
|
||||
|
||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
|
||||
srand(time(NULL));
|
||||
int input_dim = -1;
|
||||
int input_depth = -1;
|
||||
float accuracy;
|
||||
@ -190,19 +186,25 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
free_network(train_parameters[k]->network);
|
||||
}
|
||||
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
|
||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||
#else
|
||||
train_params->start = j*BATCHES;
|
||||
|
||||
train_thread((void*)train_params);
|
||||
|
||||
accuracy += train_params->accuracy / (float) nb_images_total;
|
||||
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
|
||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||
|
||||
update_weights(network);
|
||||
update_bias(network);
|
||||
|
||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||
#endif
|
||||
}
|
||||
#ifdef USE_MULTITHREADING
|
||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
||||
#else
|
||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
||||
#endif
|
||||
write_network(out, network);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user