diff --git a/src/cnn/function.c b/src/cnn/function.c index 6e9a7fd..9e1b0b0 100644 --- a/src/cnn/function.c +++ b/src/cnn/function.c @@ -76,17 +76,13 @@ void apply_function_input(float (*f)(float), float*** input, int depth, int rows void choose_apply_function_matrix(int activation, float*** input, int depth, int dim) { if (activation == RELU) { apply_function_input(relu, input, depth, dim, dim); - } - else if (activation == SIGMOID) { + } else if (activation == SIGMOID) { apply_function_input(sigmoid, input, depth, dim, dim); - } - else if (activation == SOFTMAX) { + } else if (activation == SOFTMAX) { apply_softmax_input(input, depth, dim, dim); - } - else if (activation == TANH) { + } else if (activation == TANH) { apply_function_input(tanh_, input, depth, dim, dim); - } - else { + } else { printf("Erreur, fonction d'activation inconnue (choose_apply_function_matrix): %d\n", activation); } } @@ -94,17 +90,35 @@ void choose_apply_function_matrix(int activation, float*** input, int depth, int void choose_apply_function_vector(int activation, float*** input, int dim) { if (activation == RELU) { apply_function_input(relu, input, 1, 1, dim); - } - else if (activation == SIGMOID) { + } else if (activation == SIGMOID) { apply_function_input(sigmoid, input, 1, 1, dim); - } - else if (activation == SOFTMAX) { + } else if (activation == SOFTMAX) { apply_softmax_input(input, 1, 1, dim); - } - else if (activation == TANH) { + } else if (activation == TANH) { apply_function_input(tanh_, input, 1, 1, dim); + } else { + printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation); } - else { +} + +void* get_function_activation(int activation) { + if (activation == RELU) { + return relu; + } else if (activation == -RELU) { + return relu_derivative; + } else if (activation == SIGMOID) { + return sigmoid; + } else if (activation == -SIGMOID) { + return sigmoid_derivative + } else if (activation == SOFTMAX) { + printf("Erreur, impossible de renvoyer la fonction softmax"); + } else if (activation == -SOFTMAX) { + printf("Erreur, impossible de renvoyer la dérivée de la fonction softmax"); + } else if (activation == TANH) { + return tanh_; + } else if (activation == -TANH) { + return tanh_derivative; + } else { printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation); } } \ No newline at end of file diff --git a/src/cnn/include/function.h b/src/cnn/include/function.h index 699d381..50f89c6 100644 --- a/src/cnn/include/function.h +++ b/src/cnn/include/function.h @@ -44,5 +44,9 @@ void choose_apply_function_matrix(int activation, float*** input, int depth, int */ void choose_apply_function_vector(int activation, float*** input, int dim); +/* +* Renvoie un pointeur vers la fonction d'activation correspondante +*/ +void* get_function_activation(int activation) #endif \ No newline at end of file