Add 'get_function_activation' function

This commit is contained in:
julienChemillier 2022-11-01 11:20:17 +01:00
parent 3d63e9e63b
commit a1dba81e17
2 changed files with 35 additions and 1 deletions

View File

@ -4,6 +4,7 @@
#include "include/function.h" #include "include/function.h"
float max(float a, float b) { float max(float a, float b) {
return a < b ? b:a; return a < b ? b:a;
} }
@ -100,3 +101,32 @@ void choose_apply_function_vector(int activation, float*** input, int dim) {
printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation); printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
} }
} }
ptr get_function_activation(int activation) {
if (activation == RELU) {
return &relu;
} else if (activation == -RELU) {
return &relu_derivative;
} else if (activation == SIGMOID) {
return &sigmoid;
} else if (activation == -SIGMOID) {
return &sigmoid_derivative;
} else if (activation == SOFTMAX) {
printf("Erreur, impossible de renvoyer la fonction softmax");
return NULL;
} else if (activation == -SOFTMAX) {
printf("Erreur, impossible de renvoyer la dérivée de la fonction softmax");
return NULL;
} else if (activation == TANH) {
return &tanh_;
} else if (activation == -TANH) {
return &tanh_derivative;
} else {
printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
return NULL;
}
}
// to use:
// float a = 5; int activation;
// pm u = get_function_activation;
// printf("%f", (*u(activation))(a));

View File

@ -2,13 +2,15 @@
#define DEF_FUNCTION_H #define DEF_FUNCTION_H
typedef float (*returnFunctionType)(float, float);
// Les dérivées sont l'opposé // Les dérivées sont l'opposé
#define TANH 1 #define TANH 1
#define SIGMOID 2 #define SIGMOID 2
#define RELU 3 #define RELU 3
#define SOFTMAX 4 #define SOFTMAX 4
typedef float (*ptr)(float);
typedef ptr (*pm)();
/* /*
* Fonction max pour les floats * Fonction max pour les floats
*/ */
@ -46,4 +48,6 @@ void choose_apply_function_matrix(int activation, float*** input, int depth, int
*/ */
void choose_apply_function_vector(int activation, float*** input, int dim); void choose_apply_function_vector(int activation, float*** input, int dim);
ptr get_function_activation(int activation);
#endif #endif