Mirror of https://github.com/pjreddie/darknet.git
Added fast_binarize_weights_gpu()
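This commit adds a faster GPU path for XNOR-net weight binarization. Three new CUDA kernels — set_zero_kernel(), reduce_kernel() and binarize_weights_mean_kernel() — compute the per-filter mean of |w| with a warp-level shuffle reduction plus atomicAdd and then write +mean/-mean into the binary weights. fast_binarize_weights_gpu() dispatches to them when the filter size is a multiple of 32 and falls back to the old one-thread-per-filter binarize_weights_gpu() otherwise. make_convolutional_layer() now allocates l.mean_arr on the host and l.mean_arr_gpu on the device, forward_convolutional_layer_gpu() calls the fast path (only when no aligned bit-weights are available or during training), binary_align_weights() pushes the means into the existing GPU buffer instead of reallocating it, and an experimental 256-bit XNOR/popcount inner loop is added (disabled under #ifdef NOT_USED) to the bit-packed GEMM kernel.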
@@ -61,18 +61,17 @@ void binarize_input_gpu(float *input, int n, int size, float *binary)
     check_error(cudaPeekAtLastError());
 }

-
 __global__ void binarize_weights_kernel(float *weights, int n, int size, float *binary)
 {
     int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (f >= n) return;
     int i = 0;
     float mean = 0;
-    for(i = 0; i < size; ++i){
+    for (i = 0; i < size; ++i) {
         mean += fabs(weights[f*size + i]);
     }
     mean = mean / size;
-    for(i = 0; i < size; ++i){
+    for (i = 0; i < size; ++i) {
         binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
         //binary[f*size + i] = weights[f*size + i];
     }
@@ -80,10 +79,62 @@ __global__ void binarize_weights_kernel(float *weights, int n, int size, float *binary)

 void binarize_weights_gpu(float *weights, int n, int size, float *binary)
 {
-    binarize_weights_kernel<<<cuda_gridsize(n), BLOCK>>>(weights, n, size, binary);
+    binarize_weights_kernel << <cuda_gridsize(n), BLOCK >> >(weights, n, size, binary);
     check_error(cudaPeekAtLastError());
 }

+#define WARP_SIZE 32
+
+__global__ void set_zero_kernel(float *src, int size)
+{
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < size) src[i] = 0;
+}
+
+__inline__ __device__
+float warpAllReduceSum(float val) {
+    for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
+        val += __shfl_xor(val, mask);
+    return val;
+}
+
+// only if (size % 32 == 0)
+__global__ void reduce_kernel(float *weights, int n, int size, float *mean_arr_gpu)
+{
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    int f = i / size;
+    if (f >= n) return;
+    float warp_mean = warpAllReduceSum(fabs(weights[i]));
+    if (i % 32 == 0)
+        atomicAdd(&mean_arr_gpu[f], warp_mean / size);
+}
+
+__global__ void binarize_weights_mean_kernel(float *weights, int n, int size, float *binary, float *mean_arr_gpu)
+{
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    int f = i / size;
+    if (f >= n) return;
+    float mean = mean_arr_gpu[f];
+    binary[i] = (weights[i] > 0) ? mean : -mean;
+}
+
+void fast_binarize_weights_gpu(float *weights, int n, int size, float *binary, float *mean_arr_gpu)
+{
+    if (size % 32 == 0) {
+        size_t gridsize = n * size;
+        const int num_blocks = gridsize / BLOCK + 1;
+
+        set_zero_kernel << <(n/BLOCK + 1), BLOCK >> > (mean_arr_gpu, n);
+        reduce_kernel << <num_blocks, BLOCK >> > (weights, n, size, mean_arr_gpu);
+        binarize_weights_mean_kernel << <num_blocks, BLOCK >> > (weights, n, size, binary, mean_arr_gpu);
+        check_error(cudaPeekAtLastError());
+    }
+    else {
+        binarize_weights_gpu(weights, n, size, binary);
+    }
+}
+
+
 __global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
 {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
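Taken together, the three new kernels reproduce, per filter, what binarize_weights_kernel already does — average the absolute weights and write back +mean/-mean — but with one thread per weight and the per-filter mean accumulated in mean_arr_gpu. reduce_kernel relies on each warp covering 32 consecutive weights of the same filter (hence the size % 32 == 0 requirement): warpAllReduceSum() sums |w| across the warp with __shfl_xor, and only the thread with i % 32 == 0 issues the atomicAdd. As a reference, here is a single-threaded C sketch of the result of the size % 32 == 0 branch (the function name binarize_weights_reference is ours, not from the repository; the parameters mirror the kernel arguments):

    #include <math.h>

    // Per filter f: mean_arr[f] = average of |weights| over that filter,
    // and every weight is replaced by +mean or -mean according to its sign.
    void binarize_weights_reference(const float *weights, int n, int size,
                                    float *binary, float *mean_arr)
    {
        for (int f = 0; f < n; ++f) {
            float mean = 0;
            for (int i = 0; i < size; ++i) mean += fabsf(weights[f*size + i]);
            mean /= size;
            mean_arr[f] = mean;
            for (int i = 0; i < size; ++i)
                binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
        }
    }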
@@ -128,7 +179,9 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

     if(l.xnor){
         if (!l.align_bit_weights_gpu || state.train) {
-            binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
+            //binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
+
+            fast_binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu, l.mean_arr_gpu);
         }
         //swap_binary(&l);
         //binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
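In the forward pass, weight binarization now goes through fast_binarize_weights_gpu() and is performed only when no pre-aligned bit weights exist (l.align_bit_weights_gpu is NULL) or the network is training; at inference with aligned bit weights the per-batch binarization is skipped entirely.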
@@ -314,6 +314,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
         int align = 32;// 8;
         int src_align = l.out_h*l.out_w;
         l.bit_align = src_align + (align - src_align % align);
+
+        l.mean_arr = calloc(l.n, sizeof(float));
     }

     if(batch_normalize){
@@ -369,6 +371,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     }
     if(xnor){
         l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
+        l.mean_arr_gpu = cuda_make_array(0, l.n);
         l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
     }
@@ -628,7 +631,7 @@ void binary_align_weights(convolutional_layer *l)
     }
     float_to_bit(align_weights, l->align_bit_weights, align_weights_size);

-    l->mean_arr = calloc(l->n, sizeof(float));
+    //l->mean_arr = calloc(l->n, sizeof(float));
     get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);

 #ifdef GPU
@@ -646,7 +649,8 @@ void binary_align_weights(convolutional_layer *l)
     status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k*sizeof(float), cudaMemcpyHostToDevice);
     check_error(status);

-    l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
+    //l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
+    cuda_push_array(l->mean_arr_gpu, l->mean_arr, l->n);
     cudaDeviceSynchronize();
 #endif // GPU
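binary_align_weights() keeps using l->mean_arr, which is now allocated once in make_convolutional_layer() rather than here, and the matching GPU buffer l->mean_arr_gpu, also created in make_convolutional_layer(), is filled with cuda_push_array() instead of being reallocated on every call.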
@@ -1123,6 +1123,38 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
     int count = 0;
     k = 0;
+
+#ifdef NOT_USED
+    // 32 thread X 256 bit = 8192 bit
+    for (; k < (K - 8192); k += 8192) {   // l.size*l.size*l.c - one filter size [27 - 9216]
+        ulonglong4 c_bit256;
+
+        //int64_t A_cur_index = (i*lda + k) / 8;
+        int64_t A_cur_index = (local_i*lda + k) / 8;
+        int64_t B_cur_index = (j*ldb + k) / 8;
+        if (i >= M) A_cur_index = 0;
+
+        #pragma unroll
+        for (int t = 0; t < WARP_SIZE; ++t) {
+            const int lane_id = threadIdx.x % WARP_SIZE;
+
+            const int64_t A_i = __shfl(A_cur_index, t) + 32 * lane_id;
+            const int64_t B_i = __shfl(B_cur_index, t) + 32 * lane_id;
+
+            {
+                //ulonglong4 a_bit256 = *((ulonglong4 *)(A + A_i));    // weights
+                ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + A_i));    // weights
+                ulonglong4 b_bit256 = *((ulonglong4 *)(B + B_i));      // input
+                c_bit256 = xnor_int256(a_bit256, b_bit256);
+                int tmp_count = __popcll(c_bit256.w) + __popcll(c_bit256.x) +
+                    __popcll(c_bit256.y) + __popcll(c_bit256.z);
+
+                int sum_count = warpAllReduceSum(tmp_count);
+                if (lane_id == t) count += sum_count;
+            }
+        }
+    }
+#endif

     //#ifdef NOT_USED
     // 32 thread X 64 bit = 2048 bit
     for (; k < (K - 2048); k += 2048) {   // l.size*l.size*l.c - one filter size [27 - 9216]
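The disabled 256-bit loop illustrates the core XNOR-net trick: weights and inputs are bit-packed, xnor_int256() (an existing helper in this kernel file, presumably a lane-wise XNOR over the four 64-bit components) marks the positions where the two sign bits agree, __popcll() counts them, and warpAllReduceSum() combines the per-thread counts. A minimal host-side sketch of the same bit trick on a single 64-bit chunk (illustrative only; xnor_popcount64 is our name, it uses the GCC/Clang builtin rather than CUDA's __popcll, and the conversion of the raw match count into a ±1 dot product, 2*matches - bits, is presumably applied later in the kernel):

    #include <stdint.h>

    // One 64-bit chunk of bit-packed weights (a) against inputs (b):
    // XNOR sets a 1 wherever the two sign bits agree; popcount then
    // counts the agreeing positions.
    static inline int xnor_popcount64(uint64_t a, uint64_t b)
    {
        uint64_t match = ~(a ^ b);            // 1 where bits are equal
        return __builtin_popcountll(match);   // number of matching bits
    }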
@@ -866,10 +866,10 @@ void calculate_binary_weights(network net)
                 //if (l->size*l->size*l->c >= 2048) l->lda_align = 512;

                 binary_align_weights(l);
-            }

-            if (net.layers[j].use_bin_output) {
-                l->activation = LINEAR;
+                if (net.layers[j].use_bin_output) {
+                    l->activation = LINEAR;
+                }
             }
         }
     }