extern "C" get_activation_function_cuda

This commit is contained in:
augustin64 2023-03-30 18:11:00 +02:00
parent 953c92ac61
commit dd16e34cce
3 changed files with 11 additions and 14 deletions

View File

@ -107,20 +107,16 @@ float leaky_relu_derivative(float x) {
//* Tanh //* Tanh
#ifdef __CUDACC__ #ifdef __CUDACC__
__device__ __device__ float device_tanh_(float x) {
#endif
float device_tanh_(float x) {
return tanh(x); return tanh(x);
} }
#ifdef __CUDACC__ __device__ float device_tanh_derivative(float x) {
__device__
#endif
float device_tanh_derivative(float x) {
float a = tanh(x); float a = tanh(x);
return 1 - a*a; return 1 - a*a;
} }
#endif
float tanh_(float x) { float tanh_(float x) {
return tanh(x); return tanh(x);
} }
@ -303,6 +299,7 @@ funcPtr get_activation_function(int activation) {
#ifdef __CUDACC__ #ifdef __CUDACC__
extern "C"
funcPtr get_activation_function_cuda(int activation) { funcPtr get_activation_function_cuda(int activation) {
funcPtr host_function; funcPtr host_function;

View File

@ -107,19 +107,15 @@ float leaky_relu_derivative(float x) {
//* Tanh //* Tanh
#ifdef __CUDACC__ #ifdef __CUDACC__
__device__ __device__ float device_tanh_(float x) {
#endif
float device_tanh_(float x) {
return tanh(x); return tanh(x);
} }
#ifdef __CUDACC__ __device__ float device_tanh_derivative(float x) {
__device__
#endif
float device_tanh_derivative(float x) {
float a = tanh(x); float a = tanh(x);
return 1 - a*a; return 1 - a*a;
} }
#endif
float tanh_(float x) { float tanh_(float x) {
return tanh(x); return tanh(x);
@ -303,6 +299,7 @@ funcPtr get_activation_function(int activation) {
#ifdef __CUDACC__ #ifdef __CUDACC__
extern "C"
funcPtr get_activation_function_cuda(int activation) { funcPtr get_activation_function_cuda(int activation) {
funcPtr host_function; funcPtr host_function;

View File

@ -142,6 +142,9 @@ funcPtr get_activation_function(int activation);
/* /*
* Récupère un pointeur sur le device vers la fonction d'activation demandée puis le transforme en pointeur sur l'host * Récupère un pointeur sur le device vers la fonction d'activation demandée puis le transforme en pointeur sur l'host
*/ */
#ifdef __CUDACC__
extern "C"
funcPtr get_activation_function_cuda(int activation); funcPtr get_activation_function_cuda(int activation);
#endif
#endif #endif