Mirror of https://github.com/pjreddie/darknet.git

Commit: Added Swish-activation
@@ -102,7 +102,7 @@ typedef struct tree {
 
 // activations.h
 typedef enum {
-    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU
+    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH
 }ACTIVATION;
 
 // parser.h
@@ -339,6 +339,7 @@ struct layer {
     float *col_image;
     float * delta;
     float * output;
+    float * output_sigmoid;
     int delta_pinned;
     int output_pinned;
     float * loss;
@@ -522,6 +523,7 @@ struct layer {
     float * scale_change_gpu;
 
     float * output_gpu;
+    float * output_sigmoid_gpu;
     float * loss_gpu;
     float * delta_gpu;
     float * rand_gpu;
@@ -186,6 +186,19 @@ __global__ void activate_array_kernel(float *x, int n, ACTIVATION a)
     if(i < n) x[i] = activate_kernel(x[i], a);
 }
 
+
+
+__global__ void activate_array_swish_kernel(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n) {
+        float x_val = x[i];
+        float sigmoid = logistic_activate_kernel(x_val);
+        output_sigmoid_gpu[i] = sigmoid;
+        output_gpu[i] = x_val * sigmoid;
+    }
+}
+
 __global__ void activate_array_leaky_kernel(float *x, int n)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
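Background (a note added here, not part of the diff): Swish, from Ramachandran et al., "Searching for Activation Functions" (2017), is

    f(x) = x \cdot \sigma(x), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}

The kernel above stores \sigma(x) in output_sigmoid_gpu next to the activated value in output_gpu, so the backward pass can reuse the cached sigmoid instead of recomputing it.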
@@ -240,6 +253,16 @@ __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
     if(i < n) delta[i] *= gradient_kernel(x[i], a);
 }
 
+// https://github.com/BVLC/caffe/blob/04ab089db018a292ae48d51732dd6c66766b36b6/src/caffe/layers/swish_layer.cu#L28-L30
+__global__ void gradient_array_swish_kernel(float *x, int n, float *sigmoid_gpu, float *delta)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n) {
+        float swish = x[i];
+        delta[i] *= swish + sigmoid_gpu[i] * (1 - swish); // gradient_kernel(x[i], a);
+    }
+}
+
 __global__ void gradient_array_leaky_kernel(float *x, int n, float *delta)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
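One step is worth spelling out (a derivation added here; it matches the Caffe reference linked above). At the call site, the x given to this kernel is the layer's activated output, i.e. the stored value f(x) = x\,\sigma(x), not the raw pre-activation. Rewriting the Swish derivative in terms of the stored quantities:

    f'(x) = \sigma(x) + x\,\sigma(x)\,(1 - \sigma(x)) = f(x) + \sigma(x)\,(1 - f(x))

which is exactly swish + sigmoid_gpu[i] * (1 - swish) in the code; caching \sigma(x) during the forward pass is therefore enough to recover the full gradient.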
@@ -303,6 +326,13 @@ extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
     CHECK_CUDA(cudaPeekAtLastError());
 }
 
+extern "C" void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
+{
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
+    activate_array_swish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(x, n, output_sigmoid_gpu, output_gpu);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
+
 extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
 {
     const int num_blocks = get_number_of_blocks(n, BLOCK);
@@ -317,3 +347,11 @@ extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
     gradient_array_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, a, delta);
     CHECK_CUDA(cudaPeekAtLastError());
 }
+
+
+extern "C" void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta)
+{
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
+    gradient_array_swish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, sigmoid_gpu, delta);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
@@ -45,6 +45,7 @@ char *get_activation_string(ACTIVATION a)
 ACTIVATION get_activation(char *s)
 {
     if (strcmp(s, "logistic")==0) return LOGISTIC;
+    if (strcmp(s, "swish") == 0) return SWISH;
     if (strcmp(s, "loggy")==0) return LOGGY;
     if (strcmp(s, "relu")==0) return RELU;
     if (strcmp(s, "elu")==0) return ELU;
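With this parser hook in place, any layer that reads its activation from a .cfg file can opt into Swish. A minimal sketch (the layer parameters are hypothetical):

[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=swish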
@@ -120,6 +121,18 @@ void activate_array(float *x, const int n, const ACTIVATION a)
     }
 }
 
+void activate_array_swish(float *x, const int n, float * output_sigmoid, float * output)
+{
+    int i;
+    #pragma omp parallel for
+    for (i = 0; i < n; ++i) {
+        float x_val = x[i];
+        float sigmoid = logistic_activate(x_val);
+        output_sigmoid[i] = sigmoid;
+        output[i] = x_val * sigmoid;
+    }
+}
+
 float gradient(float x, ACTIVATION a)
 {
     switch(a){
@@ -158,7 +171,19 @@ float gradient(float x, ACTIVATION a)
 void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta)
 {
     int i;
     #pragma omp parallel for
     for(i = 0; i < n; ++i){
         delta[i] *= gradient(x[i], a);
     }
 }
+
+// https://github.com/BVLC/caffe/blob/04ab089db018a292ae48d51732dd6c66766b36b6/src/caffe/layers/swish_layer.cpp#L54-L56
+void gradient_array_swish(const float *x, const int n, const float * sigmoid, float * delta)
+{
+    int i;
+    #pragma omp parallel for
+    for (i = 0; i < n; ++i) {
+        float swish = x[i];
+        delta[i] *= swish + sigmoid[i]*(1 - swish);
+    }
+}
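The CPU pair above is easy to sanity-check in isolation. Below is a standalone harness (my sketch, not part of the commit) that compares the cached-sigmoid backward formula against a central-difference derivative of f(x) = x * sigmoid(x):

/* Standalone sketch (not from the commit): verify that the backward formula
 * used above, f + sigmoid*(1 - f), matches a numerical derivative of
 * f(x) = x * sigmoid(x). Compile with: gcc swish_check.c -lm */
#include <math.h>
#include <stdio.h>

static float logistic(float x) { return 1.f / (1.f + expf(-x)); }

int main(void)
{
    const float eps = 1e-3f;
    for (float x = -4.f; x <= 4.f; x += 1.f) {
        float s = logistic(x);
        float swish = x * s;                      /* what l.output stores    */
        float analytic = swish + s * (1 - swish); /* formula from the diff   */
        float numeric = ((x + eps) * logistic(x + eps)
                       - (x - eps) * logistic(x - eps)) / (2 * eps);
        printf("x=%5.1f  analytic=%8.5f  numeric=%8.5f\n", x, analytic, numeric);
    }
    return 0;
}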
@@ -17,10 +17,14 @@ char *get_activation_string(ACTIVATION a);
 float activate(float x, ACTIVATION a);
 float gradient(float x, ACTIVATION a);
 void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta);
+void gradient_array_swish(const float *x, const int n, const float * sigmoid, float * delta);
 void activate_array(float *x, const int n, const ACTIVATION a);
+void activate_array_swish(float *x, const int n, float * output_sigmoid, float * output);
 #ifdef GPU
 void activate_array_ongpu(float *x, int n, ACTIVATION a);
+void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu);
 void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta);
+void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta);
 #endif
 
 static inline float stair_activate(float x)
@@ -391,7 +391,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 */
 
     //add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
-    if (l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
+    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.output_gpu);
+    else if (l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if (l.binary || l.xnor) swap_binary(&l);
     //cudaDeviceSynchronize();
@@ -594,7 +595,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     //#ifndef CUDNN_HALF
     //#endif // no CUDNN_HALF
 
-    if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
+    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.output_gpu);
+    else if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.dot > 0) dot_error_gpu(l);
     if(l.binary || l.xnor) swap_binary(&l);
     //cudaDeviceSynchronize(); // for correct profiling of performance
@@ -607,7 +609,9 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
     if(state.net.try_fix_nan) constrain_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1);
-    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
+
+    if (l.activation == SWISH) gradient_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.delta_gpu);
+    else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
 
     if (!l.batch_normalize)
         backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
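Note which buffer the backward pass receives: l.output_gpu, the post-activation output, is passed as x. Darknet computes activation gradients from layer outputs rather than pre-activations, and for Swish the stored pair (f(x), sigma(x)) handed over as (x, sigmoid_gpu) is exactly what gradient_array_swish_kernel expects; the raw pre-activation is never kept.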
@@ -462,7 +462,11 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
         l.scale_v = (float*)calloc(n, sizeof(float));
     }
 
+    if(l.activation == SWISH) l.output_sigmoid = (float*)calloc(total_batch*l.outputs, sizeof(float));
+
 #ifdef GPU
+    if (l.activation == SWISH) l.output_sigmoid_gpu = cuda_make_array(l.output_sigmoid, total_batch*out_h*out_w*n);
+
     l.forward_gpu = forward_convolutional_layer_gpu;
     l.backward_gpu = backward_convolutional_layer_gpu;
     l.update_gpu = update_convolutional_layer_gpu;
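Two details here (my observations, not stated in the commit). First, the CPU buffer is sized total_batch*l.outputs floats while the GPU buffer uses total_batch*out_h*out_w*n; for a convolutional layer these are the same count, since l.outputs = out_h*out_w*n. Second, every SWISH layer now pays one extra float per output element. A rough sizing sketch with hypothetical dimensions:

/* Hypothetical layer: estimate the sigmoid cache added for SWISH. */
#include <stdio.h>

int main(void)
{
    int total_batch = 64, out_h = 76, out_w = 76, n = 128;  /* made-up values */
    size_t outputs = (size_t)out_h * out_w * n;             /* == l.outputs   */
    size_t bytes = (size_t)total_batch * outputs * sizeof(float);
    printf("sigmoid cache: %.1f MiB\n", bytes / (1024.0 * 1024.0));
    return 0;
}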
@@ -1029,7 +1033,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
     //activate_array(l.output, m*n*l.batch, l.activation);
-    activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
+    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.output);
+    else activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
     return;
 
 }
@@ -1067,7 +1072,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
     //activate_array(l.output, m*n*l.batch, l.activation);
-    activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);
+    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.output);
+    else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);
 
     if(l.binary || l.xnor) swap_binary(&l);
 }
@@ -1080,7 +1086,8 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
     int n = l.size*l.size*l.c / l.groups;
     int k = l.out_w*l.out_h;
 
-    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
+    if (l.activation == SWISH) gradient_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.delta);
+    else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
 
     if (l.batch_normalize) {
         backward_batchnorm_layer(l, state);
src/layer.c (10 changed lines)
@@ -84,8 +84,9 @@ void free_layer(layer l)
         l.output = NULL;
     }
 #endif // GPU
-    if (l.delta) free(l.delta);
-    if (l.output) free(l.output);
+    if (l.delta) free(l.delta), l.delta = NULL;
+    if (l.output) free(l.output), l.output = NULL;
+    if (l.output_sigmoid) free(l.output_sigmoid), l.output_sigmoid = NULL;
     if (l.squared) free(l.squared);
     if (l.norms) free(l.norms);
     if (l.spatial_mean) free(l.spatial_mean);
@@ -165,8 +166,9 @@ void free_layer(layer l)
     if (l.bias_updates_gpu) cuda_free(l.bias_updates_gpu), l.bias_updates_gpu = NULL;
     if (l.scales_gpu) cuda_free(l.scales_gpu), l.scales_gpu = NULL;
     if (l.scale_updates_gpu) cuda_free(l.scale_updates_gpu), l.scale_updates_gpu = NULL;
-    if (l.output_gpu) cuda_free(l.output_gpu);
-    if (l.delta_gpu) cuda_free(l.delta_gpu);
+    if (l.output_gpu) cuda_free(l.output_gpu), l.output_gpu = NULL;
+    if (l.output_sigmoid_gpu) cuda_free(l.output_sigmoid_gpu), l.output_sigmoid_gpu = NULL;
+    if (l.delta_gpu) cuda_free(l.delta_gpu), l.delta_gpu = NULL;
     if (l.rand_gpu) cuda_free(l.rand_gpu);
     if (l.squared_gpu) cuda_free(l.squared_gpu);
     if (l.norms_gpu) cuda_free(l.norms_gpu);
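The change here is more than cosmetic: each free/cuda_free now also nulls the pointer via the comma operator, so a second cleanup pass over the same layer cannot double-free. A tiny standalone illustration of the idiom (my sketch, not from the commit):

#include <stdio.h>
#include <stdlib.h>

int main(void)
{
    float *buf = (float*)calloc(16, sizeof(float));
    if (buf) free(buf), buf = NULL;   /* free and clear in one statement */
    if (buf) free(buf), buf = NULL;   /* a repeated pass is now a no-op  */
    printf("buf = %p\n", (void*)buf); /* prints a null pointer           */
    return 0;
}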