mirror of
https://github.com/augustin64/projet-tipe
synced 2025-03-13 06:15:21 +01:00
Update backpropagation (It works now)
This commit is contained in:
parent
f943228e80
commit
220d0a71be
@ -86,7 +86,7 @@ void backward_linearisation(Kernel_nn* ker, float*** input, float*** input_z, fl
|
|||||||
for (int k=0; k < dim_input; k++) {
|
for (int k=0; k < dim_input; k++) {
|
||||||
for (int l=0; l < dim_input; l++) {
|
for (int l=0; l < dim_input; l++) {
|
||||||
for (int j=0; j < size_output; j++) {
|
for (int j=0; j < size_output; j++) {
|
||||||
ker->d_weights[cpt][j] += input[i][k][l]*output[j];
|
ker->d_weights[cpt][j] += input[i][k][l]*output[j]/(depth_input*dim_input*dim_input*size_output);
|
||||||
}
|
}
|
||||||
cpt++;
|
cpt++;
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
#ifndef DEF_UPDATE_H
|
#ifndef DEF_UPDATE_H
|
||||||
#define DEF_UPDATE_H
|
#define DEF_UPDATE_H
|
||||||
|
|
||||||
|
#define MAX_RESEAU 100000000
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Met à jours les poids à partir de données obtenus après plusieurs backpropagations
|
* Met à jours les poids à partir de données obtenus après plusieurs backpropagations
|
||||||
* Puis met à 0 tous les d_weights
|
* Puis met à 0 tous les d_weights
|
||||||
|
@ -26,6 +26,11 @@ void update_weights(Network* network, Network* d_network) {
|
|||||||
for (int d=0; d<k_size; d++) {
|
for (int d=0; d<k_size; d++) {
|
||||||
cnn->w[a][b][c][d] -= network->learning_rate * d_cnn->d_w[a][b][c][d];
|
cnn->w[a][b][c][d] -= network->learning_rate * d_cnn->d_w[a][b][c][d];
|
||||||
d_cnn->d_w[a][b][c][d] = 0;
|
d_cnn->d_w[a][b][c][d] = 0;
|
||||||
|
|
||||||
|
if (cnn->w[a][b][c][d] > MAX_RESEAU)
|
||||||
|
cnn->w[a][b][c][d] = MAX_RESEAU;
|
||||||
|
else if (cnn->w[a][b][c][d] < -MAX_RESEAU)
|
||||||
|
cnn->w[a][b][c][d] = -MAX_RESEAU;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -48,6 +53,11 @@ void update_weights(Network* network, Network* d_network) {
|
|||||||
for (int b=0; b<output_width; b++) {
|
for (int b=0; b<output_width; b++) {
|
||||||
nn->weights[a][b] -= network->learning_rate * d_nn->d_weights[a][b];
|
nn->weights[a][b] -= network->learning_rate * d_nn->d_weights[a][b];
|
||||||
d_nn->d_weights[a][b] = 0;
|
d_nn->d_weights[a][b] = 0;
|
||||||
|
|
||||||
|
if (nn->weights[a][b] > MAX_RESEAU)
|
||||||
|
nn->weights[a][b] = MAX_RESEAU;
|
||||||
|
else if (nn->weights[a][b] < -MAX_RESEAU)
|
||||||
|
nn->weights[a][b] = -MAX_RESEAU;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -77,6 +87,11 @@ void update_bias(Network* network, Network* d_network) {
|
|||||||
for (int c=0; c<output_width; c++) {
|
for (int c=0; c<output_width; c++) {
|
||||||
cnn->bias[a][b][c] -= network->learning_rate * d_cnn->d_bias[a][b][c];
|
cnn->bias[a][b][c] -= network->learning_rate * d_cnn->d_bias[a][b][c];
|
||||||
d_cnn->d_bias[a][b][c] = 0;
|
d_cnn->d_bias[a][b][c] = 0;
|
||||||
|
|
||||||
|
if (cnn->bias[a][b][c] > MAX_RESEAU)
|
||||||
|
cnn->bias[a][b][c] = MAX_RESEAU;
|
||||||
|
else if (cnn->bias[a][b][c] < -MAX_RESEAU)
|
||||||
|
cnn->bias[a][b][c] = -MAX_RESEAU;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -86,6 +101,11 @@ void update_bias(Network* network, Network* d_network) {
|
|||||||
for (int a=0; a<output_width; a++) {
|
for (int a=0; a<output_width; a++) {
|
||||||
nn->bias[a] -= network->learning_rate * d_nn->d_bias[a];
|
nn->bias[a] -= network->learning_rate * d_nn->d_bias[a];
|
||||||
d_nn->d_bias[a] = 0;
|
d_nn->d_bias[a] = 0;
|
||||||
|
|
||||||
|
if (nn->bias[a] > MAX_RESEAU)
|
||||||
|
nn->bias[a] = MAX_RESEAU;
|
||||||
|
else if (nn->bias[a] < -MAX_RESEAU)
|
||||||
|
nn->bias[a] = -MAX_RESEAU;
|
||||||
}
|
}
|
||||||
} else { // Pooling
|
} else { // Pooling
|
||||||
(void)0; // Ne rien faire pour la couche pooling
|
(void)0; // Ne rien faire pour la couche pooling
|
||||||
|
Loading…
x
Reference in New Issue
Block a user