tipe/src/cnn/initialisation.c

94 lines
2.9 KiB
C
Raw Normal View History

2022-07-05 08:13:25 +02:00
#include <stdlib.h>
#include <math.h>
2022-09-28 10:20:08 +02:00
2022-10-24 12:54:51 +02:00
#include "../include/colors.h"
2022-09-16 14:53:35 +02:00
#include "include/initialisation.h"
2022-07-05 08:13:25 +02:00
2022-11-04 10:54:32 +01:00
// Target weight variances per scheme:
// - Glorot (Xavier) initialisation, for linear, tanh, softmax, logistic: Var(W) = 2 / (fan_in + fan_out)
// - He initialisation, for ReLU: Var(W) = 2 / fan_in
// - LeCun initialisation, for SELU: Var(W) = 1 / fan_in (not handled by the functions below yet)

// Only the uniform variants are implemented for the moment.
/*
 * Fill `matrix[0..dim-1]` with values drawn uniformly from [-a, a], where a is
 * chosen so the variance of the draws matches the requested scheme:
 *   GLOROT: Var = 2 / (n_in + dim)  ->  a = sqrt(6 / (n_in + dim))
 *   HE:     Var = 2 / n_in          ->  a = sqrt(6 / n_in)
 * For GLOROT, `dim` is used in place of fan_out.
 *
 * initialisation: GLOROT or HE. Any other value prints a warning and leaves
 *                 `matrix` untouched.
 * matrix:         array of at least `dim` floats, overwritten in place.
 * dim:            number of elements to fill.
 * n_in:           fan-in of the layer.
 *
 * If the effective fan is not strictly positive, a warning is printed and
 * `matrix` is left untouched, rather than dividing by zero and writing NaN.
 */
void initialisation_1d_matrix(int initialisation, float* matrix, int dim, int n_in) {
    // Effective fan n such that the target variance is 1/n. It is computed in
    // floating point: integer division truncated it, and turned e.g. HE with
    // n_in == 1 into n == 0 (infinite bound, NaN weights).
    double n;
    if (initialisation == GLOROT) {
        n = ((double)n_in + dim) / 2.0;
    } else if (initialisation == HE) {
        n = n_in / 2.0;
    } else {
        printf_warning("Initialisation non reconnue dans 'initialisation_1d_matrix' \n");
        return;
    }
    if (n <= 0.0) {
        printf_warning("Fan-in/fan-out invalide dans 'initialisation_1d_matrix' \n");
        return;
    }

    // U(-a, a) has variance a^2 / 3, so reaching Var = 1/n needs a = sqrt(3/n).
    float upper_bound = (float)sqrt(3.0 / n);
    float lower_bound = -upper_bound;
    float distance_bounds = 2 * upper_bound;
    // Assumes RAND_FLT() yields a uniform value in [0, 1] — TODO confirm against its definition.
    for (int i = 0; i < dim; i++) {
        matrix[i] = lower_bound + RAND_FLT() * distance_bounds;
    }
}
2022-11-04 10:54:32 +01:00
/*
 * Fill a dim1 x dim2 matrix with values drawn uniformly from [-a, a], where a
 * is chosen so the variance of the draws matches the requested scheme:
 *   GLOROT: Var = 2 / (n_in + n_out)  ->  a = sqrt(6 / (n_in + n_out))
 *   HE:     Var = 2 / n_in            ->  a = sqrt(6 / n_in)
 *
 * initialisation: GLOROT or HE. Any other value prints a warning and leaves
 *                 `matrix` untouched.
 * matrix:         dim1 row pointers, each to at least dim2 floats; overwritten in place.
 * dim1, dim2:     extents of the region to fill.
 * n_in, n_out:    fan-in / fan-out supplied by the caller.
 *
 * If the effective fan is not strictly positive, a warning is printed and
 * `matrix` is left untouched, rather than dividing by zero and writing NaN.
 */
