Experimental repack

AlexeyAB
2019-01-18 19:52:11 +03:00
parent bf6b40f4e9
commit 3a51f4af74
6 changed files with 87 additions and 75 deletions

View File

@@ -220,10 +220,11 @@ __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
 extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
 {
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
     if (a == LINEAR) return;
-    else if(a == LEAKY) activate_array_leaky_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
-    else if (a == LOGISTIC) activate_array_logistic_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
-    else if (a == SELU) activate_array_selu_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if(a == LEAKY) activate_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if (a == LOGISTIC) activate_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
+    else if (a == SELU) activate_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
     else activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
     check_error(cudaPeekAtLastError());
 }
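
Why this change matters: the old `n / BLOCK + 1` grid size always rounds up by a full block, so whenever `n` is an exact multiple of `BLOCK` it launches one extra block whose threads all fail their bounds checks. A minimal standalone sketch of the difference (plain C, hypothetical values for `n` and `BLOCK`):

#include <stdio.h>

/* Same ceiling division as the get_number_of_blocks() added in this commit. */
static int get_number_of_blocks(int array_size, int block_size)
{
    return array_size / block_size + ((array_size % block_size > 0) ? 1 : 0);
}

int main(void)
{
    int n = 4096, BLOCK = 512;                    /* n is a multiple of BLOCK */
    printf("old: %d blocks\n", n / BLOCK + 1);    /* 9: one block fully idle  */
    printf("new: %d blocks\n", get_number_of_blocks(n, BLOCK));   /* 8        */
    return 0;
}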

View File

@@ -777,57 +777,6 @@ void binary_align_weights(convolutional_layer *l)
     free(align_weights);
 }
-/*
-void binary_align_weights(convolutional_layer *l)
-{
-    int m = l->n;
-    int k = l->size*l->size*l->c;
-    size_t new_lda = k + (l->lda_align - k % l->lda_align); // (k / 8 + 1) * 8;
-    l->new_lda = new_lda;
-
-    binarize_weights(l->weights, m, k, l->binary_weights);
-
-    size_t align_weights_size = new_lda * m;
-    l->align_bit_weights_size = align_weights_size / 8 + 1;
-    float *align_weights = calloc(align_weights_size, sizeof(float));
-    l->align_bit_weights = calloc(l->align_bit_weights_size, sizeof(char));
-
-    size_t i, j;
-    // align A without transpose
-    for (i = 0; i < m; ++i) {
-        for (j = 0; j < k; ++j) {
-            align_weights[i*new_lda + j] = l->binary_weights[i*k + j];
-        }
-    }
-    float_to_bit(align_weights, l->align_bit_weights, align_weights_size);
-
-    //l->mean_arr = calloc(l->n, sizeof(float));
-    get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
-
-#ifdef GPU
-    cudaError_t status;
-    l->align_workspace_size = l->bit_align * l->size * l->size * l->c;
-    status = cudaMalloc((void **)&l->align_workspace_gpu, l->align_workspace_size * sizeof(float));
-    status = cudaMalloc((void **)&l->transposed_align_workspace_gpu, l->align_workspace_size * sizeof(float));
-    check_error(status);
-
-    //l->align_bit_weights_gpu = cuda_make_array(l->align_bit_weights, l->align_bit_weights_size * sizeof(char)/sizeof(float));
-    status = cudaMalloc((void **)&l->align_bit_weights_gpu, l->align_bit_weights_size);
-    check_error(status);
-    status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
-    check_error(status);
-    status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k*sizeof(float), cudaMemcpyHostToDevice);
-    check_error(status);
-
-    //l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
-    cuda_push_array(l->mean_arr_gpu, l->mean_arr, l->n);
-    cudaDeviceSynchronize();
-#endif // GPU
-
-    free(align_weights);
-}
-*/
 
 // binary transpose
 size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align, int bit_align)
 {
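
A side note on the alignment arithmetic appearing both in the deleted block above and in the live code: `k + (l->lda_align - k % l->lda_align)` rounds `k` up to a multiple of `lda_align`, but when `k` is already aligned it still adds a full extra `lda_align` stride. A common exact alternative, shown here only for comparison (not part of this commit):

/* Rounds k up to a multiple of align; returns k unchanged when already aligned. */
static size_t round_up(size_t k, size_t align)
{
    return (k + align - 1) / align * align;
}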

View File

@@ -230,6 +230,11 @@ void cuda_pull_array_async(float *x_gpu, float *x, size_t n)
     //cudaStreamSynchronize(get_cuda_stream());
 }
 
+int get_number_of_blocks(int array_size, int block_size)
+{
+    return array_size / block_size + ((array_size % block_size > 0) ? 1 : 0);
+}
+
 #else // GPU
 #include "cuda.h"
 void cuda_set_device(int n) {}

View File

@@ -38,6 +38,7 @@ extern "C" {
     dim3 cuda_gridsize(size_t n);
     cudaStream_t get_cuda_stream();
     cudaStream_t get_cuda_memcpy_stream();
+    int get_number_of_blocks(int array_size, int block_size);
 #ifdef __cplusplus
 }
 #endif // __cplusplus
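
The new declaration sits inside the header's existing `extern "C"` block, so the helper defined in the C file links correctly both from the C sources and from the `.cu` files that nvcc compiles as C++. A minimal sketch of the guard pattern (hypothetical header, not the project's file):

/* mini_cuda.h -- hypothetical illustration of the extern "C" guard */
#ifdef __cplusplus
extern "C" {
#endif
int get_number_of_blocks(int array_size, int block_size);
#ifdef __cplusplus
}
#endif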

View File

@@ -635,7 +635,8 @@ __global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t size)
 void float_to_bit_gpu(float *src, unsigned char *dst, size_t size)
 {
     //const int num_blocks = size / 1024 + 1;
-    const int num_blocks = size / (32*1024) + 1;
+    //const int num_blocks = size / (32*1024) + 1;
+    const int num_blocks = get_number_of_blocks(size, 32 * 1024);
     float_to_bit_gpu_kernel<<<num_blocks, 1024, 0, get_cuda_stream()>>>(src, dst, size);
 }
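
The divisor 32*1024 reflects the kernel's packing ratio: with 1024 threads per block and each thread condensing 32 floats into one 32-bit word, one block consumes 32*1024 input floats. A hedged sketch of that packing scheme (not the kernel from this diff, just the assumed layout):

__global__ void float_to_bit_sketch(float *src, unsigned char *dst, size_t size)
{
    size_t t = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    size_t base = t * 32;                        // each thread owns 32 floats
    if (base >= size) return;
    unsigned int word = 0;
    for (int b = 0; b < 32 && base + b < size; ++b)
        if (src[base + b] > 0) word |= 1u << b;  // sign test -> one bit
    ((unsigned int *)dst)[t] = word;             // one 32-bit word per thread
}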
// --------------------------------
@@ -1025,16 +1026,17 @@ void repack_input_gpu_2(float *input, float *re_packed_input, int w, int h, int c)
 // --------------------------------
 
 // 32 channels -> 1 channel (with 32 floats)
 // 256 channels -> 8 channels (with 32 floats)
 __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
 {
-    __shared__ uint32_t tmp[32];
+    //__shared__ uint32_t tmp[32];
     int index = blockIdx.x*blockDim.x + threadIdx.x;
 
-    const int num_of_warps = blockDim.x / WARP_SIZE;
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
+    //const int num_of_warps = blockDim.x / WARP_SIZE;
+    //const int warp_id = threadIdx.x / WARP_SIZE;
+    //const int lane_id = threadIdx.x % WARP_SIZE;
 
     const int items_per_channel = w * h;
@@ -1058,18 +1060,8 @@ __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
             float src = input[(chan + c_pack)*items_per_channel + i];
 
             uint32_t bit_mask = __ballot(src > 0);
-
-            //if (threadIdx.x % 32 == 0)
-            //    re_packed_input_bin[chan*items_per_channel/32 + i + c_pack/32] = bit_mask;
-
-            if (lane_id == 0) tmp[warp_id] = bit_mask;
-            __syncthreads();
-            if (warp_id == 0) {
-                if (lane_id < num_of_warps) {
-                    re_packed_input_bin[chan*items_per_channel / 32 + i + lane_id] = tmp[lane_id];
-                }
-            }
-            __syncthreads();
+            if (threadIdx.x % 32 == 0)
+                re_packed_input_bin[chan*items_per_channel / 32 + i] = bit_mask;
         }
     }
 }
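
The shared-memory staging that this hunk removes (`tmp`, two `__syncthreads()`, warp 0 draining the buffer) can be dropped because `__ballot` already leaves its result warp-uniform: every lane of the warp receives the same 32-bit mask, with bit k set iff lane k's predicate held, so lane 0 can store the warp's 32 packed bits directly. A minimal sketch of that pattern (hypothetical kernel, using the pre-CUDA-9 `__ballot` this codebase targets):

__global__ void warp_pack_sketch(const float *src, unsigned int *dst, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    float v = (i < n) ? src[i] : 0.f;
    unsigned int mask = __ballot(v > 0);   // identical result in all 32 lanes
    if ((threadIdx.x % 32) == 0 && i < n)
        dst[i / 32] = mask;                // single write per warp
}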
@@ -1078,9 +1070,74 @@ __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
 void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
 {
     int size = w * h * c;
-    const int num_blocks = size / BLOCK + 1;
-    repack_input_kernel_bin << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c);
+    const int block_size = 128;
+    const int num_blocks = get_number_of_blocks(size, block_size);
+    repack_input_kernel_bin << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c);
 }
+
+/*
+// 32 channels -> 1 channel (with 32 floats)
+// 256 channels -> 8 channels (with 32 floats)
+__global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c, int items_per_channel_align)
+{
+    __shared__ float tmp[33*32]; // misaligned 32x32 array (row stride 33 avoids bank conflicts)
+    //const int index = blockIdx.x*blockDim.x + threadIdx.x;
+    const int num_of_warps = blockDim.x / WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    const int items_per_channel = w * h;
+    //const int items_per_channel_align = items_per_channel + (32 - items_per_channel % 32);
+    const int blocks_per_wh = items_per_channel_align / 32;
+    //const int blocks_per_c = c / 32;
+
+    // input[C x H x W] = input[C x ITEMS]
+    // BLOCK per C x ITEMS = 32x32
+    const int block_item_id = blockIdx.x % blocks_per_wh;
+    const int block_channel_id = blockIdx.x / blocks_per_wh;
+
+    const int block_item = block_item_id * 32;
+    const int block_channel = block_channel_id * 32;
+
+    const int lane_item = block_item + lane_id;
+    const int warp_channel = block_channel + warp_id;
+
+    if (warp_channel < c)
+    {
+        float src = 0;
+        if (lane_item < items_per_channel)
+            src = input[warp_channel*items_per_channel + lane_item];
+
+        tmp[warp_id * 33 + lane_id] = src;
+        __syncthreads();
+
+        src = tmp[lane_id * 33 + warp_id];
+        uint32_t bit_mask = __ballot(src > 0);
+
+        const int warp_item = block_item + warp_id;
+        if (lane_id == 0 && warp_item < items_per_channel)
+            re_packed_input_bin[block_channel_id*items_per_channel + warp_item] = bit_mask;
+    }
+}
+
+#define BLOCK_REPACK 1024
+
+void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
+{
+    int items_per_channel = w*h;
+    int items_per_channel_align = items_per_channel + (32 - items_per_channel % 32);
+    int channel_align = c + (32 - c % 32);
+    //int size = w * h * c;
+    int size = items_per_channel_align * channel_align;
+    const int num_blocks = get_number_of_blocks(size, BLOCK_REPACK);
+    repack_input_kernel_bin << <num_blocks, BLOCK_REPACK, 0, get_cuda_stream() >> >(input, re_packed_input_bin, w, h, c, items_per_channel_align);
+}
+*/
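
The `tmp[33*32]` buffer in this commented-out draft is the classic padded-tile transpose trick: a 32x32 tile stored with a row stride of 33 floats, so that reading a column hits 32 distinct shared-memory banks instead of one. A standalone sketch of the idiom (hypothetical kernel, assuming a 32x32 thread block):

__global__ void transpose32_sketch(const float *in, float *out, int width, int height)
{
    __shared__ float tile[32][33];             // 33, not 32: no bank conflicts
    int x = blockIdx.x * 32 + threadIdx.x;
    int y = blockIdx.y * 32 + threadIdx.y;
    if (x < width && y < height)
        tile[threadIdx.y][threadIdx.x] = in[y * width + x];
    __syncthreads();
    int xt = blockIdx.y * 32 + threadIdx.x;    // transposed coordinates
    int yt = blockIdx.x * 32 + threadIdx.y;
    if (xt < height && yt < width)
        out[yt * height + xt] = tile[threadIdx.x][threadIdx.y];
}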
// --------------------------------
@@ -1657,7 +1714,7 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
     float *C, int ldc, float *mean_arr, float *bias)
 {
     size_t size = M*N;
-    const int num_blocks = size / BLOCK + 1;
+    const int num_blocks = get_number_of_blocks(size, BLOCK);
 
     /*
     printf("\n gemm_bin size = %d, num_blocks = %d, M*K = %d KB, N*K = %d KB \n (w) M*K/num_blocks = %d KB, (i) N*K/num_blocks = %d KB \n",
View File

@@ -88,7 +88,6 @@ void forward_network_gpu(network net, network_state state)
     */
     }
     cudaStreamSynchronize(get_cuda_stream()); // sync CUDA-functions
-    //cudaStreamSynchronize(get_cuda_memcpy_stream()); // sync cudaMemcpyAsync()
     //cudaDeviceSynchronize();
     //show_total_time();
 }
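
For context on the synchronization kept here: `cudaStreamSynchronize(get_cuda_stream())` waits only for the work queued on that one stream, while the commented-out `cudaDeviceSynchronize()` would block on every stream on the device. A tiny standalone illustration (hypothetical stream and kernel, not from this commit):

#include <cstdio>

__global__ void noop() {}

int main()
{
    cudaStream_t s;
    cudaStreamCreate(&s);
    noop<<<1, 1, 0, s>>>();        // queued on stream s only
    cudaStreamSynchronize(s);      // waits for work on s alone
    //cudaDeviceSynchronize();     // would wait for all streams on the device
    cudaStreamDestroy(s);
    printf("done\n");
    return 0;
}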