diff --git a/Makefile b/Makefile index eee3c96c..bdbc73dd 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ OBJDIR=./obj/ CC=gcc NVCC=nvcc -OPTS=-O0 +OPTS=-O3 LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++ COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/ CFLAGS=-Wall -Wfatal-errors diff --git a/src/col2im_kernels.cu b/src/col2im_kernels.cu index 2fa20305..76a86e65 100644 --- a/src/col2im_kernels.cu +++ b/src/col2im_kernels.cu @@ -3,60 +3,112 @@ extern "C" { #include "cuda.h" } -__global__ void col2im_kernel(float *data_col, - int channels, int height, int width, - int ksize, int stride, int pad, float *data_im) -{ +// src: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu +// You may also want to read: https://github.com/BVLC/caffe/blob/master/LICENSE - int height_col = (height - ksize) / stride + 1; - int width_col = (width - ksize) / stride + 1; - if (pad){ - height_col = 1 + (height-1) / stride; - width_col = 1 + (width-1) / stride; - pad = ksize/2; - } - - int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; - if(id >= channels*height*width) return; - - int index = id; - int w = id%width + pad; - id /= width; - int h = id%height + pad; - id /= height; - int c = id%channels; - - int w_start = (w-ksize+stride)/stride; - int w_end = w/stride + 1; - - int h_start = (h-ksize+stride)/stride; - int h_end = h/stride + 1; - - // int rows = channels * ksize * ksize; - // int cols = height_col*width_col; - int col_offset = (c*ksize*ksize + h * ksize + w)*height_col*width_col; - int h_coeff = (1-stride*ksize*height_col)*width_col; - int w_coeff = 1-stride*height_col*width_col; - float val = 0; - int h_col, w_col; - for(h_col = h_start; h_col < h_end; ++h_col){ - for(w_col = w_start; w_col < w_end; ++w_col){ - int col_index = col_offset +h_col*h_coeff + w_col*w_coeff; - float part = (w_col < 0 || h_col < 0 || h_col >= height_col || w_col >= width_col) ? 0 : data_col[col_index]; - val += part; +__global__ void col2im_gpu_kernel(const int n, const float* data_col, + const int height, const int width, const int ksize, + const int pad, + const int stride, + const int height_col, const int width_col, + float *data_im) { + int index = blockIdx.x*blockDim.x+threadIdx.x; + for(; index < n; index += blockDim.x*gridDim.x){ + float val = 0; + int w = index % width + pad; + int h = (index / width) % height + pad; + int c = index / (width * height); + // compute the start and end of the output + int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1; + int w_col_end = min(w / stride + 1, width_col); + int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1; + int h_col_end = min(h / stride + 1, height_col); + // equivalent implementation + int offset = + (c * ksize * ksize + h * ksize + w) * height_col * width_col; + int coeff_h_col = (1 - stride * ksize * height_col) * width_col; + int coeff_w_col = (1 - stride * height_col * width_col); + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col]; + } } + data_im[index] = val; } - data_im[index] = val; +} + +void col2im_ongpu(float *im, + int channels, int height, int width, + int ksize, int stride, int pad, float *data_col){ + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + pad = pad ? ksize/2 : 0; + int height_col = (height + 2 * pad - ksize) / stride + 1; + int width_col = (width + 2 * pad - ksize) / stride + 1; + int num_kernels = channels * height * width; + col2im_gpu_kernel<<<(num_kernels+BLOCK-1)/BLOCK, + BLOCK>>>( + num_kernels, data_col, height, width, ksize, pad, + stride, height_col, + width_col, im); +} + +/* + __global__ void col2im_kernel(float *data_col, + int channels, int height, int width, + int ksize, int stride, int pad, float *data_im) + { + + int height_col = (height - ksize) / stride + 1; + int width_col = (width - ksize) / stride + 1; + if (pad){ + height_col = 1 + (height-1) / stride; + width_col = 1 + (width-1) / stride; + pad = ksize/2; + } + + int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if(id >= channels*height*width) return; + + int index = id; + int w = id%width + pad; + id /= width; + int h = id%height + pad; + id /= height; + int c = id%channels; + + int w_start = (w-ksize+stride)/stride; + int w_end = w/stride + 1; + + int h_start = (h-ksize+stride)/stride; + int h_end = h/stride + 1; + +// int rows = channels * ksize * ksize; +// int cols = height_col*width_col; +int col_offset = (c*ksize*ksize + h * ksize + w)*height_col*width_col; +int h_coeff = (1-stride*ksize*height_col)*width_col; +int w_coeff = 1-stride*height_col*width_col; +float val = 0; +int h_col, w_col; +for(h_col = h_start; h_col < h_end; ++h_col){ +for(w_col = w_start; w_col < w_end; ++w_col){ +int col_index = col_offset +h_col*h_coeff + w_col*w_coeff; +float part = (w_col < 0 || h_col < 0 || h_col >= height_col || w_col >= width_col) ? 0 : data_col[col_index]; +val += part; +} +} +data_im[index] = val; } extern "C" void col2im_ongpu(float *data_col, - int channels, int height, int width, - int ksize, int stride, int pad, float *data_im) +int channels, int height, int width, +int ksize, int stride, int pad, float *data_im) { - size_t n = channels*height*width; +size_t n = channels*height*width; - col2im_kernel<<>>(data_col, channels, height, width, ksize, stride, pad, data_im); - check_error(cudaPeekAtLastError()); +col2im_kernel<<>>(data_col, channels, height, width, ksize, stride, pad, data_im); +check_error(cudaPeekAtLastError()); } + */ diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 864d7fa3..18a3b7d8 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -56,7 +56,7 @@ extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch, extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state) { -clock_t time = clock(); +//clock_t time = clock(); int i; int m = layer.n; int k = layer.size*layer.size*layer.c; @@ -64,31 +64,31 @@ clock_t time = clock(); convolutional_out_width(layer); bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n); -cudaDeviceSynchronize(); -printf("bias %f\n", sec(clock() - time)); -time = clock(); +//cudaDeviceSynchronize(); +//printf("bias %f\n", sec(clock() - time)); +//time = clock(); -float imt=0; -float gemt = 0; +//float imt=0; +//float gemt = 0; for(i = 0; i < layer.batch; ++i){ -time = clock(); +//time = clock(); im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu); -cudaDeviceSynchronize(); -imt += sec(clock()-time); -time = clock(); +//cudaDeviceSynchronize(); +//imt += sec(clock()-time); +//time = clock(); float * a = layer.filters_gpu; float * b = layer.col_image_gpu; float * c = layer.output_gpu; gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n); -cudaDeviceSynchronize(); -gemt += sec(clock()-time); -time = clock(); +//cudaDeviceSynchronize(); +//gemt += sec(clock()-time); +//time = clock(); } activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation); -cudaDeviceSynchronize(); -printf("activate %f\n", sec(clock() - time)); -printf("im2col %f\n", imt); -printf("gemm %f\n", gemt); +//cudaDeviceSynchronize(); +//printf("activate %f\n", sec(clock() - time)); +//printf("im2col %f\n", imt); +//printf("gemm %f\n", gemt); } extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state) diff --git a/src/im2col_kernels.cu b/src/im2col_kernels.cu index a82c2dc5..d122748a 100644 --- a/src/im2col_kernels.cu +++ b/src/im2col_kernels.cu @@ -3,77 +3,127 @@ extern "C" { #include "cuda.h" } -__global__ void im2col_pad_kernel(float *im, - int channels, int height, int width, - int ksize, int stride, float *data_col) -{ - int c,h,w; - int height_col = 1 + (height-1) / stride; - int width_col = 1 + (width-1) / stride; - int channels_col = channels * ksize * ksize; +// src: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu +// You may also want to read: https://github.com/BVLC/caffe/blob/master/LICENSE - int pad = ksize/2; - - int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; - int col_size = height_col*width_col*channels_col; - if (id >= col_size) return; - - int col_index = id; - w = id % width_col; - id /= width_col; - h = id % height_col; - id /= height_col; - c = id % channels_col; - id /= channels_col; - - int w_offset = c % ksize; - int h_offset = (c / ksize) % ksize; - int im_channel = c / ksize / ksize; - int im_row = h_offset + h * stride - pad; - int im_col = w_offset + w * stride - pad; - - int im_index = im_col + width*(im_row + height*im_channel); - float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index]; - - data_col[col_index] = val; +__global__ void im2col_gpu_kernel(const int n, const float* data_im, + const int height, const int width, const int ksize, + const int pad, + const int stride, + const int height_col, const int width_col, + float *data_col) { + int index = blockIdx.x*blockDim.x+threadIdx.x; + for(; index < n; index += blockDim.x*gridDim.x){ + int w_out = index % width_col; + int h_index = index / width_col; + int h_out = h_index % height_col; + int channel_in = h_index / height_col; + int channel_out = channel_in * ksize * ksize; + int h_in = h_out * stride - pad; + int w_in = w_out * stride - pad; + float* data_col_ptr = data_col; + data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out; + const float* data_im_ptr = data_im; + data_im_ptr += (channel_in * height + h_in) * width + w_in; + for (int i = 0; i < ksize; ++i) { + for (int j = 0; j < ksize; ++j) { + int h = h_in + i; + int w = w_in + j; + *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ? + data_im_ptr[i * width + j] : 0; + data_col_ptr += height_col * width_col; + } + } + } } -__global__ void im2col_nopad_kernel(float *im, - int channels, int height, int width, - int ksize, int stride, float *data_col) -{ - int c,h,w; - int height_col = (height - ksize) / stride + 1; - int width_col = (width - ksize) / stride + 1; - int channels_col = channels * ksize * ksize; - - int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; - int col_size = height_col*width_col*channels_col; - if (id >= col_size) return; - - int col_index = id; - w = id % width_col; - id /= width_col; - h = id % height_col; - id /= height_col; - c = id % channels_col; - id /= channels_col; - - int w_offset = c % ksize; - int h_offset = (c / ksize) % ksize; - int im_channel = c / ksize / ksize; - int im_row = h_offset + h * stride; - int im_col = w_offset + w * stride; - - int im_index = im_col + width*(im_row + height*im_channel); - float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index]; - - data_col[col_index] = val; +void im2col_ongpu(float *im, + int channels, int height, int width, + int ksize, int stride, int pad, float *data_col){ + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + pad = pad ? ksize/2 : 0; + int height_col = (height + 2 * pad - ksize) / stride + 1; + int width_col = (width + 2 * pad - ksize) / stride + 1; + int num_kernels = channels * height_col * width_col; + im2col_gpu_kernel<<<(num_kernels+BLOCK-1)/BLOCK, + BLOCK>>>( + num_kernels, im, height, width, ksize, pad, + stride, height_col, + width_col, data_col); } +/* + __global__ void im2col_pad_kernel(float *im, + int channels, int height, int width, + int ksize, int stride, float *data_col) + { + int c,h,w; + int height_col = 1 + (height-1) / stride; + int width_col = 1 + (width-1) / stride; + int channels_col = channels * ksize * ksize; -extern "C" void im2col_ongpu(float *im, - int channels, int height, int width, - int ksize, int stride, int pad, float *data_col) + int pad = ksize/2; + + int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + int col_size = height_col*width_col*channels_col; + if (id >= col_size) return; + + int col_index = id; + w = id % width_col; + id /= width_col; + h = id % height_col; + id /= height_col; + c = id % channels_col; + id /= channels_col; + + int w_offset = c % ksize; + int h_offset = (c / ksize) % ksize; + int im_channel = c / ksize / ksize; + int im_row = h_offset + h * stride - pad; + int im_col = w_offset + w * stride - pad; + + int im_index = im_col + width*(im_row + height*im_channel); + float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index]; + + data_col[col_index] = val; + } + + __global__ void im2col_nopad_kernel(float *im, + int channels, int height, int width, + int ksize, int stride, float *data_col) + { + int c,h,w; + int height_col = (height - ksize) / stride + 1; + int width_col = (width - ksize) / stride + 1; + int channels_col = channels * ksize * ksize; + + int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + int col_size = height_col*width_col*channels_col; + if (id >= col_size) return; + + int col_index = id; + w = id % width_col; + id /= width_col; + h = id % height_col; + id /= height_col; + c = id % channels_col; + id /= channels_col; + + int w_offset = c % ksize; + int h_offset = (c / ksize) % ksize; + int im_channel = c / ksize / ksize; + int im_row = h_offset + h * stride; + int im_col = w_offset + w * stride; + + int im_index = im_col + width*(im_row + height*im_channel); + float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index]; + + data_col[col_index] = val; + } + + extern "C" void im2col_ongpu(float *im, + int channels, int height, int width, +int ksize, int stride, int pad, float *data_col) { int height_col = (height - ksize) / stride + 1; @@ -91,3 +141,4 @@ extern "C" void im2col_ongpu(float *im, else im2col_nopad_kernel<<>>(im, channels, height, width, ksize, stride, data_col); check_error(cudaPeekAtLastError()); } +*/ diff --git a/src/imagenet.c b/src/imagenet.c index 7da73a09..9118c084 100644 --- a/src/imagenet.c +++ b/src/imagenet.c @@ -13,7 +13,7 @@ void train_imagenet(char *cfgfile, char *weightfile) load_weights(&net, weightfile); } printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - int imgs = 128; + int imgs = 1024; int i = net.seen/imgs; char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list"); list *plist = get_paths("/data/imagenet/cls.train.list");