void initialisation_2d_matrix(int initialisation, float** matrix, int dim1, int dim2, int n_in, int n_out) {
    // Effective fan n such that the target variance is 1/n. It is computed in
    // floating point to avoid integer truncation (and n == 0 for tiny fans).
    double n;
    if (initialisation == GLOROT) {
        n = ((double)n_in + n_out) / 2.0;
    } else if (initialisation == HE) {
        n = n_in / 2.0;
    } else {
        printf_warning("Initialisation non reconnue dans 'initialisation_2d_matrix' \n");
        return;
    }
    if (n <= 0.0) {
        printf_warning("Fan-in/fan-out invalide dans 'initialisation_2d_matrix' \n");
        return;
    }

    // U(-a, a) has variance a^2 / 3, so reaching Var = 1/n needs a = sqrt(3/n).
    float upper_bound = (float)sqrt(3.0 / n);
    float lower_bound = -upper_bound;
    float distance_bounds = 2 * upper_bound;
    // Assumes RAND_FLT() yields a uniform value in [0, 1] — TODO confirm against its definition.
    for (int i = 0; i < dim1; i++) {
        for (int j = 0; j < dim2; j++) {
            matrix[i][j] = lower_bound + RAND_FLT() * distance_bounds;
        }
    }
}
2022-11-04 10:54:32 +01:00
/*
 * Fill a depth x dim1 x dim2 array with values drawn uniformly from [-a, a],
 * where a is chosen so the variance of the draws matches the requested scheme:
 *   GLOROT: Var = 2 / (n_in + n_out)  ->  a = sqrt(6 / (n_in + n_out))
 *   HE:     Var = 2 / n_in            ->  a = sqrt(6 / n_in)
 *
 * initialisation:    GLOROT or HE. Any other value prints a warning and leaves
 *                    `matrix` untouched.
 * matrix:            nested pointer array indexed [depth][dim1][dim2]; overwritten in place.
 * depth, dim1, dim2: extents of the region to fill.
 * n_in, n_out:       fan-in / fan-out supplied by the caller.
 *
 * If the effective fan is not strictly positive, a warning is printed and
 * `matrix` is left untouched, rather than dividing by zero and writing NaN.
 */
void initialisation_3d_matrix(int initialisation, float*** matrix, int depth, int dim1, int dim2, int n_in, int n_out) {
    // Effective fan n such that the target variance is 1/n. It is computed in
    // floating point to avoid integer truncation (and n == 0 for tiny fans).
    double n;
    if (initialisation == GLOROT) {
        n = ((double)n_in + n_out) / 2.0;
    } else if (initialisation == HE) {
        n = n_in / 2.0;
    } else {
        printf_warning("Initialisation non reconnue dans 'initialisation_3d_matrix' \n");
        return;
    }
    if (n <= 0.0) {
        printf_warning("Fan-in/fan-out invalide dans 'initialisation_3d_matrix' \n");
        return;
    }

    // U(-a, a) has variance a^2 / 3, so reaching Var = 1/n needs a = sqrt(3/n).
    float upper_bound = (float)sqrt(3.0 / n);
    float lower_bound = -upper_bound;
    float distance_bounds = 2 * upper_bound;
    // Assumes RAND_FLT() yields a uniform value in [0, 1] — TODO confirm against its definition.
    for (int i = 0; i < depth; i++) {
        for (int j = 0; j < dim1; j++) {
            for (int k = 0; k < dim2; k++) {
                matrix[i][j][k] = lower_bound + RAND_FLT() * distance_bounds;
            }
        }
    }
}
2022-11-04 10:54:32 +01:00
/*
 * Fill a depth1 x depth2 x dim1 x dim2 array with values drawn uniformly from
 * [-a, a], where a is chosen so the variance of the draws matches the
 * requested scheme:
 *   GLOROT: Var = 2 / (n_in + n_out)  ->  a = sqrt(6 / (n_in + n_out))
 *   HE:     Var = 2 / n_in            ->  a = sqrt(6 / n_in)
 *
 * initialisation:  GLOROT or HE. Any other value prints a warning and leaves
 *                  `matrix` untouched.
 * matrix:          nested pointer array indexed [depth1][depth2][dim1][dim2];
 *                  overwritten in place.
 * depth1, depth2,
 * dim1, dim2:      extents of the region to fill.
 * n_in, n_out:     fan-in / fan-out supplied by the caller.
 *
 * If the effective fan is not strictly positive, a warning is printed and
 * `matrix` is left untouched, rather than dividing by zero and writing NaN.
 */
void initialisation_4d_matrix(int initialisation, float**** matrix, int depth1, int depth2, int dim1, int dim2, int n_in, int n_out) {
    // Effective fan n such that the target variance is 1/n. It is computed in
    // floating point to avoid integer truncation (and n == 0 for tiny fans).
    double n;
    if (initialisation == GLOROT) {
        n = ((double)n_in + n_out) / 2.0;
    } else if (initialisation == HE) {
        n = n_in / 2.0;
    } else {
        // The message now names this function (it previously said 'initialisation_3d_matrix').
        printf_warning("Initialisation non reconnue dans 'initialisation_4d_matrix' \n");
        return;
    }
    if (n <= 0.0) {
        printf_warning("Fan-in/fan-out invalide dans 'initialisation_4d_matrix' \n");
        return;
    }

    // U(-a, a) has variance a^2 / 3, so reaching Var = 1/n needs a = sqrt(3/n).
    float upper_bound = (float)sqrt(3.0 / n);
    float lower_bound = -upper_bound;
    float distance_bounds = 2 * upper_bound;
    // Assumes RAND_FLT() yields a uniform value in [0, 1] — TODO confirm against its definition.
    for (int i = 0; i < depth1; i++) {
        for (int j = 0; j < depth2; j++) {
            for (int k = 0; k < dim1; k++) {
                for (int l = 0; l < dim2; l++) {
                    matrix[i][j][k][l] = lower_bound + RAND_FLT() * distance_bounds;
                }
            }
        }
    }
}