diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index e996854a..1b63c4c5 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -84,7 +84,11 @@ __global__ void set_zero_kernel(float *src, int size)
 __inline__ __device__ float warpAllReduceSum(float val) {
     for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
+#if CUDA_VERSION >= 9000
+        val += __shfl_xor_sync(0xffffffff, val, mask);
+#else
         val += __shfl_xor(val, mask);
+#endif
     return val;
 }
diff --git a/src/im2col_kernels.cu b/src/im2col_kernels.cu
index a4ae4c9e..2bdca30b 100644
--- a/src/im2col_kernels.cu
+++ b/src/im2col_kernels.cu
@@ -12,8 +12,27 @@ extern "C" {
 #include <stdio.h>
 #include <assert.h>
 
+#define FULL_MASK 0xffffffff
 #define WARP_SIZE 32
 
+template <typename T1, typename T2>
+__device__ inline T1 __shfl_custom(T1 val, T2 lane) {
+#if CUDA_VERSION >= 9000
+    return __shfl_sync(FULL_MASK, val, lane);
+#else
+    return __shfl(val, lane);
+#endif
+}
+
+template <typename T>
+__device__ inline uint32_t __ballot_custom(T val) {
+#if CUDA_VERSION >= 9000
+    return __ballot_sync(FULL_MASK, val);
+#else
+    return __ballot(val);
+#endif
+}
+
 // src: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu
 // You may also want to read: https://github.com/BVLC/caffe/blob/master/LICENSE
 
@@ -205,227 +224,6 @@ void im2col_align_ongpu(float *im,
 // --------------------------------
 
-/*
-// binary im2col
-__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
-    const int height, const int width, const int ksize, const int channels,
-    const int pad,
-    const int stride,
-    const int height_col, const int width_col,
-    float *data_col, const int bit_align)
-{
-    __shared__ float tmp_s[1];
-
-    //#define SHRED_VALS ((BLOCK / 169) * )
-    __shared__ float dst_s[1024];
-    //__shared__ float dst_s[1024];
-    //__shared__ uint32_t bit_s[32];
-    __shared__ uint8_t bit_s[128];
-
-    int index = blockIdx.x*blockDim.x + threadIdx.x;
-    for (; index < n; index += blockDim.x*gridDim.x)
-    {
-        //int c_index = index;
-        //int channel_in = c_index % channels;
-
-        int h_out = index % height_col;
-        int c_index = index / height_col;
-        int channel_in = c_index % channels;
-
-        int channel_out = channel_in * ksize * ksize;
-
-        int j_index = c_index / channels;
-        int j = j_index % ksize;
-        int i = j_index / ksize;
-        if (i < ksize)
-        {
-            for (int w_out = 0; w_out < width_col; ++w_out)
-            {
-                int h_in = h_out * stride - pad;
-                int w_in = w_out * stride - pad;
-
-                int h = h_in + i;
-                int w = w_in + j;
-
-                float val = (h >= 0 && w >= 0 && h < height && w < width) ?
-                    data_im[(channel_in * height + h_in) * width + w_in + i * width + j] : 0;
-
-                //int pre_out_index = index % (width_col*height_col);
-                int pre_out_index = h_out * width_col + w_out;
-                int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
-                data_col[out_index] = val;
-
-
-            }// w_out
-        }
-    }
-}
-*/
-
-/*
-// binary im2col
-__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
-    const int height, const int width, const int ksize, const int channels,
-    const int pad,
-    const int stride,
-    const int height_col, const int width_col,
-    float *data_col, const int bit_align)
-{
-    __shared__ float tmp_s[1];
-    __shared__ ulonglong4 tmp256_s[1];
-
-
-    //#define SHRED_VALS ((BLOCK / 169) * )
-    //__shared__ float dst_s[1024];
-    //__shared__ float dst_s[1024];
-    //__shared__ uint32_t bit_s[32];
-    //__shared__ uint8_t bit_s[128];
-
-    int index = blockIdx.x*blockDim.x + threadIdx.x;
-    //for (; index < n; index += blockDim.x*gridDim.x)
-    {
-        //int c_index = index;
-        //int channel_in = c_index % channels;
-
-        int h_out = index % height_col;
-        int c_index = index / height_col;
-        int channel_in = c_index % channels;
-
-        int channel_out = channel_in * ksize * ksize;
-
-        int j_index = c_index / channels;
-        int j = j_index % ksize;
-        int i = j_index / ksize;
-
-        int h_in = h_out * stride - pad;
-        int h = h_in + i;
-
-        //if (i < ksize)
-        {
-            int w_out = 0;
-
-            // the end of padding
-            //if(0)
-            for (; w_out < (width_col); w_out += 32)
-            {
-                int w = w_out * stride - pad + j;
-                int pre_in_index = (channel_in * height + h_in) * width + i * width;
-                int in_index = pre_in_index + w;
-                //float *src_p = (float *)&data_im[in_index];
-
-                int pre_out_index = h_out * width_col + w_out;
-                int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
-                // float *dst_p = (float *)&data_col[out_index];
-
-                if (i >= ksize) {
-                    out_index = -1;
-                }
-
-                #pragma unroll
-                for (int t = 0; t < WARP_SIZE; ++t) {
-                    const int lane_id = threadIdx.x % WARP_SIZE;
-
-                    //const int64_t cur_pre_in_index = pre_in_index;
-                    //const int64_t cur_j = j;
-                    //const int64_t out_i = out_index;// __shfl(out_index, t) + lane_id;
-
-                    const int64_t cur_out_index = __shfl(out_index, t);
-                    if (cur_out_index >= 0)
-                    {
-                        const int64_t cur_pre_in_index = __shfl(pre_in_index, t);
-                        const int64_t cur_j = __shfl(j, t);
-                        const int64_t cur_h = __shfl(h, t);
-
-                        int cur_w = ((w_out + lane_id) * stride - pad + cur_j);
-                        int in_index = cur_pre_in_index + cur_w;
-
-                        float val = (cur_w >= 0 && cur_w < width && cur_h >= 0 && cur_h < height) ?
-                            data_im[in_index] : float();
-
-                        if ((w_out + lane_id) < width_col) {
-                            data_col[cur_out_index + lane_id] = val;
-                            //tmp_s[0] = val;
-
-                            //uint32_t bit_mask = __ballot(val > 0);
-                            //uint8_t *bit8_ptr = &(((uint8_t *)data_col)[cur_out_index / 8]);
-                            //uint32_t *bit32_ptr = (uint32_t *)bit8_ptr;
-                            //*bit32_ptr = bit_mask;
-                        }
-                    }
-                }
-
-            }// w_out
-
-#ifdef NOT_USED
-            if (i < ksize && h >= 0 && h < height)
-            {
-
-                // wait for align address and the end of padding
-                for (; w_out < width_col; ++w_out)
-                {
-                    int w_in = w_out * stride - pad;
-                    int w = w_in + j;
-
-                    int in_index = (channel_in * height + h_in) * width + w_in + i * width + j;
-                    float *src_p = (float *)&data_im[in_index];
-
-                    int pre_out_index = h_out * width_col + w_out;
-                    int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
-                    float *dst_p = (float *)&data_col[out_index];
-
-                    if (((uint64_t)src_p % 32 == 0) && ((uint64_t)dst_p % 32 == 0) && w > 0) {
-                        //printf(" aligned addresses and there is no padding \n");
-                        break;
-                    }
-
-                    float val = (w >= 0 && w < width) ?
-                        (*src_p) : float();
-
-                    *dst_p = val;
-                    //tmp_s[0] = val;
-                }// w_out
-
-                // ulonglong4 (256 bit) / instead of float (32 bit) = 8x times
-                for (; w_out < (width_col - 8); w_out += 8)
-                {
-                    int w_in = w_out * stride - pad;
-                    int w = w_in + j;
-
-                    ulonglong4 *src_p = (ulonglong4 *)&data_im[(channel_in * height + h_in) * width + w_in + i * width + j];
-
-                    int pre_out_index = h_out * width_col + w_out;
-                    int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
-                    ulonglong4 *dst_p = (ulonglong4 *)&data_col[out_index];
-
-                    ulonglong4 val = (w < width) ?
-                        (*src_p) : ulonglong4();
-
-                    *dst_p = val;
-                    //tmp256_s[0] = val;
-                }// w_out
-
-                for (; w_out < width_col; ++w_out)
-                {
-                    //int h_in = h_out * stride - pad;
-                    int w_in = w_out * stride - pad;
-
-                    //int h = h_in + i;
-                    int w = w_in + j;
-
-                    float val = (w < width) ?
-                        data_im[(channel_in * height + h_in) * width + w_in + i * width + j] : 0;
-
-                    int pre_out_index = h_out * width_col + w_out;
-                    int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
-                    data_col[out_index] = val;
-                    //tmp_s[0] = val;
-                }// w_out
-            }
-#endif // NOT_USED
-        }
-    }
-}
-*/
 
 // binary im2col - stride=1
@@ -491,14 +289,14 @@ __global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
             {
                 const int lane_id = threadIdx.x % WARP_SIZE;
 
-                const int cur_wh_index = __shfl(send_wh_index, t) + lane_id;
+                const int cur_wh_index = __shfl_custom(send_wh_index, t) + lane_id;
 
                 if (cur_wh_index < (width_col*height_col))// && (cur_i_pad+pad) < ksize)
                 {
-                    const int cur_pre_out_index = __shfl(pre_out_index, t);
+                    const int cur_pre_out_index = __shfl_custom(pre_out_index, t);
 
-                    const int cur_pre_in_index = __shfl(pre_in_index, t);
-                    const int cur_pre_in_wh_index = __shfl(pre_in_wh_index, t) + lane_id;
+                    const int cur_pre_in_index = __shfl_custom(pre_in_index, t);
+                    const int cur_pre_in_wh_index = __shfl_custom(pre_in_wh_index, t) + lane_id;
 
                     int w = cur_pre_in_wh_index % width;
                     int h = cur_pre_in_wh_index / width;
@@ -512,7 +310,7 @@ __global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
                         //data_col[out_index] = val;
                         //tmp_s[0] = val;
 
-                        uint32_t bit_mask = __ballot(val > 0);
+                        uint32_t bit_mask = __ballot_custom(val > 0);
                         if (lane_id == 0) {
                             uint8_t *bit8_ptr = &(((uint8_t *)data_col)[out_index / 8]);
                             uint32_t *bit32_ptr = (uint32_t *)bit8_ptr;
@@ -565,7 +363,7 @@ __global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t s
         if(index < size) src_val = src[index];
         else src_val = 0;
         //unsigned int bit_mask = __ballot_sync(0xffffffff, src_val > 0);
-        unsigned int bit_mask = __ballot(src_val > 0);
+        unsigned int bit_mask = __ballot_custom(src_val > 0);
         if (threadIdx.x % WARP_SIZE == 0) ((unsigned int*)dst)[index / 32] = bit_mask;
     }
 }
@@ -591,7 +389,7 @@ __global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t s
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    uint32_t bit_mask = __ballot(src_val > 0);
+    uint32_t bit_mask = __ballot_custom(src_val > 0);
     if (lane_id == 0) tmp[warp_id] = bit_mask;
 
@@ -624,7 +422,7 @@ __global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t s
         const int warp_id = threadIdx.x / WARP_SIZE;
         const int lane_id = threadIdx.x % WARP_SIZE;
 
-        uint32_t bit_mask = __ballot(src_val > 0);
+        uint32_t bit_mask = __ballot_custom(src_val > 0);
         if (lane_id == 0) tmp[i * 32 + warp_id] = bit_mask;
     }
     __syncthreads();
@@ -849,8 +647,8 @@ __global__ void transpose_bin_gpu_kernel_32(uint32_t *A, uint32_t *B, const int
 void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const int m,
     const int lda, const int ldb, const int block_size)
 {
-    size_t size = n*m/ (8*8) + 1;
-    size_t size32 = n*m / (32*32) + 1;
+    int size = n*m/ (8*8) + 1;
+    int size32 = n*m / (32*32) + 1;
     const int num_blocks = size / BLOCK + 1;
     const int num_blocks32 = size32 / BLOCK_TRANSPOSE32 + 1;
     transpose_bin_gpu_kernel_32 << <num_blocks32, BLOCK_TRANSPOSE32, 0, get_cuda_stream() >> >((uint32_t *)A, (uint32_t *)B, n, m, lda, ldb, block_size);
@@ -1059,7 +857,7 @@ __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_
         {
             float src = input[(chan + c_pack)*items_per_channel + i];
 
-            uint32_t bit_mask = __ballot(src > 0);
+            uint32_t bit_mask = __ballot_custom(src > 0);
             if (threadIdx.x % 32 == 0) re_packed_input_bin[chan*items_per_channel / 32 + i] = bit_mask;
         }
@@ -1070,8 +868,9 @@ __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_
 void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
 {
     int size = w * h * c;
-    const int block_size = 128;
+    const int block_size = 256;// 128;
     const int num_blocks = get_number_of_blocks(size, block_size);
+    //printf("\n num_blocks = %d, num_blocks/32 = %d, block_size = %d \n", num_blocks, num_blocks/32, block_size);
     repack_input_kernel_bin << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c);
 }
@@ -1116,7 +915,7 @@ __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_
         __syncthreads();
 
         src = tmp[lane_id * 33 + warp_id];
-        uint32_t bit_mask = __ballot(src > 0);
+        uint32_t bit_mask = __ballot_custom(src > 0);
 
         const int warp_item = block_item + warp_id;
 
@@ -1426,12 +1225,17 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
 __inline__ __device__ int warpAllReduceSum(int val) {
     for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
+#if CUDA_VERSION >= 9000
+        val += __shfl_xor_sync(FULL_MASK, val, mask);
+#else
         val += __shfl_xor(val, mask);
+#endif
+
     return val;
 }
 
 // Tensor Cores binary (CC >= 7.3 && CUDA >= 10.0) - __CUDA_SUBBYTE_IMMA__
-#if CUDART_VERSION >= 10000
+#if CUDA_VERSION >= 10000
 #include <mma.h>
 
 #define WMMA_M 8
@@ -1529,7 +1333,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_tensor_kernel(int M, int N, i
             #pragma UNROLL
             for (int local_j = 0; local_j < 8; ++local_j)
             {
-                uint32_t b_val_cur = __shfl(b_val, local_j *4 + k_d);
+                uint32_t b_val_cur = __shfl_custom(b_val, local_j *4 + k_d);
                 c_val[local_j] = __popc(xor_int32(a_val, b_val_cur));
             }
 
@@ -1538,7 +1342,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_tensor_kernel(int M, int N, i
             {
                 #pragma UNROLL
                 for (int local_k = 0; local_k < 4; ++local_k) {
-                    accum_c_val[local_j] += __shfl(c_val[local_j], i_d * 4 + local_k);
+                    accum_c_val[local_j] += __shfl_custom(c_val[local_j], i_d * 4 + local_k);
                 }
             }
         }
@@ -1571,7 +1375,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_tensor_kernel(int M, int N, i
                 float bias_val = bias_arr[i + i_d];
                 float dst_val = count *mean_val + bias_val;
                 if (leaky_activation)
-                    dst_val = (dst_val > 0) ? (dst_val) : (0.1*dst_val); // Leaky activation
+                    dst_val = (dst_val > 0) ? (dst_val) : (0.1f*dst_val); // Leaky activation
 
                 C[(i + i_d)*ldc + (j + j_d)] = dst_val;
             }
@@ -1635,8 +1439,8 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             for (int t = 0; t < WARP_SIZE; ++t) {
                 const int lane_id = threadIdx.x % WARP_SIZE;
 
-                const int64_t A_i = __shfl(A_cur_index, t) + 32 * lane_id;
-                const int64_t B_i = __shfl(B_cur_index, t) + 32 * lane_id;
+                const int64_t A_i = __shfl_custom(A_cur_index, t) + 32 * lane_id;
+                const int64_t B_i = __shfl_custom(B_cur_index, t) + 32 * lane_id;
 
                 {
                     //ulonglong4 a_bit256 = *((ulonglong4 *)(A + A_i));    // weights
@@ -1668,8 +1472,8 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             for (int t = 0; t < WARP_SIZE; ++t) {
                 const int lane_id = threadIdx.x % WARP_SIZE;
 
-                const int64_t A_i = __shfl(A_cur_index, t) + 8 * lane_id;
-                const int64_t B_i = __shfl(B_cur_index, t) + 8 * lane_id;
+                const int64_t A_i = __shfl_custom(A_cur_index, t) + 8 * lane_id;
+                const int64_t B_i = __shfl_custom(B_cur_index, t) + 8 * lane_id;
 
                 {
                     //uint64_t a_bit64 = *((uint64_t *)(A + A_i));    // weights
@@ -1698,8 +1502,8 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             for (int t = 0; t < WARP_SIZE; ++t) {
                 const int lane_id = threadIdx.x % WARP_SIZE;
 
-                const int64_t A_i = __shfl(A_cur_index, t) + 4 * lane_id;
-                const int64_t B_i = __shfl(B_cur_index, t) + 4 * lane_id;
+                const int64_t A_i = __shfl_custom(A_cur_index, t) + 4 * lane_id;
+                const int64_t B_i = __shfl_custom(B_cur_index, t) + 4 * lane_id;
 
                 {
                     //uint64_t a_bit64 = *((uint64_t *)(A + A_i));    // weights
@@ -1749,7 +1553,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             count = count - f1;    // remove extra bits (from empty space for align only)
             float dst_val = (2 * count - K) *mean_val + bias_val;
             if(leaky_activation)
-                dst_val = (dst_val > 0) ? (dst_val) : (0.1*dst_val); // Leaky activation
+                dst_val = (dst_val > 0) ? (dst_val) : (0.1f*dst_val); // Leaky activation
             C[i*ldc + j] = dst_val;
         }
     }
@@ -1778,7 +1582,7 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
     unsigned char *B, int ldb,
     float *C, int ldc, float *mean_arr, float *bias, int leaky_activation)
 {
-    size_t size = M*N;
+    int size = M*N;
     const int num_blocks = get_number_of_blocks(size, BLOCK);
 
     //printf("\n M = %d, N = %d, M %% 8 = %d, N %% 8 = %d \n", M, N, M % 8, N % 8);
@@ -1797,7 +1601,7 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
     {
         const int M_aligned = M + (8 - (M % 8));
         const int N_aligned = N + (8 - (N % 8));
-        size_t size = (M_aligned / 8)*(N_aligned / 8)*WARP_SIZE;
+        int size = (M_aligned / 8)*(N_aligned / 8)*WARP_SIZE;
         const int num_blocks = get_number_of_blocks(size, BLOCK);
 
         //printf(" lda = %d, ldb = %d, ldc = %d, lda/32 = %d, ldb/32 = %d, ldc/32 = %d \n", lda, ldb, ldc, lda / 32, ldb / 32, ldc / 32);
@@ -1827,142 +1631,6 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
 // --------------------------------
 
-
-
-// --------------------------------
-// sequentially - B (input) in the shared_memory - BAD
-// --------------------------------
-__global__ void gemm_nn_custom_bin_mean_transposed_sequentially_gpu_kernel(int M, int N, int K,
-    unsigned char *A, int lda,
-    unsigned char *B, int ldb,
-    float *C, int ldc, float *mean_arr)
-{
-    //__shared__ float mean_shared[32];
-    //__shared__ uint32_t B_s[8192];  // 32 KB // [ldb x N`] // max = 262 144 bits
-    //__shared__ uint32_t B_s[4096];  // 16 KB // [ldb x N`] // max = 131 072 bits
-    __shared__ uint8_t B_s[4096*4];  // 16 KB // [ldb x N`] // max = 131 072 bits
-
-
-    const int K_items = WARP_SIZE;
-    int start_j = blockIdx.x*blockDim.x / (K_items * M);
-
-    {
-        int end_j = (blockIdx.x*blockDim.x + blockDim.x) / (K_items * M) + 1;
-        if (end_j > N) end_j = N;
-        size_t shared_size = ldb * (end_j - start_j);
-
-        if (shared_size != 0) {
-            //if(threadIdx.x == 0) printf(" start_j = %d, end_j = %d, shared_size = %d \n", start_j, end_j, shared_size);
-
-            int k;
-            for (int k = threadIdx.x * 32; k < shared_size; k += blockDim.x * 32) {
-                int x = start_j*ldb + k;
-                if (x < (N*ldb)) *((uint32_t *)(B_s + k / 8)) = *((uint32_t *)(B + x / 8));
-            }
-        }
-    }
-    __syncthreads();
-
-    int index = blockIdx.x*blockDim.x + threadIdx.x;
-
-    {
-        int i;  // l.n
-        int j;  // out_h*out_w
-        int k;  // l.size * l.size * l.c
-
-        const int index2 = index / K_items;
-        i = index2 % M; // max M
-        j = index2 / M; // max N
-        //j = index2 % N; // max N
-        //i = index2 / N; // max M
-
-        //int j_cur = index / M;
-        //int local_j = j_cur - start_j;
-        int local_j = j - start_j;
-
-        //if (i <= 1 && j <= 1 ) printf(" k = %d, K = %d, K_items = %d, i = %d, j = %d, lda = %d, ldb = %d, ldc = %d \n",
-        //    k, K, K_items, i, j, lda, ldb, ldc);
-        {   // l.n - filters [16 - 55 - 1024]
-            // further improvements: for (l.n == 1024) iterate several (j)
-
-
-            if (j < N)
-            {   // out_h*out_w - one channel output size [169 - 173056]
-
-                int count = 0;
-
-
-                const int bit_step = 32;
-                for (k = (threadIdx.x % WARP_SIZE) * bit_step; k < K; k += bit_step*WARP_SIZE)
-                { // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
-                    uint32_t a_bit32 = *((uint32_t *)(A + (i*lda + k) / 8));    // weights
-                    //uint32_t b_bit32 = *((uint32_t *)(B + (j*ldb + k) / 8));    // input
-                    uint32_t b_bit32 = *((uint32_t *)(B_s + (local_j*ldb + k) / 8));    // input
-                    uint32_t c_bit32 = xnor_int32(a_bit32, b_bit32);
-
-                    count += __popc(c_bit32);
-                }
-
-                /*
-                const int bit_step = 64;
-                for (k = (threadIdx.x % WARP_SIZE) * bit_step; k < K; k += bit_step*WARP_SIZE)
-                { // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
-                    uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8));    // weights
-                    //uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8));
-                    uint64_t b_bit64 = *((uint64_t *)(B_s + (local_j*ldb + k) / 8));    // input
-                    uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64);
-                    count += __popcll(c_bit64);
-                }
-                */
-
-
-                //atomicAdd(&C[i*ldc + j], (2 * count) * mean_val);
-
-                for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
-                    count += __shfl_down(count, offset);
-
-
-                if (threadIdx.x % WARP_SIZE == 0) {
-                    int f1 = (K % bit_step == 0) ? 0 : (bit_step - (K % bit_step));
-                    count = count - f1;
-                    float mean_val = mean_arr[i];
-                    C[i*ldc + j] = (2 * count - K) * mean_val;
-                    //B_s[threadIdx.x / WARP_SIZE] = (2 * count - K) * mean_val;
-                }
-            }
-        }
-    }
-}
-
-// sequentially - BAD
-void gemm_nn_custom_bin_mean_transposed_sequentially_gpu(int M, int N, int K,
-    unsigned char *A, int lda,
-    unsigned char *B, int ldb,
-    float *C, int ldc, float *mean_arr)
-{
-    //size_t size = M*N;
-    size_t size = M*N * 32;
-
-    const int num_blocks = size / BLOCK + 1;
-
-    //printf(" K = %d \n", K);
-
-    /*
-    printf("\n gemm_bin size = %d, num_blocks = %d, M*K = %d KB, N*K = %d KB \n (w) M*K/num_blocks = %d KB, (i) N*K/num_blocks = %d KB \n",
-        size, num_blocks, M*K / 1024, N*K / 1024, M*lda / num_blocks / 1024, N*ldb / num_blocks / 1024);
-    printf(" M / 512 = %d, N / 512 = %d, M*lda / 512 = %d, N*ldb / 512 = %d \n", M / 512, N / 512, M*lda/512, N*ldb/512);
-    */
-    //printf(" shared_memory: (w) lda*BLOCK/N = %d, (i) ldb*BLOCK/M = %d, \t lda = %d \n\n", lda*BLOCK / N, ldb*BLOCK / M, lda);
-
-    gemm_nn_custom_bin_mean_transposed_sequentially_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(
-        M, N, K,
-        A, lda,
-        B, ldb,
-        C, ldc,
-        mean_arr);
-}
-// --------------------------------
-
 void convolve_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
 {
     int fil;
@@ -2133,7 +1801,7 @@ __global__ void convolve_gpu_kernel(float *input, float *weights, float *output,
 void convolve_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
 {
-    size_t array_size = in_w*in_h*n;    // width X height X filters
+    int array_size = in_w*in_h*n;    // width X height X filters
     const int num_blocks = array_size / BLOCK + 1;
     //printf("\n array_size = %d, num_blocks = %d, w = %d, h = %d, n = %d, c = %d, pad = %d \n", array_size, num_blocks, in_w, in_h, n, in_c, pad);
@@ -2247,10 +1915,10 @@ __global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *out
     //const int weights_size = size*size*in_c/8;
     const int weights_size = size*size*in_c / 32 + 1;
 
-    for (int fil = min_fil; fil <= max_fil; fil++) {
+    for (int tmp_fil = min_fil; tmp_fil <= max_fil; tmp_fil++) {
         for (int s = threadIdx.x; s < weights_size; s += blockDim.x) {
-            //weights_shared[s + (fil - min_fil)*new_lda / 8] = ((uint8_t *)weights)[fil*new_lda / 8 + s];
-            weights_shared[s + (fil - min_fil)*new_lda/32] = ((uint32_t *)weights)[fil*new_lda / 32 + s];
+            //weights_shared[s + (tmp_fil - min_fil)*new_lda / 8] = ((uint8_t *)weights)[tmp_fil*new_lda / 8 + s];
+            weights_shared[s + (tmp_fil - min_fil)*new_lda/32] = ((uint32_t *)weights)[tmp_fil*new_lda / 32 + s];
        }
     }
     __syncthreads();
@@ -2348,7 +2016,7 @@ __global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *out
 void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad, int new_lda, float *mean_arr_gpu)
 {
-    size_t array_size = in_w*in_h*n;    // width X height X filters
+    int array_size = in_w*in_h*n;    // width X height X filters
     const int num_blocks = array_size / BLOCK + 1;
     //printf("\n array_size = %d, num_blocks = %d, w = %d, h = %d, n = %d, c = %d, pad = %d \n", array_size, num_blocks, in_w, in_h, n, in_c, pad);
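
Note on the pattern this diff applies: CUDA 9.0 deprecated the implicitly warp-synchronous intrinsics (__shfl, __shfl_xor, __ballot) in favor of *_sync variants that take an explicit lane mask, so every call site above is routed through CUDA_VERSION-guarded wrappers (__shfl_custom, __ballot_custom) or guarded inline, as in warpAllReduceSum. The standalone sketch below is illustrative only — it is not part of the diff, and its file and function names are invented for the example — but it shows the same guard compiling against either toolkit:

    // shfl_compat_sketch.cu -- hypothetical demo of the CUDA_VERSION guard used above
    // Build: nvcc shfl_compat_sketch.cu -o shfl_compat_sketch
    #include <cstdio>
    #include <cuda.h>             // defines CUDA_VERSION
    #include <cuda_runtime.h>

    #define FULL_MASK 0xffffffff
    #define WARP_SIZE 32

    __device__ inline int shfl_xor_compat(int val, int lane_mask) {
    #if CUDA_VERSION >= 9000
        return __shfl_xor_sync(FULL_MASK, val, lane_mask);  // CUDA 9+: explicit warp mask
    #else
        return __shfl_xor(val, lane_mask);                  // pre-CUDA-9 intrinsic
    #endif
    }

    // Butterfly reduction: afterwards every lane holds the sum of all 32 lane values.
    __global__ void warp_sum_kernel(int *out) {
        int val = threadIdx.x;                              // each lane contributes its lane id
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
            val += shfl_xor_compat(val, mask);
        if (threadIdx.x == 0) *out = val;                   // 0 + 1 + ... + 31 = 496
    }

    int main() {
        int *d_out, h_out = 0;
        cudaMalloc(&d_out, sizeof(int));
        warp_sum_kernel<<<1, WARP_SIZE>>>(d_out);
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        printf("warp sum = %d\n", h_out);                   // expect 496
        cudaFree(d_out);
        return 0;
    }

Hard-coding FULL_MASK (all 32 bits set) is safe here for the same reason it is in the diff: the kernels involved launch whole warps, so every lane is active at the shuffle or ballot. A call site that could have inactive lanes would instead need to compute the mask, e.g. with __activemask().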