From 5343aa423563c107e0071b1427ad5defc27b56d2 Mon Sep 17 00:00:00 2001
From: AlexeyAB
Date: Wed, 16 Jan 2019 18:08:11 +0300
Subject: [PATCH] CUDA minor performance improvement

---
 src/activation_kernels.cu    | 24 ++++++++++++++---
 src/blas.h                   |  2 ++
 src/blas_kernels.cu          | 51 +++++++++++++++++++++++++++++++++++-
 src/convolutional_kernels.cu |  8 +++---
 src/convolutional_layer.c    |  2 +-
 src/shortcut_layer.c         |  6 +++--
 src/yolo_layer.c             |  7 ++---
 7 files changed, 86 insertions(+), 14 deletions(-)

diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index 4285b5f3..ee112b18 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -192,8 +192,23 @@ __global__ void activate_array_leaky_kernel(float *x, int n)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
     if (index < n) {
-        float val = x[index];
-        x[index] = (val > 0) ? val : val / 10;
+        x[index] = leaky_activate_kernel(x[index]);
+    }
+}
+
+__global__ void activate_array_selu_kernel(float *x, int n)
+{
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    if (index < n) {
+        x[index] = selu_activate_kernel(x[index]);
+    }
+}
+
+__global__ void activate_array_logistic_kernel(float *x, int n)
+{
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    if (index < n) {
+        x[index] = logistic_activate_kernel(x[index]);
     }
 }
 
@@ -205,7 +220,10 @@ __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delt
 
 extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
 {
-    if(a == LEAKY) activate_array_leaky_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
+    if (a == LINEAR) return;
+    else if(a == LEAKY) activate_array_leaky_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if (a == LOGISTIC) activate_array_logistic_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if (a == SELU) activate_array_selu_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
     else activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
     check_error(cudaPeekAtLastError());
 }
diff --git a/src/blas.h b/src/blas.h
index d5f67250..9b7f3d5a 100644
--- a/src/blas.h
+++ b/src/blas.h
@@ -46,6 +46,7 @@ void softmax_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *er
 
 void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
 void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
+void simple_copy_ongpu(int size, float *src, float *dst);
 void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY);
 void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
 void scal_ongpu(int N, float ALPHA, float * X, int INCX);
@@ -69,6 +70,7 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc
 void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
 void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
 void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
+void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
 void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
 void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates);
 void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 3fdc742d..fffd46f0 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -439,6 +439,13 @@ __global__ void copy_kernel(int N, float *X, int OFFX, int INCX, float *Y, int
     if(i < N) Y[i*INCY + OFFY] = X[i*INCX + OFFX];
 }
 
+__global__ void simple_copy_kernel(int size, float *src, float *dst)
+{
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    if (index < size)
+        dst[index] = src[index];
+}
+
 __global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -557,6 +564,13 @@ extern "C" void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY)
     copy_ongpu_offset(N, X, 0, INCX, Y, 0, INCY);
 }
 
+extern "C" void simple_copy_ongpu(int size, float *src, float *dst)
+{
+    const int num_blocks = size / BLOCK + 1;
+    simple_copy_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(size, src, dst);
+    check_error(cudaPeekAtLastError());
+}
+
 extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
 {
     mul_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, X, INCX, Y, INCY);
@@ -678,6 +692,41 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void input_shortcut_kernel(float *in, int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
+{
+    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (id >= size) return;
+    int i = id % minw;
+    id /= minw;
+    int j = id % minh;
+    id /= minh;
+    int k = id % minc;
+    id /= minc;
+    int b = id % batch;
+
+    int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
+    int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
+    out[out_index] = in[out_index] + add[add_index];
+}
+
+extern "C" void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
+{
+    int minw = (w1 < w2) ? w1 : w2;
+    int minh = (h1 < h2) ? h1 : h2;
+    int minc = (c1 < c2) ? c1 : c2;
+
+    int stride = w1 / w2;
+    int sample = w2 / w1;
+    assert(stride == h1 / h2);
+    assert(sample == h2 / h1);
+    if (stride < 1) stride = 1;
+    if (sample < 1) sample = 1;
+
+    int size = batch * minw * minh * minc;
+    input_shortcut_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> >(in, size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
+    check_error(cudaPeekAtLastError());
+}
+
 __global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta, float *error)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -877,7 +926,7 @@ __global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int bat
 extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
 {
     size_t size = w*h*c*batch*stride*stride;
-    upsample_kernel << <cuda_gridsize(size), BLOCK >> >(size, in, w, h, c, batch, stride, forward, scale, out);
+    upsample_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> >(size, in, w, h, c, batch, stride, forward, scale, out);
     check_error(cudaPeekAtLastError());
 }
 
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 0014f32a..6c2dc9fa 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -36,7 +36,7 @@ __global__ void binarize_kernel(float *x, int n, float *binary)
 
 void binarize_gpu(float *x, int n, float *binary)
 {
-    binarize_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, binary);
+    binarize_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, binary);
     check_error(cudaPeekAtLastError());
 }
 
@@ -79,7 +79,7 @@ __global__ void binarize_weights_kernel(float *weights, int n, int size, float *
 
 void binarize_weights_gpu(float *weights, int n, int size, float *binary)
 {
-    binarize_weights_kernel << <cuda_gridsize(n), BLOCK >> >(weights, n, size, binary);
+    binarize_weights_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(weights, n, size, binary);
     check_error(cudaPeekAtLastError());
 }
 
@@ -126,7 +126,7 @@ void fast_binarize_weights_gpu(float *weights, int n, int size, float *binary, f
 
         set_zero_kernel << <(n/BLOCK + 1), BLOCK >> > (mean_arr_gpu, n);
         reduce_kernel << <num_blocks, BLOCK >> > (weights, n, size, mean_arr_gpu);
-        binarize_weights_mean_kernel << <num_blocks, BLOCK >> > (weights, n, size, binary, mean_arr_gpu);
+        binarize_weights_mean_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (weights, n, size, binary, mean_arr_gpu);
         check_error(cudaPeekAtLastError());
     }
     else {
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 124514a2..30529184 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -883,7 +883,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
         //gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
         //gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
 
-        if (l.xnor && l.align_bit_weights && !state.train && (l.stride == 1 && l.pad == 1))
+        if (l.xnor && l.align_bit_weights && !state.train)
         {
             memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float));
 
diff --git a/src/shortcut_layer.c b/src/shortcut_layer.c
index 9fa18db8..7cdc5368 100644
--- a/src/shortcut_layer.c
+++ b/src/shortcut_layer.c
@@ -73,8 +73,10 @@ void backward_shortcut_layer(const layer l, network_state state)
 #ifdef GPU
 void forward_shortcut_layer_gpu(const layer l, network_state state)
 {
-    copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
-    shortcut_gpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
+    //copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
+    //simple_copy_ongpu(l.outputs*l.batch, state.input, l.output_gpu);
+    //shortcut_gpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
+    input_shortcut_gpu(state.input, l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
     activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
 }
 
diff --git a/src/yolo_layer.c b/src/yolo_layer.c
index ac0768a2..d3962a16 100644
--- a/src/yolo_layer.c
+++ b/src/yolo_layer.c
@@ -399,14 +399,15 @@ int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh,
 
 void forward_yolo_layer_gpu(const layer l, network_state state)
 {
-    copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1);
+    //copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1);
+    simple_copy_ongpu(l.batch*l.inputs, state.input, l.output_gpu);
     int b, n;
     for (b = 0; b < l.batch; ++b){
         for(n = 0; n < l.n; ++n){
             int index = entry_index(l, b, n*l.w*l.h, 0);
-            activate_array_ongpu(l.output_gpu + index, 2*l.w*l.h, LOGISTIC);
+            activate_array_ongpu(l.output_gpu + index, 2*l.w*l.h, LOGISTIC);   // x,y
             index = entry_index(l, b, n*l.w*l.h, 4);
-            activate_array_ongpu(l.output_gpu + index, (1+l.classes)*l.w*l.h, LOGISTIC);
+            activate_array_ongpu(l.output_gpu + index, (1+l.classes)*l.w*l.h, LOGISTIC); // classes and objectness
         }
     }
     if(!state.train || l.onlyforward){
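
Note (not part of the patch): the new input_shortcut_gpu fuses the former copy_ongpu + shortcut_gpu pair into a single kernel, so each output element is written once as out = in + add instead of being copied and then updated in a second launch. A minimal host-side reference of the same indexing, useful for checking the fused kernel on a CPU, might look like the sketch below; the function name input_shortcut_cpu_ref is hypothetical and does not exist in the repository.

    // Hypothetical CPU reference of the fused input-shortcut (illustration only).
    // Mirrors input_shortcut_kernel: out[o] = in[o] + add[a] over the smaller
    // of the two layer volumes (minw x minh x minc per batch item).
    void input_shortcut_cpu_ref(float *in, int batch, int w1, int h1, int c1,
                                float *add, int w2, int h2, int c2, float *out)
    {
        int minw = (w1 < w2) ? w1 : w2;
        int minh = (h1 < h2) ? h1 : h2;
        int minc = (c1 < c2) ? c1 : c2;
        int stride = (w1 / w2 > 1) ? w1 / w2 : 1;   // step in the "add" layer
        int sample = (w2 / w1 > 1) ? w2 / w1 : 1;   // step in the "in"/"out" layer

        int b, k, j, i;
        for (b = 0; b < batch; ++b)
            for (k = 0; k < minc; ++k)
                for (j = 0; j < minh; ++j)
                    for (i = 0; i < minw; ++i) {
                        int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
                        int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
                        out[out_index] = in[out_index] + add[add_index];
                    }
    }

In the common residual-block case (w1 == w2, h1 == h2, c1 == c2) stride and sample are both 1, and the fused path saves one full read and write of l.output_gpu per shortcut layer compared with the old copy-then-add sequence.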