diff --git a/Makefile b/Makefile
index ee382b47..9c3043b0 100644
--- a/Makefile
+++ b/Makefile
@@ -23,19 +23,21 @@ CFLAGS= $(COMMON) $(OPTS)
 LDFLAGS+=`pkg-config --libs opencv` -lm
 VPATH=./src/
 EXEC=cnn
+OBJDIR=./obj/
 
-OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o gemm.o normalization_layer.o opencl.o im2col.o col2im.o axpy.o
+OBJ=network.o image.o cnn.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o gemm.o normalization_layer.o opencl.o im2col.o col2im.o axpy.o
+OBJS = $(addprefix $(OBJDIR), $(OBJ))
 
 all: $(EXEC)
 
-$(EXEC): $(OBJ)
+$(EXEC): $(OBJS)
 	$(CC) $(CFLAGS) $(LDFLAGS) $^ -o $@
 
-%.o: %.c
+$(OBJDIR)%.o: %.c
 	$(CC) $(CFLAGS) -c $< -o $@
 
 .PHONY: clean
 clean:
-	rm -rf $(OBJ) $(EXEC)
+	rm -rf $(OBJS) $(EXEC)
diff --git a/src/activations.cl b/src/activations.cl
index 19428b1c..6ab135a1 100644
--- a/src/activations.cl
+++ b/src/activations.cl
@@ -2,6 +2,12 @@ typedef enum{
     SIGMOID, RELU, LINEAR, RAMP, TANH
 }ACTIVATION;
 
+float linear_activate(float x){return x;}
+float sigmoid_activate(float x){return 1./(1. + exp(-x));}
+float relu_activate(float x){return x*(x>0);}
+float ramp_activate(float x){return x*(x>0)+.1*x;}
+float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
+
 float activate(float x, ACTIVATION a, float dropout)
 {
     //if((float)rand()/RAND_MAX < dropout) return 0;
diff --git a/src/tests.c b/src/cnn.c
similarity index 95%
rename from src/tests.c
rename to src/cnn.c
index 81054046..96b9463c 100644
--- a/src/tests.c
+++ b/src/cnn.c
@@ -52,7 +52,7 @@ void test_convolve_matrix()
     int i;
     clock_t start = clock(), end;
     for(i = 0; i < 1000; ++i){
-        im2col_cpu(dog.data, 1, dog.c, dog.h, dog.w, size, stride, matrix);
+        im2col_cpu(dog.data, 1, dog.c, dog.h, dog.w, size, stride, 0, matrix);
         gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw);
     }
     end = clock();
@@ -76,7 +76,7 @@ void verify_convolutional_layer()
     int size = 3;
     float eps = .00000001;
     image test = make_random_image(5,5, 1);
-    convolutional_layer layer = *make_convolutional_layer(1,test.h,test.w,test.c, n, size, stride, RELU);
+    convolutional_layer layer = *make_convolutional_layer(1,test.h,test.w,test.c, n, size, stride, 0, RELU);
     image out = get_convolutional_image(layer);
     float **jacobian = calloc(test.h*test.w*test.c, sizeof(float));
@@ -301,7 +301,7 @@ void test_vince()
 void test_nist()
 {
     srand(444444);
-    srand(888888);
+    srand(222222);
     network net = parse_network_cfg("cfg/nist.cfg");
     data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
     data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
@@ -309,22 +309,26 @@
     normalize_data_rows(test);
     //randomize_data(train);
     int count = 0;
-    float lr = .00005;
+    float lr = .000075;
     float momentum = .9;
     float decay = 0.0001;
     decay = 0;
     //clock_t start = clock(), end;
-    int batch = 10000;
-    while(++count <= 10000){
-        float loss = train_network_sgd(net, train, batch, lr, momentum, decay);
+    int iters = 100;
+    while(++count <= 10){
+        clock_t start = clock(), end;
+        float loss = train_network_sgd(net, train, iters, lr, momentum, decay);
+        end = clock();
         float test_acc = network_accuracy(net, test);
-        printf("%3d %5f %5f\n",count, loss, test_acc);
+        printf("%d: %f %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay);
+        //printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*1000, loss, lr, momentum, decay);
        //end = clock();
        //printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
        //start=end;
        //lr *= .5;
     }
+    //save_network(net, "cfg/nist_basic_trained.cfg");
 }
 
 void test_ensemble()
@@ -431,7 +435,7 @@ void test_im2row()
     float *matrix = calloc(msize, sizeof(float));
     int i;
     for(i = 0; i < 1000; ++i){
-        im2col_cpu(test.data, 1, c, h, w, size, stride, matrix);
+        im2col_cpu(test.data, 1, c, h, w, size, stride, 0, matrix);
         //image render = float_to_image(mh, mw, mc, matrix);
     }
 }
@@ -442,34 +446,36 @@ void flip_network()
     save_network(net, "cfg/voc_imagenet_rev.cfg");
 }
 
-void train_VOC()
+void tune_VOC()
 {
     network net = parse_network_cfg("cfg/voc_start.cfg");
     srand(2222222);
     int i = 20;
     char *labels[] = {"aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa","train","tvmonitor"};
-    float lr = .00001;
+    float lr = .000005;
     float momentum = .9;
-    float decay = 0.01;
+    float decay = 0.0001;
     while(i++ < 1000 || 1){
-        data train = load_data_image_pathfile_random("images/VOC2012/val_paths.txt", 1000, labels, 20, 300, 400);
+        data train = load_data_image_pathfile_random("/home/pjreddie/VOC2012/trainval_paths.txt", 10, labels, 20, 256, 256);
 
-        image im = float_to_image(300, 400, 3,train.X.vals[0]);
+        image im = float_to_image(256, 256, 3,train.X.vals[0]);
         show_image(im, "input");
         visualize_network(net);
         cvWaitKey(100);
 
-        normalize_data_rows(train);
+        translate_data_rows(train, -144);
         clock_t start = clock(), end;
-        float loss = train_network_sgd(net, train, 1000, lr, momentum, decay);
+        float loss = train_network_sgd(net, train, 10, lr, momentum, decay);
         end = clock();
         printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay);
         free_data(train);
+        /*
         if(i%10==0){
             char buff[256];
-            sprintf(buff, "cfg/voc_clean_ramp_%d.cfg", i);
+            sprintf(buff, "/home/pjreddie/voc_cfg/voc_ramp_%d.cfg", i);
             save_network(net, buff);
         }
+        */
         //lr *= .99;
     }
 }
@@ -778,7 +784,7 @@ int main(int argc, char *argv[])
     //test_cifar10();
     //test_vince();
    //test_full();
-    //train_VOC();
+    //tune_VOC();
     //features_VOC_image(argv[1], argv[2], argv[3], 0);
    //features_VOC_image(argv[1], argv[2], argv[3], 1);
     //features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
diff --git a/src/col2im.c b/src/col2im.c
index e69de29b..05205676 100644
--- a/src/col2im.c
+++ b/src/col2im.c
@@ -0,0 +1,47 @@
+inline void col2im_set_pixel(float *im, int height, int width, int channels,
+                        int row, int col, int channel, int pad, float val)
+{
+    row -= pad;
+    col -= pad;
+
+    if (row < 0 || col < 0 ||
+        row >= height || col >= width) return;
+    im[col + width*(row + channel*height)] = val;
+}
+//This one might be too, can't remember.
+void col2im_cpu(float* data_col,
+        const int batch, const int channels, const int height, const int width,
+        const int ksize, const int stride, int pad, float* data_im)
+{
+    int c,h,w,b;
+    int height_col = (height - ksize) / stride + 1;
+    int width_col = (width - ksize) / stride + 1;
+    if (pad){
+        height_col = 1 + (height-1) / stride;
+        width_col = 1 + (width-1) / stride;
+        pad = ksize/2;
+    }
+    int channels_col = channels * ksize * ksize;
+    int im_size = height*width*channels;
+    int col_size = height_col*width_col*channels_col;
+    for (b = 0; b < batch; ++b) {
+        for (c = 0; c < channels_col; ++c) {
+            int w_offset = c % ksize;
+            int h_offset = (c / ksize) % ksize;
+            int c_im = c / ksize / ksize;
+            for (h = 0; h < height_col; ++h) {
+                for (w = 0; w < width_col; ++w) {
+                    int im_row = h_offset + h * stride;
+                    int im_col = w_offset + w * stride;
+                    double val = data_col[(c * height_col + h) * width_col + w];
+                    col2im_set_pixel(data_im, height, width, channels,
+                            im_row, im_col, c_im, pad, val);
+                }
+            }
+        }
+        data_im += im_size;
+        data_col+= col_size;
+    }
+}
+
+
diff --git a/src/connected_layer.c b/src/connected_layer.c
index 72cb3fb1..d9750993 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -57,8 +57,11 @@ void update_connected_layer(connected_layer layer, float step, float momentum, f
 
 void forward_connected_layer(connected_layer layer, float *input, int train)
 {
+    int i;
     if(!train) layer.dropout = 0;
-    memcpy(layer.output, layer.biases, layer.outputs*sizeof(float));
+    for(i = 0; i < layer.batch; ++i){
+        memcpy(layer.output+i*layer.outputs, layer.biases, layer.outputs*sizeof(float));
+    }
     int m = layer.batch;
     int k = layer.inputs;
     int n = layer.outputs;
@@ -82,16 +85,16 @@ void backward_connected_layer(connected_layer layer, float *input, float *delta)
     float *a = input;
     float *b = layer.delta;
     float *c = layer.weight_updates;
-    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+    gemm(1,0,m,n,k,1,a,k,b,n,1,c,n);
 
-    m = layer.inputs;
+    m = layer.batch;
     k = layer.outputs;
-    n = layer.batch;
+    n = layer.inputs;
 
-    a = layer.weights;
-    b = layer.delta;
+    a = layer.delta;
+    b = layer.weights;
     c = delta;
 
-    if(c) gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
+    if(c) gemm(0,1,m,n,k,1,a,k,b,k,0,c,n);
 }
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 5aa76ee5..f473aefa 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -5,12 +5,18 @@
 
 int convolutional_out_height(convolutional_layer layer)
 {
-    return (layer.h-layer.size)/layer.stride + 1;
+    int h = layer.h;
+    if (!layer.pad) h -= layer.size;
+    else h -= 1;
+    return h/layer.stride + 1;
 }
 
 int convolutional_out_width(convolutional_layer layer)
 {
-    return (layer.w-layer.size)/layer.stride + 1;
+    int w = layer.w;
+    if (!layer.pad) w -= layer.size;
+    else w -= 1;
+    return w/layer.stride + 1;
 }
 
 image get_convolutional_image(convolutional_layer layer)
@@ -31,7 +37,7 @@ image get_convolutional_delta(convolutional_layer layer)
     return float_to_image(h,w,c,layer.delta);
 }
 
-convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
+convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation)
 {
     int i;
     size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter...
@@ -43,6 +49,7 @@ convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, in
     layer->batch = batch;
     layer->stride = stride;
     layer->size = size;
+    layer->pad = pad;
 
     layer->filters = calloc(c*n*size*size, sizeof(float));
     layer->filter_updates = calloc(c*n*size*size, sizeof(float));
@@ -64,6 +71,17 @@ convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, in
     layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float));
     layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float));
     #ifdef GPU
+    layer->filters_cl = cl_make_array(layer->filters, c*n*size*size);
+    layer->filter_updates_cl = cl_make_array(layer->filter_updates, c*n*size*size);
+    layer->filter_momentum_cl = cl_make_array(layer->filter_momentum, c*n*size*size);
+
+    layer->biases_cl = cl_make_array(layer->biases, n);
+    layer->bias_updates_cl = cl_make_array(layer->bias_updates, n);
+    layer->bias_momentum_cl = cl_make_array(layer->bias_momentum, n);
+
+    layer->col_image_cl = cl_make_array(layer->col_image, layer->batch*out_h*out_w*size*size*c);
+    layer->delta_cl = cl_make_array(layer->delta, layer->batch*out_h*out_w*n);
+    layer->output_cl = cl_make_array(layer->output, layer->batch*out_h*out_w*n);
     #endif
     layer->activation = activation;
@@ -91,12 +109,14 @@ void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c)
 
 void bias_output(const convolutional_layer layer)
 {
-    int i,j;
+    int i,j,b;
     int out_h = convolutional_out_height(layer);
     int out_w = convolutional_out_width(layer);
-    for(i = 0; i < layer.n; ++i){
-        for(j = 0; j < out_h*out_w; ++j){
-            layer.output[i*out_h*out_w + j] = layer.biases[i];
+    for(b = 0; b < layer.batch; ++b){
+        for(i = 0; i < layer.n; ++i){
+            for(j = 0; j < out_h*out_w; ++j){
+                layer.output[(b*layer.n + i)*out_h*out_w + j] = layer.biases[i];
+            }
         }
     }
 }
@@ -114,7 +134,7 @@ void forward_convolutional_layer(const convolutional_layer layer, float *in)
     float *b = layer.col_image;
     float *c = layer.output;
     im2col_cpu(in,layer.batch, layer.c, layer.h, layer.w,
-            layer.size, layer.stride, b);
+            layer.size, layer.stride, layer.pad, b);
     bias_output(layer);
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
     activate_array(layer.output, m*n, layer.activation, 0.);
@@ -169,7 +189,6 @@ void backward_convolutional_layer(convolutional_layer layer, float *delta)
     gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
 
     if(delta){
-        int i;
         m = layer.size*layer.size*layer.c;
         k = layer.n;
         n = convolutional_out_height(layer)*
@@ -183,9 +202,7 @@ void backward_convolutional_layer(convolutional_layer layer, float *delta)
         gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
 
         memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
-        for(i = 0; i < layer.batch; ++i){
-            col2im_cpu(c+i*n/layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, delta+i*n/layer.batch);
-        }
+        col2im_cpu(c, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, delta);
     }
 }
diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h
index 2deea62c..e0722f8d 100644
--- a/src/convolutional_layer.h
+++ b/src/convolutional_layer.h
@@ -14,6 +14,7 @@ typedef struct {
     int n;
     int size;
     int stride;
+    int pad;
     float *filters;
     float *filter_updates;
     float *filter_momentum;
@@ -47,7 +48,7 @@ typedef struct {
 void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in);
 #endif
 
-convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation);
+convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation);
 void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c);
 void forward_convolutional_layer(const convolutional_layer layer, float *in);
 void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay);
diff --git a/src/data.c b/src/data.c
index 6d2061ed..a2432af1 100644
--- a/src/data.c
+++ b/src/data.c
@@ -166,6 +166,14 @@ void scale_data_rows(data d, float s)
     }
 }
 
+void translate_data_rows(data d, float s)
+{
+    int i;
+    for(i = 0; i < d.X.rows; ++i){
+        translate_array(d.X.vals[i], d.X.cols, s);
+    }
+}
+
 void normalize_data_rows(data d)
 {
     int i;
diff --git a/src/data.h b/src/data.h
index dfbbf72f..c639d5fa 100644
--- a/src/data.h
+++ b/src/data.h
@@ -22,6 +22,7 @@ list *get_paths(char *filename);
 data load_categorical_data_csv(char *filename, int target, int k);
 void normalize_data_rows(data d);
 void scale_data_rows(data d, float s);
+void translate_data_rows(data d, float s);
 void randomize_data(data d);
 data *split_data(data d, int part, int total);
diff --git a/src/detection_layer.c b/src/detection_layer.c
new file mode 100644
index 00000000..65370795
--- /dev/null
+++ b/src/detection_layer.c
@@ -0,0 +1,72 @@
+int detection_out_height(detection_layer layer)
+{
+    return layer.size + layer.h*layer.stride;
+}
+
+int detection_out_width(detection_layer layer)
+{
+    return layer.size + layer.w*layer.stride;
+}
+
+detection_layer *make_detection_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
+{
+    int i;
+    size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter...
+    detection_layer *layer = calloc(1, sizeof(detection_layer));
+    layer->h = h;
+    layer->w = w;
+    layer->c = c;
+    layer->n = n;
+    layer->batch = batch;
+    layer->stride = stride;
+    layer->size = size;
+    assert(c%n == 0);
+
+    layer->filters = calloc(c*size*size, sizeof(float));
+    layer->filter_updates = calloc(c*size*size, sizeof(float));
+    layer->filter_momentum = calloc(c*size*size, sizeof(float));
+
+    float scale = 1./(size*size*c);
+    for(i = 0; i < c*size*size; ++i) layer->filters[i] = scale*(rand_uniform());
+
+    int out_h = detection_out_height(*layer);
+    int out_w = detection_out_width(*layer);
+
+    layer->output = calloc(layer->batch * out_h * out_w * n, sizeof(float));
+    layer->delta = calloc(layer->batch * out_h * out_w * n, sizeof(float));
+
+    layer->activation = activation;
+
+    fprintf(stderr, "Detection Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
+    srand(0);
+
+    return layer;
+}
+
+void forward_detection_layer(const detection_layer layer, float *in)
+{
+    int out_h = detection_out_height(layer);
+    int out_w = detection_out_width(layer);
+    int i,j,fh, fw,c;
+    memset(layer.output, 0, layer.batch*layer.n*out_h*out_w*sizeof(float));
+    for(c = 0; c < layer.c; ++c){
+        for(i = 0; i < layer.h; ++i){
+            for(j = 0; j < layer.w; ++j){
+                float val = in[j+(i + c*layer.h)*layer.w];
+                for(fh = 0; fh < layer.size; ++fh){
+                    for(fw = 0; fw < layer.size; ++fw){
+                        int h = i*layer.stride + fh;
+                        int w = j*layer.stride + fw;
+                        layer.output[w+(h+c/layer.n*out_h)*out_w] += val*layer.filters[fw+(fh+c*layer.size)*layer.size];
+                    }
+                }
+            }
+        }
+    }
+}
+
+void backward_detection_layer(const detection_layer layer, float *delta)
+{
+}
+
+
diff --git a/src/detection_layer.h b/src/detection_layer.h
new file mode 100644
index 00000000..fad0281e
--- /dev/null
+++ b/src/detection_layer.h
@@ -0,0 +1,40 @@
+#ifndef DETECTION_LAYER_H
+#define DETECTION_LAYER_H
+
+typedef struct {
+    int batch;
+    int h,w,c;
+    int n;
+    int size;
+    int stride;
+
+    float *filters;
+    float *filter_updates;
+    float *filter_momentum;
+
+    float *biases;
+    float *bias_updates;
+    float *bias_momentum;
+
+    float *col_image;
+    float *delta;
+    float *output;
+
+    #ifdef GPU
+    cl_mem filters_cl;
+    cl_mem filter_updates_cl;
+    cl_mem filter_momentum_cl;
+
+    cl_mem biases_cl;
+    cl_mem bias_updates_cl;
+    cl_mem bias_momentum_cl;
+
+    cl_mem col_image_cl;
+    cl_mem delta_cl;
+    cl_mem output_cl;
+    #endif
+
+    ACTIVATION activation;
+} detection_layer;
+
+#endif
diff --git a/src/gemm.cl b/src/gemm.cl
index 91375a77..9e45783b 100644
--- a/src/gemm.cl
+++ b/src/gemm.cl
@@ -27,8 +27,8 @@ __kernel void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
         int brow = i + sub_row;
         int bcol = col_block*BLOCK + sub_col;
 
-        Asub[sub_row][sub_col] = TA ? A[arow + acol*lda] : A[arow*lda + acol];
-        Bsub[sub_row][sub_col] = TB ? B[brow + bcol*ldb] : B[brow*ldb + bcol];
+        if(arow < M && acol < K)Asub[sub_row][sub_col] = TA ? A[arow + acol*lda] : A[arow*lda + acol];
+        if(brow < K && bcol < N)Bsub[sub_row][sub_col] = TB ? B[brow + bcol*ldb] : B[brow*ldb + bcol];
 
         barrier(CLK_LOCAL_MEM_FENCE);
diff --git a/src/im2col.c b/src/im2col.c
index 899f73ad..42bddf52 100644
--- a/src/im2col.c
+++ b/src/im2col.c
@@ -1,27 +1,45 @@
 #include "mini_blas.h"
 
+inline float im2col_get_pixel(float *im, int height, int width, int channels,
+                        int row, int col, int channel, int pad)
+{
+    row -= pad;
+    col -= pad;
+
+    if (row < 0 || col < 0 ||
+        row >= height || col >= width) return 0;
+    return im[col + width*(row + channel*height)];
+}
+
 //From Berkeley Vision's Caffe!
 //https://github.com/BVLC/caffe/blob/master/LICENSE
 void im2col_cpu(float* data_im,
     const int batch, const int channels, const int height, const int width,
-    const int ksize, const int stride, float* data_col)
+    const int ksize, const int stride, int pad, float* data_col)
 {
     int c,h,w,b;
     int height_col = (height - ksize) / stride + 1;
     int width_col = (width - ksize) / stride + 1;
+    if (pad){
+        height_col = 1 + (height-1) / stride;
+        width_col = 1 + (width-1) / stride;
+        pad = ksize/2;
+    }
     int channels_col = channels * ksize * ksize;
     int im_size = height*width*channels;
     int col_size = height_col*width_col*channels_col;
-    for(b = 0; b < batch; ++b){
-        for ( c = 0; c < channels_col; ++c) {
+    for (b = 0; b < batch; ++b) {
+        for (c = 0; c < channels_col; ++c) {
             int w_offset = c % ksize;
             int h_offset = (c / ksize) % ksize;
             int c_im = c / ksize / ksize;
-            for ( h = 0; h < height_col; ++h) {
-                for ( w = 0; w < width_col; ++w) {
+            for (h = 0; h < height_col; ++h) {
+                for (w = 0; w < width_col; ++w) {
+                    int im_row = h_offset + h * stride;
+                    int im_col = w_offset + w * stride;
                     data_col[(c * height_col + h) * width_col + w] =
-                        data_im[(c_im * height + h * stride + h_offset) * width
-                                + w * stride + w_offset];
+                        im2col_get_pixel(data_im, height, width, channels,
+                                im_row, im_col, c_im, pad);
                 }
             }
         }
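For reference, a minimal standalone sketch (not part of the patch) of what the new pad flag means for im2col_cpu: without padding the column buffer covers (dim-ksize)/stride+1 positions per spatial dimension, with padding it keeps 1+(dim-1)/stride positions and im2col_get_pixel returns 0 for the out-of-range border taps. The helper below only re-applies the formulas used in im2col_cpu above; the name col_dims is made up for illustration.

#include <stdio.h>

/* Column-buffer dimensions for one image, following the formulas in im2col_cpu. */
static void col_dims(int channels, int height, int width, int ksize, int stride, int pad)
{
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    if (pad){
        height_col = 1 + (height-1) / stride;   /* padded: keep full resolution */
        width_col = 1 + (width-1) / stride;     /* borders read back zeros via im2col_get_pixel */
    }
    int channels_col = channels * ksize * ksize;
    printf("data_col holds %d x %d x %d floats\n", channels_col, height_col, width_col);
}

int main(void)
{
    col_dims(3, 5, 5, 3, 1, 0);  /* 27 x 3 x 3: unpadded, the output map shrinks */
    col_dims(3, 5, 5, 3, 1, 1);  /* 27 x 5 x 5: zero-padded by ksize/2 = 1 */
    return 0;
}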
diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c
index 413816a6..54a734a8 100644
--- a/src/maxpool_layer.c
+++ b/src/maxpool_layer.c
@@ -19,7 +19,6 @@ image get_maxpool_delta(maxpool_layer layer)
 
 maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int stride)
 {
-    c = c*batch;
     fprintf(stderr, "Maxpool Layer: %d x %d x %d image, %d stride\n", h,w,c,stride);
     maxpool_layer *layer = calloc(1, sizeof(maxpool_layer));
     layer->batch = batch;
@@ -27,8 +26,8 @@ maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int stride)
     layer->w = w;
     layer->c = c;
     layer->stride = stride;
-    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float));
-    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float));
+    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(float));
+    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(float));
     return layer;
 }
 
@@ -37,22 +36,30 @@ void resize_maxpool_layer(maxpool_layer *layer, int h, int w, int c)
     layer->h = h;
     layer->w = w;
     layer->c = c;
-    layer->output = realloc(layer->output, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * sizeof(float));
-    layer->delta = realloc(layer->delta, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * sizeof(float));
+    layer->output = realloc(layer->output, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * layer->batch* sizeof(float));
+    layer->delta = realloc(layer->delta, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * layer->batch*sizeof(float));
 }
 
 void forward_maxpool_layer(const maxpool_layer layer, float *in)
 {
-    image input = float_to_image(layer.h, layer.w, layer.c, in);
-    image output = get_maxpool_image(layer);
-    int i,j,k;
-    for(i = 0; i < output.h*output.w*output.c; ++i) output.data[i] = -DBL_MAX;
-    for(k = 0; k < input.c; ++k){
-        for(i = 0; i < input.h; ++i){
-            for(j = 0; j < input.w; ++j){
-                float val = get_pixel(input, i, j, k);
-                float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
-                if(val > cur) set_pixel(output, i/layer.stride, j/layer.stride, k, val);
+    int b;
+    for(b = 0; b < layer.batch; ++b){
+        image input = float_to_image(layer.h, layer.w, layer.c, in+b*layer.h*layer.w*layer.c);
+
+        int h = (layer.h-1)/layer.stride + 1;
+        int w = (layer.w-1)/layer.stride + 1;
+        int c = layer.c;
+        image output = float_to_image(h,w,c,layer.output+b*h*w*c);
+
+        int i,j,k;
+        for(i = 0; i < output.h*output.w*output.c; ++i) output.data[i] = -DBL_MAX;
+        for(k = 0; k < input.c; ++k){
+            for(i = 0; i < input.h; ++i){
+                for(j = 0; j < input.w; ++j){
+                    float val = get_pixel(input, i, j, k);
+                    float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
+                    if(val > cur) set_pixel(output, i/layer.stride, j/layer.stride, k, val);
+                }
             }
         }
     }
@@ -60,21 +67,28 @@ void forward_maxpool_layer(const maxpool_layer layer, float *in)
 
 void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta)
 {
-    image input = float_to_image(layer.h, layer.w, layer.c, in);
-    image input_delta = float_to_image(layer.h, layer.w, layer.c, delta);
-    image output_delta = get_maxpool_delta(layer);
-    image output = get_maxpool_image(layer);
-    int i,j,k;
-    for(k = 0; k < input.c; ++k){
-        for(i = 0; i < input.h; ++i){
-            for(j = 0; j < input.w; ++j){
-                float val = get_pixel(input, i, j, k);
-                float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
-                float d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k);
-                if(val == cur) {
-                    set_pixel(input_delta, i, j, k, d);
+    int b;
+    for(b = 0; b < layer.batch; ++b){
+        image input = float_to_image(layer.h, layer.w, layer.c, in+b*layer.h*layer.w*layer.c);
+        image input_delta = float_to_image(layer.h, layer.w, layer.c, delta+b*layer.h*layer.w*layer.c);
+        int h = (layer.h-1)/layer.stride + 1;
+        int w = (layer.w-1)/layer.stride + 1;
+        int c = layer.c;
+        image output = float_to_image(h,w,c,layer.output+b*h*w*c);
+        image output_delta = float_to_image(h,w,c,layer.delta+b*h*w*c);
+
+        int i,j,k;
+        for(k = 0; k < input.c; ++k){
+            for(i = 0; i < input.h; ++i){
+                for(j = 0; j < input.w; ++j){
+                    float val = get_pixel(input, i, j, k);
+                    float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
+                    float d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k);
+                    if(val == cur) {
+                        set_pixel(input_delta, i, j, k, d);
+                    }
+                    else set_pixel(input_delta, i, j, k, 0);
                 }
-                else set_pixel(input_delta, i, j, k, 0);
             }
         }
     }
diff --git a/src/mini_blas.c b/src/mini_blas.c
index eb6953d7..0227b37c 100644
--- a/src/mini_blas.c
+++ b/src/mini_blas.c
@@ -17,28 +17,6 @@ void pm(int M, int N, float *A)
     printf("\n");
 }
 
-//This one might be too, can't remember.
-void col2im_cpu(float* data_col, const int channels,
-        const int height, const int width, const int ksize, const int stride,
-        float* data_im)
-{
-    int c,h,w;
-    int height_col = (height - ksize) / stride + 1;
-    int width_col = (width - ksize) / stride + 1;
-    int channels_col = channels * ksize * ksize;
-    for ( c = 0; c < channels_col; ++c) {
-        int w_offset = c % ksize;
-        int h_offset = (c / ksize) % ksize;
-        int c_im = c / ksize / ksize;
-        for ( h = 0; h < height_col; ++h) {
-            for ( w = 0; w < width_col; ++w) {
-                data_im[(c_im * height + h * stride + h_offset) * width
-                    + w * stride + w_offset]+= data_col[(c * height_col + h) * width_col + w];
-            }
-        }
-    }
-}
-
 float *random_matrix(int rows, int cols)
 {
     int i;
diff --git a/src/mini_blas.h b/src/mini_blas.h
index cfbb6c46..bf5debb8 100644
--- a/src/mini_blas.h
+++ b/src/mini_blas.h
@@ -27,11 +27,11 @@ void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
 
 void im2col_cpu(float* data_im,
     const int batch, const int channels, const int height, const int width,
-    const int ksize, const int stride, float* data_col);
+    const int ksize, const int stride, int pad, float* data_col);
 
-void col2im_cpu(float* data_col, const int channels,
-        const int height, const int width, const int ksize, const int stride,
-        float* data_im);
+void col2im_cpu(float* data_col,
+        const int batch, const int channels, const int height, const int width,
+        const int ksize, const int stride, int pad, float* data_im);
 
 void test_blas();
 void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
diff --git a/src/network.c b/src/network.c
index b75eddf1..ef801109 100644
--- a/src/network.c
+++ b/src/network.c
@@ -113,10 +113,9 @@ void save_network(network net, char *filename)
     fclose(fp);
 }
 
+#ifdef GPU
 void forward_network(network net, float *input, int train)
 {
-    int i;
-    #ifdef GPU
     cl_setup();
     size_t size = get_network_input_size(net);
     if(!net.input_cl){
@@ -126,16 +125,12 @@ void forward_network(network net, float *input, int train)
     }
     cl_write_array(net.input_cl, input, size);
     cl_mem input_cl = net.input_cl;
-    #endif
+    int i;
     for(i = 0; i < net.n; ++i){
         if(net.types[i] == CONVOLUTIONAL){
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
-            #ifdef GPU
             forward_convolutional_layer_gpu(layer, input_cl);
             input_cl = layer.output_cl;
-            #else
-            forward_convolutional_layer(layer, input);
-            #endif
             input = layer.output;
         }
         else if(net.types[i] == CONNECTED){
@@ -161,6 +156,41 @@ void forward_network(network net, float *input, int train)
     }
 }
 
+#else
+
+void forward_network(network net, float *input, int train)
+{
+    int i;
+    for(i = 0; i < net.n; ++i){
+        if(net.types[i] == CONVOLUTIONAL){
+            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+            forward_convolutional_layer(layer, input);
+            input = layer.output;
+        }
+        else if(net.types[i] == CONNECTED){
+            connected_layer layer = *(connected_layer *)net.layers[i];
+            forward_connected_layer(layer, input, train);
+            input = layer.output;
+        }
+        else if(net.types[i] == SOFTMAX){
+            softmax_layer layer = *(softmax_layer *)net.layers[i];
+            forward_softmax_layer(layer, input);
+            input = layer.output;
+        }
+        else if(net.types[i] == MAXPOOL){
+            maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+            forward_maxpool_layer(layer, input);
+            input = layer.output;
+        }
+        else if(net.types[i] == NORMALIZATION){
+            normalization_layer layer = *(normalization_layer *)net.layers[i];
+            forward_normalization_layer(layer, input);
+            input = layer.output;
+        }
+    }
+}
+#endif
+
 void update_network(network net, float step, float momentum, float decay)
 {
     int i;
@@ -238,9 +268,10 @@ float calculate_error_network(network net, float *truth)
     float sum = 0;
     float *delta = get_network_delta(net);
     float *out = get_network_output(net);
-    int i, k = get_network_output_size(net);
-    for(i = 0; i < k; ++i){
-        //printf("%f, ", out[i]);
+    int i;
+    for(i = 0; i < get_network_output_size(net)*net.batch; ++i){
+        //if(i %get_network_output_size(net) == 0) printf("\n");
+        //printf("%5.2f %5.2f, ", out[i], truth[i]);
         delta[i] = truth[i] - out[i];
         sum += delta[i]*delta[i];
     }
@@ -305,20 +336,38 @@ float train_network_datum(network net, float *x, float *y, float step, float mom
 
 float train_network_sgd(network net, data d, int n, float step, float momentum,float decay)
 {
-    int i;
-    float error = 0;
-    int correct = 0;
-    int pos = 0;
+    int batch = net.batch;
+    float *X = calloc(batch*d.X.cols, sizeof(float));
+    float *y = calloc(batch*d.y.cols, sizeof(float));
+
+    int i,j;
+    float sum = 0;
     for(i = 0; i < n; ++i){
-        int index = rand()%d.X.rows;
-        float err = train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay);
+        for(j = 0; j < batch; ++j){
+            int index = rand()%d.X.rows;
+            memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));
+            memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));
+        }
+        float err = train_network_datum(net, X, y, step, momentum, decay);
+        sum += err;
+        //train_network_datum(net, X, y, step, momentum, decay);
+        /*
         float *y = d.y.vals[index];
         int class = get_predicted_class_network(net);
         correct += (y[class]?1:0);
-        if(y[1]){
-            error += err;
-            ++pos;
+        */
+
+/*
+        for(j = 0; j < d.y.cols*batch; ++j){
+            printf("%6.3f ", y[j]);
         }
+        printf("\n");
+        for(j = 0; j < d.y.cols*batch; ++j){
+            printf("%6.3f ", get_network_output(net)[j]);
+        }
+        printf("\n");
+        printf("\n");
+        */
 
         //printf("%d %f %f\n", i,net.output[0], d.y.vals[index][0]);
@@ -327,7 +376,9 @@ float train_network_sgd(network net, data d, int n, float step, float momentum,f
         //}
     }
     //printf("Accuracy: %f\n",(float) correct/n);
-    return error/pos;
+    free(X);
+    free(y);
+    return (float)sum/(n*batch);
 }
 float train_network_batch(network net, data d, int n, float step, float momentum,float decay)
 {
@@ -448,7 +499,7 @@ int get_network_output_size(network net)
 
 int get_network_input_size(network net)
 {
-    return get_network_output_size_layer(net, 0);
+    return get_network_input_size_layer(net, 0);
 }
 
 image get_network_image_layer(network net, int i)
@@ -505,15 +556,24 @@ float *network_predict(network net, float *input)
 
 matrix network_predict_data(network net, data test)
 {
-    int i,j;
+    int i,j,b;
     int k = get_network_output_size(net);
     matrix pred = make_matrix(test.X.rows, k);
-    for(i = 0; i < test.X.rows; ++i){
-        float *out = network_predict(net, test.X.vals[i]);
-        for(j = 0; j < k; ++j){
-            pred.vals[i][j] = out[j];
+    float *X = calloc(net.batch*test.X.rows, sizeof(float));
+    for(i = 0; i < test.X.rows; i += net.batch){
+        for(b = 0; b < net.batch; ++b){
+            if(i+b == test.X.rows) break;
+            memcpy(X+b*test.X.cols, test.X.vals[i+b], test.X.cols*sizeof(float));
+        }
+        float *out = network_predict(net, X);
+        for(b = 0; b < net.batch; ++b){
+            if(i+b == test.X.rows) break;
+            for(j = 0; j < k; ++j){
+                pred.vals[i+b][j] = out[j+b*k];
+            }
         }
     }
+    free(X);
     return pred;
 }
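As a usage note (illustrative, not part of the patch): the batched train_network_sgd now copies net.batch random rows into contiguous X/y buffers per iteration, so its third argument counts minibatch iterations and the return value is the loss averaged over n*net.batch examples. A minimal sketch, using only calls that appear elsewhere in this patch; the wrapper name sketch_train_mnist is made up:

/* Assumes this file includes the usual headers for network, data, and stdio. */
void sketch_train_mnist()
{
    network net = parse_network_cfg("cfg/nist.cfg");   /* net.batch comes from the cfg */
    data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
    normalize_data_rows(train);
    int iters = 100;                                   /* 100 minibatches of net.batch examples each */
    float loss = train_network_sgd(net, train, iters, .000075, .9, 0);
    printf("average loss over %d examples: %f\n", iters*net.batch, loss);
    free_data(train);
}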
diff --git a/src/opencl.c b/src/opencl.c
index d06c75fa..d78537b4 100644
--- a/src/opencl.c
+++ b/src/opencl.c
@@ -32,7 +32,8 @@ cl_info cl_init()
     if(num_devices > MAX_DEVICES) num_devices = MAX_DEVICES;
     int index = getpid()%num_devices;
     printf("%d rand, %d devices, %d index\n", getpid(), num_devices, index);
-    info.device = devices[index];
+    //info.device = devices[index];
+    info.device = devices[1];
     fprintf(stderr, "Found %d device(s)\n", num_devices);
     check_error(info);
@@ -102,4 +103,21 @@ void cl_write_array(cl_mem mem, float *x, int n)
     check_error(cl);
 }
 
+void cl_copy_array(cl_mem src, cl_mem dst, int n)
+{
+    cl_setup();
+    clEnqueueCopyBuffer(cl.queue, src, dst, 0, 0, sizeof(float)*n,0,0,0);
+    check_error(cl);
+}
+
+cl_mem cl_make_array(float *x, int n)
+{
+    cl_setup();
+    cl_mem mem = clCreateBuffer(cl.context,
+            CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR,
+            sizeof(float)*n, x, &cl.error);
+    check_error(cl);
+    return mem;
+}
+
 #endif
diff --git a/src/opencl.h b/src/opencl.h
index eafb3e72..a7ee0bdb 100644
--- a/src/opencl.h
+++ b/src/opencl.h
@@ -23,5 +23,7 @@ void check_error(cl_info info);
 cl_kernel get_kernel(char *filename, char *kernelname, char *options);
 void cl_read_array(cl_mem mem, float *x, int n);
 void cl_write_array(cl_mem mem, float *x, int n);
+cl_mem cl_make_array(float *x, int n);
+void cl_copy_array(cl_mem src, cl_mem dst, int n);
 #endif
 #endif
diff --git a/src/parser.c b/src/parser.c
index 5d6aa1c4..b008882d 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -48,6 +48,7 @@ convolutional_layer *parse_convolutional(list *options, network net, int count)
     int n = option_find_int(options, "filters",1);
     int size = option_find_int(options, "size",1);
     int stride = option_find_int(options, "stride",1);
+    int pad = option_find_int(options, "pad",0);
     char *activation_s = option_find_str(options, "activation", "sigmoid");
     ACTIVATION activation = get_activation(activation_s);
     if(count == 0){
@@ -62,7 +63,7 @@ convolutional_layer *parse_convolutional(list *options, network net, int count)
         c = m.c;
         if(h == 0) error("Layer before convolutional layer must output image.");
     }
-    convolutional_layer *layer = make_convolutional_layer(net.batch,h,w,c,n,size,stride, activation);
+    convolutional_layer *layer = make_convolutional_layer(net.batch,h,w,c,n,size,stride,pad,activation);
     char *data = option_find_str(options, "data", 0);
     if(data){
         char *curr = data;
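Finally, a small arithmetic sketch (again not part of the patch) of how the new pad option read by parse_convolutional changes the convolutional output size; it mirrors convolutional_out_height/width above, and the helper name out_dim is made up for illustration. In a network cfg this corresponds to the new "pad" key, which defaults to 0.

#include <stdio.h>

/* Mirrors convolutional_out_height/width: pad=0 is a valid convolution,
 * a non-zero pad keeps the input resolution for stride 1. */
static int out_dim(int x, int size, int stride, int pad)
{
    if (!pad) x -= size;
    else      x -= 1;
    return x/stride + 1;
}

int main(void)
{
    printf("28x28 input, 5x5 filter, stride 1, pad=0 -> %d\n", out_dim(28, 5, 1, 0)); /* 24 */
    printf("28x28 input, 5x5 filter, stride 1, pad=1 -> %d\n", out_dim(28, 5, 1, 1)); /* 28 */
    return 0;
}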