extern "C" get_activation_function_cuda

This commit is contained in:
augustin64 2023-03-30 18:11:00 +02:00
parent 953c92ac61
commit dd16e34cce
3 changed files with 11 additions and 14 deletions

View File

@ -107,20 +107,16 @@ float leaky_relu_derivative(float x) {
//* Tanh
#ifdef __CUDACC__
__device__
#endif
float device_tanh_(float x) {
__device__ float device_tanh_(float x) {
return tanh(x);
}
#ifdef __CUDACC__
__device__
#endif
float device_tanh_derivative(float x) {
__device__ float device_tanh_derivative(float x) {
float a = tanh(x);
return 1 - a*a;
}
#endif
float tanh_(float x) {
return tanh(x);
}
@ -303,6 +299,7 @@ funcPtr get_activation_function(int activation) {
#ifdef __CUDACC__
extern "C"
funcPtr get_activation_function_cuda(int activation) {
funcPtr host_function;

View File

@ -107,19 +107,15 @@ float leaky_relu_derivative(float x) {
//* Tanh
#ifdef __CUDACC__
__device__
#endif
float device_tanh_(float x) {
__device__ float device_tanh_(float x) {
return tanh(x);
}
#ifdef __CUDACC__
__device__
#endif
float device_tanh_derivative(float x) {
__device__ float device_tanh_derivative(float x) {
float a = tanh(x);
return 1 - a*a;
}
#endif
float tanh_(float x) {
return tanh(x);
@ -303,6 +299,7 @@ funcPtr get_activation_function(int activation) {
#ifdef __CUDACC__
extern "C"
funcPtr get_activation_function_cuda(int activation) {
funcPtr host_function;

View File

@ -142,6 +142,9 @@ funcPtr get_activation_function(int activation);
/*
* Récupère un pointeur sur le device vers la fonction d'activation demandée puis le transforme en pointeur sur l'host
*/
#ifdef __CUDACC__
extern "C"
funcPtr get_activation_function_cuda(int activation);
#endif
#endif