tipe/src/cnn/include/train.h

47 lines
1.2 KiB
C
Raw Normal View History

2022-10-04 12:43:37 +02:00
#include "struct.h"
2022-11-19 16:09:07 +01:00
#include "jpeg.h"
2022-10-04 12:43:37 +02:00
2022-10-01 17:53:14 +02:00
#ifndef DEF_TRAIN_H
#define DEF_TRAIN_H
#define EPOCHS 10
2023-01-20 13:41:38 +01:00
#define BATCHES 500
2023-02-04 13:12:52 +01:00
#define USE_MULTITHREADING
2023-01-30 09:39:45 +01:00
#define LEARNING_RATE 0.05
2022-10-01 17:53:14 +02:00
/*
* Structure donnée en argument à la fonction 'train_thread'
*/
typedef struct TrainParameters {
2022-11-19 16:09:07 +01:00
Network* network; // Réseau
jpegDataset* dataset; // Dataset si de type JPEG
int* index; // Sert à réordonner les images
int*** images; // Images si de type MNIST
unsigned int* labels; // Labels si de type MNIST
int width; // Largeur des images
int height; // Hauteur des images
int dataset_type; // Type de dataset
int start; // Début des images
2023-02-04 13:12:52 +01:00
int nb_images; // Nombre d'images à traiter
2022-11-19 16:09:07 +01:00
float accuracy; // Accuracy (à renvoyer)
2023-01-20 13:41:38 +01:00
float loss; // Loss (à renvoyer)
2022-10-01 17:53:14 +02:00
} TrainParameters;
2023-02-04 13:12:52 +01:00
/*
* Partie entière supérieure de a/b
*/
int div_up(int a, int b);
2022-11-03 18:13:01 +01:00
2022-10-01 17:53:14 +02:00
/*
* Fonction auxiliaire d'entraînement destinée à être exécutée sur plusieurs threads à la fois
*/
void* train_thread(void* parameters);
/*
* Fonction principale d'entraînement du réseau neuronal convolutif
*/
2022-12-07 13:09:39 +01:00
void train(int dataset_type, char* images_file, char* labels_file, char* data_dir, int epochs, char* out, char* recover);
2022-10-01 17:53:14 +02:00
#endif