Rename MAX_RESEAU to CLIP_VALUE

This commit is contained in:
augustin64 2023-03-01 19:12:57 +01:00
parent 95ce123587
commit c45b21e322
2 changed files with 23 additions and 17 deletions

View File

@ -3,7 +3,13 @@
#ifndef DEF_UPDATE_H
#define DEF_UPDATE_H
#define MAX_RESEAU 100000000
/*
* Des valeurs trop grandes dans le réseau risqueraient de provoquer des overflows notamment.
* On utilise donc la méthode gradient_clipping,
* qui consiste à majorer tous les biais et poids par un hyper-paramètre choisi précédemment.
* https://arxiv.org/pdf/1905.11881.pdf
*/
#define CLIP_VALUE 300
/*
* Met à jours les poids à partir de données obtenus après plusieurs backpropagations

View File

@ -27,10 +27,10 @@ void update_weights(Network* network, Network* d_network) {
cnn->weights[a][b][c][d] -= network->learning_rate * d_cnn->d_weights[a][b][c][d];
d_cnn->d_weights[a][b][c][d] = 0;
if (cnn->weights[a][b][c][d] > MAX_RESEAU)
cnn->weights[a][b][c][d] = MAX_RESEAU;
else if (cnn->weights[a][b][c][d] < -MAX_RESEAU)
cnn->weights[a][b][c][d] = -MAX_RESEAU;
if (cnn->weights[a][b][c][d] > CLIP_VALUE)
cnn->weights[a][b][c][d] = CLIP_VALUE;
else if (cnn->weights[a][b][c][d] < -CLIP_VALUE)
cnn->weights[a][b][c][d] = -CLIP_VALUE;
}
}
}
@ -54,10 +54,10 @@ void update_weights(Network* network, Network* d_network) {
nn->weights[a][b] -= network->learning_rate * d_nn->d_weights[a][b];
d_nn->d_weights[a][b] = 0;
if (nn->weights[a][b] > MAX_RESEAU)
nn->weights[a][b] = MAX_RESEAU;
else if (nn->weights[a][b] < -MAX_RESEAU)
nn->weights[a][b] = -MAX_RESEAU;
if (nn->weights[a][b] > CLIP_VALUE)
nn->weights[a][b] = CLIP_VALUE;
else if (nn->weights[a][b] < -CLIP_VALUE)
nn->weights[a][b] = -CLIP_VALUE;
}
}
}
@ -88,10 +88,10 @@ void update_bias(Network* network, Network* d_network) {
cnn->bias[a][b][c] -= network->learning_rate * d_cnn->d_bias[a][b][c];
d_cnn->d_bias[a][b][c] = 0;
if (cnn->bias[a][b][c] > MAX_RESEAU)
cnn->bias[a][b][c] = MAX_RESEAU;
else if (cnn->bias[a][b][c] < -MAX_RESEAU)
cnn->bias[a][b][c] = -MAX_RESEAU;
if (cnn->bias[a][b][c] > CLIP_VALUE)
cnn->bias[a][b][c] = CLIP_VALUE;
else if (cnn->bias[a][b][c] < -CLIP_VALUE)
cnn->bias[a][b][c] = -CLIP_VALUE;
}
}
}
@ -102,10 +102,10 @@ void update_bias(Network* network, Network* d_network) {
nn->bias[a] -= network->learning_rate * d_nn->d_bias[a];
d_nn->d_bias[a] = 0;
if (nn->bias[a] > MAX_RESEAU)
nn->bias[a] = MAX_RESEAU;
else if (nn->bias[a] < -MAX_RESEAU)
nn->bias[a] = -MAX_RESEAU;
if (nn->bias[a] > CLIP_VALUE)
nn->bias[a] = CLIP_VALUE;
else if (nn->bias[a] < -CLIP_VALUE)
nn->bias[a] = -CLIP_VALUE;
}
} else { // Pooling
(void)0; // Ne rien faire pour la couche pooling