Mirror of https://github.com/pjreddie/darknet.git, synced 2023-08-10 21:13:14 +03:00.
XNOR-net tiny-yolo_xnor.cfg ~2x faster than cuDNN on CUDA (nVidia GPU Maxwell)
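The speed-up claimed above comes from the usual XNOR-net arithmetic: with weights and activations constrained to {-1,+1} and packed as sign bits, the inner product inside a convolution reduces to XOR plus a population count, rescaled by a per-filter mean of the absolute weights. The fragment below is a minimal, hedged CUDA illustration of that arithmetic only; it is not the kernel this commit tunes (gemm_nn_custom_bin_mean_transposed_gpu_kernel), and every name in it is invented for the example.

// Illustrative sketch only: XOR + popcount replacing float multiply-accumulate.
// A set bit means the original value was positive; for {-1,+1} vectors of length K,
// dot(a,b) = K - 2*popcount(a XOR b), then rescaled by the per-filter mean |w|.
#include <cstdio>
#include <cstdint>

__global__ void binary_dot_kernel(const uint64_t *a, const uint64_t *b,
                                  int K, float mean_val, float *out)
{
    int xor_count = 0;
    for (int w = 0; w < K / 64; ++w)
        xor_count += __popcll(a[w] ^ b[w]);      // count differing sign bits
    *out = (K - 2 * xor_count) * mean_val;       // rescale back to float range
}

int main()
{
    const int K = 128;                      // multiple of 64 for simplicity
    uint64_t h_a[2] = { ~0ull, 0ull };      // first 64 values +1, last 64 values -1
    uint64_t h_b[2] = { ~0ull, ~0ull };     // all +1
    uint64_t *d_a, *d_b; float *d_out, h_out;
    cudaMalloc((void**)&d_a, sizeof(h_a));
    cudaMalloc((void**)&d_b, sizeof(h_b));
    cudaMalloc((void**)&d_out, sizeof(float));
    cudaMemcpy(d_a, h_a, sizeof(h_a), cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, sizeof(h_b), cudaMemcpyHostToDevice);
    binary_dot_kernel<<<1, 1>>>(d_a, d_b, K, 1.0f, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("binary dot = %f (64 matches - 64 mismatches = 0)\n", h_out);
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_out);
    return 0;
}

The gemm hunks further down additionally stage the bit-packed weights in shared memory (A_s) and consume 256 to 2048 bits per loop iteration through 64-bit and ulonglong4 loads, which is presumably where the ~2x over cuDNN on Maxwell comes from.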
@@ -32,6 +32,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=32
 size=3
@@ -45,6 +46,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=64
 size=3
@@ -58,6 +60,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=128
 size=3
@@ -71,6 +74,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=256
 size=3
@@ -84,6 +88,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=512
 size=3
@@ -97,6 +102,7 @@ stride=1
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=1024
 size=3

@@ -36,6 +36,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=32
 size=3
@@ -49,6 +50,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=64
 size=3
@@ -62,6 +64,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=128
 size=3
@@ -88,6 +91,7 @@ stride=2
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=512
 size=3
@@ -101,6 +105,7 @@ stride=1
 
 [convolutional]
 xnor=1
+bin_output=1
 batch_normalize=1
 filters=1024
 size=3
@@ -173,6 +178,7 @@ stride=1
 pad=1
 activation=leaky
 
+
 [convolutional]
 size=1
 stride=1
@@ -119,18 +119,15 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     if(l.xnor){
         if (!l.align_bit_weights_gpu || state.train) {
             binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
 
-            swap_binary(&l);
-            binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
-            state.input = l.binary_input_gpu;
         }
         //swap_binary(&l);
         //binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
         //state.input = l.binary_input_gpu;
         //cudaDeviceSynchronize();
 
-        if (l.align_bit_weights_gpu && !state.train)
+        if (l.align_bit_weights_gpu && !state.train && l.c >= 256 && l.size > 1)
         {
             //return;
             cudaError_t status = cudaSuccess;
             int input_size = l.c*l.h*l.w*l.batch;
@@ -146,15 +143,25 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
-
             //if(0)
             {
                 //cudaDeviceSynchronize();
 
                 int i = 0;
-                im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
-                //cudaDeviceSynchronize();
+                if (l.stride == 1 && l.c >= 256 && l.w > 13 && l.size > 1 && 0) // disabled
+                {
+                    // stride=1 only
+                    im2col_align_bin_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align);
+                    //cudaDeviceSynchronize();
+                }
+                else
+                {
+                    im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
+                    //cudaDeviceSynchronize();
+                    //getchar();
 
-                // should be optimized
-                float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
-                //cudaDeviceSynchronize();
-
+                    //im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align);
+                    // should be optimized
+                    float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
+                    //cudaDeviceSynchronize();
+                }
+
                 transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
                 //cudaDeviceSynchronize();
@@ -197,7 +204,13 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         }
     }
 
+    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
+    if (l.xnor) {
+        swap_binary(&l);
+        binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
+        state.input = l.binary_input_gpu;
+    }
 
-    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
+    //fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
 
 #ifdef CUDNN
     float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT
@@ -294,7 +307,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 #else
 
     cudnnConvolutionForward(cudnn_handle(),
-        &one,
+        &alpha, //&one,
        l.srcTensorDesc,
        state.input,
        l.weightDesc,
@@ -303,9 +316,11 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
        l.fw_algo,
        state.workspace,
        l.workspace_size,
-        &one,
+        &beta, //&one,
        l.dstTensorDesc,
        l.output_gpu);
 
+    //cudaDeviceSynchronize();
 #endif // CUDNN_HALF
 
+
@@ -338,7 +353,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     }
 #endif // no CUDNN_HALF
 
-    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
+    if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.dot > 0) dot_error_gpu(l);
     if(l.binary || l.xnor) swap_binary(&l);
     //cudaDeviceSynchronize(); // for correct profiling of performance
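In the fast path above, the activations go through im2col_align_ongpu into a float workspace, are squeezed to one bit per value by float_to_bit_gpu, reordered by transpose_bin_gpu, and only then hit the binary GEMM (the im2col_align_bin_ongpu branch that would pack bits directly is still disabled by the trailing "&& 0"). The kernel below is a small, hedged sketch of the bit-squeezing step as a warp vote, the same idea the binary im2col kernel later in this commit uses; it is an illustration with invented names, not darknet's float_to_bit_gpu.

// Sketch only: pack the sign of 32 consecutive floats into one 32-bit word per warp.
// Launch with a block size that is a multiple of 32; needs CUDA 9+ for __ballot_sync.
#include <stdint.h>

__global__ void pack_signs_sketch(const float *src, uint32_t *dst, int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    float val = (idx < n) ? src[idx] : 0.0f;              // out-of-range lanes vote 0
    uint32_t mask = __ballot_sync(0xffffffffu, val > 0);  // one bit per lane of the warp
    if ((threadIdx.x & 31) == 0 && idx < n)
        dst[idx / 32] = mask;                             // lane 0 stores the packed word
}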
@@ -257,7 +257,7 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
 #endif
 #endif
 
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam)
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output)
 {
     int i;
     convolutional_layer l = {0};
@@ -269,6 +269,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     l.n = n;
     l.binary = binary;
     l.xnor = xnor;
+    l.use_bin_output = use_bin_output;
     l.batch = batch;
     l.stride = stride;
     l.size = size;
@@ -307,7 +308,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
         l.binary_weights = calloc(c*n*size*size, sizeof(float));
         l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
 
-        int align = 8;
+        int align = 32;// 8;
         int src_align = l.out_h*l.out_w;
         l.bit_align = src_align + (align - src_align % align);
     }
@@ -404,8 +405,9 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
 
     //fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c);
     l.bflops = (2.0 * l.n * l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.;
-    if (l.xnor) fprintf(stderr, "convX ");
-    else fprintf(stderr, "conv ");
+    if (l.xnor && l.use_bin_output) fprintf(stderr, "convXB");
+    else if (l.xnor) fprintf(stderr, "convX ");
+    else fprintf(stderr, "conv ");
     fprintf(stderr, "%5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
 
     return l;
@@ -428,7 +430,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
 
 void test_convolutional_layer()
 {
-    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0);
+    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0);
     l.batch_normalize = 1;
     float data[] = {1,1,1,1,1,
         1,1,1,1,1,
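A note on the bit_align change above: bit_align is the per-row stride, in elements, of the aligned im2col buffers, and the expression rounds out_h*out_w up to the next multiple of align, which this commit raises from 8 to 32, presumably so that the 32-bit warp-wide bit writes stay inside one row. A tiny hedged C example of the arithmetic, using the exact expression from the hunk (the 13x13 feature map is just an illustrative size):

#include <stdio.h>

int main(void)
{
    int align = 32;                  // was 8 before this commit
    int out_h = 13, out_w = 13;      // example output size
    int src_align = out_h * out_w;   // 169
    int bit_align = src_align + (align - src_align % align);
    /* 169 % 32 = 9, so bit_align = 169 + 23 = 192; note that, as written,
       the expression adds a full `align` when src_align is already a multiple. */
    printf("src_align=%d bit_align=%d\n", src_align, bit_align);
    return 0;
}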
@@ -25,7 +25,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
 #endif
 #endif
 
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output);
 void denormalize_convolutional_layer(convolutional_layer l);
 void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
 void forward_convolutional_layer(const convolutional_layer layer, network_state state);
@@ -48,17 +48,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
 
     l.input_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
+    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0);
     l.input_layer->batch = batch;
 
     l.self_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
+    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0);
     l.self_layer->batch = batch;
 
     l.output_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
+    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0);
     l.output_layer->batch = batch;
 
     l.output = l.output_layer->output;
@@ -17,6 +17,10 @@ void im2col_align_ongpu(float *im,
     int channels, int height, int width,
     int ksize, int stride, int pad, float *data_col, int bit_align);
 
+void im2col_align_bin_ongpu(float *im,
+    int channels, int height, int width,
+    int ksize, int stride, int pad, float *data_col, int bit_align);
+
 void float_to_bit_gpu(float *src, unsigned char *dst, size_t size);
 
 void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const int m,
@@ -45,6 +45,7 @@ __global__ void im2col_gpu_kernel(const int n, const float* data_im,
             *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
                 data_im_ptr[i * width + j] : 0;
 
+            //data_im[(channel_in * height + h_in) * width + w_in + i * width + j];
             //*data_col_ptr = data_im_ptr[ii * width + jj];
 
             data_col_ptr += height_col * width_col;
@@ -69,7 +70,7 @@ void im2col_ongpu(float *im,
 }
 // --------------------------------
 
-
+/*
 __global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
     const int height, const int width, const int ksize,
     const int pad,
@@ -120,6 +121,71 @@ __global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
         }
     }
 }
+*/
+
+// float 32
+__global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
+    const int height, const int width, const int ksize,
+    const int pad,
+    const int stride,
+    const int height_col, const int width_col,
+    float *data_col, const int bit_align)
+{
+    __shared__ float tmp_s[1];
+
+    //#define SHRED_VALS ((BLOCK / 169) * )
+    __shared__ float dst_s[1024];
+    //__shared__ float dst_s[1024];
+    //__shared__ uint32_t bit_s[32];
+    __shared__ uint8_t bit_s[128];
+
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    for (; index < n; index += blockDim.x*gridDim.x) {
+        int w_out = index % width_col;
+        int h_index = index / width_col;
+        int h_out = h_index % height_col;
+        int channel_in = h_index / height_col;
+        int channel_out = channel_in * ksize * ksize;
+        int h_in = h_out * stride - pad;
+        int w_in = w_out * stride - pad;
+        //float* data_col_ptr = data_col;
+        //float* data_col_ptr_32 = data_col + (channel_out * bit_align + h_out * width_col + w_out) / 32;
+        //data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out;
+        //data_col_ptr += channel_out * bit_align + h_out * width_col + w_out;
+        float* data_col_ptr = &data_col[channel_out * bit_align + h_out * width_col + w_out];
+        const float* data_im_ptr = data_im;
+        data_im_ptr += (channel_in * height + h_in) * width + w_in;
+        for (int i = 0; i < ksize; ++i) {
+            for (int j = 0; j < ksize; ++j) {
+                int h = h_in + i;
+                int w = w_in + j;
+
+                float val = (h >= 0 && w >= 0 && h < height && w < width) ?
+                    data_im_ptr[i * width + j] : 0;
+
+                int pre_out_index = index % (width_col*height_col);
+                int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;// h_out * width_col + w_out;
+                data_col[out_index] = val;
+
+                //*data_col_ptr = val;
+                //dst_s[threadIdx.x] = val;
+                //tmp_s[0] = val;
+
+                //*data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
+                //    data_im_ptr[i * width + j] : 0;
+
+                //float src_val = (h >= 0 && w >= 0 && h < height && w < width) ? data_im_ptr[i * width + j] : 0;
+                //unsigned int bit_mask = __ballot_sync(0xffffffff, src_val > 0);
+                //if (threadIdx.x % WARP_SIZE == 0) *((unsigned int*)data_col_ptr_32) = bit_mask;
+                // use atomicOr() // *dst_ptr |= (mask << (col_index % 8));
+                //data_col_ptr_32 += bit_align / 32;
+
+                //data_col_ptr += height_col * width_col;
+                data_col_ptr += bit_align;
+            }
+        }
+    }
+}
+
 void im2col_align_ongpu(float *im,
     int channels, int height, int width,
@@ -135,6 +201,354 @@ void im2col_align_ongpu(float *im,
         stride, height_col,
         width_col, data_col, bit_align);
 }
+
+
+// --------------------------------
+
+/*
+// binary im2col
+__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
+    const int height, const int width, const int ksize, const int channels,
+    const int pad,
+    const int stride,
+    const int height_col, const int width_col,
+    float *data_col, const int bit_align)
+{
+    __shared__ float tmp_s[1];
+
+    //#define SHRED_VALS ((BLOCK / 169) * )
+    __shared__ float dst_s[1024];
+    //__shared__ float dst_s[1024];
+    //__shared__ uint32_t bit_s[32];
+    __shared__ uint8_t bit_s[128];
+
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    for (; index < n; index += blockDim.x*gridDim.x)
+    {
+        //int c_index = index;
+        //int channel_in = c_index % channels;
+
+        int h_out = index % height_col;
+        int c_index = index / height_col;
+        int channel_in = c_index % channels;
+
+        int channel_out = channel_in * ksize * ksize;
+
+        int j_index = c_index / channels;
+        int j = j_index % ksize;
+        int i = j_index / ksize;
+        if (i < ksize)
+        {
+            for (int w_out = 0; w_out < width_col; ++w_out)
+            {
+                int h_in = h_out * stride - pad;
+                int w_in = w_out * stride - pad;
+
+                int h = h_in + i;
+                int w = w_in + j;
+
+                float val = (h >= 0 && w >= 0 && h < height && w < width) ?
+                    data_im[(channel_in * height + h_in) * width + w_in + i * width + j] : 0;
+
+                //int pre_out_index = index % (width_col*height_col);
+                int pre_out_index = h_out * width_col + w_out;
+                int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
+                data_col[out_index] = val;
+
+
+            }// w_out
+        }
+    }
+}
+*/
+
+/*
+// binary im2col
+__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
+    const int height, const int width, const int ksize, const int channels,
+    const int pad,
+    const int stride,
+    const int height_col, const int width_col,
+    float *data_col, const int bit_align)
+{
+    __shared__ float tmp_s[1];
+    __shared__ ulonglong4 tmp256_s[1];
+
+
+    //#define SHRED_VALS ((BLOCK / 169) * )
+    //__shared__ float dst_s[1024];
+    //__shared__ float dst_s[1024];
+    //__shared__ uint32_t bit_s[32];
+    //__shared__ uint8_t bit_s[128];
+
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    //for (; index < n; index += blockDim.x*gridDim.x)
+    {
+        //int c_index = index;
+        //int channel_in = c_index % channels;
+
+        int h_out = index % height_col;
+        int c_index = index / height_col;
+        int channel_in = c_index % channels;
+
+        int channel_out = channel_in * ksize * ksize;
+
+        int j_index = c_index / channels;
+        int j = j_index % ksize;
+        int i = j_index / ksize;
+
+        int h_in = h_out * stride - pad;
+        int h = h_in + i;
+
+        //if (i < ksize)
+        {
+            int w_out = 0;
+
+            // the end of padding
+            //if(0)
+            for (; w_out < (width_col); w_out += 32)
+            {
+                int w = w_out * stride - pad + j;
+                int pre_in_index = (channel_in * height + h_in) * width + i * width;
+                int in_index = pre_in_index + w;
+                //float *src_p = (float *)&data_im[in_index];
+
+                int pre_out_index = h_out * width_col + w_out;
+                int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
+                // float *dst_p = (float *)&data_col[out_index];
+
+                if (i >= ksize) {
+                    out_index = -1;
+                }
+
+                #pragma unroll
+                for (int t = 0; t < WARP_SIZE; ++t) {
+                    const int lane_id = threadIdx.x % WARP_SIZE;
+
+                    //const int64_t cur_pre_in_index = pre_in_index;
+                    //const int64_t cur_j = j;
+                    //const int64_t out_i = out_index;// __shfl(out_index, t) + lane_id;
+
+                    const int64_t cur_out_index = __shfl(out_index, t);
+                    if (cur_out_index >= 0)
+                    {
+                        const int64_t cur_pre_in_index = __shfl(pre_in_index, t);
+                        const int64_t cur_j = __shfl(j, t);
+                        const int64_t cur_h = __shfl(h, t);
+
+                        int cur_w = ((w_out + lane_id) * stride - pad + cur_j);
+                        int in_index = cur_pre_in_index + cur_w;
+
+                        float val = (cur_w >= 0 && cur_w < width && cur_h >= 0 && cur_h < height) ?
+                            data_im[in_index] : float();
+
+                        if ((w_out + lane_id) < width_col) {
+                            data_col[cur_out_index + lane_id] = val;
+                            //tmp_s[0] = val;
+
+                            //uint32_t bit_mask = __ballot(val > 0);
+                            //uint8_t *bit8_ptr = &(((uint8_t *)data_col)[cur_out_index / 8]);
+                            //uint32_t *bit32_ptr = (uint32_t *)bit8_ptr;
+                            //*bit32_ptr = bit_mask;
+                        }
+                    }
+                }
+
+            }// w_out
+
+#ifdef NOT_USED
+            if (i < ksize && h >= 0 && h < height)
+            {
+
+                // wait for align address and the end of padding
+                for (; w_out < width_col; ++w_out)
+                {
+                    int w_in = w_out * stride - pad;
+                    int w = w_in + j;
+
+                    int in_index = (channel_in * height + h_in) * width + w_in + i * width + j;
+                    float *src_p = (float *)&data_im[in_index];
+
+                    int pre_out_index = h_out * width_col + w_out;
+                    int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
+                    float *dst_p = (float *)&data_col[out_index];
+
+                    if (((uint64_t)src_p % 32 == 0) && ((uint64_t)dst_p % 32 == 0) && w > 0) {
+                        //printf(" aligned addresses and there is no padding \n");
+                        break;
+                    }
+
+                    float val = (w >= 0 && w < width) ?
+                        (*src_p) : float();
+
+                    *dst_p = val;
+                    //tmp_s[0] = val;
+                }// w_out
+
+                // ulonglong4 (256 bit) / instead of float (32 bit) = 8x times
+                for (; w_out < (width_col - 8); w_out += 8)
+                {
+                    int w_in = w_out * stride - pad;
+                    int w = w_in + j;
+
+                    ulonglong4 *src_p = (ulonglong4 *)&data_im[(channel_in * height + h_in) * width + w_in + i * width + j];
+
+                    int pre_out_index = h_out * width_col + w_out;
+                    int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
+                    ulonglong4 *dst_p = (ulonglong4 *)&data_col[out_index];
+
+                    ulonglong4 val = (w < width) ?
+                        (*src_p) : ulonglong4();
+
+                    *dst_p = val;
+                    //tmp256_s[0] = val;
+                }// w_out
+
+                for (; w_out < width_col; ++w_out)
+                {
+                    //int h_in = h_out * stride - pad;
+                    int w_in = w_out * stride - pad;
+
+                    //int h = h_in + i;
+                    int w = w_in + j;
+
+                    float val = (w < width) ?
+                        data_im[(channel_in * height + h_in) * width + w_in + i * width + j] : 0;
+
+                    int pre_out_index = h_out * width_col + w_out;
+                    int out_index = (channel_out + i*ksize + j) * bit_align + pre_out_index;
+                    data_col[out_index] = val;
+                    //tmp_s[0] = val;
+                }// w_out
+            }
+#endif // NOT_USED
+        }
+    }
+}
+*/
+
+
+// binary im2col - stride=1
+__global__ void im2col_align_bin_gpu_kernel(const int n, const float* data_im,
+    const int height, const int width, const int ksize, const int channels,
+    const int pad,
+    const int stride,
+    const int height_col, const int width_col,
+    float *data_col, const int bit_align)
+{
+    __shared__ float tmp_s[1];
+    __shared__ ulonglong4 tmp256_s[1];
+
+
+    //#define SHRED_VALS ((BLOCK / 169) * )
+    //__shared__ float dst_s[1024];
+    //__shared__ float dst_s[1024];
+    //__shared__ uint32_t bit_s[32];
+    //__shared__ uint8_t bit_s[128];
+
+    int index = blockIdx.x*blockDim.x + threadIdx.x;
+    //for (; index < n; index += blockDim.x*gridDim.x)
+    {
+        int c_index = index;
+        int channel_in = c_index % channels;
+
+        //int h_out = index % height_col;
+        //int c_index = index / height_col;
+        //int channel_in = c_index % channels;
+
+        int channel_out = channel_in * ksize * ksize;
+
+        int j_index = c_index / channels;
+        int j = j_index % ksize;
+        int i = j_index / ksize;
+
+        int pre_out_index = (channel_out + i*ksize + j) * bit_align;
+        int j_pad = (j - pad);
+        int i_pad = (i - pad);
+
+        for(int wh_index = 0; wh_index < (height_col*width_col); wh_index += 32)
+        //for (int h_out = 0; h_out < height_col; ++h_out)
+        {
+
+            // the end of padding
+            //if(0)
+            //for (int w_out = 0; w_out < (width_col); w_out += 32)
+            {
+                const int w_out = wh_index % width_col;
+                const int h_out = wh_index / width_col;
+
+                const int w = w_out + j_pad;
+                const int h = h_out + i_pad;
+
+                int pre_in_index = channel_in * height * width;
+                int pre_in_wh_index = h * width + w;
+
+                int send_wh_index = wh_index;
+                if (i >= ksize) send_wh_index = height_col*width_col;
+
+                #pragma unroll
+                for (int t = 0; t < WARP_SIZE; ++t)
+                {
+                    const int lane_id = threadIdx.x % WARP_SIZE;
+
+                    const int cur_wh_index = __shfl(send_wh_index, t) + lane_id;
+
+                    if (cur_wh_index < (width_col*height_col))// && (cur_i_pad+pad) < ksize)
+                    {
+                        const int cur_pre_out_index = __shfl(pre_out_index, t);
+
+                        const int cur_pre_in_index = __shfl(pre_in_index, t);
+                        const int cur_pre_in_wh_index = __shfl(pre_in_wh_index, t) + lane_id;
+
+                        int w = cur_pre_in_wh_index % width;
+                        int h = cur_pre_in_wh_index / width;
+                        int in_index = cur_pre_in_index + cur_pre_in_wh_index;
+
+                        int out_index = cur_pre_out_index + cur_wh_index;
+
+                        float val = (w >= 0 && w < width && h >= 0 && h < height) ?
+                            data_im[in_index] : float();
+
+                        //data_col[out_index] = val;
+                        //tmp_s[0] = val;
+
+                        uint32_t bit_mask = __ballot(val > 0);
+                        if (lane_id == 0) {
+                            uint8_t *bit8_ptr = &(((uint8_t *)data_col)[out_index / 8]);
+                            uint32_t *bit32_ptr = (uint32_t *)bit8_ptr;
+                            *bit32_ptr = bit_mask;
+                        }
+                    }
+
+
+                }
+
+            }// w_out
+
+        }
+    }
+}
+
+
+void im2col_align_bin_ongpu(float *im,
+    int channels, int height, int width,
+    int ksize, int stride, int pad, float *data_col, int bit_align) {
+    // We are going to launch channels * height_col * width_col kernels, each
+    // kernel responsible for copying a single-channel grid.
+    int height_col = (height + 2 * pad - ksize) / stride + 1;
+    int width_col = (width + 2 * pad - ksize) / stride + 1;
+    //int num_kernels = channels * height_col * width_col * ksize * ksize;
+    //int num_kernels = channels * ksize * ksize * height_col;
+    int num_kernels = channels * ksize * ksize;
+    int num_blocks = num_kernels / BLOCK + 1;
+
+    //im2col_align_bin_gpu_kernel << <(num_kernels + BLOCK - 1) / BLOCK,
+    im2col_align_bin_gpu_kernel << <num_blocks,
+        BLOCK, 0, get_cuda_stream() >> >(
+            num_kernels, im, height, width, ksize, channels, pad,
+            stride, height_col,
+            width_col, data_col, bit_align);
+}
 // --------------------------------
 
 
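In the stride=1 binary im2col kernel above, each thread owns one (channel, i, j) filter cell and a warp covers 32 output positions at a time: per-thread base indices are broadcast to the whole warp with __shfl, every lane recomputes its own pixel, and __ballot folds the 32 sign tests into a single word that lane 0 writes. Below is a reduced, hedged sketch of that broadcast-then-ballot pattern using the non-deprecated *_sync intrinsics; the names and indexing are simplified and are not the kernel's.

#include <stdint.h>

#define WARP_SIZE 32

// Reduced sketch: each lane of a warp owns a 32-wide slice of the input; for each
// thread t in turn, its slice base is broadcast with __shfl_sync, all lanes test one
// element, and __ballot_sync packs the 32 results into one word stored by lane 0.
__global__ void broadcast_ballot_sketch(const float *src, uint32_t *dst, int n)
{
    const int lane_id = threadIdx.x % WARP_SIZE;
    int base = (blockIdx.x * blockDim.x + threadIdx.x) * WARP_SIZE;

    #pragma unroll
    for (int t = 0; t < WARP_SIZE; ++t) {
        int cur_base = __shfl_sync(0xffffffffu, base, t);  // take thread t's slice
        int idx = cur_base + lane_id;
        float val = (idx < n) ? src[idx] : 0.0f;
        uint32_t mask = __ballot_sync(0xffffffffu, val > 0);
        if (lane_id == 0 && cur_base < n)
            dst[cur_base / WARP_SIZE] = mask;              // one packed word per 32 inputs
    }
}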
@@ -560,7 +974,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             int count = 0;
             k = 0;
 
-//#ifdef NON_USED
+//#ifdef NOT_USED
             // 32 thread X 64 bit = 2048 bit
             for (; k < (K - 2048); k += 2048) {   // l.size*l.size*l.c - one filter size [27 - 9216]
                 uint64_t c_bit64;
@@ -591,7 +1005,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             }
 //#endif
 
-//#ifdef NON_USED
+//#ifdef NOT_USED
             // 32 thread X 32 bit = 1024 bit
             for (; k < (K - 1024); k += 1024) {   // l.size*l.size*l.c - one filter size [27 - 9216]
 
@@ -626,7 +1040,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             float mean_val = mean_arr[i];
             float bias_val = bias_arr[i];
 
-//#ifdef NON_USED
+//#ifdef NOT_USED
             for (; k < K; k += 256) {   // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
                 //ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8));   // weights
                 ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + (local_i*lda + k) / 8));   // weights
@@ -638,7 +1052,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             }
 //#endif
 
-#ifdef NON_USED
+#ifdef NOT_USED
             for (; k < K; k += 64) {   // l.size*l.size*l.c - one filter size [27 - 9216]
                 //uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8));   // weights
                 uint64_t a_bit64 = *((uint64_t *)(A_s + (local_i*lda + k) / 8));   // weights
@@ -697,7 +1111,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             int count = 0;
             k = 0;
 
-//#ifdef NON_USED
+//#ifdef NOT_USED
             // 32 thread X 64 bit = 2048 bit
             for (; k < (K - 2048); k += 2048) {   // l.size*l.size*l.c - one filter size [27 - 9216]
                 uint64_t c_bit64;
@@ -728,7 +1142,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             }
 //#endif
 
-//#ifdef NON_USED
+//#ifdef NOT_USED
             // 32 thread X 32 bit = 1024 bit
             for (; k < (K - 1024); k += 1024) {   // l.size*l.size*l.c - one filter size [27 - 9216]
 
@@ -763,7 +1177,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             float mean_val = mean_arr[i];
             float bias_val = bias_arr[i];
 
-//#ifdef NON_USED
+//#ifdef NOT_USED
             for (; k < K; k += 256) {   // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
                 ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8));   // weights
                 //ulonglong4 b_bit256 = *((ulonglong4 *)(B + (j*ldb + k) / 8));   // input
@@ -775,7 +1189,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
             }
 //#endif
 
-#ifdef NON_USED
+#ifdef NOT_USED
             for (; k < K; k += 64) {   // l.size*l.size*l.c - one filter size [27 - 9216]
                 uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8));   // weights
                 //uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8));   // input
@@ -793,6 +793,7 @@ void free_network(network net)
 #ifdef GPU
     if (gpu_index >= 0) cuda_free(net.workspace);
     else free(net.workspace);
+    if (net.input_state_gpu) cuda_free(net.input_state_gpu);
     if (*net.input_gpu) cuda_free(*net.input_gpu);
     if (*net.truth_gpu) cuda_free(*net.truth_gpu);
     if (net.input_gpu) free(net.input_gpu);
@@ -866,7 +867,7 @@ void calculate_binary_weights(network net)
 
                 binary_align_weights(l);
 
-                if(net.layers[j + 1].use_bin_output) {
+                if(net.layers[j].use_bin_output) {
                     l->activation = LINEAR;
                 }
             }
@@ -64,6 +64,8 @@ typedef struct network{
     tree *hierarchy;
 
 #ifdef GPU
+    float *input_state_gpu;
+
     float **input_gpu;
     float **truth_gpu;
     float **input16_gpu;
@@ -51,7 +51,7 @@ void forward_network_gpu(network net, network_state state)
     for(i = 0; i < net.n; ++i){
         state.index = i;
         layer l = net.layers[i];
-        if(l.delta_gpu){
+        if(l.delta_gpu && state.train){
             fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
         }
         l.forward_gpu(l, state);
@@ -428,13 +428,15 @@ float *network_predict_gpu(network net, float *input)
     network_state state;
     state.index = 0;
     state.net = net;
-    state.input = cuda_make_array(input, size);
+    //state.input = cuda_make_array(input, size);   // memory will be allocated in the parse_network_cfg_custom()
+    state.input = net.input_state_gpu;
+    cuda_push_array(state.input, input, size);
     state.truth = 0;
     state.train = 0;
    state.delta = 0;
    forward_network_gpu(net, state);
    float *out = get_network_output_gpu(net);
-    cuda_free(state.input);
+    //cuda_free(state.input);   // will be freed in the free_network()
    return out;
 }
 
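network_predict_gpu now borrows net.input_state_gpu, a device buffer that parse_network_cfg_custom allocates once (sized from get_network_input_size(net) * net.batch, see the parser hunk below) and that free_network releases, instead of doing cuda_make_array/cuda_free on every call. A generic, hedged sketch of that allocate-once-and-reuse pattern with plain CUDA runtime calls (the names here are illustrative, not darknet's):

#include <cuda_runtime.h>

// Allocate the input staging buffer once at setup time ...
static float *input_state_gpu = NULL;
static size_t input_elems = 0;

void predictor_init(size_t elems) {
    input_elems = elems;
    cudaMalloc((void**)&input_state_gpu, elems * sizeof(float));
}

// ... every prediction only copies into it; no per-call malloc/free.
void predictor_run(const float *host_input) {
    cudaMemcpy(input_state_gpu, host_input, input_elems * sizeof(float), cudaMemcpyHostToDevice);
    /* the forward pass would consume input_state_gpu here */
}

// Freed once at teardown.
void predictor_free(void) {
    if (input_state_gpu) cudaFree(input_state_gpu);
    input_state_gpu = NULL;
}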
@@ -163,11 +163,12 @@ convolutional_layer parse_convolutional(list *options, size_params params)
     int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
     int binary = option_find_int_quiet(options, "binary", 0);
     int xnor = option_find_int_quiet(options, "xnor", 0);
+    int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
 
-    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam);
+    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output);
     layer.flipped = option_find_int_quiet(options, "flipped", 0);
     layer.dot = option_find_float_quiet(options, "dot", 0);
-    layer.use_bin_output = option_find_int_quiet(options, "bin_output", 0);
+
     if(params.net.adam){
         layer.B1 = params.net.B1;
         layer.B2 = params.net.B2;
@@ -819,6 +820,8 @@ network parse_network_cfg_custom(char *filename, int batch)
 #ifdef GPU
     if(gpu_index >= 0){
         net.workspace = cuda_make_array(0, workspace_size/sizeof(float) + 1);
+        int size = get_network_input_size(net) * net.batch;
+        net.input_state_gpu = cuda_make_array(0, size);
     }else {
         net.workspace = calloc(1, workspace_size);
     }
@@ -368,9 +368,11 @@ void backward_region_layer(const region_layer l, network_state state)
 
 void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *boxes, int only_objectness, int *map)
 {
-    int i,j,n;
-    float *predictions = l.output;
+    int i;
+    float *const predictions = l.output;
+    #pragma omp parallel for
     for (i = 0; i < l.w*l.h; ++i){
+        int j, n;
         int row = i / l.w;
         int col = i % l.w;
         for(n = 0; n < l.n; ++n){