diff --git a/include/darknet.h b/include/darknet.h
index 1ab2f339..4154d664 100644
--- a/include/darknet.h
+++ b/include/darknet.h
@@ -204,6 +204,7 @@ struct layer {
     int size;
     int side;
     int stride;
+    int dilation;
     int reverse;
     int flatten;
     int spatial;
diff --git a/src/col2im.c b/src/col2im.c
index 925d054c..660b0e58 100644
--- a/src/col2im.c
+++ b/src/col2im.c
@@ -37,3 +37,57 @@ void col2im_cpu(float* data_col,
         }
     }
 }
+// ----------------------------------------
+#include <string.h>  // for memset() used by caffe_set()
+void caffe_set(const int N, const float alpha, float* Y) {
+    if (alpha == 0) {
+        memset(Y, 0, sizeof(float) * N);    // NOLINT(caffe/alt_fn)
+        return;
+    }
+    for (int i = 0; i < N; ++i) {
+        Y[i] = alpha;
+    }
+}
+
+static inline int is_a_ge_zero_and_a_lt_b(int a, int b) {
+    return (unsigned)(a) < (unsigned)(b);
+}
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cpp
+void col2im_cpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_im)
+{
+    caffe_set(height * width * channels, 0.0F, data_im);
+    const int output_h = (height + 2 * pad_h -
+        (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int output_w = (width + 2 * pad_w -
+        (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+    const int channel_size = height * width;
+    for (int channel = channels; channel--; data_im += channel_size) {
+        for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+            for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+                int input_row = -pad_h + kernel_row * dilation_h;
+                for (int output_rows = output_h; output_rows; output_rows--) {
+                    if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
+                        data_col += output_w;
+                    }
+                    else {
+                        int input_col = -pad_w + kernel_col * dilation_w;
+                        for (int output_col = output_w; output_col; output_col--) {
+                            if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
+                                data_im[input_row * width + input_col] += *data_col;
+                            }
+                            data_col++;
+                            input_col += stride_w;
+                        }
+                    }
+                    input_row += stride_h;
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/col2im.h b/src/col2im.h
index a8493e38..984f7c4b 100644
--- a/src/col2im.h
+++ b/src/col2im.h
@@ -8,10 +8,24 @@
 void col2im_cpu(float* data_col,
         int channels, int height, int width,
         int ksize, int stride, int pad, float* data_im);
 
+void col2im_cpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_im);
+
 #ifdef GPU
 void col2im_ongpu(float *data_col,
         int channels, int height, int width,
         int ksize, int stride, int pad, float *data_im);
+
+
+void col2im_gpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    float* data_im);
 #endif
 #ifdef __cplusplus
 }
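Both col2im_cpu_ext above and im2col_cpu_ext further down size their output with the standard dilated-convolution formula: a kernel of size k with dilation d covers an extent of d*(k-1)+1 input pixels, so output = (input + 2*pad - extent)/stride + 1. A quick worked check of that arithmetic (the concrete numbers are illustrative only, not taken from any real cfg):

/* Worked check of the output-size formula used by the _ext functions. */
#include <stdio.h>

int main(void)
{
    int height = 7, kernel_h = 3, pad_h = 0, stride_h = 1, dilation_h = 2;
    int extent_h = dilation_h * (kernel_h - 1) + 1;                 /* = 5 */
    int output_h = (height + 2 * pad_h - extent_h) / stride_h + 1;  /* = 3 */
    printf("kernel extent = %d, output rows = %d\n", extent_h, output_h);
    return 0;
}

With dilation_h = 1 the extent collapses to kernel_h, recovering the familiar (height + 2*pad - ksize)/stride + 1 used elsewhere in darknet.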
diff --git a/src/col2im_kernels.cu b/src/col2im_kernels.cu
index a6e985f1..0e07bc37 100644
--- a/src/col2im_kernels.cu
+++ b/src/col2im_kernels.cu
@@ -55,3 +55,82 @@ void col2im_ongpu(float *data_col,
 
     CHECK_CUDA(cudaPeekAtLastError());
 }
+// -----------------------------------------
+
+// CUDA: use 512 threads per block
+const int CAFFE_CUDA_NUM_THREADS = 512;
+
+// CUDA: number of blocks for N threads.
+inline int CAFFE_GET_BLOCKS(const int N) {
+    return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+}
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n) \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+        i < (n); \
+        i += blockDim.x * gridDim.x)
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu
+__global__ void col2im_gpu_kernel_ext(const int n, const float* data_col,
+    const int height, const int width, const int channels,
+    const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int height_col, const int width_col,
+    float* data_im) {
+    CUDA_KERNEL_LOOP(index, n) {
+        float val = 0;
+        const int w_im = index % width + pad_w;
+        const int h_im = (index / width) % height + pad_h;
+        const int c_im = index / (width * height);
+        int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
+        int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
+        // compute the start and end of the output
+        const int w_col_start =
+            (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
+        const int w_col_end = min(w_im / stride_w + 1, width_col);
+        const int h_col_start =
+            (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
+        const int h_col_end = min(h_im / stride_h + 1, height_col);
+        // TODO: use LCM of stride and dilation to avoid unnecessary loops
+        for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
+            for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
+                int h_k = (h_im - h_col * stride_h);
+                int w_k = (w_im - w_col * stride_w);
+                if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
+                    h_k /= dilation_h;
+                    w_k /= dilation_w;
+                    int data_col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) *
+                        height_col + h_col) * width_col + w_col;
+                    val += data_col[data_col_index];
+                }
+            }
+        }
+        data_im[index] = val;
+    }
+}
+
+void col2im_gpu_ext(const float* data_col, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    float* data_im)
+{
+    int height_col = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) /
+        stride_h + 1;
+    int width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) /
+        stride_w + 1;
+    int num_kernels = channels * height * width;
+    // To avoid involving atomic operations, we will launch one kernel per
+    // bottom dimension, and then in the kernel add up the top dimensions.
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    col2im_gpu_kernel_ext<<<CAFFE_GET_BLOCKS(num_kernels),
+        CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_col, height, width, channels, kernel_h, kernel_w,
+        pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+        height_col, width_col, data_im);
+
+    CHECK_CUDA(cudaPeekAtLastError());
+}
\ No newline at end of file
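CAFFE_GET_BLOCKS and CUDA_KERNEL_LOOP are Caffe's grid-stride launch idiom: each thread starts at its global index and strides by the total number of launched threads, so the kernel stays correct even if the grid is capped below n. A minimal sketch of how the pair is meant to be used, with a hypothetical element-wise kernel that is not part of this patch:

// Hypothetical kernel illustrating the CUDA_KERNEL_LOOP / CAFFE_GET_BLOCKS pair.
__global__ void scale_kernel(const int n, const float alpha, float *x)
{
    CUDA_KERNEL_LOOP(i, n) {    // grid-stride loop: visits all n elements
        x[i] *= alpha;
    }
}

void scale_ongpu(const int n, const float alpha, float *x)
{
    scale_kernel<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n, alpha, x);
    CHECK_CUDA(cudaPeekAtLastError());
}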
diff --git a/src/conv_lstm_layer.c b/src/conv_lstm_layer.c
index f3041ced..5caa3754 100644
--- a/src/conv_lstm_layer.c
+++ b/src/conv_lstm_layer.c
@@ -32,7 +32,7 @@ static void increment_layer(layer *l, int steps)
 }
 
 
-layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor)
+layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor)
 {
     fprintf(stderr, "CONV_LSTM Layer: %d x %d x %d image, %d filters\n", h, w, c, output_filters);
     /*
@@ -53,6 +53,7 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
     l.steps = steps;
     l.size = size;
     l.stride = stride;
+    l.dilation = dilation;
     l.pad = pad;
     l.h = h;
     l.w = w;
@@ -65,44 +66,44 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
 
     // U
     l.uf = (layer*)calloc(1, sizeof(layer));
-    *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.uf->batch = batch;
     if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size;
 
     l.ui = (layer*)calloc(1, sizeof(layer));
-    *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.ui->batch = batch;
     if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size;
 
     l.ug = (layer*)calloc(1, sizeof(layer));
-    *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.ug->batch = batch;
     if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size;
 
     l.uo = (layer*)calloc(1, sizeof(layer));
-    *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.uo->batch = batch;
     if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size;
 
     // W
     l.wf = (layer*)calloc(1, sizeof(layer));
-    *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wf->batch = batch;
     if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size;
     l.wi = (layer*)calloc(1, sizeof(layer));
-    *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wi->batch = batch;
     if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size;
 
     l.wg = (layer*)calloc(1, sizeof(layer));
-    *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wg->batch = batch;
     if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size;
 
     l.wo = (layer*)calloc(1, sizeof(layer));
-    *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.wo->batch = batch;
     if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size;
 
@@ -110,21 +111,21 @@ layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, i
     // V
     l.vf = (layer*)calloc(1, sizeof(layer));
     if (l.peephole) {
-        *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+        *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
         l.vf->batch = batch;
         if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size;
     }
 
     l.vi = (layer*)calloc(1, sizeof(layer));
     if (l.peephole) {
-        *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+        *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
         l.vi->batch = batch;
         if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size;
     }
 
     l.vo = (layer*)calloc(1, sizeof(layer));
     if (l.peephole) {
-        *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+        *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
         l.vo->batch = batch;
         if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size;
     }
diff --git a/src/conv_lstm_layer.h b/src/conv_lstm_layer.h
index 56a57298..17e4fdc3 100644
--- a/src/conv_lstm_layer.h
+++ b/src/conv_lstm_layer.h
@@ -9,7 +9,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor);
+layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor);
 void resize_conv_lstm_layer(layer *l, int w, int h);
 void free_state_conv_lstm(layer l);
 void randomize_state_conv_lstm(layer l);
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index a26b95ea..dd6625f1 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -566,7 +566,17 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
             b = im;
         }
         else {
-            im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+            //im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+
+            im2col_gpu_ext(im,          // input
+                l.c / l.groups,         // input channels
+                l.h, l.w,               // input size (h, w)
+                l.size, l.size,         // kernel size (h, w)
+                l.pad, l.pad,           // padding (h, w)
+                l.stride, l.stride,     // stride (h, w)
+                l.dilation, l.dilation, // dilation (h, w)
+                state.workspace);       // output
+
         }
         gemm_ongpu(0, 0, m, n, k, 1., a, k, b, n, 1., c + i*m*n, n);
     }
@@ -798,7 +808,15 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
 
             float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
 
-            im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+            //im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+            im2col_gpu_ext(im,          // input
+                l.c / l.groups,         // input channels
+                l.h, l.w,               // input size (h, w)
+                l.size, l.size,         // kernel size (h, w)
+                l.pad, l.pad,           // padding (h, w)
+                l.stride, l.stride,     // stride (h, w)
+                l.dilation, l.dilation, // dilation (h, w)
+                state.workspace);       // output
             gemm_ongpu(0, 1, m, n, k, 1, a + i*m*k, k, b, k, 1, c, n);
 
             if (state.delta) {
@@ -811,7 +829,17 @@
 
                 float *delta = state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
 
-                col2im_ongpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, delta);
+                //col2im_ongpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, delta);
+                col2im_gpu_ext(
+                    state.workspace,        // input
+                    l.c / l.groups,         // input channels
+                    l.h, l.w,               // input size (h, w)
+                    l.size, l.size,         // kernel size (h, w)
+                    l.pad, l.pad,           // padding size (h, w)
+                    l.stride, l.stride,     // stride size (h, w)
+                    l.dilation, l.dilation, // dilation size (h, w)
+                    delta);                 // output (delta)
+
             if (l.binary || l.xnor) {
                 swap_binary(&l);
             }
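On both GPU paths the structure stays im2col followed by gemm: the column buffer written into state.workspace has (l.c/l.groups * l.size * l.size) rows and (out_h * out_w) columns, and dilation only changes which input pixels are gathered into it, not its size. A size helper, as a sketch (the function name and the helper itself are illustrative, not part of darknet's API):

#include <stddef.h>

/* Floats needed for one group's im2col column buffer:
 * rows = (c/groups) * size * size, cols = out_h * out_w.
 * Illustrative helper, not a darknet function. */
size_t im2col_buffer_floats(int c, int groups, int size, int out_h, int out_w)
{
    return (size_t)(c / groups) * size * size * (size_t)out_h * (size_t)out_w;
}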
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index d983ab61..84bc75fb 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -275,9 +275,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
     CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w));
 
 #if(CUDNN_MAJOR >= 6)
-    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));    // cudnn >= 6.0
+    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, l->dilation, l->dilation, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));    // cudnn >= 6.0
 #else
-    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION));    // cudnn 5.1
+    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, l->dilation, l->dilation, CUDNN_CROSS_CORRELATION));    // cudnn 5.1
 #endif
     int forward_algo = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
     int backward_algo = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
@@ -331,7 +331,7 @@
 #endif
 #endif
 
-convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
+convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
 {
     int total_batch = batch*steps;
     int i;
@@ -353,6 +353,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
     l.batch = batch;
     l.steps = steps;
     l.stride = stride;
+    l.dilation = dilation;
     l.size = size;
     l.pad = padding;
     l.batch_normalize = batch_normalize;
@@ -525,7 +526,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
 
 void test_convolutional_layer()
 {
-    convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0);
+    convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0);
     l.batch_normalize = 1;
     float data[] = {1,1,1,1,1,
         1,1,1,1,1,
@@ -981,8 +982,17 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
             }
             else {
                 //printf(" l.index = %d - FP32 \n", l.index);
-                im2col_cpu(state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w,
-                    l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+                float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
+                //im2col_cpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+
+                im2col_cpu_ext(im,          // input
+                    l.c / l.groups,         // input channels
+                    l.h, l.w,               // input size (h, w)
+                    l.size, l.size,         // kernel size (h, w)
+                    l.pad, l.pad,           // padding (h, w)
+                    l.stride, l.stride,     // stride (h, w)
+                    l.dilation, l.dilation, // dilation (h, w)
+                    b);                     // output
 
                 gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n);
                 // bit-count to float
@@ -1028,8 +1038,17 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
 
             float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w;
 
-            im2col_cpu(im, l.c / l.groups, l.h, l.w,
-                l.size, l.stride, l.pad, b);
+            //im2col_cpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+            im2col_cpu_ext(
+                im,                     // input
+                l.c / l.groups,         // input channels
+                l.h, l.w,               // input size (h, w)
+                l.size, l.size,         // kernel size (h, w)
+                l.pad, l.pad,           // padding (h, w)
+                l.stride, l.stride,     // stride (h, w)
+                l.dilation, l.dilation, // dilation (h, w)
+                b);                     // output
+
             gemm(0, 1, m, n, k, 1, a, k, b, k, 1, c, n);
 
             if (state.delta) {
@@ -1039,8 +1058,18 @@
 
                 gemm(1, 0, n, k, m, 1, a, n, b, k, 0, c, k);
 
-                col2im_cpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride,
-                    l.pad, state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w);
+                //col2im_cpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride,
+                //    l.pad, state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w);
+
+                col2im_cpu_ext(
+                    state.workspace,        // input
+                    l.c / l.groups,         // input channels
+                    l.h, l.w,               // input size (h, w)
+                    l.size, l.size,         // kernel size (h, w)
+                    l.pad, l.pad,           // padding (h, w)
+                    l.stride, l.stride,     // stride (h, w)
+                    l.dilation, l.dilation, // dilation (h, w)
+                    state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w); // output (delta)
             }
         }
     }
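On the cuDNN path nothing beyond the descriptor has to change: since cuDNN 6, cudnnSetConvolution2dDescriptor takes dilation_h and dilation_w directly, and cuDNN handles the dilated filter internally. One way to sanity-check such a descriptor is to ask cuDNN for the output size it implies and compare against the formula above. A standalone sketch (illustrative dimensions, error checks omitted for brevity):

#include <cudnn.h>
#include <stdio.h>

int main(void)
{
    cudnnTensorDescriptor_t x;
    cudnnFilterDescriptor_t w;
    cudnnConvolutionDescriptor_t conv;
    int n, c, h, wo;

    cudnnCreateTensorDescriptor(&x);
    cudnnSetTensor4dDescriptor(x, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
        1, 16, 32, 32);                             // N=1, C=16, H=W=32
    cudnnCreateFilterDescriptor(&w);
    cudnnSetFilter4dDescriptor(w, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW,
        32, 16, 3, 3);                              // 32 filters of 16x3x3
    cudnnCreateConvolutionDescriptor(&conv);
    cudnnSetConvolution2dDescriptor(conv, 2, 2, 1, 1, 2, 2,
        CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // pad=2, stride=1, dilation=2

    // Dilated 3x3 spans 5 pixels; (32 + 4 - 5)/1 + 1 = 32, so expect 1x32x32x32.
    cudnnGetConvolution2dForwardOutputDim(conv, x, w, &n, &c, &h, &wo);
    printf("out: %d x %d x %d x %d\n", n, c, h, wo);
    return 0;
}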
diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h
index dc00dabf..586bd2cd 100644
--- a/src/convolutional_layer.h
+++ b/src/convolutional_layer.h
@@ -30,7 +30,7 @@
 void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
 #endif
 
 size_t get_convolutional_workspace_size(layer l);
-convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index);
+convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index);
 void denormalize_convolutional_layer(convolutional_layer l);
 void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
 void forward_convolutional_layer(const convolutional_layer layer, network_state state);
diff --git a/src/crnn_layer.c b/src/crnn_layer.c
index 8534c69a..d7c75b50 100644
--- a/src/crnn_layer.c
+++ b/src/crnn_layer.c
@@ -26,7 +26,7 @@ static void increment_layer(layer *l, int steps)
 #endif
 }
 
-layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor)
+layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor)
 {
     fprintf(stderr, "CRNN Layer: %d x %d x %d image, %d filters\n", h,w,c,output_filters);
     batch = batch / steps;
@@ -36,6 +36,7 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
     l.steps = steps;
     l.size = size;
     l.stride = stride;
+    l.dilation = dilation;
     l.pad = pad;
     l.h = h;
     l.w = w;
@@ -49,17 +50,17 @@
     l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float));
 
     l.input_layer = (layer*)calloc(1, sizeof(layer));
-    *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.input_layer->batch = batch;
     if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
 
     l.self_layer = (layer*)calloc(1, sizeof(layer));
-    *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.self_layer->batch = batch;
     if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
 
     l.output_layer = (layer*)calloc(1, sizeof(layer));
-    *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
+    *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, dilation, pad, activation, batch_normalize, 0, xnor, 0, 0, 0);
     l.output_layer->batch = batch;
     if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;
diff --git a/src/crnn_layer.h b/src/crnn_layer.h
index 55feb599..33560aae 100644
--- a/src/crnn_layer.h
+++ b/src/crnn_layer.h
@@ -9,7 +9,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor);
+layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int dilation, int pad, ACTIVATION activation, int batch_normalize, int xnor);
 void resize_crnn_layer(layer *l, int w, int h);
 void free_state_crnn(layer l);
 
diff --git a/src/im2col.c b/src/im2col.c
index 40b5251d..ee08405a 100644
--- a/src/im2col.c
+++ b/src/im2col.c
@@ -37,3 +37,56 @@ void im2col_cpu(float* data_im,
         }
     }
 }
+
+
+// Uses a cast from int to unsigned to check, in a single comparison,
+// that the value of parameter a is both greater than or equal to zero
+// and lower than the value of parameter b. Parameter b is signed and
+// always positive, so its value is always below 0x800..., while casting
+// a negative value of a converts it to a value above 0x800...
+// The cast therefore allows one condition to be used instead of two.
+static inline int is_a_ge_zero_and_a_lt_b(int a, int b) {
+    return (unsigned)(a) < (unsigned)(b);
+}
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cpp
+void im2col_cpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col)
+{
+    const int output_h = (height + 2 * pad_h -
+        (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int output_w = (width + 2 * pad_w -
+        (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+    const int channel_size = height * width;
+    for (int channel = channels; channel--; data_im += channel_size) {
+        for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+            for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+                int input_row = -pad_h + kernel_row * dilation_h;
+                for (int output_rows = output_h; output_rows; output_rows--) {
+                    if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
+                        for (int output_cols = output_w; output_cols; output_cols--) {
+                            *(data_col++) = 0;
+                        }
+                    }
+                    else {
+                        int input_col = -pad_w + kernel_col * dilation_w;
+                        for (int output_col = output_w; output_col; output_col--) {
+                            if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
+                                *(data_col++) = data_im[input_row * width + input_col];
+                            }
+                            else {
+                                *(data_col++) = 0;
+                            }
+                            input_col += stride_w;
+                        }
+                    }
+                    input_row += stride_h;
+                }
+            }
+        }
+    }
+}
diff --git a/src/im2col.h b/src/im2col.h
index 35f0039a..65dd6ec8 100644
--- a/src/im2col.h
+++ b/src/im2col.h
@@ -14,12 +14,26 @@
 void im2col_cpu(float* data_im,
         int channels, int height, int width,
         int ksize, int stride, int pad, float* data_col);
 
 float im2col_get_pixel(float* im, int height, int width, int channels,
     int row, int col, int channel, int pad);
 
+void im2col_cpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col);
+
 #ifdef GPU
 
 void im2col_ongpu(float *im,
         int channels, int height, int width,
         int ksize, int stride, int pad,float *data_col);
 
+void im2col_gpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col);
+
 void im2col_align_ongpu(float *im,
     int channels, int height, int width,
     int ksize, int stride, int pad, float *data_col, int bit_align);
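With dilation_h = dilation_w = 1 the dilated extent collapses to the plain kernel size and the column layout is identical (c*k*k rows by out_h*out_w columns), so im2col_cpu_ext should reproduce im2col_cpu exactly. That makes a cheap regression check while wiring in the new path; a sketch, assuming only the two declarations from im2col.h above:

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "im2col.h"

/* Regression-check sketch: the _ext path with dilation 1 must match the
 * original im2col_cpu. Sizes below are illustrative. */
int main(void)
{
    int c = 3, h = 8, w = 8, size = 3, stride = 1, pad = 1;
    int out_h = (h + 2 * pad - size) / stride + 1;
    int out_w = (w + 2 * pad - size) / stride + 1;
    int n_im  = c * h * w;
    int n_col = c * size * size * out_h * out_w;

    float *im    = (float*)calloc(n_im,  sizeof(float));
    float *col_a = (float*)calloc(n_col, sizeof(float));
    float *col_b = (float*)calloc(n_col, sizeof(float));
    for (int i = 0; i < n_im; ++i) im[i] = (float)rand() / RAND_MAX;

    im2col_cpu(im, c, h, w, size, stride, pad, col_a);
    im2col_cpu_ext(im, c, h, w, size, size, pad, pad,
        stride, stride, 1, 1, col_b);

    for (int i = 0; i < n_col; ++i) {
        if (fabsf(col_a[i] - col_b[i]) > 1e-6f) {
            printf("mismatch at %d\n", i);
            return 1;
        }
    }
    printf("im2col_cpu and im2col_cpu_ext (dilation=1) agree\n");
    return 0;
}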
diff --git a/src/im2col_kernels.cu b/src/im2col_kernels.cu
index 954a4694..54961d3a 100644
--- a/src/im2col_kernels.cu
+++ b/src/im2col_kernels.cu
@@ -2214,3 +2214,75 @@ void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int
 }
 
 // --------------------------------
+
+// CUDA: use 512 threads per block
+const int CAFFE_CUDA_NUM_THREADS = 512;
+
+// CUDA: number of blocks for N threads.
+inline int CAFFE_GET_BLOCKS(const int N) {
+    return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+}
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n) \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+        i < (n); \
+        i += blockDim.x * gridDim.x)
+
+// https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu
+__global__ void im2col_gpu_kernel_ext(const int n, const float* data_im,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int height_col, const int width_col,
+    float* data_col) {
+    CUDA_KERNEL_LOOP(index, n) {
+        const int h_index = index / width_col;
+        const int h_col = h_index % height_col;
+        const int w_col = index % width_col;
+        const int c_im = h_index / height_col;
+        const int c_col = c_im * kernel_h * kernel_w;
+        const int h_offset = h_col * stride_h - pad_h;
+        const int w_offset = w_col * stride_w - pad_w;
+        float* data_col_ptr = data_col;
+        data_col_ptr += (c_col * height_col + h_col) * width_col + w_col;
+        const float* data_im_ptr = data_im;
+        data_im_ptr += (c_im * height + h_offset) * width + w_offset;
+        for (int i = 0; i < kernel_h; ++i) {
+            for (int j = 0; j < kernel_w; ++j) {
+                int h_im = h_offset + i * dilation_h;
+                int w_im = w_offset + j * dilation_w;
+                *data_col_ptr =
+                    (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
+                    data_im_ptr[i * dilation_h * width + j * dilation_w] : 0;
+                data_col_ptr += height_col * width_col;
+            }
+        }
+    }
+}
+
+
+void im2col_gpu_ext(const float* data_im, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    float* data_col)
+{
+    // We are going to launch channels * height_col * width_col kernels, each
+    // kernel responsible for copying a single-channel grid.
+    int height_col = (height + 2 * pad_h -
+        (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    int width_col = (width + 2 * pad_w -
+        (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+    int num_kernels = channels * height_col * width_col;
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    im2col_gpu_kernel_ext<<<CAFFE_GET_BLOCKS(num_kernels),
+        CAFFE_CUDA_NUM_THREADS>>>(
+        num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h,
+        pad_w, stride_h, stride_w, dilation_h, dilation_w, height_col,
+        width_col, data_col);
+
+    CHECK_CUDA(cudaPeekAtLastError());
+}
\ No newline at end of file
diff --git a/src/parser.c b/src/parser.c
index b37b2873..0cb906a5 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -153,6 +153,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
     int groups = option_find_int_quiet(options, "groups", 1);
     int size = option_find_int(options, "size",1);
     int stride = option_find_int(options, "stride",1);
+    int dilation = option_find_int_quiet(options, "dilation", 1);
     int pad = option_find_int_quiet(options, "pad",0);
     int padding = option_find_int_quiet(options, "padding",0);
     if(pad) padding = size/2;
@@ -171,7 +172,7 @@
     int xnor = option_find_int_quiet(options, "xnor", 0);
     int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
 
-    convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index);
+    convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index);
     layer.flipped = option_find_int_quiet(options, "flipped", 0);
     layer.dot = option_find_float_quiet(options, "dot", 0);
 
@@ -188,6 +189,7 @@ layer parse_crnn(list *options, size_params params)
 {
     int size = option_find_int_quiet(options, "size", 3);
     int stride = option_find_int_quiet(options, "stride", 1);
+    int dilation = option_find_int_quiet(options, "dilation", 1);
     int pad = option_find_int_quiet(options, "pad", 0);
     int padding = option_find_int_quiet(options, "padding", 0);
     if (pad) padding = size / 2;
@@ -200,7 +202,7 @@
     int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
     int xnor = option_find_int_quiet(options, "xnor", 0);
 
-    layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, xnor);
+    layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, xnor);
 
     l.shortcut = option_find_int_quiet(options, "shortcut", 0);
 
@@ -248,6 +250,7 @@ layer parse_conv_lstm(list *options, size_params params)
     // a ConvLSTM with a larger transitional kernel should be able to capture faster motions
     int size = option_find_int_quiet(options, "size", 3);
     int stride = option_find_int_quiet(options, "stride", 1);
+    int dilation = option_find_int_quiet(options, "dilation", 1);
     int pad = option_find_int_quiet(options, "pad", 0);
     int padding = option_find_int_quiet(options, "padding", 0);
     if (pad) padding = size / 2;
@@ -260,7 +263,7 @@
     int xnor = option_find_int_quiet(options, "xnor", 0);
     int peephole = option_find_int_quiet(options, "peephole", 0);
 
-    layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor);
+    layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, dilation, padding, activation, batch_normalize, peephole, xnor);
     l.state_constrain = option_find_int_quiet(options, "state_constrain", params.time_steps * 32);
     l.shortcut = option_find_int_quiet(options, "shortcut", 0);
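All three parsers read the new key with option_find_int_quiet(options, "dilation", 1), so existing cfg files default to dilation 1 and behave exactly as before. A hypothetical [convolutional] block using the option could look like:

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
dilation=2
pad=1
activation=leaky

One caveat worth noting: pad=1 still sets padding = size/2 independent of dilation, so with size=3 and dilation=2 the dilated kernel spans 5 pixels and this block shrinks the feature map by 2 in each spatial dimension. To keep the spatial size unchanged, set padding = dilation * (size - 1) / 2 explicitly (padding=2 here) instead of relying on pad=1.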