XNOR-net: tiny-yolo_xnor.cfg runs ~2x faster than cuDNN on CUDA (NVIDIA Maxwell GPU)
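The cfg hunks below add a new bin_output=1 flag to the XNOR convolutional sections of the tiny-yolo configs. A minimal sketch of one such section, assembled from those hunks (the stride, pad and activation values here are illustrative, not taken from this commit):

[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky

bin_output=1 marks the layer for binary output handling; such layers are printed as "convXB" in the layer summary (see the convolutional_layer.c hunk further down).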

AlexeyAB
2018-09-22 02:01:14 +03:00
parent 0224ba3d0d
commit 7dd97537fb
15 changed files with 511 additions and 42 deletions

View File

@ -32,6 +32,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=32
size=3
@ -45,6 +46,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=64
size=3
@ -58,6 +60,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=128
size=3
@ -71,6 +74,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=256
size=3
@ -84,6 +88,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=512
size=3
@ -97,6 +102,7 @@ stride=1
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=1024
size=3

View File

@ -36,6 +36,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=32
size=3
@ -49,6 +50,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=64
size=3
@ -62,6 +64,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=128
size=3
@ -88,6 +91,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=512
size=3
@ -101,6 +105,7 @@ stride=1
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=1024
size=3
@ -173,6 +178,7 @@ stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1

View File

@ -32,6 +32,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=32
size=3
@ -45,6 +46,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=64
size=3
@ -58,6 +60,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=128
size=3
@ -71,6 +74,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=256
size=3
@ -84,6 +88,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=512
size=3
@ -97,6 +102,7 @@ stride=1
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=1024
size=3

View File

@ -36,6 +36,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=32
size=3
@ -49,6 +50,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=64
size=3
@ -62,6 +64,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=128
size=3
@ -88,6 +91,7 @@ stride=2
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=512
size=3
@ -101,6 +105,7 @@ stride=1
[convolutional]
xnor=1
bin_output=1
batch_normalize=1
filters=1024
size=3
@ -173,6 +178,7 @@ stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1

View File

@ -119,18 +119,15 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
if(l.xnor){
if (!l.align_bit_weights_gpu || state.train) {
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
swap_binary(&l);
binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
state.input = l.binary_input_gpu;
}
//swap_binary(&l);
//binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
//state.input = l.binary_input_gpu;
//cudaDeviceSynchronize();
if (l.align_bit_weights_gpu && !state.train)
if (l.align_bit_weights_gpu && !state.train && l.c >= 256 && l.size > 1)
{
//return;
cudaError_t status = cudaSuccess;
int input_size = l.c*l.h*l.w*l.batch;
@ -146,15 +143,25 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
//if(0)
{
//cudaDeviceSynchronize();
int i = 0;
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
//cudaDeviceSynchronize();
if (l.stride == 1 && l.c >= 256 && l.w > 13 && l.size > 1 && 0) // disabled
{
// stride=1 only
im2col_align_bin_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align);
//cudaDeviceSynchronize();
}
else
{
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
//cudaDeviceSynchronize();
//getchar();
// should be optimized
float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
//cudaDeviceSynchronize();
//im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align);
// should be optimized
float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
//cudaDeviceSynchronize();
}
transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
//cudaDeviceSynchronize();
@ -197,7 +204,13 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
}
}
fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
if (l.xnor) {
swap_binary(&l);
binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
state.input = l.binary_input_gpu;
}
//fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
#ifdef CUDNN
float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT
@ -294,7 +307,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
#else
cudnnConvolutionForward(cudnn_handle(),
&one,
&alpha, //&one,
l.srcTensorDesc,
state.input,
l.weightDesc,
@ -303,9 +316,11 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
l.fw_algo,
state.workspace,
l.workspace_size,
&one,
&beta, //&one,
l.dstTensorDesc,
l.output_gpu);
//cudaDeviceSynchronize();
#endif // CUDNN_HALF
@ -338,7 +353,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
}
#endif // no CUDNN_HALF
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
//if(l.dot > 0) dot_error_gpu(l);
if(l.binary || l.xnor) swap_binary(&l);
//cudaDeviceSynchronize(); // for correct profiling of performance
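Taken together, the hunks above narrow the fast bit-packed path to inference on layers with l.c >= 256 and l.size > 1, and add a stride=1 branch that would pack bits directly inside im2col_align_bin_ongpu but is currently disabled (note the trailing "&& 0"). A rough sketch of the enabled call sequence, using names from the diff (k, n and new_ldb are computed from the layer shape earlier in the full source; the per-image offset i*l.c*l.h*l.w is omitted):

if (l.align_bit_weights_gpu && !state.train && l.c >= 256 && l.size > 1) {
    // 1) lay out input patches as columns in an aligned float workspace
    im2col_align_ongpu(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad,
                       l.align_workspace_gpu, l.bit_align);
    // 2) collapse each float to one bit (its sign), giving a packed bit matrix
    float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
    // 3) transpose the bit matrix so the XNOR/popcount GEMM that follows reads it contiguously
    transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu,
                      k, n, l.bit_align, new_ldb, 8);
    // ... binary GEMM, bias and activation follow as in the full source ...
}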

View File

@ -257,7 +257,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam)
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output)
{
int i;
convolutional_layer l = {0};
@ -269,6 +269,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.n = n;
l.binary = binary;
l.xnor = xnor;
l.use_bin_output = use_bin_output;
l.batch = batch;
l.stride = stride;
l.size = size;
@ -307,7 +308,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.binary_weights = calloc(c*n*size*size, sizeof(float));
l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
int align = 8;
int align = 32;// 8;
int src_align = l.out_h*l.out_w;
l.bit_align = src_align + (align - src_align % align);
}
@ -404,8 +405,9 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
//fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c);
l.bflops = (2.0 * l.n * l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.;
if (l.xnor) fprintf(stderr, "convX ");
else fprintf(stderr, "conv ");
if (l.xnor && l.use_bin_output) fprintf(stderr, "convXB");
else if (l.xnor) fprintf(stderr, "convX ");
else fprintf(stderr, "conv ");
fprintf(stderr, "%5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
return l;
@ -428,7 +430,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
void test_convolutional_layer()
{
convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0);
convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0);
l.batch_normalize = 1;
float data[] = {1,1,1,1,1,
1,1,1,1,1,

View File

@ -25,7 +25,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output);
void denormalize_convolutional_layer(convolutional_layer l);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);

View File

@ -48,17 +48,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.input_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
*(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0);
l.input_layer->batch = batch;
l.self_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
*(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0);
l.self_layer->batch = batch;
l.output_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
*(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0);
l.output_layer->batch = batch;
l.output = l.output_layer->output;

View File

@ -17,6 +17,10 @@ void im2col_align_ongpu(float *im,
int channels, int height, int width,
int ksize, int stride, int pad, float *data_col, int bit_align);
void im2col_align_bin_ongpu(float *im,
int channels, int height, int width,
int ksize, int stride, int pad, float *data_col, int bit_align);
void float_to_bit_gpu(float *src, unsigned char *dst, size_t size);
void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const int m,

View File

@ -45,6 +45,7 @@ __global__ void im2col_gpu_kernel(const int n, const float* data_im,
*data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
data_im_ptr[i * width + j] : 0;
//data_im[(channel_in * height + h_in) * width + w_in + i * width + j];
//*data_col_ptr = data_im_ptr[ii * width + jj];
data_col_ptr += height_col * width_col;
@ -69,7 +70,7 @@ void im2col_ongpu(float *im,
}
// --------------------------------
/*
__global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
const int height, const int width, const int ksize,
const int pad,
@ -120,6 +121,71 @@ __global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
}
}
}
*/
// float 32
__global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
const int height, const int width, const int ksize,
const int pad,
const int stride,
const int height_col, const int width_col,
float *data_col, const int bit_align)
{
__shared__ float tmp_s[1];
//#define SHRED_VALS ((BLOCK / 169) * )
__shared__ float dst_s[1024];
//__shared__ float dst_s[1024];
//__shared__ uint32_t bit_s[32];
__shared__ uint8_t bit_s[128];
int index = blockIdx.x*blockDim.x + threadIdx.x;
for (; index < n; index += blockDim.x*gridDim.x) {
int w_out = index % width_col;
int h_index = index / width_col;
int h_out = h_index % height_col;
int channel_in = h_index / height_col;
int channel_out = channel_in * ksize * ksize;
int h_in = h_out * stride - pad;
int w_in = w_out * stride - pad;
//float* data_col_ptr = data_col;
//float* data_col_ptr_32 = data_col + (channel_out * bit_align + h_out * width_col + w_out) / 32;
//data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out;
//data_col_ptr += channel_out * bit_align + h_out * width_col + w_out;
float* data_col_ptr = &data_col[channel_out * bit_align + h_out * width_col + w_out];
const float* data_im_ptr = data_im;
data_im_ptr += (channel_in * height + h_in) * width + w_in;
for (int i = 0; i < ksize; ++i) {
for (int j = 0; j < ksize; ++j) {
int h = h_in + i;
int w = w_in + j;
float val = (h >= 0 && w >= 0 && h < height && w < width) ?
data_im_ptr[i * width + j] : 0;
int pre_out_index = index % (width_col*height_col);
int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;// h_out * width_col + w_out;
data_col[out_index] = val;
//*data_col_ptr = val;
//dst_s[threadIdx.x] = val;
//tmp_s[0] = val;
//*data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
// data_im_ptr[i * width + j] : 0;
//float src_val = (h >= 0 && w >= 0 && h < height && w < width) ? data_im_ptr[i * width + j] : 0;
//unsigned int bit_mask = __ballot_sync(0xffffffff, src_val > 0);
//if (threadIdx.x % WARP_SIZE == 0) *((unsigned int*)data_col_ptr_32) = bit_mask;
// use atomicOr() // *dst_ptr |= (mask << (col_index % 8));
//data_col_ptr_32 += bit_align / 32;
//data_col_ptr += height_col * width_col;
data_col_ptr += bit_align;
}
}
}
}
void im2col_align_ongpu(float *im,
int channels, int height, int width,
@ -135,6 +201,354 @@ void im2col_align_ongpu(float *im,
stride, height_col,
width_col, data_col, bit_align);
}
// --------------------------------
/*
// binary im2col
__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
const int height, const int width, const int ksize, const int channels,
const int pad,
const int stride,
const int height_col, const int width_col,
float *data_col, const int bit_align)
{
__shared__ float tmp_s[1];
//#define SHRED_VALS ((BLOCK / 169) * )
__shared__ float dst_s[1024];
//__shared__ float dst_s[1024];
//__shared__ uint32_t bit_s[32];
__shared__ uint8_t bit_s[128];
int index = blockIdx.x*blockDim.x + threadIdx.x;
for (; index < n; index += blockDim.x*gridDim.x)
{
//int c_index = index;
//int channel_in = c_index % channels;
int h_out = index % height_col;
int c_index = index / height_col;
int channel_in = c_index % channels;
int channel_out = channel_in * ksize * ksize;
int j_index = c_index / channels;
int j = j_index % ksize;
int i = j_index / ksize;
if (i < ksize)
{
for (int w_out = 0; w_out < width_col; ++w_out)
{
int h_in = h_out * stride - pad;
int w_in = w_out * stride - pad;
int h = h_in + i;
int w = w_in + j;
float val = (h >= 0 && w >= 0 && h < height && w < width) ?
data_im[(channel_in * height + h_in) * width + w_in + i * width + j] : 0;
//int pre_out_index = index % (width_col*height_col);
int pre_out_index = h_out * width_col + w_out;
int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
data_col[out_index] = val;
}// w_out
}
}
}
*/
/*
// binary im2col
__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
const int height, const int width, const int ksize, const int channels,
const int pad,
const int stride,
const int height_col, const int width_col,
float *data_col, const int bit_align)
{
__shared__ float tmp_s[1];
__shared__ ulonglong4 tmp256_s[1];
//#define SHRED_VALS ((BLOCK / 169) * )
//__shared__ float dst_s[1024];
//__shared__ float dst_s[1024];
//__shared__ uint32_t bit_s[32];
//__shared__ uint8_t bit_s[128];
int index = blockIdx.x*blockDim.x + threadIdx.x;
//for (; index < n; index += blockDim.x*gridDim.x)
{
//int c_index = index;
//int channel_in = c_index % channels;
int h_out = index % height_col;
int c_index = index / height_col;
int channel_in = c_index % channels;
int channel_out = channel_in * ksize * ksize;
int j_index = c_index / channels;
int j = j_index % ksize;
int i = j_index / ksize;
int h_in = h_out * stride - pad;
int h = h_in + i;
//if (i < ksize)
{
int w_out = 0;
// the end of padding
//if(0)
for (; w_out < (width_col); w_out += 32)
{
int w = w_out * stride - pad + j;
int pre_in_index = (channel_in * height + h_in) * width + i * width;
int in_index = pre_in_index + w;
//float *src_p = (float *)&data_im[in_index];
int pre_out_index = h_out * width_col + w_out;
int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
// float *dst_p = (float *)&data_col[out_index];
if (i >= ksize) {
out_index = -1;
}
#pragma unroll
for (int t = 0; t < WARP_SIZE; ++t) {
const int lane_id = threadIdx.x % WARP_SIZE;
//const int64_t cur_pre_in_index = pre_in_index;
//const int64_t cur_j = j;
//const int64_t out_i = out_index;// __shfl(out_index, t) + lane_id;
const int64_t cur_out_index = __shfl(out_index, t);
if (cur_out_index >= 0)
{
const int64_t cur_pre_in_index = __shfl(pre_in_index, t);
const int64_t cur_j = __shfl(j, t);
const int64_t cur_h = __shfl(h, t);
int cur_w = ((w_out + lane_id) * stride - pad + cur_j);
int in_index = cur_pre_in_index + cur_w;
float val = (cur_w >= 0 && cur_w < width && cur_h >= 0 && cur_h < height) ?
data_im[in_index] : float();
if ((w_out + lane_id) < width_col) {
data_col[cur_out_index + lane_id] = val;
//tmp_s[0] = val;
//uint32_t bit_mask = __ballot(val > 0);
//uint8_t *bit8_ptr = &(((uint8_t *)data_col)[cur_out_index / 8]);
//uint32_t *bit32_ptr = (uint32_t *)bit8_ptr;
//*bit32_ptr = bit_mask;
}
}
}
}// w_out
#ifdef NOT_USED
if (i < ksize && h >= 0 && h < height)
{
// wait for align address and the end of padding
for (; w_out < width_col; ++w_out)
{
int w_in = w_out * stride - pad;
int w = w_in + j;
int in_index = (channel_in * height + h_in) * width + w_in + i * width + j;
float *src_p = (float *)&data_im[in_index];
int pre_out_index = h_out * width_col + w_out;
int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
float *dst_p = (float *)&data_col[out_index];
if (((uint64_t)src_p % 32 == 0) && ((uint64_t)dst_p % 32 == 0) && w > 0) {
//printf(" aligned addresses and there is no padding \n");
break;
}
float val = (w >= 0 && w < width) ?
(*src_p) : float();
*dst_p = val;
//tmp_s[0] = val;
}// w_out
// ulonglong4 (256 bit) / instead of float (32 bit) = 8x times
for (; w_out < (width_col - 8); w_out += 8)
{
int w_in = w_out * stride - pad;
int w = w_in + j;
ulonglong4 *src_p = (ulonglong4 *)&data_im[(channel_in * height + h_in) * width + w_in + i * width + j];
int pre_out_index = h_out * width_col + w_out;
int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
ulonglong4 *dst_p = (ulonglong4 *)&data_col[out_index];
ulonglong4 val = (w < width) ?
(*src_p) : ulonglong4();
*dst_p = val;
//tmp256_s[0] = val;
}// w_out
for (; w_out < width_col; ++w_out)
{
//int h_in = h_out * stride - pad;
int w_in = w_out * stride - pad;
//int h = h_in + i;
int w = w_in + j;
float val = (w < width) ?
data_im[(channel_in * height + h_in) * width + w_in + i * width + j] : 0;
int pre_out_index = h_out * width_col + w_out;
int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
data_col[out_index] = val;
//tmp_s[0] = val;
}// w_out
}
#endif // NOT_USED
}
}
}
*/
// binary im2col - stride=1
__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
const int height, const int width, const int ksize, const int channels,
const int pad,
const int stride,
const int height_col, const int width_col,
float *data_col, const int bit_align)
{
__shared__ float tmp_s[1];
__shared__ ulonglong4 tmp256_s[1];
//#define SHRED_VALS ((BLOCK / 169) * )
//__shared__ float dst_s[1024];
//__shared__ float dst_s[1024];
//__shared__ uint32_t bit_s[32];
//__shared__ uint8_t bit_s[128];
int index = blockIdx.x*blockDim.x + threadIdx.x;
//for (; index < n; index += blockDim.x*gridDim.x)
{
int c_index = index;
int channel_in = c_index % channels;
//int h_out = index % height_col;
//int c_index = index / height_col;
//int channel_in = c_index % channels;
int channel_out = channel_in * ksize * ksize;
int j_index = c_index / channels;
int j = j_index % ksize;
int i = j_index / ksize;
int pre_out_index = (channel_out + i*ksize + j) * bit_align;
int j_pad = (j - pad);
int i_pad = (i - pad);
for(int wh_index = 0; wh_index < (height_col*width_col); wh_index += 32)
//for (int h_out = 0; h_out < height_col; ++h_out)
{
// the end of padding
//if(0)
//for (int w_out = 0; w_out < (width_col); w_out += 32)
{
const int w_out = wh_index % width_col;
const int h_out = wh_index / width_col;
const int w = w_out + j_pad;
const int h = h_out + i_pad;
int pre_in_index = channel_in * height * width;
int pre_in_wh_index = h * width + w;
int send_wh_index = wh_index;
if (i >= ksize) send_wh_index = height_col*width_col;
#pragma unroll
for (int t = 0; t < WARP_SIZE; ++t)
{
const int lane_id = threadIdx.x % WARP_SIZE;
const int cur_wh_index = __shfl(send_wh_index, t) + lane_id;
if (cur_wh_index < (width_col*height_col))// && (cur_i_pad+pad) < ksize)
{
const int cur_pre_out_index = __shfl(pre_out_index, t);
const int cur_pre_in_index = __shfl(pre_in_index, t);
const int cur_pre_in_wh_index = __shfl(pre_in_wh_index, t) + lane_id;
int w = cur_pre_in_wh_index % width;
int h = cur_pre_in_wh_index / width;
int in_index = cur_pre_in_index + cur_pre_in_wh_index;
int out_index = cur_pre_out_index + cur_wh_index;
float val = (w >= 0 && w < width && h >= 0 && h < height) ?
data_im[in_index] : float();
//data_col[out_index] = val;
//tmp_s[0] = val;
uint32_t bit_mask = __ballot(val > 0);
if (lane_id == 0) {
uint8_t *bit8_ptr = &(((uint8_t *)data_col)[out_index / 8]);
uint32_t *bit32_ptr = (uint32_t *)bit8_ptr;
*bit32_ptr = bit_mask;
}
}
}
}// w_out
}
}
}
void im2col_align_bin_ongpu(float *im,
int channels, int height, int width,
int ksize, int stride, int pad, float *data_col, int bit_align) {
// We are going to launch channels * height_col * width_col kernels, each
// kernel responsible for copying a single-channel grid.
int height_col = (height + 2 * pad - ksize) / stride + 1;
int width_col = (width + 2 * pad - ksize) / stride + 1;
//int num_kernels = channels * height_col * width_col * ksize * ksize;
//int num_kernels = channels * ksize * ksize * height_col;
int num_kernels = channels * ksize * ksize;
int num_blocks = num_kernels / BLOCK + 1;
//im2col_align_bin_gpu_kernel << <(num_kernels + BLOCK - 1) / BLOCK,
im2col_align_bin_gpu_kernel << <num_blocks,
BLOCK, 0, get_cuda_stream() >> >(
num_kernels, im, height, width, ksize, channels, pad,
stride, height_col,
width_col, data_col, bit_align);
}
// --------------------------------
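The active (non-commented) kernel above packs activations into bits on the fly: each lane of a warp evaluates one position, the legacy __ballot warp-vote intrinsic gathers the 32 sign bits into a single mask, and lane 0 stores the mask. A self-contained sketch of just that packing step, assuming blockDim.x is a multiple of 32 (the kernel and buffer names here are illustrative, not from this commit):

// Each warp writes one 32-bit mask: bit i of the mask is set when lane i's value is > 0.
__global__ void pack_bits_kernel(const float *src, unsigned char *dst, int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    float val = (idx < n) ? src[idx] : 0.f;          // out-of-range lanes contribute a 0 bit
    unsigned int mask = __ballot(val > 0);           // same warp-vote intrinsic as in the kernel above
    if ((threadIdx.x % 32) == 0 && idx < n)          // lane 0 stores the packed bits for lanes idx..idx+31
        *(unsigned int *)(dst + idx / 8) = mask;
}

Packing this way lets the bit-GEMM kernels below operate on 32-, 64- and 256-bit words with XNOR and population count instead of floating-point multiplies, which is the idea behind the speedup claimed in the commit title.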
@ -560,7 +974,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
int count = 0;
k = 0;
//#ifdef NON_USED
//#ifdef NOT_USED
// 32 thread X 64 bit = 2048 bit
for (; k < (K - 2048); k += 2048) { // l.size*l.size*l.c - one filter size [27 - 9216]
uint64_t c_bit64;
@ -591,7 +1005,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
}
//#endif
//#ifdef NON_USED
//#ifdef NOT_USED
// 32 thread X 32 bit = 1024 bit
for (; k < (K - 1024); k += 1024) { // l.size*l.size*l.c - one filter size [27 - 9216]
@ -626,7 +1040,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
float mean_val = mean_arr[i];
float bias_val = bias_arr[i];
//#ifdef NON_USED
//#ifdef NOT_USED
for (; k < K; k += 256) { // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
//ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8)); // weights
ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + (local_i*lda + k) / 8)); // weights
@ -638,7 +1052,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
}
//#endif
#ifdef NON_USED
#ifdef NOT_USED
for (; k < K; k += 64) { // l.size*l.size*l.c - one filter size [27 - 9216]
//uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8)); // weights
uint64_t a_bit64 = *((uint64_t *)(A_s + (local_i*lda + k) / 8)); // weights
@ -697,7 +1111,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
int count = 0;
k = 0;
//#ifdef NON_USED
//#ifdef NOT_USED
// 32 thread X 64 bit = 2048 bit
for (; k < (K - 2048); k += 2048) { // l.size*l.size*l.c - one filter size [27 - 9216]
uint64_t c_bit64;
@ -728,7 +1142,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
}
//#endif
//#ifdef NON_USED
//#ifdef NOT_USED
// 32 thread X 32 bit = 1024 bit
for (; k < (K - 1024); k += 1024) { // l.size*l.size*l.c - one filter size [27 - 9216]
@ -763,7 +1177,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
float mean_val = mean_arr[i];
float bias_val = bias_arr[i];
//#ifdef NON_USED
//#ifdef NOT_USED
for (; k < K; k += 256) { // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8)); // weights
//ulonglong4 b_bit256 = *((ulonglong4 *)(B + (j*ldb + k) / 8)); // input
@ -775,7 +1189,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
}
//#endif
#ifdef NON_USED
#ifdef NOT_USED
for (; k < K; k += 64) { // l.size*l.size*l.c - one filter size [27 - 9216]
uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8)); // weights
//uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8)); // input

View File

@ -793,6 +793,7 @@ void free_network(network net)
#ifdef GPU
if (gpu_index >= 0) cuda_free(net.workspace);
else free(net.workspace);
if (net.input_state_gpu) cuda_free(net.input_state_gpu);
if (*net.input_gpu) cuda_free(*net.input_gpu);
if (*net.truth_gpu) cuda_free(*net.truth_gpu);
if (net.input_gpu) free(net.input_gpu);
@ -866,7 +867,7 @@ void calculate_binary_weights(network net)
binary_align_weights(l);
if(net.layers[j + 1].use_bin_output) {
if(net.layers[j].use_bin_output) {
l->activation = LINEAR;
}
}

View File

@ -64,6 +64,8 @@ typedef struct network{
tree *hierarchy;
#ifdef GPU
float *input_state_gpu;
float **input_gpu;
float **truth_gpu;
float **input16_gpu;

View File

@ -51,7 +51,7 @@ void forward_network_gpu(network net, network_state state)
for(i = 0; i < net.n; ++i){
state.index = i;
layer l = net.layers[i];
if(l.delta_gpu){
if(l.delta_gpu && state.train){
fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
}
l.forward_gpu(l, state);
@ -428,13 +428,15 @@ float *network_predict_gpu(network net, float *input)
network_state state;
state.index = 0;
state.net = net;
state.input = cuda_make_array(input, size);
//state.input = cuda_make_array(input, size); // memory will be allocated in the parse_network_cfg_custom()
state.input = net.input_state_gpu;
cuda_push_array(state.input, input, size);
state.truth = 0;
state.train = 0;
state.delta = 0;
forward_network_gpu(net, state);
float *out = get_network_output_gpu(net);
cuda_free(state.input);
//cuda_free(state.input); // will be freed in the free_network()
return out;
}
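The network_predict_gpu hunk above stops allocating and freeing a device buffer on every call; the input now goes into net.input_state_gpu, which the parse.c hunk further down allocates once. A condensed sketch of the new flow, with unrelated fields omitted:

/* one-time allocation (parse_network_cfg_custom, GPU branch) */
int size = get_network_input_size(net) * net.batch;
net.input_state_gpu = cuda_make_array(0, size);

/* per call (network_predict_gpu) */
state.input = net.input_state_gpu;              /* reuse the preallocated device buffer */
cuda_push_array(state.input, input, size);      /* copy the new host input to the device */
forward_network_gpu(net, state);
float *out = get_network_output_gpu(net);
/* no cuda_free(state.input) here; free_network() releases net.input_state_gpu */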

View File

@ -163,11 +163,12 @@ convolutional_layer parse_convolutional(list *options, size_params params)
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
int binary = option_find_int_quiet(options, "binary", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam);
convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output);
layer.flipped = option_find_int_quiet(options, "flipped", 0);
layer.dot = option_find_float_quiet(options, "dot", 0);
layer.use_bin_output = option_find_int_quiet(options, "bin_output", 0);
if(params.net.adam){
layer.B1 = params.net.B1;
layer.B2 = params.net.B2;
@ -819,6 +820,8 @@ network parse_network_cfg_custom(char *filename, int batch)
#ifdef GPU
if(gpu_index >= 0){
net.workspace = cuda_make_array(0, workspace_size/sizeof(float) + 1);
int size = get_network_input_size(net) * net.batch;
net.input_state_gpu = cuda_make_array(0, size);
}else {
net.workspace = calloc(1, workspace_size);
}

View File

@ -368,9 +368,11 @@ void backward_region_layer(const region_layer l, network_state state)
void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *boxes, int only_objectness, int *map)
{
int i,j,n;
float *predictions = l.output;
int i;
float *const predictions = l.output;
#pragma omp parallel for
for (i = 0; i < l.w*l.h; ++i){
int j, n;
int row = i / l.w;
int col = i % l.w;
for(n = 0; n < l.n; ++n){