From 61156239e09cc839ce4324b593de89673da1352a Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Sun, 3 Feb 2019 00:18:30 +0300 Subject: [PATCH] Minor performance improvement --- src/activation_kernels.cu | 52 +++++++++++++++++++++++++++++++----- src/batchnorm_layer.c | 15 +++++++---- src/convolutional_kernels.cu | 36 ++++++++++++++----------- src/network_kernels.cu | 2 +- 4 files changed, 78 insertions(+), 27 deletions(-) diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu index 29154824..aec279d5 100644 --- a/src/activation_kernels.cu +++ b/src/activation_kernels.cu @@ -164,7 +164,7 @@ __global__ void binary_gradient_array_kernel(float *x, float *dy, int n, int s, extern "C" void binary_gradient_array_gpu(float *x, float *dx, int n, int size, BINARY_ACTIVATION a, float *y) { binary_gradient_array_kernel << > >(x, dx, n / 2, size, a, y); - check_error(cudaPeekAtLastError()); + CHECK_CUDA(cudaPeekAtLastError()); } __global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTIVATION a, float *y) { @@ -179,7 +179,7 @@ __global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTI extern "C" void binary_activate_array_gpu(float *x, int n, int size, BINARY_ACTIVATION a, float *y) { binary_activate_array_kernel << > >(x, n / 2, size, a, y); - check_error(cudaPeekAtLastError()); + CHECK_CUDA(cudaPeekAtLastError()); } __global__ void activate_array_kernel(float *x, int n, ACTIVATION a) @@ -218,6 +218,38 @@ __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delt if(i < n) delta[i] *= gradient_kernel(x[i], a); } +__global__ void gradient_array_leaky_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= leaky_gradient_kernel(x[index]); + } +} + +__global__ void gradient_array_selu_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= selu_gradient_kernel(x[index]); + } +} + +__global__ void gradient_array_logistic_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= logistic_gradient_kernel(x[index]); + } +} + +__global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= hardtan_gradient_kernel(x[index]); + } +} + extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a) { const int num_blocks = get_number_of_blocks(n, BLOCK); @@ -225,12 +257,20 @@ extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a) else if(a == LEAKY) activate_array_leaky_kernel << > >(x, n); else if (a == LOGISTIC) activate_array_logistic_kernel << > >(x, n); else if (a == SELU) activate_array_selu_kernel << > >(x, n); - else activate_array_kernel<<>>(x, n, a); - check_error(cudaPeekAtLastError()); + else + activate_array_kernel<<>>(x, n, a); + CHECK_CUDA(cudaPeekAtLastError()); } extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta) { - gradient_array_kernel<<>>(x, n, a, delta); - check_error(cudaPeekAtLastError()); + const int num_blocks = get_number_of_blocks(n, BLOCK); + if (a == LINEAR) return; + else if (a == LEAKY) gradient_array_leaky_kernel << > >(x, n, delta); + else if (a == LOGISTIC) gradient_array_logistic_kernel << > >(x, n, delta); + else if (a == SELU) gradient_array_selu_kernel << > >(x, n, delta); + else if (a == HARDTAN) gradient_array_hardtan_kernel << > >(x, n, delta); + else + gradient_array_kernel << > > (x, n, a, delta); + CHECK_CUDA(cudaPeekAtLastError()); } diff --git a/src/batchnorm_layer.c b/src/batchnorm_layer.c index 3fa129db..a2870405 100644 --- a/src/batchnorm_layer.c +++ b/src/batchnorm_layer.c @@ -142,7 +142,7 @@ void forward_batchnorm_layer(layer l, network_state state) axpy_cpu(l.out_c, .1, l.variance, 1, l.rolling_variance, 1); copy_cpu(l.outputs*l.batch, l.output, 1, l.x, 1); - normalize_cpu(l.output, l.mean, l.variance, l.batch, l.out_c, l.out_h*l.out_w); + normalize_cpu(l.output, l.mean, l.variance, l.batch, l.out_c, l.out_h*l.out_w); copy_cpu(l.outputs*l.batch, l.output, 1, l.x_norm, 1); } else { normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.out_c, l.out_h*l.out_w); @@ -179,8 +179,11 @@ void push_batchnorm_layer(layer l) void forward_batchnorm_layer_gpu(layer l, network_state state) { - if (l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1); - copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1); + if (l.type == BATCHNORM) simple_copy_ongpu(l.outputs*l.batch, state.input, l.output_gpu); + //copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1); + + simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.x_gpu); + //copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1); if (state.train) { #ifdef CUDNN float one = 1; @@ -255,7 +258,8 @@ void backward_batchnorm_layer_gpu(layer l, network_state state) .00001, l.mean_gpu, // input (should be FP32) l.variance_gpu); // input (should be FP32) - copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, 1, l.delta_gpu, 1); + simple_copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, l.delta_gpu); + //copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, 1, l.delta_gpu, 1); #else backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h); backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates_gpu); @@ -266,6 +270,7 @@ void backward_batchnorm_layer_gpu(layer l, network_state state) fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta_gpu); normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.delta_gpu); #endif - if (l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1); + if (l.type == BATCHNORM) simple_copy_ongpu(l.outputs*l.batch, l.delta_gpu, state.delta); + //copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1); } #endif \ No newline at end of file diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 4681416a..2ba2acdd 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -27,7 +27,7 @@ __global__ void binarize_kernel(float *x, int n, float *binary) void binarize_gpu(float *x, int n, float *binary) { binarize_kernel<<>>(x, n, binary); - check_error(cudaPeekAtLastError()); + CHECK_CUDA(cudaPeekAtLastError()); } __global__ void binarize_input_kernel(float *input, int n, int size, float *binary) @@ -48,7 +48,7 @@ __global__ void binarize_input_kernel(float *input, int n, int size, float *bina void binarize_input_gpu(float *input, int n, int size, float *binary) { binarize_input_kernel<<>>(input, n, size, binary); - check_error(cudaPeekAtLastError()); + CHECK_CUDA(cudaPeekAtLastError()); } __global__ void binarize_weights_kernel(float *weights, int n, int size, float *binary) @@ -70,7 +70,7 @@ __global__ void binarize_weights_kernel(float *weights, int n, int size, float * void binarize_weights_gpu(float *weights, int n, int size, float *binary) { binarize_weights_kernel << > >(weights, n, size, binary); - check_error(cudaPeekAtLastError()); + CHECK_CUDA(cudaPeekAtLastError()); } #define WARP_SIZE 32 @@ -121,7 +121,7 @@ void fast_binarize_weights_gpu(float *weights, int n, int size, float *binary, f set_zero_kernel << <(n/BLOCK + 1), BLOCK, 0, get_cuda_stream() >> > (mean_arr_gpu, n); reduce_kernel << > > (weights, n, size, mean_arr_gpu); binarize_weights_mean_kernel << > > (weights, n, size, binary, mean_arr_gpu); - check_error(cudaPeekAtLastError()); + CHECK_CUDA(cudaPeekAtLastError()); } else { binarize_weights_gpu(weights, n, size, binary); @@ -140,6 +140,7 @@ __global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16) void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16) { cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16); + CHECK_CUDA(cudaPeekAtLastError()); } __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32) @@ -151,6 +152,7 @@ __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32) void cuda_convert_f16_to_f32(float* input_f16, size_t size, float *output_f32) { cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32); + CHECK_CUDA(cudaPeekAtLastError()); } half *cuda_make_f16_from_f32_array(float *src, size_t n) @@ -465,7 +467,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) { if (state.train) // Training { - copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1); + simple_copy_ongpu(l.outputs*l.batch / 2, output16, l.x_gpu); + //copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1); //cudaMemcpyAsync(l.x_gpu, output16, l.outputs*l.batch*sizeof(half), cudaMemcpyDefault, get_cuda_stream()); float one = 1; float zero = 0; @@ -645,7 +648,9 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state .00001, l.mean_gpu, // input (should be FP32) l.variance_gpu)); // input (should be FP32) - copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, 1, delta16, 1); + + simple_copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, delta16); + //copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, 1, delta16, 1); //cudaMemcpyAsync(delta16, l.x_norm_gpu, l.outputs*l.batch * sizeof(half), cudaMemcpyDefault, get_cuda_stream()); } else @@ -789,19 +794,20 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state void pull_convolutional_layer(convolutional_layer layer) { - cuda_pull_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size); - cuda_pull_array(layer.biases_gpu, layer.biases, layer.n); - cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size); - cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n); + cuda_pull_array_async(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size); + cuda_pull_array_async(layer.biases_gpu, layer.biases, layer.n); + cuda_pull_array_async(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size); + cuda_pull_array_async(layer.bias_updates_gpu, layer.bias_updates, layer.n); if (layer.batch_normalize){ - cuda_pull_array(layer.scales_gpu, layer.scales, layer.n); - cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); - cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); + cuda_pull_array_async(layer.scales_gpu, layer.scales, layer.n); + cuda_pull_array_async(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); + cuda_pull_array_async(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); } if (layer.adam){ - cuda_pull_array(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size); - cuda_pull_array(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size); + cuda_pull_array_async(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size); + cuda_pull_array_async(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size); } + cudaStreamSynchronize(get_cuda_stream()); } void push_convolutional_layer(convolutional_layer layer) diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 0e14ce56..9ab124e4 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -87,7 +87,7 @@ void forward_network_gpu(network net, network_state state) } */ } - cudaStreamSynchronize(get_cuda_stream()); // sync CUDA-functions + //cudaStreamSynchronize(get_cuda_stream()); // sync CUDA-functions //cudaDeviceSynchronize(); //show_total_time(); }