diff --git a/src/cnn/update.c b/src/cnn/update.c index 0392799..f3ab286 100644 --- a/src/cnn/update.c +++ b/src/cnn/update.c @@ -22,7 +22,7 @@ void update_weights(Network* network) { for (int b=0; bw[a][b][c][d] += cnn->d_w[a][b][c][d]; + cnn->w[a][b][c][d] += network->learning_rate * cnn->d_w[a][b][c][d]; cnn->d_w[a][b][c][d] = 0; } } @@ -33,7 +33,7 @@ void update_weights(Network* network) { Kernel_nn* nn = k_i_1->nn; for (int a=0; aweights[a][b] += nn->d_weights[a][b]; + nn->weights[a][b] += network->learning_rate * nn->d_weights[a][b]; nn->d_weights[a][b] = 0; } } @@ -42,7 +42,7 @@ void update_weights(Network* network) { int input_size = input_width*input_width*input_depth; for (int a=0; aweights[a][b] += nn->d_weights[a][b]; + nn->weights[a][b] += network->learning_rate * nn->d_weights[a][b]; nn->d_weights[a][b] = 0; } } @@ -69,7 +69,7 @@ void update_bias(Network* network) { for (int a=0; abias[a][b][c] += cnn->d_bias[a][b][c]; + cnn->bias[a][b][c] += network->learning_rate * cnn->d_bias[a][b][c]; cnn->d_bias[a][b][c] = 0; } } @@ -77,7 +77,7 @@ void update_bias(Network* network) { } else if (k_i->nn) { // Full connection Kernel_nn* nn = k_i_1->nn; for (int a=0; abias[a] += nn->d_bias[a]; + nn->bias[a] += network->learning_rate * nn->d_bias[a]; nn->d_bias[a] = 0; } } else { // Pooling