Add learning rate

This commit is contained in:
Julien Chemillier 2022-10-03 10:04:11 +02:00
parent 9f44e4a189
commit a604c96476
6 changed files with 10 additions and 8 deletions

View File

@ -4,11 +4,12 @@
#include "include/function.h" #include "include/function.h"
#include "initialisation.c" #include "initialisation.c"
Network* create_network(int max_size, int dropout, int initialisation, int input_dim, int input_depth) { Network* create_network(int max_size, int learning_rate, int dropout, int initialisation, int input_dim, int input_depth) {
if (dropout < 0 || dropout > 100) { if (dropout < 0 || dropout > 100) {
printf("Erreur, la probabilité de dropout n'est pas respecté, elle doit être comprise entre 0 et 100\n"); printf("Erreur, la probabilité de dropout n'est pas respecté, elle doit être comprise entre 0 et 100\n");
} }
Network* network = (Network*)malloc(sizeof(Network)); Network* network = (Network*)malloc(sizeof(Network));
network->learning_rate = learning_rate;
network->max_size = max_size; network->max_size = max_size;
network->dropout = dropout; network->dropout = dropout;
network->initialisation = initialisation; network->initialisation = initialisation;
@ -28,8 +29,8 @@ Network* create_network(int max_size, int dropout, int initialisation, int input
return network; return network;
} }
Network* create_network_lenet5(int dropout, int activation, int initialisation, int input_dim, int input_depth) { Network* create_network_lenet5(int learning_rate, int dropout, int activation, int initialisation, int input_dim, int input_depth) {
Network* network = create_network(8, dropout, initialisation, input_dim, input_depth); Network* network = create_network(8, learning_rate, dropout, initialisation, input_dim, input_depth);
network->kernel[0]->activation = activation; network->kernel[0]->activation = activation;
network->kernel[0]->linearisation = 0; network->kernel[0]->linearisation = 0;
add_convolution(network, 1, 32, 6, 28, activation); add_convolution(network, 1, 32, 6, 28, activation);

View File

@ -7,12 +7,12 @@
/* /*
* Créé un réseau qui peut contenir max_size couche (dont celle d'input et d'output) * Créé un réseau qui peut contenir max_size couche (dont celle d'input et d'output)
*/ */
Network* create_network(int max_size, int dropout, int initialisation, int input_dim, int input_depth); Network* create_network(int max_size, int learning_rate, int dropout, int initialisation, int input_dim, int input_depth);
/* /*
* Renvoie un réseau suivant l'architecture LeNet5 * Renvoie un réseau suivant l'architecture LeNet5
*/ */
Network* create_network_lenet5(int dropout, int activation, int initialisation, int input_dim, int input_depth); Network* create_network_lenet5(int learning_rate, int dropout, int activation, int initialisation, int input_dim, int input_depth);
/* /*
* Créé et alloue de la mémoire à une couche de type input cube * Créé et alloue de la mémoire à une couche de type input cube

View File

@ -30,6 +30,7 @@ typedef struct Kernel {
typedef struct Network{ typedef struct Network{
int dropout; // Contient la probabilité d'abandon d'un neurone dans [0, 100] (entiers) int dropout; // Contient la probabilité d'abandon d'un neurone dans [0, 100] (entiers)
int learning_rate;
int initialisation; // Contient le type d'initialisation int initialisation; // Contient le type d'initialisation
int max_size; // Taille du tableau contenant le réseau int max_size; // Taille du tableau contenant le réseau
int size; // Taille actuelle du réseau (size ≤ max_size) int size; // Taille actuelle du réseau (size ≤ max_size)

View File

@ -27,7 +27,7 @@ void help(char* call) {
void dev_conv() { void dev_conv() {
Network* network = create_network_lenet5(0, TANH, GLOROT_NORMAL, 32, 1); Network* network = create_network_lenet5(0, 0, TANH, GLOROT_NORMAL, 32, 1);
forward_propagation(network); forward_propagation(network);
} }

View File

@ -77,7 +77,7 @@ void train(int dataset_type, char* images_file, char* labels_file, char* data_di
} }
// Initialisation du réseau // Initialisation du réseau
Network* network = create_network_lenet5(0, TANH, GLOROT_NORMAL, input_dim, input_depth); Network* network = create_network_lenet5(0, 0, TANH, GLOROT_NORMAL, input_dim, input_depth);
#ifdef USE_MULTITHREADING #ifdef USE_MULTITHREADING
// Récupération du nombre de threads disponibles // Récupération du nombre de threads disponibles

View File

@ -12,7 +12,7 @@
int main() { int main() {
printf("Création du réseau\n"); printf("Création du réseau\n");
Network* network = create_network_lenet5(0, 3, 2, 32, 1); Network* network = create_network_lenet5(0, 0, 3, 2, 32, 1);
printf("OK\n"); printf("OK\n");
printf("Écriture du réseau\n"); printf("Écriture du réseau\n");