From 787d5345609459f21fd65d2d8b4fcd55201e21a1 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Mon, 13 Oct 2014 00:29:01 -0700 Subject: [PATCH] Convolutional working on GPU --- Makefile | 4 +- src/activations.c | 1 + src/axpy.c | 114 +++++++++++++++++++++++++++++++++- src/axpy.cl | 18 ++++++ src/cnn.c | 100 ++++++++++++++++++++++++------ src/col2im.c | 27 +++++++- src/col2im.cl | 61 +++++++++--------- src/connected_layer.c | 9 ++- src/connected_layer.h | 14 +++++ src/convolutional_layer.c | 33 +++++++--- src/convolutional_layer.h | 4 +- src/cost_layer.c | 49 +++++++++++++++ src/cost_layer.h | 24 +++++++ src/freeweight_layer.c | 24 +++++++ src/freeweight_layer.h | 14 +++++ src/im2col.cl | 1 - src/mini_blas.h | 12 +++- src/network.c | 127 +++++++++++++++++++++++++++++++++----- src/network.h | 15 +++-- src/opencl.c | 4 +- src/parser.c | 81 ++++++++++++++++++++++++ 21 files changed, 643 insertions(+), 93 deletions(-) create mode 100644 src/cost_layer.c create mode 100644 src/cost_layer.h create mode 100644 src/freeweight_layer.c create mode 100644 src/freeweight_layer.h diff --git a/Makefile b/Makefile index cf0cfdf1..b5ad1eb0 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ CC=gcc -GPU=0 +GPU=1 COMMON=-Wall -Wfatal-errors `pkg-config --cflags opencv` -I/usr/local/cuda/include/ ifeq ($(GPU), 1) COMMON+=-DGPU @@ -25,7 +25,7 @@ VPATH=./src/ EXEC=cnn OBJDIR=./obj/ -OBJ=network.o image.o cnn.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o gemm.o normalization_layer.o opencl.o im2col.o col2im.o axpy.o dropout_layer.o crop_layer.o +OBJ=network.o image.o cnn.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o gemm.o normalization_layer.o opencl.o im2col.o col2im.o axpy.o dropout_layer.o crop_layer.o freeweight_layer.o cost_layer.o OBJS = $(addprefix $(OBJDIR), $(OBJ)) all: $(EXEC) diff --git a/src/activations.c b/src/activations.c index 4a4bd3fe..84fe9f96 100644 --- a/src/activations.c +++ b/src/activations.c @@ -40,6 +40,7 @@ float sigmoid_activate(float x){return 1./(1. 
+ exp(-x));} float relu_activate(float x){return x*(x>0);} float ramp_activate(float x){return x*(x>0)+.1*x;} float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);} +//float tanh_activate(float x){return x - (x*x*x)/3;} float linear_gradient(float x){return 1;} float sigmoid_gradient(float x){return (1-x)*x;} diff --git a/src/axpy.c b/src/axpy.c index 750f47ea..c4ec1ebc 100644 --- a/src/axpy.c +++ b/src/axpy.c @@ -1,14 +1,124 @@ #include "mini_blas.h" -void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY) +inline void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY) { int i; for(i = 0; i < N; ++i) Y[i*INCY] += ALPHA*X[i*INCX]; } -void scal_cpu(int N, float ALPHA, float *X, int INCX) +inline void scal_cpu(int N, float ALPHA, float *X, int INCX) { int i; for(i = 0; i < N; ++i) X[i*INCX] *= ALPHA; } +inline void copy_cpu(int N, float *X, int INCX, float *Y, int INCY) +{ + int i; + for(i = 0; i < N; ++i) Y[i*INCY] = X[i*INCX]; +} + +inline float dot_cpu(int N, float *X, int INCX, float *Y, int INCY) +{ + int i; + float dot = 0; + for(i = 0; i < N; ++i) dot += X[i*INCX] * Y[i*INCY]; + return dot; +} + +#ifdef GPU +#include "opencl.h" + +cl_kernel get_axpy_kernel() +{ + static int init = 0; + static cl_kernel kernel; + if(!init){ + kernel = get_kernel("src/axpy.cl", "axpy", 0); + init = 1; + } + return kernel; +} + +cl_kernel get_copy_kernel() +{ + static int init = 0; + static cl_kernel kernel; + if(!init){ + kernel = get_kernel("src/axpy.cl", "copy", 0); + init = 1; + } + return kernel; +} + +cl_kernel get_scal_kernel() +{ + static int init = 0; + static cl_kernel kernel; + if(!init){ + kernel = get_kernel("src/axpy.cl", "scal", 0); + init = 1; + } + return kernel; +} + + +void axpy_ongpu(int N, float ALPHA, cl_mem X, int INCX, cl_mem Y, int INCY) +{ + cl_setup(); + cl_kernel kernel = get_axpy_kernel(); + cl_command_queue queue = cl.queue; + + cl_uint i = 0; + cl.error = clSetKernelArg(kernel, i++, sizeof(N), (void*) &N); + cl.error = clSetKernelArg(kernel, i++, sizeof(ALPHA), (void*) &ALPHA); + cl.error = clSetKernelArg(kernel, i++, sizeof(X), (void*) &X); + cl.error = clSetKernelArg(kernel, i++, sizeof(INCX), (void*) &INCX); + cl.error = clSetKernelArg(kernel, i++, sizeof(Y), (void*) &Y); + cl.error = clSetKernelArg(kernel, i++, sizeof(INCY), (void*) &INCY); + check_error(cl); + + const size_t global_size[] = {N}; + + clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); + check_error(cl); + +} +void copy_ongpu(int N, cl_mem X, int INCX, cl_mem Y, int INCY) +{ + cl_setup(); + cl_kernel kernel = get_copy_kernel(); + cl_command_queue queue = cl.queue; + + cl_uint i = 0; + cl.error = clSetKernelArg(kernel, i++, sizeof(N), (void*) &N); + cl.error = clSetKernelArg(kernel, i++, sizeof(X), (void*) &X); + cl.error = clSetKernelArg(kernel, i++, sizeof(INCX), (void*) &INCX); + cl.error = clSetKernelArg(kernel, i++, sizeof(Y), (void*) &Y); + cl.error = clSetKernelArg(kernel, i++, sizeof(INCY), (void*) &INCY); + check_error(cl); + + const size_t global_size[] = {N}; + + clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); + check_error(cl); +} +void scal_ongpu(int N, float ALPHA, cl_mem X, int INCX) +{ + cl_setup(); + cl_kernel kernel = get_scal_kernel(); + cl_command_queue queue = cl.queue; + + cl_uint i = 0; + cl.error = clSetKernelArg(kernel, i++, sizeof(N), (void*) &N); + cl.error = clSetKernelArg(kernel, i++, sizeof(ALPHA), (void*) &ALPHA); + cl.error = clSetKernelArg(kernel, i++, sizeof(X), (void*) &X); + cl.error = 
clSetKernelArg(kernel, i++, sizeof(INCX), (void*) &INCX); + check_error(cl); + + const size_t global_size[] = {N}; + + clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); + check_error(cl); +} +#endif diff --git a/src/axpy.cl b/src/axpy.cl index e69de29b..394d8976 100644 --- a/src/axpy.cl +++ b/src/axpy.cl @@ -0,0 +1,18 @@ +__kernel void axpy(int N, float ALPHA, __global float *X, int INCX, __global float *Y, int INCY) +{ + int i = get_global_id(0); + Y[i*INCY] += ALPHA*X[i*INCX]; +} + +__kernel void scal(int N, float ALPHA, __global float *X, int INCX) +{ + int i = get_global_id(0); + X[i*INCX] *= ALPHA; +} + +__kernel void copy(int N, __global float *X, int INCX, __global float *Y, int INCY) +{ + int i = get_global_id(0); + Y[i*INCY] = X[i*INCX]; +} + diff --git a/src/cnn.c b/src/cnn.c index 0cd6da30..472aa03b 100644 --- a/src/cnn.c +++ b/src/cnn.c @@ -37,42 +37,104 @@ void test_convolve() void test_convolutional_layer() { int i; - image dog = load_image("data/dog.jpg",256,256); + image dog = load_image("data/dog.jpg",224,224); network net = parse_network_cfg("cfg/convolutional.cfg"); // data test = load_cifar10_data("data/cifar10/test_batch.bin"); // float *X = calloc(net.batch*test.X.cols, sizeof(float)); // float *y = calloc(net.batch*test.y.cols, sizeof(float)); int in_size = get_network_input_size(net)*net.batch; + int del_size = get_network_output_size_layer(net, 0)*net.batch; int size = get_network_output_size(net)*net.batch; -float *X = calloc(in_size, sizeof(float)); + float *X = calloc(in_size, sizeof(float)); + float *y = calloc(size, sizeof(float)); for(i = 0; i < in_size; ++i){ X[i] = dog.data[i%get_network_input_size(net)]; } // get_batch(test, net.batch, X, y); clock_t start, end; cl_mem input_cl = cl_make_array(X, in_size); + cl_mem truth_cl = cl_make_array(y, size); - forward_network_gpu(net, input_cl, 1); + forward_network_gpu(net, input_cl, truth_cl, 1); start = clock(); - forward_network_gpu(net, input_cl, 1); + forward_network_gpu(net, input_cl, truth_cl, 1); end = clock(); float gpu_sec = (float)(end-start)/CLOCKS_PER_SEC; + printf("forward gpu: %f sec\n", gpu_sec); + start = clock(); + backward_network_gpu(net, input_cl); + end = clock(); + gpu_sec = (float)(end-start)/CLOCKS_PER_SEC; + printf("backward gpu: %f sec\n", gpu_sec); + //float gpu_cost = get_network_cost(net); float *gpu_out = calloc(size, sizeof(float)); memcpy(gpu_out, get_network_output(net), size*sizeof(float)); + float *gpu_del = calloc(del_size, sizeof(float)); + memcpy(gpu_del, get_network_delta_layer(net, 0), del_size*sizeof(float)); + +/* start = clock(); - forward_network(net, X, 1); + forward_network(net, X, y, 1); + backward_network(net, X); + float cpu_cost = get_network_cost(net); end = clock(); float cpu_sec = (float)(end-start)/CLOCKS_PER_SEC; float *cpu_out = calloc(size, sizeof(float)); memcpy(cpu_out, get_network_output(net), size*sizeof(float)); + float *cpu_del = calloc(del_size, sizeof(float)); + memcpy(cpu_del, get_network_delta_layer(net, 0), del_size*sizeof(float)); float sum = 0; - for(i = 0; i < size; ++i) { - //printf("%f, %f\n", gpu_out[i], cpu_out[i]); - sum += pow(gpu_out[i] - cpu_out[i], 2); + float del_sum = 0; + for(i = 0; i < size; ++i) sum += pow(gpu_out[i] - cpu_out[i], 2); + for(i = 0; i < del_size; ++i) { + //printf("%f %f\n", cpu_del[i], gpu_del[i]); + del_sum += pow(cpu_del[i] - gpu_del[i], 2); } - printf("gpu: %f sec, cpu: %f sec, diff: %f, size: %d\n", gpu_sec, cpu_sec, sum, size); + printf("GPU cost: %f, CPU cost: %f\n", gpu_cost, cpu_cost); 
+ printf("gpu: %f sec, cpu: %f sec, diff: %f, delta diff: %f, size: %d\n", gpu_sec, cpu_sec, sum, del_sum, size); + */ +} + +void test_col2im() +{ + float col[] = {1,2,1,2, + 1,2,1,2, + 1,2,1,2, + 1,2,1,2, + 1,2,1,2, + 1,2,1,2, + 1,2,1,2, + 1,2,1,2, + 1,2,1,2}; + float im[16] = {0}; + int batch = 1; + int channels = 1; + int height=4; + int width=4; + int ksize = 3; + int stride = 1; + int pad = 0; + col2im_gpu(col, batch, + channels, height, width, + ksize, stride, pad, im); + int i; + for(i = 0; i < 16; ++i)printf("%f,", im[i]); + printf("\n"); + /* + float data_im[] = { + 1,2,3,4, + 5,6,7,8, + 9,10,11,12 + }; + float data_col[18] = {0}; + im2col_cpu(data_im, batch, + channels, height, width, + ksize, stride, pad, data_col) ; + for(i = 0; i < 18; ++i)printf("%f,", data_col[i]); + printf("\n"); + */ } #endif @@ -274,7 +336,7 @@ void test_full() normalize_data_rows(test); for(j = 0; j < test.X.rows; ++j){ float *x = test.X.vals[j]; - forward_network(net, x, 0); + forward_network(net, x, 0, 0); int class = get_predicted_class_network(net); fprintf(fp, "%d\n", class); } @@ -285,7 +347,6 @@ void test_full() void test_cifar10() { - network net = parse_network_cfg("cfg/cifar10_part5.cfg"); data test = load_cifar10_data("data/cifar10/test_batch.bin"); clock_t start = clock(), end; @@ -457,7 +518,7 @@ void test_random_classify() int index = rand()%m.rows; //image p = float_to_image(1690,1,1,m.vals[index]); //normalize_image(p); - forward_network(net, m.vals[index], 1); + forward_network(net, m.vals[index], 0, 1); float *out = get_network_output(net); float *delta = get_network_delta(net); //printf("%f\n", out[0]); @@ -478,7 +539,7 @@ void test_random_classify() matrix test = csv_to_matrix("test.csv"); truth = pop_column(&test, 0); for(i = 0; i < test.rows; ++i){ - forward_network(net, test.vals[i], 0); + forward_network(net, test.vals[i],0, 0); float *out = get_network_output(net); if(fabs(out[0]) < .5) fprintf(fp, "0\n"); else fprintf(fp, "1\n"); @@ -578,7 +639,7 @@ image features_output_size(network net, IplImage *src, int outh, int outw) //normalize_array(im.data, im.h*im.w*im.c); translate_image(im, -144); resize_network(net, im.h, im.w, im.c); - forward_network(net, im.data, 0); + forward_network(net, im.data, 0, 0); image out = get_network_image(net); free_image(im); cvReleaseImage(&sized); @@ -630,7 +691,7 @@ void visualize_imagenet_topk(char *filename) resize_network(net, im.h, im.w, im.c); //scale_image(im, 1./255); translate_image(im, -144); - forward_network(net, im.data, 0); + forward_network(net, im.data, 0, 0); image out = get_network_image(net); int dh = (im.h - h)/(out.h-1); @@ -692,7 +753,7 @@ void visualize_imagenet_features(char *filename) image im = load_image(image_path, 0, 0); printf("Processing %dx%d image\n", im.h, im.w); resize_network(net, im.h, im.w, im.c); - forward_network(net, im.data, 0); + forward_network(net, im.data, 0, 0); image out = get_network_image(net); int dh = (im.h - h)/h; @@ -725,7 +786,7 @@ void visualize_cat() image im = load_image("data/cat.png", 0, 0); printf("Processing %dx%d image\n", im.h, im.w); resize_network(net, im.h, im.w, im.c); - forward_network(net, im.data, 0); + forward_network(net, im.data, 0, 0); visualize_network(net); cvWaitKey(0); @@ -855,8 +916,9 @@ int main(int argc, char *argv[]) //test_ensemble(); //test_nist_single(); //test_nist(); - train_nist(); - //test_convolutional_layer(); + //train_nist(); + test_convolutional_layer(); + //test_col2im(); //test_cifar10(); //train_cifar10(); //test_vince(); diff --git a/src/col2im.c 
b/src/col2im.c index c418fa54..65db22a6 100644 --- a/src/col2im.c +++ b/src/col2im.c @@ -80,11 +80,32 @@ void col2im_ongpu(cl_mem data_col, int batch, cl.error = clSetKernelArg(kernel, i++, sizeof(data_im), (void*) &data_im); check_error(cl); - size_t global_size = {channels*height*width*batch}; + size_t global_size = channels*height*width*batch; - clEnqueueNDRangeKernel(queue, kernel, 3, 0, - global_size, 0, 0, 0, 0); + clEnqueueNDRangeKernel(queue, kernel, 1, 0, + &global_size, 0, 0, 0, 0); check_error(cl); } +void col2im_gpu(float *data_col, int batch, + int channels, int height, int width, + int ksize, int stride, int pad, float *data_im) +{ + int height_col = (height - ksize) / stride + 1; + int width_col = (width - ksize) / stride + 1; + int channels_col = channels * ksize * ksize; + + size_t size = height_col*width_col*channels_col*batch; + cl_mem col_gpu = cl_make_array(data_col, size); + size = channels*height*width*batch; + cl_mem im_gpu = cl_make_array(data_im, size); + + col2im_ongpu(col_gpu, batch, channels, height, width, + ksize, stride, pad, im_gpu); + + cl_read_array(im_gpu, data_im, size); + clReleaseMemObject(col_gpu); + clReleaseMemObject(im_gpu); +} + #endif diff --git a/src/col2im.cl b/src/col2im.cl index c8e3b30b..00d8f83b 100644 --- a/src/col2im.cl +++ b/src/col2im.cl @@ -1,41 +1,46 @@ -int index(int row, int col) +__kernel void col2im(__global float *data_col, int batch, + int channels, int height, int width, + int ksize, int stride, int pad, __global float *data_im) { - -} - -__kernel void col2im(__global float *data_col, int batch, - int channels, int height, int width, - int ksize, int stride, int pad, __global float *data_im) -{ - int id = get_global_id(0); - int index = id; - int w = id%width; - id /= width; - int h = id%height; - id /= height; - int c = id%channels; - id /= channels; - int b = id%batch; int height_col = (height - ksize) / stride + 1; int width_col = (width - ksize) / stride + 1; - int rows = channels * ksize * ksize; if (pad){ height_col = 1 + (height-1) / stride; width_col = 1 + (width-1) / stride; pad = ksize/2; } + + int id = get_global_id(0); + int index = id; + int w = id%width + pad; + id /= width; + int h = id%height + pad; + id /= height; + int c = id%channels; + id /= channels; + int b = id%batch; + + int w_start = (wdelta = calloc(batch*outputs, sizeof(float*)); layer->weight_updates = calloc(inputs*outputs, sizeof(float)); - layer->weight_adapt = calloc(inputs*outputs, sizeof(float)); + //layer->weight_adapt = calloc(inputs*outputs, sizeof(float)); layer->weight_momentum = calloc(inputs*outputs, sizeof(float)); layer->weights = calloc(inputs*outputs, sizeof(float)); float scale = 1./inputs; @@ -34,13 +34,16 @@ connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVA layer->weights[i] = scale*2*(rand_uniform()-.5); layer->bias_updates = calloc(outputs, sizeof(float)); - layer->bias_adapt = calloc(outputs, sizeof(float)); + //layer->bias_adapt = calloc(outputs, sizeof(float)); layer->bias_momentum = calloc(outputs, sizeof(float)); layer->biases = calloc(outputs, sizeof(float)); - for(i = 0; i < outputs; ++i) + for(i = 0; i < outputs; ++i){ //layer->biases[i] = rand_normal()*scale + scale; layer->biases[i] = 1; + } + #ifdef GPU + #endif layer->activation = activation; return layer; } diff --git a/src/connected_layer.h b/src/connected_layer.h index e9e461c5..43226594 100644 --- a/src/connected_layer.h +++ b/src/connected_layer.h @@ -2,6 +2,7 @@ #define CONNECTED_LAYER_H #include "activations.h" +#include 
"opencl.h" typedef struct{ float learning_rate; @@ -26,6 +27,19 @@ typedef struct{ float *output; float *delta; + #ifdef GPU + cl_mem weights_cl; + cl_mem biases_cl; + + cl_mem weight_updates_cl; + cl_mem bias_updates_cl; + + cl_mem weight_momentum_cl; + cl_mem bias_momentum_cl; + + cl_mem output_cl; + cl_mem delta_cl; + #endif ACTIVATION activation; } connected_layer; diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index bdbfbfd9..00de1538 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -195,13 +195,14 @@ void backward_convolutional_layer(convolutional_layer layer, float *delta) b = layer.delta; c = layer.col_image; - memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float)); - for(i = 0; i < layer.batch; ++i){ gemm(1,0,m,n,k,1,a,m,b,n,0,c,n); b += k*n; c += m*n; } + + memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float)); + col2im_cpu(layer.col_image, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, delta); } } @@ -361,7 +362,7 @@ void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in) clReleaseMemObject(c); } activate_array_ongpu(layer.output_cl, m*n*layer.batch, layer.activation); - cl_read_array(layer.output_cl, layer.output, m*n*layer.batch); + //cl_read_array(layer.output_cl, layer.output, m*n*layer.batch); } void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl) @@ -384,9 +385,7 @@ void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl clReleaseMemObject(a); clReleaseMemObject(b); } - cl_read_array(layer.filter_updates_cl, layer.filter_updates, m*n); - cl_read_array(layer.bias_updates_cl, layer.bias_updates, m); - + //cl_read_array(layer.delta_cl, layer.delta, m*k*layer.batch); if(delta_cl){ m = layer.size*layer.size*layer.c; @@ -395,17 +394,31 @@ void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl convolutional_out_width(layer); for(i = 0; i < layer.batch; ++i){ - a = layer.filters_cl; - b = cl_sub_array(layer.delta_cl, i*k*n, k*n); - c = cl_sub_array(layer.col_image_cl, i*m*n, m*n); + cl_mem a = layer.filters_cl; + cl_mem b = cl_sub_array(layer.delta_cl, i*k*n, k*n); + cl_mem c = cl_sub_array(layer.col_image_cl, i*m*n, m*n); gemm_ongpu(1,0,m,n,k,1,a,m,b,n,0,c,n); clReleaseMemObject(b); clReleaseMemObject(c); } - col2im_gpu(layer.col_image_cl, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, delta_cl); + + scal_ongpu(layer.batch*layer.h*layer.w*layer.c,0,delta_cl, 1); + col2im_ongpu(layer.col_image_cl, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, delta_cl); } } +void update_convolutional_layer_gpu(convolutional_layer layer) +{ + int size = layer.size*layer.size*layer.c*layer.n; + axpy_ongpu(layer.n, layer.learning_rate, layer.bias_updates_cl, 1, layer.biases_cl, 1); + scal_ongpu(layer.n,layer.momentum, layer.bias_updates_cl, 1); + + scal_ongpu(size, 1.-layer.learning_rate*layer.decay, layer.filters_cl, 1); + axpy_ongpu(size, layer.learning_rate, layer.filter_updates_cl, 1, layer.filters_cl, 1); + scal_ongpu(size, layer.momentum, layer.filter_updates_cl, 1); +} + + #endif diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index cf897a74..465d3091 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -1,10 +1,7 @@ #ifndef CONVOLUTIONAL_LAYER_H #define CONVOLUTIONAL_LAYER_H -#ifdef GPU #include "opencl.h" -#endif - #include "image.h" #include "activations.h" @@ -51,6 +48,7 @@ typedef struct { 
#ifdef GPU void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in); void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl); +void update_convolutional_layer_gpu(convolutional_layer layer); #endif convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay); diff --git a/src/cost_layer.c b/src/cost_layer.c new file mode 100644 index 00000000..dd0ff905 --- /dev/null +++ b/src/cost_layer.c @@ -0,0 +1,49 @@ +#include "cost_layer.h" +#include "mini_blas.h" +#include +#include +#include + +cost_layer *make_cost_layer(int batch, int inputs) +{ + fprintf(stderr, "Cost Layer: %d inputs\n", inputs); + cost_layer *layer = calloc(1, sizeof(cost_layer)); + layer->batch = batch; + layer->inputs = inputs; + layer->delta = calloc(inputs*batch, sizeof(float)); + layer->output = calloc(1, sizeof(float)); + #ifdef GPU + layer->delta_cl = cl_make_array(layer->delta, inputs*batch); + #endif + return layer; +} + +void forward_cost_layer(cost_layer layer, float *input, float *truth) +{ + if (!truth) return; + copy_cpu(layer.batch*layer.inputs, truth, 1, layer.delta, 1); + axpy_cpu(layer.batch*layer.inputs, -1, input, 1, layer.delta, 1); + *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); +} + +void backward_cost_layer(const cost_layer layer, float *input, float *delta) +{ + copy_cpu(layer.batch*layer.inputs, layer.delta, 1, delta, 1); +} + +#ifdef GPU +void forward_cost_layer_gpu(cost_layer layer, cl_mem input, cl_mem truth) +{ + if (!truth) return; + copy_ongpu(layer.batch*layer.inputs, truth, 1, layer.delta_cl, 1); + axpy_ongpu(layer.batch*layer.inputs, -1, input, 1, layer.delta_cl, 1); + cl_read_array(layer.delta_cl, layer.delta, layer.batch*layer.inputs); + *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); +} + +void backward_cost_layer_gpu(const cost_layer layer, cl_mem input, cl_mem delta) +{ + copy_ongpu(layer.batch*layer.inputs, layer.delta_cl, 1, delta, 1); +} +#endif + diff --git a/src/cost_layer.h b/src/cost_layer.h new file mode 100644 index 00000000..edda8f95 --- /dev/null +++ b/src/cost_layer.h @@ -0,0 +1,24 @@ +#ifndef COST_LAYER_H +#define COST_LAYER_H +#include "opencl.h" + +typedef struct { + int inputs; + int batch; + float *delta; + float *output; + #ifdef GPU + cl_mem delta_cl; + #endif +} cost_layer; + +cost_layer *make_cost_layer(int batch, int inputs); +void forward_cost_layer(const cost_layer layer, float *input, float *truth); +void backward_cost_layer(const cost_layer layer, float *input, float *delta); + +#ifdef GPU +void forward_cost_layer_gpu(cost_layer layer, cl_mem input, cl_mem truth); +void backward_cost_layer_gpu(const cost_layer layer, cl_mem input, cl_mem delta); +#endif + +#endif diff --git a/src/freeweight_layer.c b/src/freeweight_layer.c new file mode 100644 index 00000000..2cc805a7 --- /dev/null +++ b/src/freeweight_layer.c @@ -0,0 +1,24 @@ +#include "freeweight_layer.h" +#include "stdlib.h" +#include "stdio.h" + +freeweight_layer *make_freeweight_layer(int batch, int inputs) +{ + fprintf(stderr, "Freeweight Layer: %d inputs\n", inputs); + freeweight_layer *layer = calloc(1, sizeof(freeweight_layer)); + layer->inputs = inputs; + layer->batch = batch; + return layer; +} + +void forward_freeweight_layer(freeweight_layer layer, float *input) +{ + int i; + for(i = 0; i < layer.batch * layer.inputs; ++i){ + input[i] *= 
2.*((float)rand()/RAND_MAX); + } +} +void backward_freeweight_layer(freeweight_layer layer, float *input, float *delta) +{ + // Don't do shit LULZ +} diff --git a/src/freeweight_layer.h b/src/freeweight_layer.h new file mode 100644 index 00000000..bfca2c19 --- /dev/null +++ b/src/freeweight_layer.h @@ -0,0 +1,14 @@ +#ifndef FREEWEIGHT_LAYER_H +#define FREEWEIGHT_LAYER_H + +typedef struct{ + int batch; + int inputs; +} freeweight_layer; + +freeweight_layer *make_freeweight_layer(int batch, int inputs); + +void forward_freeweight_layer(freeweight_layer layer, float *input); +void backward_freeweight_layer(freeweight_layer layer, float *input, float *delta); + +#endif diff --git a/src/im2col.cl b/src/im2col.cl index 6ed5d89c..877ee529 100644 --- a/src/im2col.cl +++ b/src/im2col.cl @@ -1,4 +1,3 @@ - float im2col_get_pixel(__global float *im, int height, int width, int channels, int batch, int row, int col, int channel, int pad) { diff --git a/src/mini_blas.h b/src/mini_blas.h index 34905a17..a155c351 100644 --- a/src/mini_blas.h +++ b/src/mini_blas.h @@ -10,10 +10,16 @@ float *random_matrix(int rows, int cols); void time_random_matrix(int TA, int TB, int m, int k, int n); #ifdef GPU +void axpy_ongpu(int N, float ALPHA, cl_mem X, int INCX, cl_mem Y, int INCY); +void copy_ongpu(int N, cl_mem X, int INCX, cl_mem Y, int INCY); +void scal_ongpu(int N, float ALPHA, cl_mem X, int INCX); void im2col_ongpu(cl_mem data_im, int batch, int channels, int height, int width, int ksize, int stride, int pad, cl_mem data_col); +void col2im_gpu(float *data_col, int batch, + int channels, int height, int width, + int ksize, int stride, int pad, float *data_im); void col2im_ongpu(cl_mem data_col, int batch, int channels, int height, int width, int ksize, int stride, int pad, cl_mem data_im); @@ -49,6 +55,8 @@ void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, float *B, int ldb, float BETA, float *C, int ldc); -void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY); -void scal_cpu(int N, float ALPHA, float *X, int INCX); +inline void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY); +inline void copy_cpu(int N, float *X, int INCX, float *Y, int INCY); +inline void scal_cpu(int N, float ALPHA, float *X, int INCX); +inline float dot_cpu(int N, float *X, int INCX, float *Y, int INCY); void test_gpu_blas(); diff --git a/src/network.c b/src/network.c index 3761bf92..58331667 100644 --- a/src/network.c +++ b/src/network.c @@ -8,7 +8,9 @@ #include "connected_layer.h" #include "convolutional_layer.h" #include "maxpool_layer.h" +#include "cost_layer.h" #include "normalization_layer.h" +#include "freeweight_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" @@ -28,14 +30,18 @@ network make_network(int n, int batch) } #ifdef GPU -void forward_network_gpu(network net, cl_mem input_cl, int train) +void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train) { int i; for(i = 0; i < net.n; ++i){ if(net.types[i] == CONVOLUTIONAL){ convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - forward_convolutional_layer_gpu(layer, input_cl); - input_cl = layer.output_cl; + forward_convolutional_layer_gpu(layer, input); + input = layer.output_cl; + } + else if(net.types[i] == COST){ + cost_layer layer = *(cost_layer *)net.layers[i]; + forward_cost_layer_gpu(layer, input, truth); } /* else if(net.types[i] == CONNECTED){ @@ -67,9 +73,75 @@ void forward_network_gpu(network net, cl_mem input_cl, int train) } } +void backward_network_gpu(network net, 
cl_mem input) +{ + int i; + cl_mem prev_input; + cl_mem prev_delta; + for(i = net.n-1; i >= 0; --i){ + if(i == 0){ + prev_input = input; + prev_delta = 0; + }else{ + prev_input = get_network_output_cl_layer(net, i-1); + prev_delta = get_network_delta_cl_layer(net, i-1); + } + if(net.types[i] == CONVOLUTIONAL){ + convolutional_layer layer = *(convolutional_layer *)net.layers[i]; + backward_convolutional_layer_gpu(layer, prev_delta); + } + else if(net.types[i] == COST){ + cost_layer layer = *(cost_layer *)net.layers[i]; + backward_cost_layer_gpu(layer, prev_input, prev_delta); + } + } +} + +void update_network_gpu(network net) +{ + int i; + for(i = 0; i < net.n; ++i){ + if(net.types[i] == CONVOLUTIONAL){ + convolutional_layer layer = *(convolutional_layer *)net.layers[i]; + update_convolutional_layer_gpu(layer); + } + else if(net.types[i] == MAXPOOL){ + //maxpool_layer layer = *(maxpool_layer *)net.layers[i]; + } + else if(net.types[i] == SOFTMAX){ + //maxpool_layer layer = *(maxpool_layer *)net.layers[i]; + } + else if(net.types[i] == NORMALIZATION){ + //maxpool_layer layer = *(maxpool_layer *)net.layers[i]; + } + else if(net.types[i] == CONNECTED){ + connected_layer layer = *(connected_layer *)net.layers[i]; + update_connected_layer(layer); + } + } +} + +cl_mem get_network_output_cl_layer(network net, int i) +{ + if(net.types[i] == CONVOLUTIONAL){ + convolutional_layer layer = *(convolutional_layer *)net.layers[i]; + return layer.output_cl; + } + return 0; +} + +cl_mem get_network_delta_cl_layer(network net, int i) +{ + if(net.types[i] == CONVOLUTIONAL){ + convolutional_layer layer = *(convolutional_layer *)net.layers[i]; + return layer.delta_cl; + } + return 0; +} + #endif -void forward_network(network net, float *input, int train) +void forward_network(network net, float *input, float *truth, int train) { int i; for(i = 0; i < net.n; ++i){ @@ -88,6 +160,10 @@ void forward_network(network net, float *input, int train) forward_crop_layer(layer, input); input = layer.output; } + else if(net.types[i] == COST){ + cost_layer layer = *(cost_layer *)net.layers[i]; + forward_cost_layer(layer, input, truth); + } else if(net.types[i] == SOFTMAX){ softmax_layer layer = *(softmax_layer *)net.layers[i]; forward_softmax_layer(layer, input); @@ -108,6 +184,11 @@ void forward_network(network net, float *input, int train) dropout_layer layer = *(dropout_layer *)net.layers[i]; forward_dropout_layer(layer, input); } + else if(net.types[i] == FREEWEIGHT){ + if(!train) continue; + freeweight_layer layer = *(freeweight_layer *)net.layers[i]; + forward_freeweight_layer(layer, input); + } } } @@ -159,7 +240,9 @@ float *get_network_output_layer(network net, int i) } float *get_network_output(network net) { - return get_network_output_layer(net, net.n-1); + int i; + for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; + return get_network_output_layer(net, i); } float *get_network_delta_layer(network net, int i) @@ -182,6 +265,14 @@ float *get_network_delta_layer(network net, int i) return 0; } +float get_network_cost(network net) +{ + if(net.types[net.n-1] == COST){ + return ((cost_layer *)net.layers[net.n-1])->output[0]; + } + return 0; +} + float *get_network_delta(network net) { return get_network_delta_layer(net, net.n-1); @@ -212,9 +303,8 @@ int get_predicted_class_network(network net) return max_index(out, k); } -float backward_network(network net, float *input, float *truth) +void backward_network(network net, float *input) { - float error = calculate_error_network(net, truth); int i; float 
*prev_input; float *prev_delta; @@ -246,15 +336,19 @@ float backward_network(network net, float *input, float *truth) connected_layer layer = *(connected_layer *)net.layers[i]; backward_connected_layer(layer, prev_input, prev_delta); } + else if(net.types[i] == COST){ + cost_layer layer = *(cost_layer *)net.layers[i]; + backward_cost_layer(layer, prev_input, prev_delta); + } } - return error; } float train_network_datum(network net, float *x, float *y) { - forward_network(net, x, 1); + forward_network(net, x, y, 1); //int class = get_predicted_class_network(net); - float error = backward_network(net, x, y); + backward_network(net, x); + float error = get_network_cost(net); update_network(net); //return (y[class]?1:0); return error; @@ -287,8 +381,9 @@ float train_network_batch(network net, data d, int n) int index = rand()%d.X.rows; float *x = d.X.vals[index]; float *y = d.y.vals[index]; - forward_network(net, x, 1); - sum += backward_network(net, x, y); + forward_network(net, x, y, 1); + backward_network(net, x); + sum += get_network_cost(net); } update_network(net); } @@ -351,7 +446,8 @@ int get_network_output_size_layer(network net, int i) else if(net.types[i] == CONNECTED){ connected_layer layer = *(connected_layer *)net.layers[i]; return layer.outputs; - } else if(net.types[i] == DROPOUT){ + } + else if(net.types[i] == DROPOUT){ dropout_layer layer = *(dropout_layer *) net.layers[i]; return layer.inputs; } @@ -396,7 +492,8 @@ int resize_network(network net, int h, int w, int c) int get_network_output_size(network net) { - int i = net.n-1; + int i; + for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; return get_network_output_size_layer(net, i); } @@ -457,7 +554,7 @@ void visualize_network(network net) float *network_predict(network net, float *input) { - forward_network(net, input, 0); + forward_network(net, input, 0, 0); float *out = get_network_output(net); return out; } diff --git a/src/network.h b/src/network.h index 65ace579..37c145db 100644 --- a/src/network.h +++ b/src/network.h @@ -13,7 +13,9 @@ typedef enum { SOFTMAX, NORMALIZATION, DROPOUT, - CROP + FREEWEIGHT, + CROP, + COST } LAYER_TYPE; typedef struct { @@ -34,12 +36,16 @@ typedef struct { } network; #ifdef GPU -void forward_network_gpu(network net, cl_mem input, int train); +void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train); +void backward_network_gpu(network net, cl_mem input); +void update_network_gpu(network net); +cl_mem get_network_output_cl_layer(network net, int i); +cl_mem get_network_delta_cl_layer(network net, int i); #endif network make_network(int n, int batch); -void forward_network(network net, float *input, int train); -float backward_network(network net, float *input, float *truth); +void forward_network(network net, float *input, float *truth, int train); +void backward_network(network net, float *input); void update_network(network net); float train_network_sgd(network net, data d, int n); float train_network_batch(network net, data d, int n); @@ -60,6 +66,7 @@ void print_network(network net); void visualize_network(network net); int resize_network(network net, int h, int w, int c); int get_network_input_size(network net); +float get_network_cost(network net); #endif diff --git a/src/opencl.c b/src/opencl.c index bcc0f09e..5aec33c9 100644 --- a/src/opencl.c +++ b/src/opencl.c @@ -1,11 +1,12 @@ #ifdef GPU -#include "opencl.h" #include #include #include #include #include +#include "opencl.h" +#include "utils.h" cl_info cl = {0}; @@ -103,6 +104,7 @@ cl_program 
cl_fprog(char *filename, char *options, cl_info info) char src[64*1024]; memset(src, 0, 64*1024); FILE *fil=fopen(filename,"r"); + if(fil == 0) file_error(filename); srcsize=fread(src, sizeof src, 1, fil); fclose(fil); const char *srcptr[]={src}; diff --git a/src/parser.c b/src/parser.c index 5c991a54..9bd2eb76 100644 --- a/src/parser.c +++ b/src/parser.c @@ -5,12 +5,14 @@ #include "parser.h" #include "activations.h" #include "crop_layer.h" +#include "cost_layer.h" #include "convolutional_layer.h" #include "connected_layer.h" #include "maxpool_layer.h" #include "normalization_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" +#include "freeweight_layer.h" #include "list.h" #include "option_list.h" #include "utils.h" @@ -24,8 +26,10 @@ int is_convolutional(section *s); int is_connected(section *s); int is_maxpool(section *s); int is_dropout(section *s); +int is_freeweight(section *s); int is_softmax(section *s); int is_crop(section *s); +int is_cost(section *s); int is_normalization(section *s); list *read_cfg(char *filename); @@ -182,6 +186,20 @@ softmax_layer *parse_softmax(list *options, network *net, int count) return layer; } +cost_layer *parse_cost(list *options, network *net, int count) +{ + int input; + if(count == 0){ + input = option_find_int(options, "input",1); + net->batch = option_find_int(options, "batch",1); + }else{ + input = get_network_output_size_layer(*net, count-1); + } + cost_layer *layer = make_cost_layer(net->batch, input); + option_unused(options); + return layer; +} + crop_layer *parse_crop(list *options, network *net, int count) { float learning_rate, momentum, decay; @@ -234,6 +252,20 @@ maxpool_layer *parse_maxpool(list *options, network *net, int count) return layer; } +freeweight_layer *parse_freeweight(list *options, network *net, int count) +{ + int input; + if(count == 0){ + net->batch = option_find_int(options, "batch",1); + input = option_find_int(options, "input",1); + }else{ + input = get_network_output_size_layer(*net, count-1); + } + freeweight_layer *layer = make_freeweight_layer(net->batch,input); + option_unused(options); + return layer; +} + dropout_layer *parse_dropout(list *options, network *net, int count) { int input; @@ -295,6 +327,10 @@ network parse_network_cfg(char *filename) crop_layer *layer = parse_crop(options, &net, count); net.types[count] = CROP; net.layers[count] = layer; + }else if(is_cost(s)){ + cost_layer *layer = parse_cost(options, &net, count); + net.types[count] = COST; + net.layers[count] = layer; }else if(is_softmax(s)){ softmax_layer *layer = parse_softmax(options, &net, count); net.types[count] = SOFTMAX; @@ -311,6 +347,10 @@ network parse_network_cfg(char *filename) dropout_layer *layer = parse_dropout(options, &net, count); net.types[count] = DROPOUT; net.layers[count] = layer; + }else if(is_freeweight(s)){ + freeweight_layer *layer = parse_freeweight(options, &net, count); + net.types[count] = FREEWEIGHT; + net.layers[count] = layer; }else{ fprintf(stderr, "Type not recognized: %s\n", s->type); } @@ -328,6 +368,10 @@ int is_crop(section *s) { return (strcmp(s->type, "[crop]")==0); } +int is_cost(section *s) +{ + return (strcmp(s->type, "[cost]")==0); +} int is_convolutional(section *s) { return (strcmp(s->type, "[conv]")==0 @@ -347,6 +391,10 @@ int is_dropout(section *s) { return (strcmp(s->type, "[dropout]")==0); } +int is_freeweight(section *s) +{ + return (strcmp(s->type, "[freeweight]")==0); +} int is_softmax(section *s) { @@ -447,6 +495,25 @@ void print_convolutional_cfg(FILE *fp, 
convolutional_layer *l, network net, int for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]); fprintf(fp, "\n\n"); } + +void print_freeweight_cfg(FILE *fp, freeweight_layer *l, network net, int count) +{ + fprintf(fp, "[freeweight]\n"); + if(count == 0){ + fprintf(fp, "batch=%d\ninput=%d\n",l->batch, l->inputs); + } + fprintf(fp, "\n"); +} + +void print_dropout_cfg(FILE *fp, dropout_layer *l, network net, int count) +{ + fprintf(fp, "[dropout]\n"); + if(count == 0){ + fprintf(fp, "batch=%d\ninput=%d\n", l->batch, l->inputs); + } + fprintf(fp, "probability=%g\n\n", l->probability); +} + void print_connected_cfg(FILE *fp, connected_layer *l, network net, int count) { int i; @@ -526,6 +593,14 @@ void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count) fprintf(fp, "\n"); } +void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count) +{ + fprintf(fp, "[cost]\n"); + if(count == 0) fprintf(fp, "batch=%d\ninput=%d\n", l->batch, l->inputs); + fprintf(fp, "\n"); +} + + void save_network(network net, char *filename) { FILE *fp = fopen(filename, "w"); @@ -541,10 +616,16 @@ void save_network(network net, char *filename) print_crop_cfg(fp, (crop_layer *)net.layers[i], net, i); else if(net.types[i] == MAXPOOL) print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], net, i); + else if(net.types[i] == FREEWEIGHT) + print_freeweight_cfg(fp, (freeweight_layer *)net.layers[i], net, i); + else if(net.types[i] == DROPOUT) + print_dropout_cfg(fp, (dropout_layer *)net.layers[i], net, i); else if(net.types[i] == NORMALIZATION) print_normalization_cfg(fp, (normalization_layer *)net.layers[i], net, i); else if(net.types[i] == SOFTMAX) print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i); + else if(net.types[i] == COST) + print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i); } fclose(fp); }
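
A minimal sketch of how the GPU entry points added here fit together for one training step, assuming a network built from convolutional layers with the new [cost] layer at the end. The helper name train_network_datum_gpu is hypothetical; every call it makes (cl_make_array, forward_network_gpu, backward_network_gpu, update_network_gpu, get_network_cost, clReleaseMemObject) is either introduced by this patch or already present in the tree, and it mirrors the CPU path in train_network_datum().

#include "network.h"   /* forward/backward/update_network_gpu, get_network_cost */
#include "opencl.h"    /* cl_mem, cl_make_array */

/* Hypothetical helper (a sketch, not part of this patch): one SGD step on the GPU.
 * x holds net.batch input vectors, y holds the matching truth vectors. */
float train_network_datum_gpu(network net, float *x, float *y)
{
    int in_size  = get_network_input_size(net)*net.batch;
    int out_size = get_network_output_size(net)*net.batch;   /* skips the cost layer */

    cl_mem input_cl = cl_make_array(x, in_size);    /* upload inputs to the device */
    cl_mem truth_cl = cl_make_array(y, out_size);   /* upload truth for the cost layer */

    forward_network_gpu(net, input_cl, truth_cl, 1);  /* cost layer computes truth - output */
    backward_network_gpu(net, input_cl);              /* backprop deltas, accumulate updates */
    update_network_gpu(net);                          /* apply filter/bias updates on the GPU */

    float error = get_network_cost(net);  /* sum of squared error read back by the cost layer */

    clReleaseMemObject(input_cl);
    clReleaseMemObject(truth_cl);
    return error;
}

In a cfg file this path is enabled simply by ending the file with a [cost] section; when it is not the first layer, parse_cost takes its input size from the previous layer, so the section needs no keys.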