From 76ee68f96d864a27312c9aa09856ddda559a5cd9 Mon Sep 17 00:00:00 2001
From: Joseph Redmon
Date: Wed, 27 Aug 2014 19:11:46 -0700
Subject: [PATCH] Trying some stuff w/ dropout

---
 src/activations.c          |  51 +++++++++++--
 src/activations.cl         |  33 +++++++-
 src/activations.h          |   2 +
 src/cnn.c                  |  55 +++++++++++++-
 src/col2im.c               |  81 +++++++++++++++-----
 src/col2im.cl              |  41 ++++++++++
 src/convolutional_layer.c  | 152 ++++++++++++++++++++++++++++++++-----
 src/convolutional_layer.cl |  25 ++++++
 src/convolutional_layer.h  |   1 +
 src/data.c                 |  16 +++-
 src/data.h                 |   1 +
 src/gemm.c                 |  14 +++-
 src/im2col.c               |  84 ++++++-------------
 src/im2col.cl              |  47 ++++++++----
 src/mini_blas.h            |  29 ++++---
 src/network.c              |  53 ++-----------
 src/network.h              |   4 +
 src/opencl.c               |  61 +++++++++++--
 18 files changed, 550 insertions(+), 200 deletions(-)
 create mode 100644 src/convolutional_layer.cl

diff --git a/src/activations.c b/src/activations.c
index 04b27c92..4a4bd3fe 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -41,6 +41,12 @@ float relu_activate(float x){return x*(x>0);}
 float ramp_activate(float x){return x*(x>0)+.1*x;}
 float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
 
+float linear_gradient(float x){return 1;}
+float sigmoid_gradient(float x){return (1-x)*x;}
+float relu_gradient(float x){return (x>0);}
+float ramp_gradient(float x){return (x>0)+.1;}
+float tanh_gradient(float x){return 1-x*x;}
+
 float activate(float x, ACTIVATION a)
 {
     switch(a){
@@ -66,19 +72,19 @@ void activate_array(float *x, const int n, const ACTIVATION a)
     }
 }
 
-
-float gradient(float x, ACTIVATION a){
+float gradient(float x, ACTIVATION a)
+{
     switch(a){
         case LINEAR:
-            return 1;
+            return linear_gradient(x);
         case SIGMOID:
-            return (1.-x)*x;
+            return sigmoid_gradient(x);
         case RELU:
-            return (x>0);
+            return relu_gradient(x);
         case RAMP:
-            return (x>0) + .1;
+            return ramp_gradient(x);
         case TANH:
-            return 1-x*x;
+            return tanh_gradient(x);
     }
     return 0;
 }
@@ -107,7 +113,6 @@ cl_kernel get_activation_kernel()
     return kernel;
 }
 
-
 void activate_array_ongpu(cl_mem x, int n, ACTIVATION a)
 {
     cl_setup();
@@ -125,4 +130,34 @@ void activate_array_ongpu(cl_mem x, int n, ACTIVATION a)
     clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0);
     check_error(cl);
 }
+
+cl_kernel get_gradient_kernel()
+{
+    static int init = 0;
+    static cl_kernel kernel;
+    if(!init){
+        kernel = get_kernel("src/activations.cl", "gradient_array", 0);
+        init = 1;
+    }
+    return kernel;
+}
+
+void gradient_array_ongpu(cl_mem x, int n, ACTIVATION a, cl_mem delta)
+{
+    cl_setup();
+    cl_kernel kernel = get_gradient_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(x), (void*) &x);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(n), (void*) &n);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(a), (void*) &a);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(delta), (void*) &delta);
+    check_error(cl);
+
+    size_t gsize = n;
+
+    clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0);
+    check_error(cl);
+}
 #endif
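Note that the new *_gradient functions take the activation's output y, not its raw input: for y = sigmoid(x), dy/dx = y*(1-y), which is exactly sigmoid_gradient; tanh (1-y*y) and relu (y>0 iff x>0) follow the same pattern. A minimal standalone check against a finite difference, using sigmoid_activate as it appears earlier in activations.c (the harness itself is not part of the patch):

    #include <math.h>
    #include <stdio.h>

    float sigmoid_activate(float x){return 1./(1.+exp(-x));}
    float sigmoid_gradient(float y){return (1-y)*y;}   /* y is the OUTPUT */

    int main()
    {
        float x = .7;
        float eps = 1e-4;
        /* central difference of the activation itself */
        float numeric = (sigmoid_activate(x+eps) - sigmoid_activate(x-eps))/(2*eps);
        /* gradient evaluated on the activated output */
        float analytic = sigmoid_gradient(sigmoid_activate(x));
        printf("numeric %f vs analytic %f\n", numeric, analytic);
        return 0;
    }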
diff --git a/src/activations.cl b/src/activations.cl
index 65131c55..da06e8a2 100644
--- a/src/activations.cl
+++ b/src/activations.cl
@@ -8,6 +8,12 @@ float relu_activate(float x){return x*(x>0);}
 float ramp_activate(float x){return x*(x>0)+.1*x;}
 float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
 
+float linear_gradient(float x){return 1;}
+float sigmoid_gradient(float x){return (1-x)*x;}
+float relu_gradient(float x){return (x>0);}
+float ramp_gradient(float x){return (x>0)+.1;}
+float tanh_gradient(float x){return 1-x*x;}
+
 float activate(float x, ACTIVATION a)
 {
     switch(a){
@@ -25,9 +31,32 @@ float activate(float x, ACTIVATION a)
     return 0;
 }
 
-__kernel void activate_array(__global float *x,
-        const int n, const ACTIVATION a)
+float gradient(float x, ACTIVATION a)
+{
+    switch(a){
+        case LINEAR:
+            return linear_gradient(x);
+        case SIGMOID:
+            return sigmoid_gradient(x);
+        case RELU:
+            return relu_gradient(x);
+        case RAMP:
+            return ramp_gradient(x);
+        case TANH:
+            return tanh_gradient(x);
+    }
+    return 0;
+}
+
+__kernel void activate_array(__global float *x, int n, ACTIVATION a)
 {
     int i = get_global_id(0);
     x[i] = activate(x[i], a);
 }
+
+__kernel void gradient_array(__global float *x, int n, ACTIVATION a, __global float *delta)
+{
+    int i = get_global_id(0);
+    delta[i] *= gradient(x[i], a);
+}
+
diff --git a/src/activations.h b/src/activations.h
index 8c4287e0..c406c181 100644
--- a/src/activations.h
+++ b/src/activations.h
@@ -14,7 +14,9 @@ float gradient(float x, ACTIVATION a);
 void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta);
 void activate_array(float *x, const int n, const ACTIVATION a);
 #ifdef GPU
+cl_kernel get_activation_kernel();
 void activate_array_ongpu(cl_mem x, int n, ACTIVATION a);
+void gradient_array_ongpu(cl_mem x, int n, ACTIVATION a, cl_mem delta);
 #endif
 
 #endif
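Both kernels index x[i] with the raw global id and no bounds check, which is safe only because the host enqueues exactly n work-items with no local size. A hedged variant (not in the patch) for the case where the global size is ever rounded up to a multiple of the work-group size:

    __kernel void gradient_array(__global float *x, int n, ACTIVATION a, __global float *delta)
    {
        int i = get_global_id(0);
        if(i >= n) return;   /* guard against padded global sizes */
        delta[i] *= gradient(x[i], a);
    }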
diff --git a/src/cnn.c b/src/cnn.c
index 72ad4a14..0cd6da30 100644
--- a/src/cnn.c
+++ b/src/cnn.c
@@ -32,6 +32,51 @@ void test_convolve()
     show_image_layers(edge, "Test Convolve");
 }
 
+#ifdef GPU
+
+void test_convolutional_layer()
+{
+    int i;
+    image dog = load_image("data/dog.jpg",256,256);
+    network net = parse_network_cfg("cfg/convolutional.cfg");
+//    data test = load_cifar10_data("data/cifar10/test_batch.bin");
+//    float *X = calloc(net.batch*test.X.cols, sizeof(float));
+//    float *y = calloc(net.batch*test.y.cols, sizeof(float));
+    int in_size = get_network_input_size(net)*net.batch;
+    int size = get_network_output_size(net)*net.batch;
+    float *X = calloc(in_size, sizeof(float));
+    for(i = 0; i < in_size; ++i){
+        X[i] = dog.data[i%get_network_input_size(net)];
+    }
+//    get_batch(test, net.batch, X, y);
+    clock_t start, end;
+    cl_mem input_cl = cl_make_array(X, in_size);
+
+    forward_network_gpu(net, input_cl, 1);
+    start = clock();
+    forward_network_gpu(net, input_cl, 1);
+    end = clock();
+    float gpu_sec = (float)(end-start)/CLOCKS_PER_SEC;
+    float *gpu_out = calloc(size, sizeof(float));
+    memcpy(gpu_out, get_network_output(net), size*sizeof(float));
+
+    start = clock();
+    forward_network(net, X, 1);
+    end = clock();
+    float cpu_sec = (float)(end-start)/CLOCKS_PER_SEC;
+    float *cpu_out = calloc(size, sizeof(float));
+    memcpy(cpu_out, get_network_output(net), size*sizeof(float));
+
+    float sum = 0;
+    for(i = 0; i < size; ++i) {
+        //printf("%f, %f\n", gpu_out[i], cpu_out[i]);
+        sum += pow(gpu_out[i] - cpu_out[i], 2);
+    }
+    printf("gpu: %f sec, cpu: %f sec, diff: %f, size: %d\n", gpu_sec, cpu_sec, sum, size);
+}
+
+#endif
+
 void test_convolve_matrix()
 {
     image dog = load_image("dog.jpg",300,400);
@@ -325,7 +370,7 @@ void test_nist()
 void train_nist()
 {
     srand(222222);
-    network net = parse_network_cfg("cfg/nist_final.cfg");
+    network net = parse_network_cfg("cfg/nist.cfg");
     data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
     data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
     translate_data_rows(train, -144);
@@ -349,7 +394,7 @@ void train_nist()
                 mean_array(get_network_output_layer(net,3), 100),
                 mean_array(get_network_output_layer(net,4), 100));
          */
-        save_network(net, "cfg/nist_final2.cfg");
+        //save_network(net, "cfg/nist_final2.cfg");
 
         //printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*1000, loss, lr, momentum, decay);
         //end = clock();
@@ -798,7 +843,7 @@ int main(int argc, char *argv[])
 {
     //train_full();
     //test_distribution();
-    feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW);
+    //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW);
 
     //test_blas();
     //test_visualize();
@@ -809,7 +854,9 @@ int main(int argc, char *argv[])
     //test_split();
     //test_ensemble();
     //test_nist_single();
-    test_nist();
+    //test_nist();
+    train_nist();
+    //test_convolutional_layer();
     //test_cifar10();
     //train_cifar10();
     //test_vince();
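The test above reports a raw sum of squared differences, which grows with the output size. A size-independent pass/fail sketch (not in the patch) compares the root-mean-square error against a tolerance:

    #include <math.h>

    /* 1 if the GPU and CPU outputs agree to within `tolerance` RMSE */
    int outputs_match(float *a, float *b, int n, float tolerance)
    {
        int i;
        double sum = 0;
        for(i = 0; i < n; ++i) sum += (a[i]-b[i])*(a[i]-b[i]);
        return sqrt(sum/n) < tolerance;
    }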
diff --git a/src/col2im.c b/src/col2im.c
index fd7de4fa..c418fa54 100644
--- a/src/col2im.c
+++ b/src/col2im.c
@@ -1,21 +1,21 @@
 #include <stdio.h>
 #include <math.h>
 inline void col2im_add_pixel(float *im, int height, int width, int channels,
-        int row, int col, int channel, int pad, float val)
+        int b, int row, int col, int channel, int pad, float val)
 {
     row -= pad;
     col -= pad;
 
     if (row < 0 || col < 0 || row >= height || col >= width) return;
-    im[col + width*(row + channel*height)] += val;
+    im[col + width*(row + height*(channel+b*channels))] += val;
 }
 //This one might be too, can't remember.
-void col2im_cpu(float* data_col,
-        const int channels, const int height, const int width,
-        const int ksize, const int stride, int pad, float* data_im)
+void col2im_cpu(float* data_col, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float* data_im)
 {
-    int c,h,w;
+    int b,c,h,w;
     int height_col = (height - ksize) / stride + 1;
     int width_col = (width - ksize) / stride + 1;
     if (pad){
@@ -24,20 +24,67 @@ void col2im_cpu(float* data_col,
         pad = ksize/2;
     }
     int channels_col = channels * ksize * ksize;
-    for (c = 0; c < channels_col; ++c) {
-        int w_offset = c % ksize;
-        int h_offset = (c / ksize) % ksize;
-        int c_im = c / ksize / ksize;
-        for (h = 0; h < height_col; ++h) {
-            for (w = 0; w < width_col; ++w) {
-                int im_row = h_offset + h * stride;
-                int im_col = w_offset + w * stride;
-                double val = data_col[(c * height_col + h) * width_col + w];
-                col2im_add_pixel(data_im, height, width, channels,
-                        im_row, im_col, c_im, pad, val);
+    int col_size = height_col*width_col*channels_col;
+    for(b = 0; b < batch; ++b){
+        for (c = 0; c < channels_col; ++c) {
+            int w_offset = c % ksize;
+            int h_offset = (c / ksize) % ksize;
+            int c_im = c / ksize / ksize;
+            for (h = 0; h < height_col; ++h) {
+                for (w = 0; w < width_col; ++w) {
+                    int im_row = h_offset + h * stride;
+                    int im_col = w_offset + w * stride;
+                    int col_index = (c * height_col + h) * width_col + w + b*col_size;
+                    double val = data_col[col_index];
+                    col2im_add_pixel(data_im, height, width, channels,
+                            b, im_row, im_col, c_im, pad, val);
+                }
             }
         }
     }
 }
+
+#ifdef GPU
+
+#include "opencl.h"
+
+cl_kernel get_col2im_kernel()
+{
+    static int init = 0;
+    static cl_kernel im2col_kernel;
+    if(!init){
+        im2col_kernel = get_kernel("src/col2im.cl", "col2im", 0);
+        init = 1;
+    }
+    return im2col_kernel;
+}
+
+void col2im_ongpu(cl_mem data_col, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, cl_mem data_im)
+{
+    cl_setup();
+    cl_kernel kernel = get_col2im_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(data_col), (void*) &data_col);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(batch), (void*) &batch);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(channels), (void*) &channels);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(height), (void*) &height);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(width), (void*) &width);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(ksize), (void*) &ksize);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(stride), (void*) &stride);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(pad), (void*) &pad);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(data_im), (void*) &data_im);
+    check_error(cl);
+
+    // one work-item per image pixel (b,c,h,w), launched as a 1-D range
+    const size_t global_size[] = {channels*height*width*batch};
+
+    clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
+    check_error(cl);
+}
+
+#endif
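Since col2im_cpu scatter-adds exactly the entries that im2col_cpu gathers, the two functions are adjoint: for any image x and column matrix c, <im2col(x), c> equals <x, col2im(c)> as long as the col2im destination starts zeroed. A standalone harness (not in the patch; sizes are arbitrary) that exercises both post-patch signatures:

    #include <stdio.h>
    #include <stdlib.h>
    #include "mini_blas.h"

    int main()
    {
        int batch = 2, channels = 3, height = 8, width = 8;
        int ksize = 3, stride = 1, pad = 0;
        int out_h = (height - ksize)/stride + 1;
        int out_w = (width - ksize)/stride + 1;
        int im_size = batch*channels*height*width;
        int col_size = batch*channels*ksize*ksize*out_h*out_w;

        float *x   = calloc(im_size, sizeof(float));
        float *c   = calloc(col_size, sizeof(float));
        float *col = calloc(col_size, sizeof(float));
        float *im  = calloc(im_size, sizeof(float));   /* zeroed: col2im accumulates */
        int i;
        for(i = 0; i < im_size; ++i)  x[i] = rand()/(float)RAND_MAX;
        for(i = 0; i < col_size; ++i) c[i] = rand()/(float)RAND_MAX;

        im2col_cpu(x, batch, channels, height, width, ksize, stride, pad, col);
        col2im_cpu(c, batch, channels, height, width, ksize, stride, pad, im);

        double lhs = 0, rhs = 0;
        for(i = 0; i < col_size; ++i) lhs += (double)col[i]*c[i];
        for(i = 0; i < im_size; ++i)  rhs += (double)im[i]*x[i];
        printf("%f == %f ?\n", lhs, rhs);   /* should agree to float precision */
        return 0;
    }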
diff --git a/src/col2im.cl b/src/col2im.cl
index e69de29b..c8e3b30b 100644
--- a/src/col2im.cl
+++ b/src/col2im.cl
@@ -0,0 +1,44 @@
+__kernel void col2im(__global float *data_col, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, __global float *data_im)
+{
+    int id = get_global_id(0);
+    int index = id;
+    int w = id%width;
+    id /= width;
+    int h = id%height;
+    id /= height;
+    int c = id%channels;
+    id /= channels;
+    int b = id%batch;
+
+    int height_col = (height - ksize) / stride + 1;
+    int width_col = (width - ksize) / stride + 1;
+    if (pad){
+        height_col = 1 + (height-1) / stride;
+        width_col = 1 + (width-1) / stride;
+        pad = ksize/2;
+    }
+    int rows = channels * ksize * ksize;
+    int cols = height_col*width_col;
+    int batch_offset = b*cols*rows;
+
+    // Gather every column entry that im2col copied out of pixel (b,c,h,w):
+    // filter offset (i,j) at output (h_col,w_col) reads this pixel when
+    // h = h_col*stride + i - pad and w = w_col*stride + j - pad.
+    float sum = 0;
+    int i,j;
+    for(i = 0; i < ksize; ++i){
+        for(j = 0; j < ksize; ++j){
+            int h_col = h + pad - i;
+            int w_col = w + pad - j;
+            if(h_col % stride != 0 || w_col % stride != 0) continue;
+            h_col /= stride;
+            w_col /= stride;
+            if(h_col < 0 || h_col >= height_col || w_col < 0 || w_col >= width_col) continue;
+            int c_col = (c*ksize + i)*ksize + j;
+            sum += data_col[batch_offset + (c_col*height_col + h_col)*width_col + w_col];
+        }
+    }
+    data_im[index] += sum;
+}
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 2d4d7489..bdbfbfd9 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -147,15 +147,9 @@ void forward_convolutional_layer(const convolutional_layer layer, float *in)
 
     for(i = 0; i < layer.batch; ++i){
         gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-        c += n*m;
-        in += layer.h*layer.w*layer.c;
         b += k*n;
+        c += n*m;
     }
-    /*
-    int i;
-    for(i = 0; i < m*n; ++i) printf("%f, ", layer.output[i]);
-    printf("\n");
-    */
     activate_array(layer.output, m*n*layer.batch, layer.activation);
 }
@@ -205,10 +199,10 @@ void backward_convolutional_layer(convolutional_layer layer, float *delta)
 
         for(i = 0; i < layer.batch; ++i){
             gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
-            col2im_cpu(c,  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, delta);
-            c += k*n;
-            delta += layer.h*layer.w*layer.c;
+            b += k*n;
+            c += m*n;
         }
+        col2im_cpu(layer.col_image, layer.batch, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, delta);
     }
 }
@@ -278,22 +272,140 @@ image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters)
 }
 
 #ifdef GPU
+
+cl_kernel get_convolutional_learn_bias_kernel()
+{
+    static int init = 0;
+    static cl_kernel kernel;
+    if(!init){
+        kernel = get_kernel("src/convolutional_layer.cl", "learn_bias", 0);
+        init = 1;
+    }
+    return kernel;
+}
+
+void learn_bias_convolutional_layer_ongpu(convolutional_layer layer)
+{
+    int size = convolutional_out_height(layer) * convolutional_out_width(layer);
+
+    cl_setup();
+    cl_kernel kernel = get_convolutional_learn_bias_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.batch), (void*) &layer.batch);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.n), (void*) &layer.n);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(size), (void*) &size);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.delta_cl), (void*) &layer.delta_cl);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.bias_updates_cl), (void*) &layer.bias_updates_cl);
+    check_error(cl);
+
+    const size_t global_size[] = {layer.n};
+
+    clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
+    check_error(cl);
+}
+
+cl_kernel get_convolutional_bias_kernel()
+{
+    static int init = 0;
+    static cl_kernel kernel;
+    if(!init){
+        kernel = get_kernel("src/convolutional_layer.cl", "bias", 0);
+        init = 1;
+    }
+    return kernel;
+}
+
+void bias_output_gpu(const convolutional_layer layer)
+{
+    int out_h = convolutional_out_height(layer);
+    int out_w = convolutional_out_width(layer);
+    int size = out_h*out_w;
+
+    cl_setup();
+    cl_kernel kernel = get_convolutional_bias_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.n), (void*) &layer.n);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(size), (void*) &size);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.biases_cl), (void*) &layer.biases_cl);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.output_cl), (void*) &layer.output_cl);
+    check_error(cl);
+
+    const size_t global_size[] = {layer.batch, layer.n*size};
+
+    clEnqueueNDRangeKernel(queue, kernel, 2, 0, global_size, 0, 0, 0, 0);
+    check_error(cl);
+}
+
 void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in)
 {
+    int i;
     int m = layer.n;
     int k = layer.size*layer.size*layer.c;
     int n = convolutional_out_height(layer)*
-        convolutional_out_width(layer)*
-        layer.batch;
+        convolutional_out_width(layer);
 
-    cl_write_array(layer.filters_cl, layer.filters, m*k);
-    cl_mem a = layer.filters_cl;
-    cl_mem b = layer.col_image_cl;
-    cl_mem c = layer.output_cl;
-    im2col_ongpu(in, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, b);
-    gemm_ongpu(0,0,m,n,k,1,a,k,b,n,0,c,n);
-    activate_array_ongpu(layer.output_cl, m*n, layer.activation);
-    cl_read_array(layer.output_cl, layer.output, m*n);
+    //cl_write_array(layer.filters_cl, layer.filters, m*k);
+    //cl_write_array(layer.biases_cl, layer.biases, m);
+    bias_output_gpu(layer);
+    im2col_ongpu(in, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_cl);
+    for(i = 0; i < layer.batch; ++i){
+        cl_mem a = layer.filters_cl;
+        cl_mem b = cl_sub_array(layer.col_image_cl, i*k*n, k*n);
+        cl_mem c = cl_sub_array(layer.output_cl, i*m*n, m*n);
+        gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c,n);
+        clReleaseMemObject(b);
+        clReleaseMemObject(c);
+    }
+    activate_array_ongpu(layer.output_cl, m*n*layer.batch, layer.activation);
+    cl_read_array(layer.output_cl, layer.output, m*n*layer.batch);
 }
+
+void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl)
+{
+    int i;
+    int m = layer.n;
+    int n = layer.size*layer.size*layer.c;
+    int k = convolutional_out_height(layer)*
+        convolutional_out_width(layer);
+    gradient_array_ongpu(layer.output_cl, m*k*layer.batch, layer.activation, layer.delta_cl);
+    learn_bias_convolutional_layer_ongpu(layer);
+
+    for(i = 0; i < layer.batch; ++i){
+        cl_mem a = cl_sub_array(layer.delta_cl, i*m*k, m*k);
+        cl_mem b = cl_sub_array(layer.col_image_cl, i*k*n, k*n);
+        cl_mem c = layer.filter_updates_cl;
+
+        gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
+
+        clReleaseMemObject(a);
+        clReleaseMemObject(b);
+    }
+    cl_read_array(layer.filter_updates_cl, layer.filter_updates, m*n);
+    cl_read_array(layer.bias_updates_cl, layer.bias_updates, m);
+
+    if(delta_cl){
+        m = layer.size*layer.size*layer.c;
+        k = layer.n;
+        n = convolutional_out_height(layer)*
+            convolutional_out_width(layer);
+
+        for(i = 0; i < layer.batch; ++i){
+            cl_mem a = layer.filters_cl;
+            cl_mem b = cl_sub_array(layer.delta_cl, i*k*n, k*n);
+            cl_mem c = cl_sub_array(layer.col_image_cl, i*m*n, m*n);
+
+            gemm_ongpu(1,0,m,n,k,1,a,m,b,n,0,c,n);
+            clReleaseMemObject(b);
+            clReleaseMemObject(c);
+        }
+        col2im_ongpu(layer.col_image_cl, layer.batch, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, delta_cl);
+    }
+}
+
 #endif
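cl_sub_array is not defined anywhere in this patch; from its call sites it takes an element offset and element count and returns a cl_mem that the caller releases. A hypothetical implementation (names and flags are my assumption, not the author's code) on top of the standard clCreateSubBuffer:

    /* Hypothetical: one way cl_sub_array could be written.  Note the
     * region origin must be aligned to the device's
     * CL_DEVICE_MEM_BASE_ADDR_ALIGN, or clCreateSubBuffer fails with
     * CL_MISALIGNED_SUB_BUFFER_OFFSET. */
    cl_mem cl_sub_array(cl_mem buffer, int offset, int size)
    {
        cl_buffer_region region;
        region.origin = offset*sizeof(float);
        region.size = size*sizeof(float);
        cl_mem sub = clCreateSubBuffer(buffer, CL_MEM_READ_WRITE,
                CL_BUFFER_CREATE_TYPE_REGION, &region, &cl.error);
        check_error(cl);
        return sub;
    }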
diff --git a/src/convolutional_layer.cl b/src/convolutional_layer.cl
new file mode 100644
index 00000000..6393c37b
--- /dev/null
+++ b/src/convolutional_layer.cl
@@ -0,0 +1,25 @@
+
+__kernel void bias(int n, int size, __global float *biases, __global float *output)
+{
+    int batch = get_global_id(0);
+    int id = get_global_id(1);
+    int filter = id/size;
+    int position = id%size;
+
+    output[batch*n*size + id] = biases[filter];
+}
+
+__kernel void learn_bias(int batch, int n, int size, __global float *delta, __global float *bias_updates)
+{
+    int i,b;
+    int filter = get_global_id(0);
+    float sum = 0;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < size; ++i){
+            int index = i + size*(filter + n*b);
+            sum += delta[index];
+        }
+    }
+    bias_updates[filter] += sum;
+}
+
diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h
index f876e8b4..cf897a74 100644
--- a/src/convolutional_layer.h
+++ b/src/convolutional_layer.h
@@ -50,6 +50,7 @@ typedef struct {
 
 #ifdef GPU
 void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in);
+void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl);
 #endif
 
 convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay);
diff --git a/src/data.c b/src/data.c
index 846b950a..aa8fecf2 100644
--- a/src/data.c
+++ b/src/data.c
@@ -148,6 +148,16 @@ data load_cifar10_data(char *filename)
     return d;
 }
 
+void get_batch(data d, int n, float *X, float *y)
+{
+    int j;
+    for(j = 0; j < n; ++j){
+        int index = rand()%d.X.rows;
+        memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));
+        memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));
+    }
+}
+
 data load_all_cifar10()
 {
     data d;
@@ -158,7 +168,7 @@ data load_all_cifar10()
 
     d.X = X;
     d.y = y;
-    
+
     for(b = 0; b < 5; ++b){
         char buff[256];
         sprintf(buff, "data/cifar10/data_batch_%d.bin", b+1);
@@ -176,8 +186,8 @@ data load_all_cifar10()
         fclose(fp);
     }
     //normalize_data_rows(d);
-    translate_data_rows(d, -144);
-    scale_data_rows(d, 1./128);
+    translate_data_rows(d, -144);
+    scale_data_rows(d, 1./128);
     return d;
 }
diff --git a/src/data.h b/src/data.h
index 0a1830e6..bd677e86 100644
--- a/src/data.h
+++ b/src/data.h
@@ -20,6 +20,7 @@ data load_data_image_pathfile_random(char *filename, int n, char **labels,
 data load_cifar10_data(char *filename);
 data load_all_cifar10();
 list *get_paths(char *filename);
+void get_batch(data d, int n, float *X, float *y);
 data load_categorical_data_csv(char *filename, int target, int k);
 void normalize_data_rows(data d);
 void scale_data_rows(data d, float s);
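The learn_bias kernel accumulates, per filter, the layer's delta summed over every spatial position and batch image. A CPU reference (not in the patch; the function name is mine) that mirrors the kernel's index = i + size*(filter + n*b) layout, useful for cross-checking the GPU result:

    /* bias gradient for filter f = sum over batch b and position i of
     * delta[i + size*(f + n*b)], accumulated into bias_updates[f] */
    void learn_bias_cpu(int batch, int n, int size, float *delta, float *bias_updates)
    {
        int f, b, i;
        for(f = 0; f < n; ++f){
            float sum = 0;
            for(b = 0; b < batch; ++b){
                for(i = 0; i < size; ++i){
                    sum += delta[i + size*(f + n*b)];
                }
            }
            bias_updates[f] += sum;
        }
    }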
diff --git a/src/gemm.c b/src/gemm.c
index 1a7bcdd4..65542bcc 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -6,11 +6,7 @@ void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
         float BETA,
         float *C, int ldc)
 {
-#ifdef GPU
-    gemm_gpu( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
-#else
     gemm_cpu( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
-#endif
 }
 
 void gemm_nn(int M, int N, int K, float ALPHA,
@@ -83,6 +79,7 @@ void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
         float BETA,
         float *C, int ldc)
 {
+    //printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
     int i, j;
     for(i = 0; i < M; ++i){
         for(j = 0; j < N; ++j){
@@ -107,7 +104,11 @@ void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
 #define STR_HELPER(x) #x
 #define STR(x) STR_HELPER(x)
 
+#ifdef __APPLE__
+#define BLOCK 1
+#else
 #define BLOCK 8
+#endif
 
 cl_kernel get_gemm_kernel()
 {
@@ -126,6 +127,7 @@ void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
         float BETA,
         cl_mem C_gpu, int ldc)
 {
+    //printf("gpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
     cl_setup();
     cl_kernel gemm_kernel = get_gemm_kernel();
     cl_command_queue queue = cl.queue;
@@ -256,6 +258,8 @@ void test_gpu_accuracy(int TA, int TB, int m, int k, int n)
 
 void test_gpu_blas()
 {
+    test_gpu_accuracy(0,0,10,576,75);
+
     test_gpu_accuracy(0,0,17,10,10);
     test_gpu_accuracy(1,0,17,10,10);
     test_gpu_accuracy(0,1,17,10,10);
@@ -266,6 +270,7 @@ void test_gpu_blas()
     test_gpu_accuracy(0,1,1000,10,100);
     test_gpu_accuracy(1,1,1000,10,100);
 
+/*
     time_gpu_random_matrix(0,0,1000,1000,100);
     time_random_matrix(0,0,1000,1000,100);
 
@@ -277,6 +282,7 @@ void test_gpu_blas()
 
     time_gpu_random_matrix(1,1,1000,1000,100);
     time_random_matrix(1,1,1000,1000,100);
+    */
 }
 #endif
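For reference, gemm() computes C = ALPHA*op(A)*op(B) + BETA*C on row-major data, where TA/TB select op = transpose and lda/ldb/ldc are the row strides. A tiny worked call (standalone, not in the patch):

    #include <stdio.h>
    #include "mini_blas.h"

    int main()
    {
        float A[] = {1, 2,
                     3, 4};
        float B[] = {5, 6,
                     7, 8};
        float C[] = {0, 0,
                     0, 0};
        /* M=2 rows of C, N=2 cols of C, K=2 inner dim; lda=ldb=ldc=2 */
        gemm(0, 0, 2, 2, 2, 1, A, 2, B, 2, 0, C, 2);
        printf("%f %f\n%f %f\n", C[0], C[1], C[2], C[3]);   /* 19 22 / 43 50 */
        return 0;
    }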
diff --git a/src/im2col.c b/src/im2col.c
index 6ed9d891..08f7ce43 100644
--- a/src/im2col.c
+++ b/src/im2col.c
@@ -1,22 +1,21 @@
 #include "mini_blas.h"
 #include <stdio.h>
-
 inline float im2col_get_pixel(float *im, int height, int width, int channels,
-        int row, int col, int channel, int pad)
+        int b, int row, int col, int channel, int pad)
 {
     row -= pad;
     col -= pad;
 
     if (row < 0 || col < 0 || row >= height || col >= width) return 0;
-    return im[col + width*(row + channel*height)];
+    return im[col + width*(row + height*(channel+b*channels))];
 }
 
 //From Berkeley Vision's Caffe!
 //https://github.com/BVLC/caffe/blob/master/LICENSE
-void im2col_cpu_batch(float* data_im,
-        const int batch, const int channels, const int height, const int width,
-        const int ksize, const int stride, int pad, float* data_col)
+void im2col_cpu(float* data_im, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float* data_col)
 {
     int c,h,w,b;
     int height_col = (height - ksize) / stride + 1;
@@ -27,44 +26,6 @@ void im2col_cpu_batch(float* data_im,
         pad = ksize/2;
     }
     int channels_col = channels * ksize * ksize;
-    int im_size = height*width*channels;
-    //int col_size = height_col*width_col*channels_col;
-    for (b = 0; b < batch; ++b) {
-        for (c = 0; c < channels_col; ++c) {
-            int w_offset = c % ksize;
-            int h_offset = (c / ksize) % ksize;
-            int c_im = c / ksize / ksize;
-            for (h = 0; h < height_col; ++h) {
-                for (w = 0; w < width_col; ++w) {
-                    int im_row = h_offset + h * stride;
-                    int im_col = w_offset + w * stride;
-                    int col_index = (c * height_col + h) * width_col + w + (batch-1) * c * height_col*width_col;
-                    data_col[col_index] = im2col_get_pixel(data_im, height, width, channels,
-                            im_row, im_col, c_im, pad);
-                }
-            }
-        }
-        data_im += im_size;
-        data_col+= channels_col;
-    }
-}
-
-//From Berkeley Vision's Caffe!
-//https://github.com/BVLC/caffe/blob/master/LICENSE
-void im2col_cpu(float* data_im, const int batch,
-        const int channels, const int height, const int width,
-        const int ksize, const int stride, int pad, float* data_col)
-{
-    int c,h,w,b;
-    int height_col = (height - ksize) / stride + 1;
-    int width_col = (width - ksize) / stride + 1;
-    if (pad){
-        height_col = 1 + (height-1) / stride;
-        width_col = 1 + (width-1) / stride;
-        pad = ksize/2;
-    }
-    int channels_col = channels * ksize * ksize;
-    int im_size = height*width*channels;
     int col_size = height_col*width_col*channels_col;
     for (b = 0; b < batch; ++b) {
         for (c = 0; c < channels_col; ++c) {
@@ -75,14 +36,12 @@ void im2col_cpu(float* data_im, const int batch,
                 for (w = 0; w < width_col; ++w) {
                     int im_row = h_offset + h * stride;
                     int im_col = w_offset + w * stride;
-                    int col_index = (c * height_col + h) * width_col + w;
+                    int col_index = (c * height_col + h) * width_col + w + b*col_size;
                     data_col[col_index] = im2col_get_pixel(data_im, height, width, channels,
-                            im_row, im_col, c_im, pad);
+                            b, im_row, im_col, c_im, pad);
                 }
             }
         }
-        data_im += im_size;
-        data_col += col_size;
     }
 }
@@ -104,9 +63,9 @@ cl_kernel get_im2col_kernel()
 }
 
 
-void im2col_ongpu(cl_mem data_im, const int batch,
-        const int channels, const int height, const int width,
-        const int ksize, const int stride, cl_mem data_col)
+void im2col_ongpu(cl_mem data_im, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, cl_mem data_col)
 {
     cl_setup();
     cl_kernel im2col_kernel = get_im2col_kernel();
@@ -120,29 +79,30 @@ void im2col_ongpu(cl_mem data_im, const int batch,
     cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(width), (void*) &width);
     cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(ksize), (void*) &ksize);
     cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(stride), (void*) &stride);
+    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(pad), (void*) &pad);
     cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_col), (void*) &data_col);
     check_error(cl);
 
     int height_col = (height - ksize) / stride + 1;
     int width_col = (width - ksize) / stride + 1;
     int channels_col = channels * ksize * ksize;
+    if (pad){
+        height_col = 1 + (height-1) / stride;
+        width_col = 1 + (width-1) / stride;
+    }
 
     size_t global_size[2];
-    size_t local_size[2];
-    global_size[0] = batch;
-    global_size[1] = channels_col;
-    local_size[0] = height_col;
-    local_size[1] = width_col;
+    global_size[0] = batch*channels_col;
+    global_size[1] = height_col*width_col;
 
     clEnqueueNDRangeKernel(queue, im2col_kernel, 2, 0,
-            global_size, local_size, 0, 0, 0);
+            global_size, 0, 0, 0, 0);
     check_error(cl);
 }
 
-void im2col_gpu(float *data_im,
-        const int batch, const int channels, const int height, const int width,
-        const int ksize, const int stride,
-        float *data_col)
+void im2col_gpu(float *data_im, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float *data_col)
 {
     cl_setup();
     cl_context context = cl.context;
@@ -165,7 +125,7 @@ void im2col_gpu(float *data_im,
     check_error(cl);
 
     im2col_ongpu(im_gpu, batch, channels, height, width,
-            ksize, stride, col_gpu);
+            ksize, stride, pad, col_gpu);
 
     clEnqueueReadBuffer(queue, col_gpu, CL_TRUE, 0, size, data_col, 0, 0, 0);
     check_error(cl);
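With im2col, a convolution collapses to one GEMM per image, exactly as forward_convolutional_layer sets it up: M = n (filters), K = c*size*size (weights per filter), N = out_h*out_w (output positions), so filters (MxK) times columns (KxN) gives output (MxN). A single-image sketch (not in the patch; the function and buffer names are mine, and the caller is assumed to size col and output to K*N and M*N floats):

    #include "mini_blas.h"

    void convolve_one_image(float *image, float *filters, float *col, float *output,
            int c, int h, int w, int n, int size, int stride, int pad)
    {
        int out_h = pad ? 1 + (h-1)/stride : (h - size)/stride + 1;
        int out_w = pad ? 1 + (w-1)/stride : (w - size)/stride + 1;
        int m = n;                 /* rows of C: one per filter */
        int k = c*size*size;       /* inner dimension */
        int nn = out_h*out_w;      /* cols of C: one per output position */
        im2col_cpu(image, 1, c, h, w, size, stride, pad, col);
        gemm(0, 0, m, nn, k, 1, filters, k, col, nn, 0, output, nn);
    }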
diff --git a/src/im2col.cl b/src/im2col.cl
index 765a92df..6ed5d89c 100644
--- a/src/im2col.cl
+++ b/src/im2col.cl
@@ -1,26 +1,43 @@
-__kernel void im2col(__global float *data_im, const int im_offset,
-        const int channels, const int height, const int width,
-        const int ksize, const int stride, __global float *data_col, const int col_offset)
+float im2col_get_pixel(__global float *im, int height, int width, int channels,
+        int batch, int row, int col, int channel, int pad)
 {
-    int b = get_global_id(0);
-    int c = get_global_id(1);
+    row -= pad;
+    col -= pad;
 
-    int h = get_local_id(0);
-    int w = get_local_id(1);
+    if (row < 0 || col < 0 || row >= height || col >= width) return 0;
+    int index = col + width*(row + height*(channel+batch*channels));
+    return im[index];
+}
 
+__kernel void im2col(__global float *data_im, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, __global float *data_col)
+{
+    int c,h,w,b;
     int height_col = (height - ksize) / stride + 1;
     int width_col = (width - ksize) / stride + 1;
+    if (pad){
+        height_col = 1 + (height-1) / stride;
+        width_col = 1 + (width-1) / stride;
+        pad = ksize/2;
+    }
+    int gid1 = get_global_id(0);
+    b = gid1%batch;
+    c = gid1/batch;
+
+    int gid2 = get_global_id(1);
+    h = gid2%height_col;
+    w = gid2/height_col;
+
     int channels_col = channels * ksize * ksize;
-
-    int im_offset = height*width*channels*b;
-    int col_offset = height_col*width_col*channels_col*b;
-
+    int col_size = height_col*width_col*channels_col;
     int w_offset = c % ksize;
     int h_offset = (c / ksize) % ksize;
    int c_im = c / ksize / ksize;
-
-    data_col[(c * height_col + h) * width_col + w + col_offset] =
-        data_im[(c_im * height + h * stride + h_offset) * width
-        + w * stride + w_offset + im_offset];
+    int im_row = h_offset + h * stride;
+    int im_col = w_offset + w * stride;
+    int col_index = (c * height_col + h) * width_col + w + b*col_size;
+    data_col[col_index] = im2col_get_pixel(data_im, height, width, channels, b, im_row, im_col, c_im, pad);
 }
diff --git a/src/mini_blas.h b/src/mini_blas.h
index c80e6ad5..34905a17 100644
--- a/src/mini_blas.h
+++ b/src/mini_blas.h
@@ -10,13 +10,17 @@ float *random_matrix(int rows, int cols);
 void time_random_matrix(int TA, int TB, int m, int k, int n);
 
 #ifdef GPU
-void im2col_ongpu(cl_mem data_im, const int batch,
-        const int channels, const int height, const int width,
-        const int ksize, const int stride, cl_mem data_col);
+void im2col_ongpu(cl_mem data_im, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, cl_mem data_col);
 
-void im2col_gpu(float *data_im,
-        const int batch, const int channels, const int height, const int width,
-        const int ksize, const int stride, float *data_col);
+void col2im_ongpu(cl_mem data_col, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, cl_mem data_im);
+
+void im2col_gpu(float *data_im, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float *data_col);
 
 void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
         cl_mem A_gpu, int lda,
@@ -25,13 +29,14 @@ void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
         cl_mem C_gpu, int ldc);
 #endif
 
-void im2col_cpu(float* data_im, const int batch,
-        const int channels, const int height, const int width,
-        const int ksize, const int stride, int pad, float* data_col);
+void im2col_cpu(float* data_im, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float* data_col);
+
+void col2im_cpu(float* data_col, int batch,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float* data_im);
 
-void col2im_cpu(float* data_col,
-        const int channels, const int height, const int width,
-        const int ksize, const int stride, int pad, float* data_im);
 void test_blas();
 void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
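Throughout these signatures, pad is a flag rather than a pixel count: pad == 0 gives the "valid" output size, while nonzero pad switches to 1 + (dim-1)/stride and implies ksize/2 pixels of zero padding. A one-function sketch of the rule these files all repeat (not in the patch; the name is mine):

    /* output size along one dimension, matching the im2col/col2im code */
    int convolutional_out_size(int dim, int ksize, int stride, int pad)
    {
        if(pad) return 1 + (dim-1)/stride;
        return (dim - ksize)/stride + 1;
    }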
gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
diff --git a/src/network.c b/src/network.c
index 292bba0e..3761bf92 100644
--- a/src/network.c
+++ b/src/network.c
@@ -28,25 +28,16 @@ network make_network(int n, int batch)
 }
 
 #ifdef GPU
-void forward_network(network net, float *input, int train)
+void forward_network_gpu(network net, cl_mem input_cl, int train)
 {
-    cl_setup();
-    size_t size = get_network_input_size(net);
-    if(!net.input_cl){
-        net.input_cl = clCreateBuffer(cl.context,
-                CL_MEM_READ_WRITE, size*sizeof(float), 0, &cl.error);
-        check_error(cl);
-    }
-    cl_write_array(net.input_cl, input, size);
-    cl_mem input_cl = net.input_cl;
     int i;
     for(i = 0; i < net.n; ++i){
         if(net.types[i] == CONVOLUTIONAL){
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
             forward_convolutional_layer_gpu(layer, input_cl);
             input_cl = layer.output_cl;
-            input = layer.output;
         }
+        /*
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
             forward_connected_layer(layer, input, train);
@@ -72,10 +63,11 @@ void forward_network(network net, float *input, int train)
             forward_normalization_layer(layer, input);
             input = layer.output;
         }
+        */
     }
 }
 
-#else
+#endif
 
 void forward_network(network net, float *input, int train)
 {
@@ -118,7 +110,6 @@ void forward_network(network net, float *input, int train)
         }
     }
 }
-#endif
 
 void update_network(network net)
 {
@@ -275,45 +266,13 @@ float train_network_sgd(network net, data d, int n)
     float *X = calloc(batch*d.X.cols, sizeof(float));
     float *y = calloc(batch*d.y.cols, sizeof(float));
 
-    int i,j;
+    int i;
     float sum = 0;
-    int index = 0;
     for(i = 0; i < n; ++i){
-        for(j = 0; j < batch; ++j){
-            index = rand()%d.X.rows;
-            memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));
-            memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));
-        }
-
+        get_batch(d, batch, X, y);
         float err = train_network_datum(net, X, y);
         sum += err;
-        //train_network_datum(net, X, y);
-        /*
-        float *y = d.y.vals[index];
-        int class = get_predicted_class_network(net);
-        correct += (y[class]?1:0);
-        */
-
-        /*
-        for(j = 0; j < d.y.cols*batch; ++j){
-            printf("%6.3f ", y[j]);
-        }
-        printf("\n");
-        for(j = 0; j < d.y.cols*batch; ++j){
-            printf("%6.3f ", get_network_output(net)[j]);
-        }
-        printf("\n");
-        printf("\n");
-        */
-
-        //printf("%d %f %f\n", i,net.output[0], d.y.vals[index][0]);
-        //if((i+1)%10 == 0){
-        //    printf("%d: %f\n", (i+1), (float)correct/(i+1));
-        //}
     }
-    //printf("Accuracy: %f\n",(float) correct/n);
-    //show_image(float_to_image(32,32,3,X), "Orig");
     free(X);
     free(y);
     return (float)sum/(n*batch);
diff --git a/src/network.h b/src/network.h
index f8666e65..65ace579 100644
--- a/src/network.h
+++ b/src/network.h
@@ -33,6 +33,10 @@ typedef struct {
     #endif
 } network;
 
+#ifdef GPU
+void forward_network_gpu(network net, cl_mem input, int train);
+#endif
+
 network make_network(int n, int batch);
 void forward_network(network net, float *input, int train);
 float backward_network(network net, float *input, float *truth);
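train_network_sgd now draws each mini-batch with get_batch (sampling with replacement) and returns sum/(n*batch), so the result reads as average loss per training example, assuming train_network_datum returns the summed error of its batch. A usage sketch (not in the patch) mirroring train_nist; the header names are assumed from the calls in cnn.c:

    #include <stdio.h>
    #include "network.h"
    #include "parser.h"
    #include "data.h"

    int main()
    {
        srand(222222);
        network net = parse_network_cfg("cfg/nist.cfg");
        data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
        translate_data_rows(train, -144);
        scale_data_rows(train, 1./128);
        float loss = train_network_sgd(net, train, 100);   /* 100 random batches */
        printf("avg loss per example: %f\n", loss);
        return 0;
    }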
diff --git a/src/opencl.c b/src/opencl.c
index 8f9edd3c..bcc0f09e 100644
--- a/src/opencl.c
+++ b/src/opencl.c
@@ -11,6 +11,7 @@ cl_info cl = {0};
 
 void check_error(cl_info info)
 {
+    clFinish(cl.queue);
     if (info.error != CL_SUCCESS) {
         printf("\n Error number %d", info.error);
         exit(1);
@@ -27,13 +28,60 @@ cl_info cl_init()
     // Fetch the Platform and Device IDs; we only want one.
     cl_device_id devices[MAX_DEVICES];
     info.error=clGetPlatformIDs(1, &info.platform, &num_platforms);
+
+    printf("=== %d OpenCL platform(s) found: ===\n", num_platforms);
+    char buffer[10240];
+    clGetPlatformInfo(info.platform, CL_PLATFORM_PROFILE, 10240, buffer, NULL);
+    printf(" PROFILE = %s\n", buffer);
+    clGetPlatformInfo(info.platform, CL_PLATFORM_VERSION, 10240, buffer, NULL);
+    printf(" VERSION = %s\n", buffer);
+    clGetPlatformInfo(info.platform, CL_PLATFORM_NAME, 10240, buffer, NULL);
+    printf(" NAME = %s\n", buffer);
+    clGetPlatformInfo(info.platform, CL_PLATFORM_VENDOR, 10240, buffer, NULL);
+    printf(" VENDOR = %s\n", buffer);
+    clGetPlatformInfo(info.platform, CL_PLATFORM_EXTENSIONS, 10240, buffer, NULL);
+    printf(" EXTENSIONS = %s\n", buffer);
+
     check_error(info);
     info.error=clGetDeviceIDs(info.platform, CL_DEVICE_TYPE_ALL, MAX_DEVICES, devices, &num_devices);
     if(num_devices > MAX_DEVICES) num_devices = MAX_DEVICES;
+    printf("=== %d OpenCL device(s) found on platform:\n", num_devices);
+    int i;
+    for (i=0; i<num_devices; ++i){
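The clFinish added to check_error forces each enqueued command to complete before the error code is inspected, so asynchronous failures surface at the offending call instead of some later, unrelated one, at the cost of serializing the queue. A hedged variant (not in the patch) that also prints readable names for a few common codes from the standard OpenCL headers:

    #include <stdio.h>
    #include <stdlib.h>
    #include "opencl.h"

    void check_error_verbose(cl_info info)
    {
        clFinish(cl.queue);   /* debug aid: trade throughput for precise blame */
        if (info.error != CL_SUCCESS) {
            const char *name = "unknown";
            switch(info.error){   /* partial table; extend as needed */
                case CL_OUT_OF_RESOURCES:        name = "CL_OUT_OF_RESOURCES"; break;
                case CL_OUT_OF_HOST_MEMORY:      name = "CL_OUT_OF_HOST_MEMORY"; break;
                case CL_BUILD_PROGRAM_FAILURE:   name = "CL_BUILD_PROGRAM_FAILURE"; break;
                case CL_INVALID_KERNEL_ARGS:     name = "CL_INVALID_KERNEL_ARGS"; break;
                case CL_INVALID_WORK_GROUP_SIZE: name = "CL_INVALID_WORK_GROUP_SIZE"; break;
            }
            printf("\n Error %d (%s)", info.error, name);
            exit(1);
        }
    }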