darknet/src/activation_kernels.cu

108 lines
3.6 KiB
Plaintext
Raw Normal View History

2015-11-16 06:51:26 +03:00
#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"
2015-01-23 03:38:24 +03:00
extern "C" {
#include "activations.h"
#include "cuda.h"
}
// Identity activation: the input is passed through unchanged.
__device__ float linear_activate_kernel(float x)
{
    return x;
}
2015-03-08 21:31:12 +03:00
// Logistic sigmoid: 1 / (1 + e^-x).
// Float literals and expf keep the whole expression in single precision;
// double literals such as `1.` would force FP64 arithmetic per element.
// For very negative x, expf(-x) overflows to inf and the result is 0,
// which is the correct limit.
__device__ float logistic_activate_kernel(float x)
{
    return 1.f / (1.f + expf(-x));
}
2015-01-23 03:38:24 +03:00
// ReLU: x for x > 0, otherwise 0.
// The input is multiplied by the 0/1 comparison result. As a consequence,
// a NaN or -inf input yields NaN rather than 0.
__device__ float relu_activate_kernel(float x)
{
    const float keep = (x > 0);
    return keep * x;
}
2015-11-26 22:48:01 +03:00
// ELU: x for x >= 0, e^x - 1 for x < 0.
// The branch is selected rather than blending both halves with 0/1 masks.
// A mask blend still evaluates exp(x) for large positive x, where it
// overflows to inf, and the masked term 0 * inf would make the result NaN.
// expm1f computes e^x - 1 accurately for x near 0.
// NaN input still propagates as NaN, because it takes the expm1f branch.
__device__ float elu_activate_kernel(float x)
{
    return (x >= 0) ? x : expm1f(x);
}
2015-04-11 11:24:07 +03:00
// RELIE forward pass: x for x > 0, otherwise 0. This is the same
// computation as ReLU: the input times the 0/1 comparison result.
// NOTE(review): relie_gradient_kernel returns a 0.01 slope for x <= 0. That
// would match a leaky forward pass ((x > 0) ? x : .01 * x). Confirm that
// the plain-ReLU forward pass here is intentional.
__device__ float relie_activate_kernel(float x)
{
    const float keep = (x > 0);
    return keep * x;
}
2015-01-23 03:38:24 +03:00
// Ramp: ReLU plus a 0.1 * x linear term, i.e. 1.1 * x for x > 0 and
// 0.1 * x otherwise.
// The float literal keeps the arithmetic in single precision; a bare `.1`
// is a double and promotes the whole expression to FP64.
__device__ float ramp_activate_kernel(float x)
{
    return x * (x > 0) + .1f * x;
}
2015-07-08 10:36:43 +03:00
// Leaky ReLU: x for x > 0, 0.1 * x otherwise.
// The float literal avoids a double-precision multiply in the negative branch.
__device__ float leaky_activate_kernel(float x)
{
    return (x > 0) ? x : .1f * x;
}
2015-01-23 03:38:24 +03:00
// Hyperbolic tangent, computed with the CUDA math library's tanhf.
// The expanded ratio (e^2x - 1) / (e^2x + 1) overflows to inf / inf = NaN
// once e^2x exceeds the floating-point range, even for moderately large x.
// tanhf saturates correctly to +/-1.
__device__ float tanh_activate_kernel(float x)
{
    return tanhf(x);
}
2015-03-21 22:25:14 +03:00
// Piecewise-linear logistic approximation.
// On [-4, 4] it is 0.125 * x + 0.5, which maps that interval onto [0, 1].
// Outside that range it continues with a shallow 0.01 slope, so the
// function stays continuous at x = -4 (value 0) and x = 4 (value 1).
// Float literals keep every branch in single precision.
__device__ float plse_activate_kernel(float x)
{
    if (x < -4) return .01f * (x + 4);
    if (x > 4)  return .01f * (x - 4) + 1;
    return .125f * x + .5f;
}
2015-01-23 03:38:24 +03:00
// Derivative of the identity activation: 1 everywhere.
__device__ float linear_gradient_kernel(float x)
{
    (void)x;  // the slope does not depend on the value
    return 1;
}
2015-03-08 21:31:12 +03:00
// Logistic derivative, written in terms of the sigmoid's output y:
// dy/dx = y * (1 - y). The argument is the already-activated value.
__device__ float logistic_gradient_kernel(float x)
{
    return x * (1 - x);
}
2015-01-23 03:38:24 +03:00
// ReLU derivative: 1 where the value is positive, 0 otherwise
// (including NaN).
__device__ float relu_gradient_kernel(float x)
{
    return (x > 0) ? 1 : 0;
}
2015-11-26 22:48:01 +03:00
// ELU derivative, written in terms of the ELU output y.
// It is 1 for y >= 0. For y < 0 it is y + 1, because y = e^x - 1 there,
// so the slope e^x equals y + 1.
// Both halves are blended with 0/1 masks.
__device__ float elu_gradient_kernel(float x)
{
    const float on_linear_side = (x >= 0);
    const float on_exp_side = (x < 0);
    return on_linear_side + on_exp_side * (x + 1);
}
2015-04-11 11:24:07 +03:00
// RELIE derivative: 1 for positive values, and a small 0.01 slope otherwise
// (including NaN).
__device__ float relie_gradient_kernel(float x)
{
    if (x > 0) return 1;
    return .01;
}
2015-01-23 03:38:24 +03:00
// Ramp derivative: 1.1 for positive values, 0.1 otherwise.
// This yields the same float values as (x > 0) + .1, but selects float
// constants. It does not convert the comparison to double and perform an
// FP64 add for every element.
__device__ float ramp_gradient_kernel(float x)
{
    return (x > 0) ? 1.1f : .1f;
}
2015-07-08 10:36:43 +03:00
// Leaky ReLU derivative: 1 for positive values, the 0.1 leak slope otherwise.
__device__ float leaky_gradient_kernel(float x)
{
    if (x > 0) return 1;
    return .1;
}
2015-01-23 03:38:24 +03:00
// tanh derivative, written in terms of the tanh output y: 1 - y^2.
__device__ float tanh_gradient_kernel(float x)
{
    return 1 - x * x;
}
2015-03-21 22:25:14 +03:00
// PLSE derivative, written in terms of the plse output y.
// Outputs below 0 or above 1 come from the shallow 0.01-slope segments.
// Outputs inside [0, 1] come from the 0.125-slope middle segment.
__device__ float plse_gradient_kernel(float x)
{
    const bool on_outer_segment = (x < 0) || (x > 1);
    if (on_outer_segment) return .01;
    return .125;
}
2015-01-23 03:38:24 +03:00
// Applies the activation selected by `a` to a single value.
// Any ACTIVATION value not listed below produces 0.
__device__ float activate_kernel(float x, ACTIVATION a)
{
    float y = 0;
    switch (a) {
        case LINEAR:   y = linear_activate_kernel(x);   break;
        case LOGISTIC: y = logistic_activate_kernel(x); break;
        case RELU:     y = relu_activate_kernel(x);     break;
        case ELU:      y = elu_activate_kernel(x);      break;
        case RELIE:    y = relie_activate_kernel(x);    break;
        case RAMP:     y = ramp_activate_kernel(x);     break;
        case LEAKY:    y = leaky_activate_kernel(x);    break;
        case TANH:     y = tanh_activate_kernel(x);     break;
        case PLSE:     y = plse_activate_kernel(x);     break;
        default:                                        break;
    }
    return y;
}
// Returns the derivative of the activation selected by `a`.
// The *_gradient_kernel helpers are written in terms of the activation's
// output (e.g. (1 - y) * y for logistic, 1 - y * y for tanh), so `x` is
// expected to be the already-activated value.
// Any ACTIVATION value not listed below produces 0.
__device__ float gradient_kernel(float x, ACTIVATION a)
{
    float g = 0;
    switch (a) {
        case LINEAR:   g = linear_gradient_kernel(x);   break;
        case LOGISTIC: g = logistic_gradient_kernel(x); break;
        case RELU:     g = relu_gradient_kernel(x);     break;
        case ELU:      g = elu_gradient_kernel(x);      break;
        case RELIE:    g = relie_gradient_kernel(x);    break;
        case RAMP:     g = ramp_gradient_kernel(x);     break;
        case LEAKY:    g = leaky_gradient_kernel(x);    break;
        case TANH:     g = tanh_gradient_kernel(x);     break;
        case PLSE:     g = plse_gradient_kernel(x);     break;
        default:                                        break;
    }
    return g;
}
// Applies activation `a` in place to x[0..n).
// Each thread handles one element. The block index is flattened from a 2D
// grid (blockIdx.y rows of gridDim.x blocks), and threads past n do nothing.
__global__ void activate_array_kernel(float *x, int n, ACTIVATION a)
{
    const int block_id = blockIdx.y * gridDim.x + blockIdx.x;
    const int idx = block_id * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    x[idx] = activate_kernel(x[idx], a);
}
// Chain-rule step: scales delta[i] in place by the activation derivative
// evaluated at x[i], for i in [0, n).
// x is only read; delta is read and written.
// Indexing matches activate_array_kernel: one element per thread on a
// flattened 2D grid, with threads past n doing nothing.
__global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
{
    const int block_id = blockIdx.y * gridDim.x + blockIdx.x;
    const int idx = block_id * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    delta[idx] *= gradient_kernel(x[idx], a);
}
// Host entry point: applies activation `a` in place to the n floats at the
// device pointer x.
// It launches activate_array_kernel with cuda_gridsize(n) blocks of BLOCK
// threads. cudaPeekAtLastError only reports launch/configuration errors;
// faults inside the kernel surface at a later synchronizing call.
// n <= 0 is a no-op. There is no work to do, and a grid with zero blocks is
// not a valid launch configuration.
extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
{
    if (n <= 0) return;
    activate_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a);
    check_error(cudaPeekAtLastError());
}
// Host entry point: multiplies each delta[i] (device memory) in place by the
// derivative of activation `a` evaluated at x[i], for i in [0, n).
// It launches gradient_array_kernel with cuda_gridsize(n) blocks of BLOCK
// threads. cudaPeekAtLastError only reports launch/configuration errors;
// faults inside the kernel surface at a later synchronizing call.
// n <= 0 is a no-op. There is no work to do, and a grid with zero blocks is
// not a valid launch configuration.
extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
{
    if (n <= 0) return;
    gradient_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a, delta);
    check_error(cudaPeekAtLastError());
}