Add 'get_function_activation' function

This commit is contained in:
julienChemillier 2022-10-26 17:32:54 +02:00
parent 3d812701f7
commit 816f7ea334
2 changed files with 33 additions and 15 deletions

View File

@ -76,17 +76,13 @@ void apply_function_input(float (*f)(float), float*** input, int depth, int rows
void choose_apply_function_matrix(int activation, float*** input, int depth, int dim) { void choose_apply_function_matrix(int activation, float*** input, int depth, int dim) {
if (activation == RELU) { if (activation == RELU) {
apply_function_input(relu, input, depth, dim, dim); apply_function_input(relu, input, depth, dim, dim);
} } else if (activation == SIGMOID) {
else if (activation == SIGMOID) {
apply_function_input(sigmoid, input, depth, dim, dim); apply_function_input(sigmoid, input, depth, dim, dim);
} } else if (activation == SOFTMAX) {
else if (activation == SOFTMAX) {
apply_softmax_input(input, depth, dim, dim); apply_softmax_input(input, depth, dim, dim);
} } else if (activation == TANH) {
else if (activation == TANH) {
apply_function_input(tanh_, input, depth, dim, dim); apply_function_input(tanh_, input, depth, dim, dim);
} } else {
else {
printf("Erreur, fonction d'activation inconnue (choose_apply_function_matrix): %d\n", activation); printf("Erreur, fonction d'activation inconnue (choose_apply_function_matrix): %d\n", activation);
} }
} }
@ -94,17 +90,35 @@ void choose_apply_function_matrix(int activation, float*** input, int depth, int
void choose_apply_function_vector(int activation, float*** input, int dim) { void choose_apply_function_vector(int activation, float*** input, int dim) {
if (activation == RELU) { if (activation == RELU) {
apply_function_input(relu, input, 1, 1, dim); apply_function_input(relu, input, 1, 1, dim);
} } else if (activation == SIGMOID) {
else if (activation == SIGMOID) {
apply_function_input(sigmoid, input, 1, 1, dim); apply_function_input(sigmoid, input, 1, 1, dim);
} } else if (activation == SOFTMAX) {
else if (activation == SOFTMAX) {
apply_softmax_input(input, 1, 1, dim); apply_softmax_input(input, 1, 1, dim);
} } else if (activation == TANH) {
else if (activation == TANH) {
apply_function_input(tanh_, input, 1, 1, dim); apply_function_input(tanh_, input, 1, 1, dim);
} } else {
else { printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
}
}
void* get_function_activation(int activation) {
if (activation == RELU) {
return relu;
} else if (activation == -RELU) {
return relu_derivative;
} else if (activation == SIGMOID) {
return sigmoid;
} else if (activation == -SIGMOID) {
return sigmoid_derivative
} else if (activation == SOFTMAX) {
printf("Erreur, impossible de renvoyer la fonction softmax");
} else if (activation == -SOFTMAX) {
printf("Erreur, impossible de renvoyer la dérivée de la fonction softmax");
} else if (activation == TANH) {
return tanh_;
} else if (activation == -TANH) {
return tanh_derivative;
} else {
printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation); printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
} }
} }

View File

@ -44,5 +44,9 @@ void choose_apply_function_matrix(int activation, float*** input, int depth, int
*/ */
void choose_apply_function_vector(int activation, float*** input, int dim); void choose_apply_function_vector(int activation, float*** input, int dim);
/*
* Renvoie un pointeur vers la fonction d'activation correspondante
*/
void* get_function_activation(int activation)
#endif #endif