mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-24 07:36:24 +01:00
Move update
This commit is contained in:
parent
50af676027
commit
d03f7493b2
@ -4,7 +4,7 @@
|
|||||||
#define DEF_TRAIN_H
|
#define DEF_TRAIN_H
|
||||||
|
|
||||||
#define EPOCHS 10
|
#define EPOCHS 10
|
||||||
#define BATCHES 100
|
#define BATCHES 120
|
||||||
#define USE_MULTITHREADING
|
#define USE_MULTITHREADING
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include <float.h>
|
#include <float.h>
|
||||||
#include <pthread.h>
|
#include <pthread.h>
|
||||||
#include <sys/sysinfo.h>
|
#include <sys/sysinfo.h>
|
||||||
|
#include <time.h>
|
||||||
|
|
||||||
#include "../mnist/include/mnist.h"
|
#include "../mnist/include/mnist.h"
|
||||||
#include "include/initialisation.h"
|
#include "include/initialisation.h"
|
||||||
@ -46,19 +47,13 @@ void* train_thread(void* parameters) {
|
|||||||
int start = param->start;
|
int start = param->start;
|
||||||
int nb_images = param->nb_images;
|
int nb_images = param->nb_images;
|
||||||
float accuracy = 0.;
|
float accuracy = 0.;
|
||||||
int cpt=1;
|
|
||||||
for (int i=start; i < start+nb_images; i++) {
|
for (int i=start; i < start+nb_images; i++) {
|
||||||
if (dataset_type == 0) {
|
if (dataset_type == 0) {
|
||||||
write_image_in_network_32(images[i], height, width, network->input[0][0]);
|
write_image_in_network_32(images[i], height, width, network->input[0][0]);
|
||||||
forward_propagation(network);
|
forward_propagation(network);
|
||||||
maxi = indice_max(network->input[network->size-1][0][0], 10);
|
maxi = indice_max(network->input[network->size-1][0][0], 10);
|
||||||
backward_propagation(network, labels[i]);
|
backward_propagation(network, labels[i]);
|
||||||
if (cpt==16) { // Update the network
|
|
||||||
update_weights(network);
|
|
||||||
update_bias(network);
|
|
||||||
cpt = 0;
|
|
||||||
}
|
|
||||||
cpt++;
|
|
||||||
if (maxi == labels[i]) {
|
if (maxi == labels[i]) {
|
||||||
accuracy += 1.;
|
accuracy += 1.;
|
||||||
}
|
}
|
||||||
@ -74,6 +69,7 @@ void* train_thread(void* parameters) {
|
|||||||
|
|
||||||
|
|
||||||
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
|
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out) {
|
||||||
|
srand(time(NULL));
|
||||||
int input_dim = -1;
|
int input_dim = -1;
|
||||||
int input_depth = -1;
|
int input_depth = -1;
|
||||||
float accuracy;
|
float accuracy;
|
||||||
@ -190,19 +186,25 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
|||||||
free_network(train_parameters[k]->network);
|
free_network(train_parameters[k]->network);
|
||||||
}
|
}
|
||||||
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
|
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
|
||||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||||
#else
|
#else
|
||||||
train_params->start = j*BATCHES;
|
train_params->start = j*BATCHES;
|
||||||
|
|
||||||
train_thread((void*)train_params);
|
train_thread((void*)train_params);
|
||||||
|
|
||||||
accuracy += train_params->accuracy / (float) nb_images_total;
|
accuracy += train_params->accuracy / (float) nb_images_total;
|
||||||
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
|
current_accuracy = accuracy * nb_images_total/(j*BATCHES);
|
||||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET"\t", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
|
||||||
|
update_weights(network);
|
||||||
|
update_bias(network);
|
||||||
|
|
||||||
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.1f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#ifdef USE_MULTITHREADING
|
#ifdef USE_MULTITHREADING
|
||||||
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
||||||
#else
|
#else
|
||||||
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET"\t\n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.1f%%"RESET" \n", i, epochs, nb_images_total, nb_images_total, accuracy*100);
|
||||||
#endif
|
#endif
|
||||||
write_network(out, network);
|
write_network(out, network);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user