From a1dba81e17b12cdac6d83f49cc8bd31ec89e88b0 Mon Sep 17 00:00:00 2001 From: julienChemillier Date: Tue, 1 Nov 2022 11:20:17 +0100 Subject: [PATCH] Add 'get_function_activation' function --- src/cnn/function.c | 30 ++++++++++++++++++++++++++++++ src/cnn/include/function.h | 6 +++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/cnn/function.c b/src/cnn/function.c index c89df78..edd3695 100644 --- a/src/cnn/function.c +++ b/src/cnn/function.c @@ -4,6 +4,7 @@ #include "include/function.h" + float max(float a, float b) { return a < b ? b:a; } @@ -100,3 +101,32 @@ void choose_apply_function_vector(int activation, float*** input, int dim) { printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation); } } + +ptr get_function_activation(int activation) { + if (activation == RELU) { + return &relu; + } else if (activation == -RELU) { + return &relu_derivative; + } else if (activation == SIGMOID) { + return &sigmoid; + } else if (activation == -SIGMOID) { + return &sigmoid_derivative; + } else if (activation == SOFTMAX) { + printf("Erreur, impossible de renvoyer la fonction softmax"); + return NULL; + } else if (activation == -SOFTMAX) { + printf("Erreur, impossible de renvoyer la dérivée de la fonction softmax"); + return NULL; + } else if (activation == TANH) { + return &tanh_; + } else if (activation == -TANH) { + return &tanh_derivative; + } else { + printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation); + return NULL; + } +} +// to use: +// float a = 5; int activation; +// pm u = get_function_activation; +// printf("%f", (*u(activation))(a)); diff --git a/src/cnn/include/function.h b/src/cnn/include/function.h index 78189af..a3fdcf8 100644 --- a/src/cnn/include/function.h +++ b/src/cnn/include/function.h @@ -2,13 +2,15 @@ #define DEF_FUNCTION_H -typedef float (*returnFunctionType)(float, float); // Les dérivées sont l'opposé #define TANH 1 #define SIGMOID 2 #define RELU 3 #define SOFTMAX 4 +typedef float (*ptr)(float); +typedef ptr (*pm)(); + /* * Fonction max pour les floats */ @@ -46,4 +48,6 @@ void choose_apply_function_matrix(int activation, float*** input, int depth, int */ void choose_apply_function_vector(int activation, float*** input, int dim); +ptr get_function_activation(int activation); + #endif \ No newline at end of file