mirror of
https://github.com/augustin64/projet-tipe
synced 2025-01-23 15:16:26 +01:00
Compare commits
2 Commits
964687d1b4
...
08993ade85
Author | SHA1 | Date | |
---|---|---|---|
08993ade85 | |||
5738892142 |
@ -8,6 +8,8 @@
|
||||
#define LEARNING_RATE 3e-4 // Taux d'apprentissage
|
||||
#define USE_MULTITHREADING // Commenter pour utiliser un seul coeur durant l'apprentissage (meilleur pour des tailles de batchs traités rapidement)
|
||||
|
||||
//#define DETAILED_TRAIN_TIMINGS // Afficher le temps de forward/ backward
|
||||
|
||||
//* Paramètres d'ADAM optimizer
|
||||
#define ALPHA 3e-4
|
||||
#define BETA_1 0.9
|
||||
|
@ -62,6 +62,10 @@ void* train_thread(void* parameters) {
|
||||
float accuracy = 0.;
|
||||
float loss = 0.;
|
||||
|
||||
#ifdef DETAILED_TRAIN_TIMINGS
|
||||
double start_time;
|
||||
#endif
|
||||
|
||||
pthread_t tid;
|
||||
LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters));
|
||||
if (dataset_type != 0) {
|
||||
@ -74,7 +78,20 @@ void* train_thread(void* parameters) {
|
||||
for (int i=start; i < start+nb_images; i++) {
|
||||
if (dataset_type == 0) {
|
||||
write_image_in_network_32(images[index[i]], height, width, network->input[0][0], true);
|
||||
|
||||
#ifdef DETAILED_TRAIN_TIMINGS
|
||||
start_time = omp_get_wtime();
|
||||
#endif
|
||||
|
||||
forward_propagation(network);
|
||||
|
||||
#ifdef DETAILED_TRAIN_TIMINGS
|
||||
printf("Temps de forward: ");
|
||||
printf_time(omp_get_wtime() - start_time);
|
||||
printf("\n");
|
||||
start_time = omp_get_wtime();
|
||||
#endif
|
||||
|
||||
maxi = indice_max(network->input[network->size-1][0][0], 10);
|
||||
if (maxi == -1) {
|
||||
printf("\n");
|
||||
@ -88,6 +105,13 @@ void* train_thread(void* parameters) {
|
||||
|
||||
backward_propagation(network, labels[index[i]]);
|
||||
|
||||
#ifdef DETAILED_TRAIN_TIMINGS
|
||||
printf("Temps de backward: ");
|
||||
printf_time(omp_get_wtime() - start_time);
|
||||
printf("\n");
|
||||
start_time = omp_get_wtime();
|
||||
#endif
|
||||
|
||||
if (maxi == labels[index[i]]) {
|
||||
accuracy += 1.;
|
||||
}
|
||||
@ -104,9 +128,30 @@ void* train_thread(void* parameters) {
|
||||
pthread_create(&tid, NULL, load_image, (void*) load_image_param);
|
||||
}
|
||||
write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]);
|
||||
|
||||
#ifdef DETAILED_TRAIN_TIMINGS
|
||||
start_time = omp_get_wtime();
|
||||
#endif
|
||||
|
||||
forward_propagation(network);
|
||||
|
||||
#ifdef DETAILED_TRAIN_TIMINGS
|
||||
printf("Temps de forward: ");
|
||||
printf_time(omp_get_wtime() - start_time);
|
||||
printf("\n");
|
||||
start_time = omp_get_wtime();
|
||||
#endif
|
||||
|
||||
maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories);
|
||||
backward_propagation(network, param->dataset->labels[index[i]]);
|
||||
|
||||
#ifdef DETAILED_TRAIN_TIMINGS
|
||||
printf("Temps de backward: ");
|
||||
printf_time(omp_get_wtime() - start_time);
|
||||
printf("\n");
|
||||
start_time = omp_get_wtime();
|
||||
#endif
|
||||
|
||||
|
||||
if (maxi == (int)param->dataset->labels[index[i]]) {
|
||||
accuracy += 1.;
|
||||
@ -160,9 +205,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
int nb_images_total_remaining; // Images restantes dans un batch
|
||||
int batches_epoques; // Batches par époque
|
||||
|
||||
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
||||
unsigned int* labels; // Labels associés aux images du dataset MNIST
|
||||
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
|
||||
int*** images = NULL; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
|
||||
unsigned int* labels = NULL; // Labels associés aux images du dataset MNIST
|
||||
jpegDataset* dataset = NULL; // Structure de données décrivant un dataset d'images jpeg
|
||||
if (dataset_type == 0) { // Type MNIST
|
||||
// Chargement des images du set de données MNIST
|
||||
int* parameters = read_mnist_images_parameters(images_file);
|
||||
@ -186,10 +231,14 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
|
||||
Network* network;
|
||||
if (!recover) {
|
||||
if (dataset_type == 0) {
|
||||
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth);
|
||||
network = create_network_lenet5(LEARNING_RATE, 0, LEAKY_RELU, HE, input_width, input_depth);
|
||||
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
|
||||
} else {
|
||||
network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories);
|
||||
network = create_network_VGG16(LEARNING_RATE, 0, RELU, HE, dataset->numCategories);
|
||||
|
||||
#ifdef USE_MULTITHREADING
|
||||
printf_warning("Utilisation de VGG16 avec multithreading. La quantité de RAM utilisée peut devenir excessive\n");
|
||||
#endif
|
||||
}
|
||||
} else {
|
||||
network = read_network(recover);
|
||||
|
Loading…
Reference in New Issue
Block a user