diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index ee112b18..e6718fe1 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -220,10 +220,11 @@ __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delt
 
 extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
 {
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
     if (a == LINEAR) return;
-    else if(a == LEAKY) activate_array_leaky_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
-    else if (a == LOGISTIC) activate_array_logistic_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
-    else if (a == SELU) activate_array_selu_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if(a == LEAKY) activate_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if (a == LOGISTIC) activate_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if (a == SELU) activate_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
     else activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
     check_error(cudaPeekAtLastError());
 }
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 30529184..04dd88b8 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -777,57 +777,6 @@ void binary_align_weights(convolutional_layer *l)
     free(align_weights);
 }
 
-/*
-void binary_align_weights(convolutional_layer *l)
-{
-    int m = l->n;
-    int k = l->size*l->size*l->c;
-    size_t new_lda = k + (l->lda_align - k % l->lda_align); // (k / 8 + 1) * 8;
-    l->new_lda = new_lda;
-
-    binarize_weights(l->weights, m, k, l->binary_weights);
-
-    size_t align_weights_size = new_lda * m;
-    l->align_bit_weights_size = align_weights_size / 8 + 1;
-    float *align_weights = calloc(align_weights_size, sizeof(float));
-    l->align_bit_weights = calloc(l->align_bit_weights_size, sizeof(char));
-
-    size_t i, j;
-    // align A without transpose
-    for (i = 0; i < m; ++i) {
-        for (j = 0; j < k; ++j) {
-            align_weights[i*new_lda + j] = l->binary_weights[i*k + j];
-        }
-    }
-    float_to_bit(align_weights, l->align_bit_weights, align_weights_size);
-
-    //l->mean_arr = calloc(l->n, sizeof(float));
-    get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
-
-#ifdef GPU
-    cudaError_t status;
-    l->align_workspace_size = l->bit_align * l->size * l->size * l->c;
-    status = cudaMalloc((void **)&l->align_workspace_gpu, l->align_workspace_size * sizeof(float));
-    status = cudaMalloc((void **)&l->transposed_align_workspace_gpu, l->align_workspace_size * sizeof(float));
-    check_error(status);
-
-    //l->align_bit_weights_gpu = cuda_make_array(l->align_bit_weights, l->align_bit_weights_size * sizeof(char)/sizeof(float));
-    status = cudaMalloc((void **)&l->align_bit_weights_gpu, l->align_bit_weights_size);
-    check_error(status);
-    status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
-    check_error(status);
-    status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k*sizeof(float), cudaMemcpyHostToDevice);
-    check_error(status);
-
-    //l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
-    cuda_push_array(l->mean_arr_gpu, l->mean_arr, l->n);
-    cudaDeviceSynchronize();
-#endif // GPU
-
-    free(align_weights);
-}
-*/
-
 // binary transpose
 size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align, int bit_align)
 {
diff --git a/src/cuda.c b/src/cuda.c
index bd8fa5ab..16a50c4e 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -230,6 +230,11 @@ void cuda_pull_array_async(float *x_gpu, float *x, size_t n)
     //cudaStreamSynchronize(get_cuda_stream());
 }
 
+int get_number_of_blocks(int array_size, int block_size)
+{
+    return array_size / block_size + ((array_size % block_size > 0) ? 1 : 0);
+}
+
 #else // GPU
 #include "cuda.h"
 void cuda_set_device(int n) {}
diff --git a/src/cuda.h b/src/cuda.h
index 89af2465..b803c17f 100644
--- a/src/cuda.h
+++ b/src/cuda.h
@@ -38,6 +38,7 @@ extern "C" {
     dim3 cuda_gridsize(size_t n);
    cudaStream_t get_cuda_stream();
    cudaStream_t get_cuda_memcpy_stream();
+    int get_number_of_blocks(int array_size, int block_size);
#ifdef __cplusplus
}
#endif // __cplusplus
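
Note on `get_number_of_blocks()`: the old grid-size formula `n / BLOCK + 1` launches one superfluous block whenever `n` is an exact multiple of `BLOCK` (the kernels bound-check with `if (i < n)`, so the extra block only wastes scheduling; it does not write out of bounds). The helper computes the exact integer ceiling `ceil(n / block_size)`. A minimal host-side sketch of the arithmetic, not part of the patch; `BLOCK = 512` is assumed here from darknet's cuda.h:

    #include <assert.h>

    /* Same integer-ceiling arithmetic as the helper added in src/cuda.c. */
    static int number_of_blocks(int array_size, int block_size)
    {
        return array_size / block_size + ((array_size % block_size > 0) ? 1 : 0);
    }

    int main(void)
    {
        const int BLOCK = 512;                                /* assumed value */
        assert(number_of_blocks(1, BLOCK) == 1);
        assert(number_of_blocks(BLOCK, BLOCK) == 1);          /* old formula: 2 */
        assert(number_of_blocks(BLOCK + 1, BLOCK) == 2);
        assert(number_of_blocks(10 * BLOCK, BLOCK) == 10);    /* old formula: 11 */
        return 0;
    }
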
diff --git a/src/im2col_kernels.cu b/src/im2col_kernels.cu
index d7d6503b..46208d3a 100644
--- a/src/im2col_kernels.cu
+++ b/src/im2col_kernels.cu
@@ -635,7 +635,8 @@ __global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t s
 void float_to_bit_gpu(float *src, unsigned char *dst, size_t size)
 {
     //const int num_blocks = size / 1024 + 1;
-    const int num_blocks = size / (32*1024) + 1;
+    //const int num_blocks = size / (32*1024) + 1;
+    const int num_blocks = get_number_of_blocks(size, 32 * 1024);
     float_to_bit_gpu_kernel<<<num_blocks, 1024, 0, get_cuda_stream()>>>(src, dst, size);
 }
 // --------------------------------
@@ -1025,16 +1026,17 @@ void repack_input_gpu_2(float *input, float *re_packed_input, int w, int h, int
 // --------------------------------
 
+
 // 32 channels -> 1 channel (with 32 floats)
 // 256 channels -> 8 channels (with 32 floats)
 __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
 {
-    __shared__ uint32_t tmp[32];
+    //__shared__ uint32_t tmp[32];
     int index = blockIdx.x*blockDim.x + threadIdx.x;
 
-    const int num_of_warps = blockDim.x / WARP_SIZE;
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
+    //const int num_of_warps = blockDim.x / WARP_SIZE;
+    //const int warp_id = threadIdx.x / WARP_SIZE;
+    //const int lane_id = threadIdx.x % WARP_SIZE;
 
     const int items_per_channel = w * h;
 
@@ -1058,18 +1060,8 @@ __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_
             float src = input[(chan + c_pack)*items_per_channel + i];
 
             uint32_t bit_mask = __ballot(src > 0);
-            //if (threadIdx.x % 32 == 0)
-            //    re_packed_input_bin[chan*items_per_channel/32 + i + c_pack/32] = bit_mask;
-
-            if (lane_id == 0) tmp[warp_id] = bit_mask;
-
-            __syncthreads();
-            if (warp_id == 0) {
-                if (lane_id < num_of_warps) {
-                    re_packed_input_bin[chan*items_per_channel / 32 + i + lane_id] = tmp[lane_id];
-                }
-            }
-            __syncthreads();
+            if (threadIdx.x % 32 == 0)
+                re_packed_input_bin[chan*items_per_channel / 32 + i] = bit_mask;
         }
     }
 }
@@ -1078,9 +1070,74 @@ __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_
 
 void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
 {
     int size = w * h * c;
-    const int num_blocks = size / BLOCK + 1;
-    repack_input_kernel_bin << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c);
+    const int block_size = 128;
+    const int num_blocks = get_number_of_blocks(size, block_size);
+    repack_input_kernel_bin << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c);
 }
+
+
+/*
+// 32 channels -> 1 channel (with 32 floats)
+// 256 channels -> 8 channels (with 32 floats)
+__global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c, int items_per_channel_align)
+{
+    __shared__ float tmp[33*32]; // misalgined array 32x32
+    //const int index = blockIdx.x*blockDim.x + threadIdx.x;
+
+    const int num_of_warps = blockDim.x / WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    const int items_per_channel = w * h;
+    //const int items_per_channel_align = items_per_channel + (32 - items_per_channel % 32);
+    const int blocks_per_wh = items_per_channel_align / 32;
+    //const int blocks_per_c = c / 32;
+
+    // input[C x H x W] = input[C x ITEMS]
+    // BLOCK per C x ITEMS = 32x32
+
+    const int block_item_id = blockIdx.x % blocks_per_wh;
+    const int block_channel_id = blockIdx.x / blocks_per_wh;
+
+    const int block_item = block_item_id * 32;
+    const int block_channel = block_channel_id * 32;
+
+    const int lane_item = block_item + lane_id;
+    const int warp_channel = block_channel + warp_id;
+
+    if (warp_channel < c)
+    {
+        float src = 0;
+
+        if (lane_item < items_per_channel)
+            src = input[warp_channel*items_per_channel + lane_item];
+
+        tmp[warp_id * 33 + lane_id] = src;
+        __syncthreads();
+        src = tmp[lane_id * 33 + warp_id];
+
+        uint32_t bit_mask = __ballot(src > 0);
+
+        const int warp_item = block_item + warp_id;
+
+        if (lane_id == 0 && warp_item < items_per_channel)
+            re_packed_input_bin[block_channel_id*items_per_channel + warp_item] = bit_mask;
+    }
+}
+
+#define BLOCK_REPACK 1024
+void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
+{
+    int items_per_channel = w*h;
+    int items_per_channel_align = items_per_channel + (32 - items_per_channel % 32);
+    int channel_align = c + (32 - c % 32);
+
+    //int size = w * h * c;
+    int size = items_per_channel_align * channel_align;
+    const int num_blocks = get_number_of_blocks(size, BLOCK_REPACK);
+    repack_input_kernel_bin << <num_blocks, BLOCK_REPACK, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c, items_per_channel_align);
+}
+*/
 // --------------------------------
@@ -1657,7 +1714,7 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
     float *C, int ldc, float *mean_arr, float *bias)
 {
     size_t size = M*N;
-    const int num_blocks = size / BLOCK + 1;
+    const int num_blocks = get_number_of_blocks(size, BLOCK);
 
     /*
     printf("\n gemm_bin size = %d, num_blocks = %d, M*K = %d KB, N*K = %d KB \n (w) M*K/num_blocks = %d KB, (i) N*K/num_blocks = %d KB \n",
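
Note on the `repack_input_kernel_bin` change above: each warp lane reads one of 32 consecutive channels at a given spatial position, so `__ballot(src > 0)` already assembles one sign bit per channel into a 32-bit mask; letting lane 0 (`threadIdx.x % 32 == 0`) store it directly makes the shared-memory staging through `tmp[]` and both `__syncthreads()` calls per iteration unnecessary. A CPU reference of the resulting output layout, a sketch under that lane-to-channel assumption rather than code from the patch:

    #include <stdint.h>

    /* Word out[(chan/32)*w*h + i] packs the sign bits of channels
       chan..chan+31 at spatial position i, bit b taken from channel
       chan+b -- mirroring what __ballot() gathers from lane b. */
    void repack_input_bin_cpu(const float *input, uint32_t *out, int w, int h, int c)
    {
        const int items_per_channel = w * h;
        for (int chan = 0; chan < c; chan += 32) {
            for (int i = 0; i < items_per_channel; ++i) {
                uint32_t bit_mask = 0;
                for (int b = 0; b < 32 && chan + b < c; ++b) {
                    if (input[(chan + b) * items_per_channel + i] > 0)
                        bit_mask |= 1u << b;
                }
                out[chan * items_per_channel / 32 + i] = bit_mask;
            }
        }
    }
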
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index b021911d..0076a87b 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -88,7 +88,6 @@ void forward_network_gpu(network net, network_state state)
         */
     }
     cudaStreamSynchronize(get_cuda_stream());   // sync CUDA-functions
-    //cudaStreamSynchronize(get_cuda_memcpy_stream());   // sync cudaMemcpyAsync()
     //cudaDeviceSynchronize();
     //show_total_time();
 }
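
Note on the last hunk: `forward_network_gpu()` enqueues every layer's kernels on the single stream returned by `get_cuda_stream()`, so they execute in launch order and one `cudaStreamSynchronize()` at the end of the pass suffices; the commented-out sync of the memcpy stream was dead code. A minimal standalone sketch of that single-stream pattern (illustrative only; `dummy_layer` is a made-up stand-in, not darknet code):

    #include <cuda_runtime.h>
    #include <stdio.h>

    __global__ void dummy_layer(float *x, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] += 1.0f;              /* stand-in for one layer's forward kernel */
    }

    int main(void)
    {
        const int n = 1 << 20, block = 512;
        const int grid = n / block + ((n % block > 0) ? 1 : 0);  /* get_number_of_blocks */

        cudaStream_t stream;
        cudaStreamCreate(&stream);

        float *x;
        cudaMalloc(&x, n * sizeof(float));
        cudaMemsetAsync(x, 0, n * sizeof(float), stream);

        for (int layer = 0; layer < 4; ++layer)   /* layers run back-to-back, in order */
            dummy_layer<<<grid, block, 0, stream>>>(x, n);

        cudaStreamSynchronize(stream);            /* one sync for the whole pass */

        float x0;
        cudaMemcpy(&x0, x, sizeof(float), cudaMemcpyDeviceToHost);
        printf("x[0] = %.1f\n", x0);              /* prints 4.0 */

        cudaFree(x);
        cudaStreamDestroy(stream);
        return 0;
    }
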