tipe/src/cnn/function.c

150 lines
3.8 KiB
C
Raw Normal View History

2022-07-05 08:13:25 +02:00
#include <stdio.h>
#include <math.h>
#include <float.h>
2022-10-24 12:54:51 +02:00
2022-09-16 14:53:35 +02:00
#include "include/function.h"
2022-07-05 08:13:25 +02:00
2022-11-01 11:20:17 +01:00
2022-11-03 18:13:01 +01:00
/*
 * Returns the larger of a and b.
 * b is returned only when a < b; otherwise (including ties and when the
 * comparison is false because of a NaN) a is returned.
 */
float max_float(float a, float b) {
    if (a < b) {
        return b;
    }
    return a;
}
2022-11-12 14:20:13 +01:00
/* Identity activation: the output is the input, unchanged. */
float identity(float x) {
    const float unchanged = x;
    return unchanged;
}
/* Derivative of the identity activation: the constant 1, whatever x is. */
float identity_derivative(float x) {
    (void)x; /* unused: the derivative does not depend on the input */
    return 1.0f;
}
2022-07-05 08:13:25 +02:00
float sigmoid(float x) {
return 1/(1 + exp(-x));
}
float sigmoid_derivative(float x) {
float tmp = exp(-x);
return tmp/((1+tmp)*(1+tmp));
}
/*
 * Rectified linear unit: x when x is strictly positive, otherwise 0.
 * (Equivalent to max_float(0, x): a NaN input yields 0.)
 */
float relu(float x) {
    return (0 < x) ? x : 0.0f;
}
/* Derivative of ReLU: 1 for strictly positive x, 0 otherwise (including x == 0). */
float relu_derivative(float x) {
    return (x > 0) ? 1.0f : 0.0f;
}
float tanh_(float x) {
return tanh(x);
}
float tanh_derivative(float x) {
float a = tanh(x);
return 1 - a*a;
}
void apply_softmax_input(float ***input, int depth, int rows, int columns) {
float m = FLT_MIN;
float sum=0;
for (int i=0; i < depth; i++) {
for (int j=0; j < rows; j++) {
for (int k=0; k < columns; k++) {
2022-11-03 18:13:01 +01:00
m = max_float(m, input[i][j][k]);
2022-07-05 08:13:25 +02:00
}
}
}
for (int i=0; i < depth; i++) {
for (int j=0; j < rows; j++) {
for (int k=0; k < columns; k++) {
2022-07-05 08:13:25 +02:00
input[i][j][k] = exp(m-input[i][j][k]);
sum += input[i][j][k];
}
}
}
for (int i=0; i < depth; i++) {
for (int j=0; j < rows; j++) {
for (int k=0; k < columns; k++) {
2022-07-05 08:13:25 +02:00
input[i][j][k] = input[i][j][k]/sum;
}
}
}
}
void apply_function_input(float (*f)(float), float*** input, int depth, int rows, int columns) {
for (int i=0; i < depth; i++) {
for (int j=0; j < rows; j++) {
for (int k=0; k < columns; k++) {
2022-07-05 08:13:25 +02:00
input[i][j][k] = (*f)(input[i][j][k]);
}
}
}
}
2022-09-30 15:50:29 +02:00
void choose_apply_function_matrix(int activation, float*** input, int depth, int dim) {
2022-07-05 08:13:25 +02:00
if (activation == RELU) {
2022-09-30 15:50:29 +02:00
apply_function_input(relu, input, depth, dim, dim);
2022-10-26 17:32:54 +02:00
} else if (activation == SIGMOID) {
2022-09-30 15:50:29 +02:00
apply_function_input(sigmoid, input, depth, dim, dim);
2022-10-26 17:32:54 +02:00
} else if (activation == SOFTMAX) {
2022-09-30 15:50:29 +02:00
apply_softmax_input(input, depth, dim, dim);
2022-10-26 17:32:54 +02:00
} else if (activation == TANH) {
2022-09-30 15:50:29 +02:00
apply_function_input(tanh_, input, depth, dim, dim);
2022-10-26 17:32:54 +02:00
} else {
2022-09-30 15:50:29 +02:00
printf("Erreur, fonction d'activation inconnue (choose_apply_function_matrix): %d\n", activation);
}
}
void choose_apply_function_vector(int activation, float*** input, int dim) {
if (activation == RELU) {
apply_function_input(relu, input, 1, 1, dim);
2022-10-26 17:32:54 +02:00
} else if (activation == SIGMOID) {
2022-09-30 15:50:29 +02:00
apply_function_input(sigmoid, input, 1, 1, dim);
2022-10-26 17:32:54 +02:00
} else if (activation == SOFTMAX) {
2022-09-30 15:50:29 +02:00
apply_softmax_input(input, 1, 1, dim);
2022-10-26 17:32:54 +02:00
} else if (activation == TANH) {
2022-09-30 15:50:29 +02:00
apply_function_input(tanh_, input, 1, 1, dim);
2022-10-26 17:32:54 +02:00
} else {
printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
2022-09-30 15:50:29 +02:00
}
2022-10-26 17:32:54 +02:00
}
2022-11-01 11:20:17 +01:00
ptr get_function_activation(int activation) {
if (activation == RELU) {
return &relu;
2022-11-12 14:20:13 +01:00
}
if (activation == -RELU) {
2022-11-01 11:20:17 +01:00
return &relu_derivative;
2022-11-12 14:20:13 +01:00
}
if (activation == -IDENTITY) {
return &identity_derivative;
}
if (activation == IDENTITY) {
return &identity;
}
if (activation == SIGMOID) {
2022-11-01 11:20:17 +01:00
return &sigmoid;
2022-11-12 14:20:13 +01:00
}
if (activation == -SIGMOID) {
2022-11-01 11:20:17 +01:00
return &sigmoid_derivative;
2022-11-12 14:20:13 +01:00
}
if (activation == SOFTMAX) {
2022-11-09 12:55:55 +01:00
printf("Erreur, impossible de renvoyer la fonction softmax\n");
2022-11-01 11:20:17 +01:00
return NULL;
2022-11-12 14:20:13 +01:00
}
if (activation == -SOFTMAX) {
2022-11-09 12:55:55 +01:00
printf("Erreur, impossible de renvoyer la dérivée de la fonction softmax\n");
2022-11-01 11:20:17 +01:00
return NULL;
2022-11-12 14:20:13 +01:00
}
if (activation == TANH) {
2022-11-01 11:20:17 +01:00
return &tanh_;
2022-11-12 14:20:13 +01:00
}
if (activation == -TANH) {
2022-11-01 11:20:17 +01:00
return &tanh_derivative;
}
2022-11-12 14:20:13 +01:00
printf("Erreur, fonction d'activation inconnue (choose_apply_function_vector): %d\n", activation);
return NULL;
2022-11-01 11:20:17 +01:00
}