diff --git a/src/cnn/include/update.h b/src/cnn/include/update.h
index ce6d635..e370130 100644
--- a/src/cnn/include/update.h
+++ b/src/cnn/include/update.h
@@ -3,7 +3,13 @@
 #ifndef DEF_UPDATE_H
 #define DEF_UPDATE_H
 
-#define MAX_RESEAU 100000000
+/*
+* Des valeurs trop grandes dans le réseau risqueraient de provoquer des overflows notamment.
+* On utilise donc la méthode gradient_clipping,
+* qui consiste à majorer tous les biais et poids par un hyper-paramètre choisi précédemment.
+* https://arxiv.org/pdf/1905.11881.pdf
+*/
+#define CLIP_VALUE 300
 
 /*
 * Met à jours les poids à partir de données obtenus après plusieurs backpropagations
diff --git a/src/cnn/update.c b/src/cnn/update.c
index e2501a4..504bebd 100644
--- a/src/cnn/update.c
+++ b/src/cnn/update.c
@@ -27,10 +27,10 @@ void update_weights(Network* network, Network* d_network) {
                         cnn->weights[a][b][c][d] -= network->learning_rate * d_cnn->d_weights[a][b][c][d];
                         d_cnn->d_weights[a][b][c][d] = 0;
 
-                        if (cnn->weights[a][b][c][d] > MAX_RESEAU)
-                            cnn->weights[a][b][c][d] = MAX_RESEAU;
-                        else if (cnn->weights[a][b][c][d] < -MAX_RESEAU)
-                            cnn->weights[a][b][c][d] = -MAX_RESEAU;
+                        if (cnn->weights[a][b][c][d] > CLIP_VALUE)
+                            cnn->weights[a][b][c][d] = CLIP_VALUE;
+                        else if (cnn->weights[a][b][c][d] < -CLIP_VALUE)
+                            cnn->weights[a][b][c][d] = -CLIP_VALUE;
                     }
                 }
             }
@@ -54,10 +54,10 @@ void update_weights(Network* network, Network* d_network) {
                 nn->weights[a][b] -= network->learning_rate * d_nn->d_weights[a][b];
                 d_nn->d_weights[a][b] = 0;
 
-                if (nn->weights[a][b] > MAX_RESEAU)
-                    nn->weights[a][b] = MAX_RESEAU;
-                else if (nn->weights[a][b] < -MAX_RESEAU)
-                    nn->weights[a][b] = -MAX_RESEAU;
+                if (nn->weights[a][b] > CLIP_VALUE)
+                    nn->weights[a][b] = CLIP_VALUE;
+                else if (nn->weights[a][b] < -CLIP_VALUE)
+                    nn->weights[a][b] = -CLIP_VALUE;
             }
         }
     }
@@ -88,10 +88,10 @@ void update_bias(Network* network, Network* d_network) {
                     cnn->bias[a][b][c] -= network->learning_rate * d_cnn->d_bias[a][b][c];
                     d_cnn->d_bias[a][b][c] = 0;
 
-                    if (cnn->bias[a][b][c] > MAX_RESEAU)
-                        cnn->bias[a][b][c] = MAX_RESEAU;
-                    else if (cnn->bias[a][b][c] < -MAX_RESEAU)
-                        cnn->bias[a][b][c] = -MAX_RESEAU;
+                    if (cnn->bias[a][b][c] > CLIP_VALUE)
+                        cnn->bias[a][b][c] = CLIP_VALUE;
+                    else if (cnn->bias[a][b][c] < -CLIP_VALUE)
+                        cnn->bias[a][b][c] = -CLIP_VALUE;
                 }
             }
         }
@@ -102,10 +102,10 @@ void update_bias(Network* network, Network* d_network) {
             nn->bias[a] -= network->learning_rate * d_nn->d_bias[a];
             d_nn->d_bias[a] = 0;
 
-            if (nn->bias[a] > MAX_RESEAU)
-                nn->bias[a] = MAX_RESEAU;
-            else if (nn->bias[a] < -MAX_RESEAU)
-                nn->bias[a] = -MAX_RESEAU;
+            if (nn->bias[a] > CLIP_VALUE)
+                nn->bias[a] = CLIP_VALUE;
+            else if (nn->bias[a] < -CLIP_VALUE)
+                nn->bias[a] = -CLIP_VALUE;
         }
     } else { // Pooling
         (void)0; // Ne rien faire pour la couche pooling