Compare commits

..

No commits in common. "08993ade85732487a106415df75c681b765960e6" and "964687d1b4277a9e4c6a98c54b48b172b379487b" have entirely different histories.

2 changed files with 5 additions and 56 deletions

View File

@ -8,8 +8,6 @@
#define LEARNING_RATE 3e-4 // Taux d'apprentissage
#define USE_MULTITHREADING // Commenter pour utiliser un seul coeur durant l'apprentissage (meilleur pour des tailles de batchs traités rapidement)
//#define DETAILED_TRAIN_TIMINGS // Afficher le temps de forward/ backward
//* Paramètres d'ADAM optimizer
#define ALPHA 3e-4
#define BETA_1 0.9

View File

@ -62,10 +62,6 @@ void* train_thread(void* parameters) {
float accuracy = 0.;
float loss = 0.;
#ifdef DETAILED_TRAIN_TIMINGS
double start_time;
#endif
pthread_t tid;
LoadImageParameters* load_image_param = (LoadImageParameters*)malloc(sizeof(LoadImageParameters));
if (dataset_type != 0) {
@ -78,20 +74,7 @@ void* train_thread(void* parameters) {
for (int i=start; i < start+nb_images; i++) {
if (dataset_type == 0) {
write_image_in_network_32(images[index[i]], height, width, network->input[0][0], true);
#ifdef DETAILED_TRAIN_TIMINGS
start_time = omp_get_wtime();
#endif
forward_propagation(network);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de forward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
maxi = indice_max(network->input[network->size-1][0][0], 10);
if (maxi == -1) {
printf("\n");
@ -105,13 +88,6 @@ void* train_thread(void* parameters) {
backward_propagation(network, labels[index[i]]);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de backward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
if (maxi == labels[index[i]]) {
accuracy += 1.;
}
@ -128,30 +104,9 @@ void* train_thread(void* parameters) {
pthread_create(&tid, NULL, load_image, (void*) load_image_param);
}
write_256_image_in_network(param->dataset->images[index[i]], width, param->dataset->numComponents, network->width[0], network->input[0]);
#ifdef DETAILED_TRAIN_TIMINGS
start_time = omp_get_wtime();
#endif
forward_propagation(network);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de forward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
maxi = indice_max(network->input[network->size-1][0][0], param->dataset->numCategories);
backward_propagation(network, param->dataset->labels[index[i]]);
#ifdef DETAILED_TRAIN_TIMINGS
printf("Temps de backward: ");
printf_time(omp_get_wtime() - start_time);
printf("\n");
start_time = omp_get_wtime();
#endif
if (maxi == (int)param->dataset->labels[index[i]]) {
accuracy += 1.;
@ -205,9 +160,9 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
int nb_images_total_remaining; // Images restantes dans un batch
int batches_epoques; // Batches par époque
int*** images = NULL; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
unsigned int* labels = NULL; // Labels associés aux images du dataset MNIST
jpegDataset* dataset = NULL; // Structure de données décrivant un dataset d'images jpeg
int*** images; // Images sous forme de tableau de tableaux de tableaux de pixels (degré de gris, MNIST)
unsigned int* labels; // Labels associés aux images du dataset MNIST
jpegDataset* dataset; // Structure de données décrivant un dataset d'images jpeg
if (dataset_type == 0) { // Type MNIST
// Chargement des images du set de données MNIST
int* parameters = read_mnist_images_parameters(images_file);
@ -231,14 +186,10 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
Network* network;
if (!recover) {
if (dataset_type == 0) {
network = create_network_lenet5(LEARNING_RATE, 0, LEAKY_RELU, HE, input_width, input_depth);
network = create_network_lenet5(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, input_width, input_depth);
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_width, input_depth);
} else {
network = create_network_VGG16(LEARNING_RATE, 0, RELU, HE, dataset->numCategories);
#ifdef USE_MULTITHREADING
printf_warning("Utilisation de VGG16 avec multithreading. La quantité de RAM utilisée peut devenir excessive\n");
#endif
network = create_network_VGG16(LEARNING_RATE, 0, RELU, NORMALIZED_XAVIER, dataset->numCategories);
}
} else {
network = read_network(recover);