Add simple_one

This commit is contained in:
augustin64 2023-01-21 18:59:59 +01:00
parent 5f47b93672
commit 7deef7c5c5
4 changed files with 27 additions and 7 deletions

View File

@ -47,6 +47,15 @@ Network* create_network_lenet5(float learning_rate, int dropout, int activation,
return network;
}
Network* create_simple_one(float learning_rate, int dropout, int activation, int initialisation, int input_dim, int input_depth) {
Network* network = create_network(3, learning_rate, dropout, initialisation, input_dim, input_depth);
network->kernel[0]->activation = activation;
network->kernel[0]->linearisation = 0;
add_dense_linearisation(network, 80, activation);
add_dense(network, 10, SOFTMAX);
return network;
}
void create_a_cube_input_layer(Network* network, int pos, int depth, int dim) {
network->input[pos] = (float***)malloc(sizeof(float**)*depth);
for (int i=0; i < depth; i++) {

View File

@ -14,6 +14,11 @@ Network* create_network(int max_size, float learning_rate, int dropout, int init
*/
Network* create_network_lenet5(float learning_rate, int dropout, int activation, int initialisation, int input_dim, int input_depth);
/*
* Renvoie un réseau sans convolution, similaire à celui utilisé dans src/mnist
*/
Network* create_simple_one(float learning_rate, int dropout, int activation, int initialisation, int input_dim, int input_depth);
/*
* Créé et alloue de la mémoire à une couche de type input cube
*/

View File

@ -7,7 +7,7 @@
#define EPOCHS 10
#define BATCHES 500
#define USE_MULTITHREADING
#define LEARNING_RATE 0.01
#define LEARNING_RATE 0.5
/*

View File

@ -50,6 +50,11 @@ void* train_thread(void* parameters) {
write_image_in_network_32(images[index[i]], height, width, network->input[0][0]);
forward_propagation(network);
maxi = indice_max(network->input[network->size-1][0][0], 10);
if (maxi == -1) {
printf("\n");
printf_error("Le réseau sature.\n");
exit(1);
}
wanted_output = generate_wanted_output(labels[index[i]], 10);
loss += compute_mean_squared_error(network->input[network->size-1][0][0], wanted_output, 10);
@ -134,7 +139,8 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
// Initialisation du réseau
if (!recover) {
network = create_network_lenet5(LEARNING_RATE, 0, TANH, GLOROT, input_dim, input_depth);
network = create_network_lenet5(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth);
//network = create_simple_one(LEARNING_RATE, 0, RELU, GLOROT, input_dim, input_depth);
} else {
network = read_network(recover);
network->learning_rate = LEARNING_RATE;
@ -272,7 +278,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
}
}
current_accuracy = accuracy * nb_images_total/((j+1)*BATCHES);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" ", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.2f%%"RESET" \tRéussies: %d", nb_threads, i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, (int)(accuracy*nb_images_total));
fflush(stdout);
#else
(void)nb_images_total_remaining; // Juste pour enlever un warning
@ -294,19 +300,19 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
update_weights(network, network);
update_bias(network, network);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" ", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "YELLOW"%0.4f%%"RESET" \tRéussies: %d", i, epochs, BATCHES*(j+1), nb_images_total, current_accuracy*100, (int)(accuracy*nb_images_total));
fflush(stdout);
#endif
// Il serait intéressant d'utiliser la perte calculée pour
// savoir l'avancement dans l'apprentissage et donc comment adapter le taux d'apprentissage
//network->learning_rate = 0.01*batch_loss;
network->learning_rate = 10*LEARNING_RATE*batch_loss;
}
end_time = omp_get_wtime();
elapsed_time = end_time - start_time;
#ifdef USE_MULTITHREADING
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET"\tTemps: %0.2f s\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
printf("\rThreads [%d]\tÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tRéussies: %d\tTemps: %0.2f s\n", nb_threads, i, epochs, nb_images_total, nb_images_total, accuracy*100, (int)(accuracy*nb_images_total), elapsed_time);
#else
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET"\tTemps: %0.2f s\n", i, epochs, nb_images_total, nb_images_total, accuracy*100, elapsed_time);
printf("\rÉpoque [%d/%d]\tImage [%d/%d]\tAccuracy: "GREEN"%0.4f%%"RESET" \tRéussies: %d\tTemps: %0.2f s\n", i, epochs, nb_images_total, nb_images_total, accuracy*100, (int)(accuracy*nb_images_total), elapsed_time);
#endif
write_network(out, network);
}