diff --git a/Makefile b/Makefile index 4c1bb148..640f3082 100644 --- a/Makefile +++ b/Makefile @@ -1,18 +1,21 @@ CC=gcc COMMON=-Wall `pkg-config --cflags opencv` UNAME = $(shell uname) +OPTS=-O3 ifeq ($(UNAME), Darwin) -COMMON += -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include +COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include +LDFLAGS= -framework OpenCL else -COMMON += -march=native +OPTS+= -march=native -flto +LDFLAGS= -lOpenCL endif -CFLAGS= $(COMMON) -Ofast -flto +CFLAGS= $(COMMON) $(OPTS) #CFLAGS= $(COMMON) -O0 -g -LDFLAGS=`pkg-config --libs opencv` -lm +LDFLAGS+=`pkg-config --libs opencv` -lm VPATH=./src/ EXEC=cnn -OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o +OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o opencl.o gpu_gemm.o cpu_gemm.o normalization_layer.o all: $(EXEC) diff --git a/dog.jpg b/dog.jpg deleted file mode 100644 index 3b9f7abd..00000000 Binary files a/dog.jpg and /dev/null differ diff --git a/src/connected_layer.c b/src/connected_layer.c index 07fad695..16a39be9 100644 --- a/src/connected_layer.c +++ b/src/connected_layer.c @@ -7,16 +7,17 @@ #include #include -connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activation) +connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation) { fprintf(stderr, "Connected Layer: %d inputs, %d outputs\n", inputs, outputs); int i; connected_layer *layer = calloc(1, sizeof(connected_layer)); layer->inputs = inputs; layer->outputs = outputs; + layer->batch=batch; - layer->output = calloc(outputs, sizeof(float*)); - layer->delta = calloc(outputs, sizeof(float*)); + layer->output = calloc(batch*outputs, sizeof(float*)); + layer->delta = calloc(batch*outputs, sizeof(float*)); layer->weight_updates = calloc(inputs*outputs, sizeof(float)); layer->weight_adapt = calloc(inputs*outputs, sizeof(float)); @@ -78,14 +79,14 @@ void forward_connected_layer(connected_layer layer, float *input) { int i; memcpy(layer.output, layer.biases, layer.outputs*sizeof(float)); - int m = 1; + int m = layer.batch; int k = layer.inputs; int n = layer.outputs; float *a = input; float *b = layer.weights; float *c = layer.output; gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); - for(i = 0; i < layer.outputs; ++i){ + for(i = 0; i < layer.outputs*layer.batch; ++i){ layer.output[i] = activate(layer.output[i], layer.activation); } //for(i = 0; i < layer.outputs; ++i) if(i%(layer.outputs/10+1)==0) printf("%f, ", layer.output[i]); printf("\n"); @@ -94,12 +95,12 @@ void forward_connected_layer(connected_layer layer, float *input) void learn_connected_layer(connected_layer layer, float *input) { int i; - for(i = 0; i < layer.outputs; ++i){ + for(i = 0; i < layer.outputs*layer.batch; ++i){ layer.delta[i] *= gradient(layer.output[i], layer.activation); - layer.bias_updates[i] += layer.delta[i]; + layer.bias_updates[i%layer.batch] += layer.delta[i]/layer.batch; } int m = layer.inputs; - int k = 1; + int k = layer.batch; int n = layer.outputs; float *a = input; float *b = layer.delta; @@ -113,7 +114,7 @@ void backward_connected_layer(connected_layer layer, float *input, float *delta) int m = layer.inputs; int k = layer.outputs; - int n = 1; + int n = layer.batch; float *a = layer.weights; float *b = layer.delta; diff --git a/src/connected_layer.h b/src/connected_layer.h index 4b17c59b..83ae914f 100644 --- a/src/connected_layer.h +++ b/src/connected_layer.h @@ -4,6 +4,7 @@ #include "activations.h" typedef struct{ + int batch; int inputs; int outputs; float *weights; @@ -25,7 +26,7 @@ typedef struct{ } connected_layer; -connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activation); +connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation); void forward_connected_layer(connected_layer layer, float *input); void backward_connected_layer(connected_layer layer, float *input, float *delta); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 8d8efc11..6916eebc 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -31,7 +31,7 @@ image get_convolutional_delta(convolutional_layer layer) return float_to_image(h,w,c,layer.delta); } -convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation) +convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation) { int i; size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter... @@ -40,6 +40,7 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si layer->w = w; layer->c = c; layer->n = n; + layer->batch = batch; layer->stride = stride; layer->size = size; @@ -56,12 +57,12 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si //layer->biases[i] = rand_normal()*scale + scale; layer->biases[i] = 0; } - int out_h = (h-size)/stride + 1; - int out_w = (w-size)/stride + 1; + int out_h = convolutional_out_height(*layer); + int out_w = convolutional_out_width(*layer); - layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(float)); - layer->output = calloc(out_h * out_w * n, sizeof(float)); - layer->delta = calloc(out_h * out_w * n, sizeof(float)); + layer->col_image = calloc(layer->batch*out_h*out_w*size*size*c, sizeof(float)); + layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float)); + layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float)); layer->activation = activation; fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); @@ -70,21 +71,39 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si return layer; } +void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c) +{ + layer->h = h; + layer->w = w; + layer->c = c; + int out_h = convolutional_out_height(*layer); + int out_w = convolutional_out_width(*layer); + + layer->col_image = realloc(layer->col_image, + layer->batch*out_h*out_w*layer->size*layer->size*layer->c*sizeof(float)); + layer->output = realloc(layer->output, + layer->batch*out_h * out_w * layer->n*sizeof(float)); + layer->delta = realloc(layer->delta, + layer->batch*out_h * out_w * layer->n*sizeof(float)); +} + void forward_convolutional_layer(const convolutional_layer layer, float *in) { int i; int m = layer.n; int k = layer.size*layer.size*layer.c; - int n = ((layer.h-layer.size)/layer.stride + 1)* - ((layer.w-layer.size)/layer.stride + 1); + int n = convolutional_out_height(layer)* + convolutional_out_width(layer)* + layer.batch; memset(layer.output, 0, m*n*sizeof(float)); float *a = layer.filters; float *b = layer.col_image; float *c = layer.output; - - im2col_cpu(in, layer.c, layer.h, layer.w, layer.size, layer.stride, b); + for(i = 0; i < layer.batch; ++i){ + im2col_cpu(in+i*(n/layer.batch), layer.c, layer.h, layer.w, layer.size, layer.stride, b+i*(n/layer.batch)); + } gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); for(i = 0; i < m*n; ++i){ @@ -97,9 +116,10 @@ void forward_convolutional_layer(const convolutional_layer layer, float *in) void gradient_delta_convolutional_layer(convolutional_layer layer) { int i; - int size = convolutional_out_height(layer) - *convolutional_out_width(layer) - *layer.n; + int size = convolutional_out_height(layer)* + convolutional_out_width(layer)* + layer.n* + layer.batch; for(i = 0; i < size; ++i){ layer.delta[i] *= gradient(layer.output[i], layer.activation); } @@ -107,15 +127,17 @@ void gradient_delta_convolutional_layer(convolutional_layer layer) void learn_bias_convolutional_layer(convolutional_layer layer) { - int i,j; + int i,j,b; int size = convolutional_out_height(layer) *convolutional_out_width(layer); - for(i = 0; i < layer.n; ++i){ - float sum = 0; - for(j = 0; j < size; ++j){ - sum += layer.delta[j+i*size]; + for(b = 0; b < layer.batch; ++b){ + for(i = 0; i < layer.n; ++i){ + float sum = 0; + for(j = 0; j < size; ++j){ + sum += layer.delta[j+size*(i+b*layer.n)]; + } + layer.bias_updates[i] += sum/size; } - layer.bias_updates[i] += sum/size; } } @@ -125,8 +147,9 @@ void learn_convolutional_layer(convolutional_layer layer) learn_bias_convolutional_layer(layer); int m = layer.n; int n = layer.size*layer.size*layer.c; - int k = ((layer.h-layer.size)/layer.stride + 1)* - ((layer.w-layer.size)/layer.stride + 1); + int k = convolutional_out_height(layer)* + convolutional_out_width(layer)* + layer.batch; float *a = layer.delta; float *b = layer.col_image; @@ -137,10 +160,12 @@ void learn_convolutional_layer(convolutional_layer layer) void backward_convolutional_layer(convolutional_layer layer, float *delta) { + int i; int m = layer.size*layer.size*layer.c; int k = layer.n; - int n = ((layer.h-layer.size)/layer.stride + 1)* - ((layer.w-layer.size)/layer.stride + 1); + int n = convolutional_out_height(layer)* + convolutional_out_width(layer)* + layer.batch; float *a = layer.filters; float *b = layer.delta; @@ -150,8 +175,10 @@ void backward_convolutional_layer(convolutional_layer layer, float *delta) memset(c, 0, m*n*sizeof(float)); gemm(1,0,m,n,k,1,a,m,b,n,1,c,n); - memset(delta, 0, layer.h*layer.w*layer.c*sizeof(float)); - col2im_cpu(c, layer.c, layer.h, layer.w, layer.size, layer.stride, delta); + memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float)); + for(i = 0; i < layer.batch; ++i){ + col2im_cpu(c+i*n/layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, delta+i*n/layer.batch); + } } void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay) @@ -225,7 +252,7 @@ void update_convolutional_layer(convolutional_layer layer, float step, float mom void test_convolutional_layer() { - convolutional_layer l = *make_convolutional_layer(4,4,1,1,3,1,LINEAR); + convolutional_layer l = *make_convolutional_layer(1,4,4,1,1,3,1,LINEAR); float input[] = {1,2,3,4, 5,6,7,8, 9,10,11,12, @@ -258,52 +285,48 @@ image get_convolutional_filter(convolutional_layer layer, int i) return float_to_image(h,w,c,layer.filters+i*h*w*c); } -void visualize_convolutional_layer(convolutional_layer layer, char *window) +image *weighted_sum_filters(convolutional_layer layer, image *prev_filters) { - int color = 1; - int border = 1; - int h,w,c; - int size = layer.size; - h = size; - w = (size + border) * layer.n - border; - c = layer.c; - if(c != 3 || !color){ - h = (h+border)*c - border; - c = 1; + image *filters = calloc(layer.n, sizeof(image)); + int i,j,k,c; + if(!prev_filters){ + for(i = 0; i < layer.n; ++i){ + filters[i] = copy_image(get_convolutional_filter(layer, i)); + } } - - image filters = make_image(h,w,c); - int i,j; - for(i = 0; i < layer.n; ++i){ - int w_offset = i*(size+border); - image k = get_convolutional_filter(layer, i); - //printf("%f ** ", layer.biases[i]); - //print_image(k); - image copy = copy_image(k); - normalize_image(copy); - for(j = 0; j < k.c; ++j){ - //set_pixel(copy,0,0,j,layer.biases[i]); - } - if(c == 3 && color){ - embed_image(copy, filters, 0, w_offset); - } - else{ - for(j = 0; j < k.c; ++j){ - int h_offset = j*(size+border); - image layer = get_image_layer(k, j); - embed_image(layer, filters, h_offset, w_offset); - free_image(layer); + else{ + image base = prev_filters[0]; + for(i = 0; i < layer.n; ++i){ + image filter = get_convolutional_filter(layer, i); + filters[i] = make_image(base.h, base.w, base.c); + for(j = 0; j < layer.size; ++j){ + for(k = 0; k < layer.size; ++k){ + for(c = 0; c < layer.c; ++c){ + float weight = get_pixel(filter, j, k, c); + image prev_filter = copy_image(prev_filters[c]); + scale_image(prev_filter, weight); + add_into_image(prev_filter, filters[i], 0,0); + free_image(prev_filter); + } + } } } - free_image(copy); } - image delta = get_convolutional_delta(layer); - image dc = collapse_image_layers(delta, 1); - char buff[256]; - sprintf(buff, "%s: Delta", window); - show_image(dc, buff); - free_image(dc); - show_image(filters, window); - free_image(filters); + return filters; +} + +image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters) +{ + image *single_filters = weighted_sum_filters(layer, 0); + show_images(single_filters, layer.n, window); + + image delta = get_convolutional_image(layer); + image dc = collapse_image_layers(delta, 1); + char buff[256]; + sprintf(buff, "%s: Output", window); + show_image(dc, buff); + save_image(dc, buff); + free_image(dc); + return single_filters; } diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 8ca69b1b..7404defd 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -5,6 +5,7 @@ #include "activations.h" typedef struct { + int batch; int h,w,c; int n; int size; @@ -24,11 +25,12 @@ typedef struct { ACTIVATION activation; } convolutional_layer; -convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation); +convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation); +void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c); void forward_convolutional_layer(const convolutional_layer layer, float *in); void learn_convolutional_layer(convolutional_layer layer); void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay); -void visualize_convolutional_layer(convolutional_layer layer, char *window); +image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters); void backward_convolutional_layer(convolutional_layer layer, float *delta); diff --git a/src/cpu_gemm.c b/src/cpu_gemm.c new file mode 100644 index 00000000..437b39a4 --- /dev/null +++ b/src/cpu_gemm.c @@ -0,0 +1,86 @@ +#include "mini_blas.h" + +void cpu_gemm_nn(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) +{ + int i,j,k; + for(i = 0; i < M; ++i){ + for(k = 0; k < K; ++k){ + register float A_PART = ALPHA*A[i*lda+k]; + for(j = 0; j < N; ++j){ + C[i*ldc+j] += A_PART*B[k*ldb+j]; + } + } + } +} + +void cpu_gemm_nt(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) +{ + int i,j,k; + for(i = 0; i < M; ++i){ + for(j = 0; j < N; ++j){ + register float sum = 0; + for(k = 0; k < K; ++k){ + sum += ALPHA*A[i*lda+k]*B[k+j*ldb]; + } + C[i*ldc+j] += sum; + } + } +} + +void cpu_gemm_tn(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) +{ + int i,j,k; + for(i = 0; i < M; ++i){ + for(k = 0; k < K; ++k){ + register float A_PART = ALPHA*A[k*lda+i]; + for(j = 0; j < N; ++j){ + C[i*ldc+j] += A_PART*B[k*ldb+j]; + } + } + } +} +void cpu_gemm_tt(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) +{ + int i,j,k; + for(i = 0; i < M; ++i){ + for(j = 0; j < N; ++j){ + for(k = 0; k < K; ++k){ + C[i*ldc+j] += ALPHA*A[i+k*lda]*B[k+j*ldb]; + } + } + } +} + + +void cpu_gemm(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) +{ + // Assume beta = 1 LULZ + if(!TA && !TB) + cpu_gemm_nn( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); + else if(TA && !TB) + cpu_gemm_tn( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); + else if(!TA && TB) + cpu_gemm_nt( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); + else + cpu_gemm_tt( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); +} diff --git a/src/data.c b/src/data.c index f44f5daf..39ece116 100644 --- a/src/data.c +++ b/src/data.c @@ -119,6 +119,30 @@ data load_categorical_data_csv(char *filename, int target, int k) return d; } +data load_cifar10_data(char *filename) +{ + data d; + d.shallow = 0; + unsigned long i,j; + matrix X = make_matrix(10000, 3072); + matrix y = make_matrix(10000, 10); + d.X = X; + d.y = y; + + FILE *fp = fopen(filename, "rb"); + for(i = 0; i < 10000; ++i){ + unsigned char bytes[3073]; + fread(bytes, 1, 3073, fp); + int class = bytes[0]; + y.vals[i][class] = 1; + for(j = 0; j < X.cols; ++j){ + X.vals[i][j] = (double)bytes[j+1]; + } + } + fclose(fp); + return d; +} + void randomize_data(data d) { int i; diff --git a/src/data.h b/src/data.h index 4df0c687..dfbbf72f 100644 --- a/src/data.h +++ b/src/data.h @@ -17,6 +17,7 @@ data load_data_image_pathfile_part(char *filename, int part, int total, char **labels, int k, int h, int w); data load_data_image_pathfile_random(char *filename, int n, char **labels, int k, int h, int w); +data load_cifar10_data(char *filename); list *get_paths(char *filename); data load_categorical_data_csv(char *filename, int target, int k); void normalize_data_rows(data d); diff --git a/src/gemm.cl b/src/gemm.cl new file mode 100644 index 00000000..7c868f41 --- /dev/null +++ b/src/gemm.cl @@ -0,0 +1,72 @@ + + +__kernel void gemm(int TA, int TB, int M, int N, int K, float ALPHA, + __global float *A, int lda, + __global float *B, int ldb, + float BETA, + __global float *C, int ldc) +{ + __local float Asub[BLOCK][BLOCK]; + __local float Bsub[BLOCK][BLOCK]; + + float val = 0; + + int row_block = get_group_id(0); + int col_block = get_group_id(1); + + int sub_row = get_local_id(0); + int sub_col = get_local_id(1); + + int row = row_block*BLOCK + sub_row; + int col = col_block*BLOCK + sub_col; + + int i,j; + for(i = 0; i < K; i += BLOCK){ + int arow = row_block*BLOCK + sub_row; + int acol = i + sub_col; + + int brow = i + sub_row; + int bcol = col_block*BLOCK + sub_col; + + Asub[sub_row][sub_col] = TA ? A[arow + acol*lda] : A[arow*lda + acol]; + Bsub[sub_row][sub_col] = TB ? B[brow + bcol*ldb] : B[brow*ldb + bcol]; + + barrier(CLK_LOCAL_MEM_FENCE); + + for(j = 0; j < BLOCK && i+j +#include +#include +#include +#include + +#include "opencl.h" +#include "mini_blas.h" + +#define STR_HELPER(x) #x +#define STR(x) STR_HELPER(x) + +#define BLOCK 8 + +cl_kernel get_gemm_kernel() +{ + static int init = 0; + static cl_kernel gemm_kernel; + if(!init){ + gemm_kernel = get_kernel("src/gemm.cl", "gemm", "-D BLOCK=" STR(BLOCK) ); + init = 1; + } + return gemm_kernel; +} + +void gpu_gemm(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) +{ + cl_setup(); + cl_kernel gemm_kernel = get_gemm_kernel(); + cl_context context = cl.context; + cl_command_queue queue = cl.queue; + + size_t size = sizeof(float)*(TA ? lda*K:lda*M); + cl_mem A_gpu = clCreateBuffer(context, + CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, + size, A, &cl.error); + check_error(cl); + + size = sizeof(float)*(TB ? ldb*N:ldb*K); + cl_mem B_gpu = clCreateBuffer(context, + CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, + size, B, &cl.error); + check_error(cl); + + size = sizeof(float)*(ldc*M); + cl_mem C_gpu = clCreateBuffer(context, + CL_MEM_WRITE_ONLY|CL_MEM_COPY_HOST_PTR, + size, C, &cl.error); + check_error(cl); + + cl_uint i = 0; + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc); + check_error(cl); + + const size_t global_size[] = {ceil((float)M/BLOCK)*BLOCK, ceil((float)N/BLOCK)*BLOCK}; + const size_t local_size[] = {BLOCK, BLOCK}; + //printf("%zd %zd %zd %zd\n", global_size[0], global_size[1], local_size[0], local_size[1]); + + clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0); + check_error(cl); + clEnqueueReadBuffer(queue, C_gpu, CL_TRUE, 0, size, C, 0, 0, 0); + check_error(cl); + + clReleaseMemObject(A_gpu); + clReleaseMemObject(B_gpu); + clReleaseMemObject(C_gpu); + +} + +/* +cl_kernel get_gemm_kernel_slow() +{ + static int init = 0; + static cl_kernel gemm_kernel; + if(!init){ + gemm_kernel = get_kernel("src/gemm.cl", "gemm_slow"); + init = 1; + } + return gemm_kernel; +} + +void gpu_gemm_slow(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) +{ + cl_setup(); + cl_kernel gemm_kernel = get_gemm_kernel_slow(); + cl_context context = cl.context; + cl_command_queue queue = cl.queue; + + size_t size = sizeof(float)*(TA ? lda*K:lda*M); + cl_mem A_gpu = clCreateBuffer(context, + CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, + size, A, &cl.error); + check_error(cl); + + size = sizeof(float)*(TB ? ldb*N:ldb*K); + cl_mem B_gpu = clCreateBuffer(context, + CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, + size, B, &cl.error); + check_error(cl); + + size = sizeof(float)*(ldc*M); + cl_mem C_gpu = clCreateBuffer(context, + CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR, + size, C, &cl.error); + check_error(cl); + + cl_uint i = 0; + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu); + cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc); + check_error(cl); + + const size_t global_size[] = {M, N}; + + clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, 0, 0, 0, 0); + clEnqueueReadBuffer(queue, C_gpu, CL_TRUE, 0, size, C, 0, 0, 0); + + clReleaseMemObject(A_gpu); + clReleaseMemObject(B_gpu); + clReleaseMemObject(C_gpu); + +} +*/ diff --git a/src/image.c b/src/image.c index 16679776..453919fb 100644 --- a/src/image.c +++ b/src/image.c @@ -113,6 +113,7 @@ image copy_image(image p) return copy; } + void show_image(image p, char *name) { int i,j,k; @@ -136,7 +137,7 @@ void show_image(image p, char *name) } } free_image(copy); - if(disp->height < 500 || disp->width < 500){ + if(disp->height < 500 || disp->width < 500 || disp->height > 1000){ int w = 1500; int h = w*p.h/p.w; if(h > 1000){ @@ -152,6 +153,30 @@ void show_image(image p, char *name) cvReleaseImage(&disp); } +void save_image(image p, char *name) +{ + int i,j,k; + image copy = copy_image(p); + normalize_image(copy); + + char buff[256]; + //sprintf(buff, "%s (%d)", name, windows); + sprintf(buff, "%s.png", name); + + IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c); + int step = disp->widthStep; + for(i = 0; i < p.h; ++i){ + for(j = 0; j < p.w; ++j){ + for(k= 0; k < p.c; ++k){ + disp->imageData[i*step + j*p.c + k] = (unsigned char)(get_pixel(copy,i,j,k)*255); + } + } + } + free_image(copy); + cvSaveImage(buff, disp,0); + cvReleaseImage(&disp); +} + void show_image_layers(image p, char *name) { int i; @@ -227,7 +252,19 @@ image make_random_image(int h, int w, int c) return out; } -void add_scalar_image(image m, float s) +void add_into_image(image src, image dest, int h, int w) +{ + int i,j,k; + for(k = 0; k < src.c; ++k){ + for(i = 0; i < src.h; ++i){ + for(j = 0; j < src.w; ++j){ + add_pixel(dest, h+i, w+j, k, get_pixel(src, i, j, k)); + } + } + } +} + +void translate_image(image m, float s) { int i; for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s; @@ -404,6 +441,20 @@ image get_image_layer(image m, int l) } return out; } +image get_sub_image(image m, int h, int w, int dh, int dw) +{ + image out = make_image(dh, dw, m.c); + int i,j,k; + for(k = 0; k < out.c; ++k){ + for(i = 0; i < dh; ++i){ + for(j = 0; j < dw; ++j){ + float val = get_pixel(m, h+i, w+j, k); + set_pixel(out, i, j, k, val); + } + } + } + return out; +} float get_pixel(image m, int x, int y, int c) { @@ -594,6 +645,121 @@ void print_image(image m) for(i =0 ; i < m.h*m.w*m.c; ++i) printf("%lf, ", m.data[i]); printf("\n"); } +image collapse_images_vert(image *ims, int n) +{ + int color = 1; + int border = 1; + int h,w,c; + w = ims[0].w; + h = (ims[0].h + border) * n - border; + c = ims[0].c; + if(c != 3 || !color){ + w = (w+border)*c - border; + c = 1; + } + + image filters = make_image(h,w,c); + int i,j; + for(i = 0; i < n; ++i){ + int h_offset = i*(ims[0].h+border); + image copy = copy_image(ims[i]); + //normalize_image(copy); + if(c == 3 && color){ + embed_image(copy, filters, h_offset, 0); + } + else{ + for(j = 0; j < copy.c; ++j){ + int w_offset = j*(ims[0].w+border); + image layer = get_image_layer(copy, j); + embed_image(layer, filters, h_offset, w_offset); + free_image(layer); + } + } + free_image(copy); + } + return filters; +} + +image collapse_images_horz(image *ims, int n) +{ + int color = 1; + int border = 1; + int h,w,c; + int size = ims[0].h; + h = size; + w = (ims[0].w + border) * n - border; + c = ims[0].c; + if(c != 3 || !color){ + h = (h+border)*c - border; + c = 1; + } + + image filters = make_image(h,w,c); + int i,j; + for(i = 0; i < n; ++i){ + int w_offset = i*(size+border); + image copy = copy_image(ims[i]); + //normalize_image(copy); + if(c == 3 && color){ + embed_image(copy, filters, 0, w_offset); + } + else{ + for(j = 0; j < copy.c; ++j){ + int h_offset = j*(size+border); + image layer = get_image_layer(copy, j); + embed_image(layer, filters, h_offset, w_offset); + free_image(layer); + } + } + free_image(copy); + } + return filters; +} + +void show_images(image *ims, int n, char *window) +{ + image m = collapse_images_vert(ims, n); + save_image(m, window); + show_image(m, window); + free_image(m); +} + +image grid_images(image **ims, int h, int w) +{ + int i; + image *rows = calloc(h, sizeof(image)); + for(i = 0; i < h; ++i){ + rows[i] = collapse_images_horz(ims[i], w); + } + image out = collapse_images_vert(rows, h); + for(i = 0; i < h; ++i){ + free_image(rows[i]); + } + free(rows); + return out; +} + +void test_grid() +{ + int i,j; + int num = 3; + int topk = 3; + image **vizs = calloc(num, sizeof(image*)); + for(i = 0; i < num; ++i){ + vizs[i] = calloc(topk, sizeof(image)); + for(j = 0; j < topk; ++j) vizs[i][j] = make_image(3,3,3); + } + image grid = grid_images(vizs, num, topk); + save_image(grid, "Test Grid"); + free_image(grid); +} + +void show_images_grid(image **ims, int h, int w, char *window) +{ + image out = grid_images(ims, h, w); + show_image(out, window); + free_image(out); +} void free_image(image m) { diff --git a/src/image.h b/src/image.h index 9f7d74d4..fe257425 100644 --- a/src/image.h +++ b/src/image.h @@ -1,6 +1,7 @@ #ifndef IMAGE_H #define IMAGE_H + #include "opencv2/highgui/highgui_c.h" #include "opencv2/imgproc/imgproc_c.h" typedef struct { @@ -12,7 +13,7 @@ typedef struct { image image_distance(image a, image b); void scale_image(image m, float s); -void add_scalar_image(image m, float s); +void translate_image(image m, float s); void normalize_image(image p); void z_normalize_image(image p); void threshold_image(image p, float t); @@ -21,11 +22,20 @@ void rotate_image(image m); void subtract_image(image a, image b); float avg_image_layer(image m, int l); void embed_image(image source, image dest, int h, int w); +void add_into_image(image src, image dest, int h, int w); image collapse_image_layers(image source, int border); +image collapse_images_horz(image *ims, int n); +image collapse_images_vert(image *ims, int n); +image get_sub_image(image m, int h, int w, int dh, int dw); void show_image(image p, char *name); +void save_image(image p, char *name); +void show_images(image *ims, int n, char *window); void show_image_layers(image p, char *name); void show_image_collapsed(image p, char *name); +void show_images_grid(image **ims, int h, int w, char *window); +void test_grid(); +image grid_images(image **ims, int h, int w); void print_image(image m); image make_image(int h, int w, int c); @@ -39,6 +49,7 @@ image ipl_to_image(IplImage* src); float get_pixel(image m, int x, int y, int c); float get_pixel_extend(image m, int x, int y, int c); +void add_pixel(image m, int x, int y, int c, float val); void set_pixel(image m, int x, int y, int c, float val); image get_image_layer(image m, int l); diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c index 8c409b94..413816a6 100644 --- a/src/maxpool_layer.c +++ b/src/maxpool_layer.c @@ -17,10 +17,12 @@ image get_maxpool_delta(maxpool_layer layer) return float_to_image(h,w,c,layer.delta); } -maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride) +maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int stride) { + c = c*batch; fprintf(stderr, "Maxpool Layer: %d x %d x %d image, %d stride\n", h,w,c,stride); maxpool_layer *layer = calloc(1, sizeof(maxpool_layer)); + layer->batch = batch; layer->h = h; layer->w = w; layer->c = c; @@ -30,6 +32,15 @@ maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride) return layer; } +void resize_maxpool_layer(maxpool_layer *layer, int h, int w, int c) +{ + layer->h = h; + layer->w = w; + layer->c = c; + layer->output = realloc(layer->output, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * sizeof(float)); + layer->delta = realloc(layer->delta, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * sizeof(float)); +} + void forward_maxpool_layer(const maxpool_layer layer, float *in) { image input = float_to_image(layer.h, layer.w, layer.c, in); diff --git a/src/maxpool_layer.h b/src/maxpool_layer.h index 27d6f55a..92d41e66 100644 --- a/src/maxpool_layer.h +++ b/src/maxpool_layer.h @@ -4,6 +4,7 @@ #include "image.h" typedef struct { + int batch; int h,w,c; int stride; float *delta; @@ -11,7 +12,8 @@ typedef struct { } maxpool_layer; image get_maxpool_image(maxpool_layer layer); -maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride); +maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int stride); +void resize_maxpool_layer(maxpool_layer *layer, int h, int w, int c); void forward_maxpool_layer(const maxpool_layer layer, float *in); void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta); diff --git a/src/mini_blas.c b/src/mini_blas.c index 262798bc..bac3e226 100644 --- a/src/mini_blas.c +++ b/src/mini_blas.c @@ -3,6 +3,8 @@ #include #include #include +#include +#include "mini_blas.h" void pm(int M, int N, float *A) { @@ -17,42 +19,12 @@ void pm(int M, int N, float *A) } void gemm(int TA, int TB, int M, int N, int K, float ALPHA, - float *A, int lda, - float *B, int ldb, - float BETA, - float *C, int ldc) + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc) { - // Assume beta = 1 LULZ - int i,j,k; - if(TB && !TA){ - for(i = 0; i < M; ++i){ - for(j = 0; j < N; ++j){ - register float sum = 0; - for(k = 0; k < K; ++k){ - sum += ALPHA*A[i*lda+k]*B[k+j*ldb]; - } - C[i*ldc+j] += sum; - } - } - }else if(TA && !TB){ - for(i = 0; i < M; ++i){ - for(k = 0; k < K; ++k){ - register float A_PART = ALPHA*A[k*lda+i]; - for(j = 0; j < N; ++j){ - C[i*ldc+j] += A_PART*B[k*ldb+j]; - } - } - } - }else{ - for(i = 0; i < M; ++i){ - for(k = 0; k < K; ++k){ - register float A_PART = ALPHA*A[i*lda+k]; - for(j = 0; j < N; ++j){ - C[i*ldc+j] += A_PART*B[k*ldb+j]; - } - } - } - } + gpu_gemm( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc); } void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix) @@ -150,16 +122,26 @@ float *random_matrix(int rows, int cols) void time_random_matrix(int TA, int TB, int m, int k, int n) { - float *a = random_matrix(m,k); - float *b = random_matrix(k,n); + float *a; + if(!TA) a = random_matrix(m,k); + else a = random_matrix(k,m); + int lda = (!TA)?k:m; + float *b; + if(!TB) b = random_matrix(k,n); + else b = random_matrix(n,k); + int ldb = (!TB)?n:k; + float *c = random_matrix(m,n); int i; clock_t start = clock(), end; for(i = 0; i<1000; ++i){ - gemm(TA,TB,m,n,k,1,a,k,b,n,1,c,n); + cpu_gemm(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n); } end = clock(); printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC); + free(a); + free(b); + free(c); } void test_blas() @@ -167,9 +149,97 @@ void test_blas() time_random_matrix(0,0,100,100,100); time_random_matrix(1,0,100,100,100); time_random_matrix(0,1,100,100,100); + time_random_matrix(1,1,100,100,100); - time_random_matrix(0,1,1000,100,100); + time_random_matrix(0,0,1000,100,100); time_random_matrix(1,0,1000,100,100); + time_random_matrix(0,1,1000,100,100); + time_random_matrix(1,1,1000,100,100); + } +void time_gpu_random_matrix(int TA, int TB, int m, int k, int n) +{ + float *a; + if(!TA) a = random_matrix(m,k); + else a = random_matrix(k,m); + int lda = (!TA)?k:m; + float *b; + if(!TB) b = random_matrix(k,n); + else b = random_matrix(n,k); + int ldb = (!TB)?n:k; + + float *c = random_matrix(m,n); + int i; + clock_t start = clock(), end; + for(i = 0; i<1000; ++i){ + gpu_gemm(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n); + } + end = clock(); + printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC); + free(a); + free(b); + free(c); +} + +void test_gpu_accuracy(int TA, int TB, int m, int k, int n) +{ + srand(0); + float *a; + if(!TA) a = random_matrix(m,k); + else a = random_matrix(k,m); + int lda = (!TA)?k:m; + float *b; + if(!TB) b = random_matrix(k,n); + else b = random_matrix(n,k); + int ldb = (!TB)?n:k; + + float *c = random_matrix(m,n); + float *c_gpu = random_matrix(m,n); + memset(c, 0, m*n*sizeof(float)); + memset(c_gpu, 0, m*n*sizeof(float)); + int i; + //pm(m,k,b); + gpu_gemm(TA,TB,m,n,k,1,a,lda,b,ldb,1,c_gpu,n); + //pm(m, n, c_gpu); + cpu_gemm(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n); + //pm(m, n, c); + double sse = 0; + for(i = 0; i < m*n; ++i) { + //printf("%f %f\n", c[i], c_gpu[i]); + sse += pow(c[i]-c_gpu[i], 2); + } + printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %g MSE\n",m,k,k,n, TA, TB, sse/(m*n)); + free(a); + free(b); + free(c); +} + +void test_gpu_blas() +{ + test_gpu_accuracy(0,0,17,10,10); + test_gpu_accuracy(1,0,17,10,10); + test_gpu_accuracy(0,1,17,10,10); + test_gpu_accuracy(1,1,17,10,10); + + test_gpu_accuracy(0,0,1000,10,100); + test_gpu_accuracy(1,0,1000,10,100); + test_gpu_accuracy(0,1,1000,10,100); + test_gpu_accuracy(1,1,1000,10,100); + + time_gpu_random_matrix(0,0,1000,1000,100); + time_random_matrix(0,0,1000,1000,100); + + time_gpu_random_matrix(0,1,1000,1000,100); + time_random_matrix(0,1,1000,1000,100); + + time_gpu_random_matrix(1,0,1000,1000,100); + time_random_matrix(1,0,1000,1000,100); + + time_gpu_random_matrix(1,1,1000,1000,100); + time_random_matrix(1,1,1000,1000,100); + +} + + diff --git a/src/mini_blas.h b/src/mini_blas.h index ff82a60c..56e4fa72 100644 --- a/src/mini_blas.h +++ b/src/mini_blas.h @@ -4,6 +4,7 @@ void gemm(int TA, int TB, int M, int N, int K, float ALPHA, float *B, int ldb, float BETA, float *C, int ldc); +float *random_matrix(int rows, int cols); void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix); void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix); void im2col_cpu(float* data_im, const int channels, @@ -13,3 +14,15 @@ void col2im_cpu(float* data_col, const int channels, const int height, const int width, const int ksize, const int stride, float* data_im); void test_blas(); + +void gpu_gemm(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc); +void cpu_gemm(int TA, int TB, int M, int N, int K, float ALPHA, + float *A, int lda, + float *B, int ldb, + float BETA, + float *C, int ldc); +void test_gpu_blas(); diff --git a/src/network.c b/src/network.c index b2fc9225..7d4b1fac 100644 --- a/src/network.c +++ b/src/network.c @@ -8,12 +8,14 @@ #include "convolutional_layer.h" //#include "old_conv.h" #include "maxpool_layer.h" +#include "normalization_layer.h" #include "softmax_layer.h" -network make_network(int n) +network make_network(int n, int batch) { network net; net.n = n; + net.batch = batch; net.layers = calloc(net.n, sizeof(void *)); net.types = calloc(net.n, sizeof(LAYER_TYPE)); net.outputs = 0; @@ -25,10 +27,11 @@ void print_convolutional_cfg(FILE *fp, convolutional_layer *l, int first) { int i; fprintf(fp, "[convolutional]\n"); - if(first) fprintf(fp, "height=%d\n" + if(first) fprintf(fp, "batch=%d\n" + "height=%d\n" "width=%d\n" "channels=%d\n", - l->h, l->w, l->c); + l->batch,l->h, l->w, l->c); fprintf(fp, "filters=%d\n" "size=%d\n" "stride=%d\n" @@ -38,17 +41,28 @@ void print_convolutional_cfg(FILE *fp, convolutional_layer *l, int first) fprintf(fp, "data="); for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]); for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]); + /* + int j,k; + for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]); + for(i = 0; i < l->n; ++i){ + for(j = l->c-1; j >= 0; --j){ + for(k = 0; k < l->size*l->size; ++k){ + fprintf(fp, "%g,", l->filters[i*(l->c*l->size*l->size)+j*l->size*l->size+k]); + } + } + } + */ fprintf(fp, "\n\n"); } void print_connected_cfg(FILE *fp, connected_layer *l, int first) { int i; fprintf(fp, "[connected]\n"); - if(first) fprintf(fp, "input=%d\n", l->inputs); + if(first) fprintf(fp, "batch=%d\ninput=%d\n", l->batch, l->inputs); fprintf(fp, "output=%d\n" - "activation=%s\n", - l->outputs, - get_activation_string(l->activation)); + "activation=%s\n", + l->outputs, + get_activation_string(l->activation)); fprintf(fp, "data="); for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]); for(i = 0; i < l->inputs*l->outputs; ++i) fprintf(fp, "%g,", l->weights[i]); @@ -58,17 +72,32 @@ void print_connected_cfg(FILE *fp, connected_layer *l, int first) void print_maxpool_cfg(FILE *fp, maxpool_layer *l, int first) { fprintf(fp, "[maxpool]\n"); - if(first) fprintf(fp, "height=%d\n" - "width=%d\n" - "channels=%d\n", - l->h, l->w, l->c); + if(first) fprintf(fp, "batch=%d\n" + "height=%d\n" + "width=%d\n" + "channels=%d\n", + l->batch,l->h, l->w, l->c); fprintf(fp, "stride=%d\n\n", l->stride); } +void print_normalization_cfg(FILE *fp, normalization_layer *l, int first) +{ + fprintf(fp, "[localresponsenormalization]\n"); + if(first) fprintf(fp, "batch=%d\n" + "height=%d\n" + "width=%d\n" + "channels=%d\n", + l->batch,l->h, l->w, l->c); + fprintf(fp, "size=%d\n" + "alpha=%g\n" + "beta=%g\n" + "kappa=%g\n\n", l->size, l->alpha, l->beta, l->kappa); +} + void print_softmax_cfg(FILE *fp, softmax_layer *l, int first) { fprintf(fp, "[softmax]\n"); - if(first) fprintf(fp, "input=%d\n", l->inputs); + if(first) fprintf(fp, "batch=%d\ninput=%d\n", l->batch, l->inputs); fprintf(fp, "\n"); } @@ -85,6 +114,8 @@ void save_network(network net, char *filename) print_connected_cfg(fp, (connected_layer *)net.layers[i], i==0); else if(net.types[i] == MAXPOOL) print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], i==0); + else if(net.types[i] == NORMALIZATION) + print_normalization_cfg(fp, (normalization_layer *)net.layers[i], i==0); else if(net.types[i] == SOFTMAX) print_softmax_cfg(fp, (softmax_layer *)net.layers[i], i==0); } @@ -115,6 +146,11 @@ void forward_network(network net, float *input) forward_maxpool_layer(layer, input); input = layer.output; } + else if(net.types[i] == NORMALIZATION){ + normalization_layer layer = *(normalization_layer *)net.layers[i]; + forward_normalization_layer(layer, input); + input = layer.output; + } } } @@ -132,6 +168,9 @@ void update_network(network net, float step, float momentum, float decay) else if(net.types[i] == SOFTMAX){ //maxpool_layer layer = *(maxpool_layer *)net.layers[i]; } + else if(net.types[i] == NORMALIZATION){ + //maxpool_layer layer = *(maxpool_layer *)net.layers[i]; + } else if(net.types[i] == CONNECTED){ connected_layer layer = *(connected_layer *)net.layers[i]; update_connected_layer(layer, step, momentum, decay); @@ -153,6 +192,9 @@ float *get_network_output_layer(network net, int i) } else if(net.types[i] == CONNECTED){ connected_layer layer = *(connected_layer *)net.layers[i]; return layer.output; + } else if(net.types[i] == NORMALIZATION){ + normalization_layer layer = *(normalization_layer *)net.layers[i]; + return layer.output; } return 0; } @@ -191,11 +233,11 @@ float calculate_error_network(network net, float *truth) float *out = get_network_output(net); int i, k = get_network_output_size(net); for(i = 0; i < k; ++i){ - printf("%f, ", out[i]); + //printf("%f, ", out[i]); delta[i] = truth[i] - out[i]; sum += delta[i]*delta[i]; } - printf("\n"); + //printf("\n"); return sum; } @@ -230,6 +272,10 @@ float backward_network(network net, float *input, float *truth) maxpool_layer layer = *(maxpool_layer *)net.layers[i]; if(i != 0) backward_maxpool_layer(layer, prev_input, prev_delta); } + else if(net.types[i] == NORMALIZATION){ + normalization_layer layer = *(normalization_layer *)net.layers[i]; + if(i != 0) backward_normalization_layer(layer, prev_input, prev_delta); + } else if(net.types[i] == SOFTMAX){ softmax_layer layer = *(softmax_layer *)net.layers[i]; if(i != 0) backward_softmax_layer(layer, prev_input, prev_delta); @@ -258,19 +304,26 @@ float train_network_sgd(network net, data d, int n, float step, float momentum,f int i; float error = 0; int correct = 0; + int pos = 0; for(i = 0; i < n; ++i){ int index = rand()%d.X.rows; - error += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay); + float err = train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay); float *y = d.y.vals[index]; int class = get_predicted_class_network(net); correct += (y[class]?1:0); + if(y[1]){ + error += err; + ++pos; + } + + //printf("%d %f %f\n", i,net.output[0], d.y.vals[index][0]); //if((i+1)%10 == 0){ // printf("%d: %f\n", (i+1), (float)correct/(i+1)); //} } - printf("Accuracy: %f\n",(float) correct/n); - return error/n; + //printf("Accuracy: %f\n",(float) correct/n); + return error/pos; } float train_network_batch(network net, data d, int n, float step, float momentum,float decay) { @@ -304,7 +357,7 @@ void train_network(network net, data d, float step, float momentum, float decay) } visualize_network(net); cvWaitKey(100); - printf("Accuracy: %f\n", (float)correct/d.X.rows); + fprintf(stderr, "Accuracy: %f\n", (float)correct/d.X.rows); } int get_network_output_size_layer(network net, int i) @@ -330,29 +383,63 @@ int get_network_output_size_layer(network net, int i) return 0; } -int reset_network_size(network net, int h, int w, int c) +/* + int resize_network(network net, int h, int w, int c) + { + int i; + for (i = 0; i < net.n; ++i){ + if(net.types[i] == CONVOLUTIONAL){ + convolutional_layer *layer = (convolutional_layer *)net.layers[i]; + layer->h = h; + layer->w = w; + layer->c = c; + image output = get_convolutional_image(*layer); + h = output.h; + w = output.w; + c = output.c; + } + else if(net.types[i] == MAXPOOL){ + maxpool_layer *layer = (maxpool_layer *)net.layers[i]; + layer->h = h; + layer->w = w; + layer->c = c; + image output = get_maxpool_image(*layer); + h = output.h; + w = output.w; + c = output.c; + } + } + return 0; + } + */ + +int resize_network(network net, int h, int w, int c) { int i; for (i = 0; i < net.n; ++i){ if(net.types[i] == CONVOLUTIONAL){ convolutional_layer *layer = (convolutional_layer *)net.layers[i]; - layer->h = h; - layer->w = w; - layer->c = c; + resize_convolutional_layer(layer, h, w, c); image output = get_convolutional_image(*layer); h = output.h; w = output.w; c = output.c; - } - else if(net.types[i] == MAXPOOL){ + }else if(net.types[i] == MAXPOOL){ maxpool_layer *layer = (maxpool_layer *)net.layers[i]; - layer->h = h; - layer->w = w; - layer->c = c; + resize_maxpool_layer(layer, h, w, c); image output = get_maxpool_image(*layer); h = output.h; w = output.w; c = output.c; + }else if(net.types[i] == NORMALIZATION){ + normalization_layer *layer = (normalization_layer *)net.layers[i]; + resize_normalization_layer(layer, h, w, c); + image output = get_normalization_image(*layer); + h = output.h; + w = output.w; + c = output.c; + }else{ + error("Cannot resize this type of layer"); } } return 0; @@ -374,6 +461,10 @@ image get_network_image_layer(network net, int i) maxpool_layer layer = *(maxpool_layer *)net.layers[i]; return get_maxpool_image(layer); } + else if(net.types[i] == NORMALIZATION){ + normalization_layer layer = *(normalization_layer *)net.layers[i]; + return get_normalization_image(layer); + } return make_empty_image(0,0,0); } @@ -389,13 +480,18 @@ image get_network_image(network net) void visualize_network(network net) { + image *prev = 0; int i; char buff[256]; for(i = 0; i < net.n; ++i){ sprintf(buff, "Layer %d", i); if(net.types[i] == CONVOLUTIONAL){ convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - visualize_convolutional_layer(layer, buff); + prev = visualize_convolutional_layer(layer, buff, prev); + } + if(net.types[i] == NORMALIZATION){ + normalization_layer layer = *(normalization_layer *)net.layers[i]; + visualize_normalization_layer(layer, buff); } } } @@ -467,3 +563,4 @@ float network_accuracy(network net, data d) return acc; } + diff --git a/src/network.h b/src/network.h index c75804d3..f6dac7e6 100644 --- a/src/network.h +++ b/src/network.h @@ -9,18 +9,20 @@ typedef enum { CONVOLUTIONAL, CONNECTED, MAXPOOL, - SOFTMAX + SOFTMAX, + NORMALIZATION } LAYER_TYPE; typedef struct { int n; + int batch; void **layers; LAYER_TYPE *types; int outputs; float *output; } network; -network make_network(int n); +network make_network(int n, int batch); void forward_network(network net, float *input); float backward_network(network net, float *input, float *truth); void update_network(network net, float step, float momentum, float decay); @@ -41,7 +43,7 @@ int get_predicted_class_network(network net); void print_network(network net); void visualize_network(network net); void save_network(network net, char *filename); -int reset_network_size(network net, int h, int w, int c); +int resize_network(network net, int h, int w, int c); #endif diff --git a/src/normalization_layer.c b/src/normalization_layer.c new file mode 100644 index 00000000..2d844e0e --- /dev/null +++ b/src/normalization_layer.c @@ -0,0 +1,96 @@ +#include "normalization_layer.h" +#include + +image get_normalization_image(normalization_layer layer) +{ + int h = layer.h; + int w = layer.w; + int c = layer.c; + return float_to_image(h,w,c,layer.output); +} + +image get_normalization_delta(normalization_layer layer) +{ + int h = layer.h; + int w = layer.w; + int c = layer.c; + return float_to_image(h,w,c,layer.delta); +} + +normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa) +{ + fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size); + normalization_layer *layer = calloc(1, sizeof(normalization_layer)); + layer->batch = batch; + layer->h = h; + layer->w = w; + layer->c = c; + layer->kappa = kappa; + layer->size = size; + layer->alpha = alpha; + layer->beta = beta; + layer->output = calloc(h * w * c * batch, sizeof(float)); + layer->delta = calloc(h * w * c * batch, sizeof(float)); + layer->sums = calloc(h*w, sizeof(float)); + return layer; +} + +void resize_normalization_layer(normalization_layer *layer, int h, int w, int c) +{ + layer->h = h; + layer->w = w; + layer->c = c; + layer->output = realloc(layer->output, h * w * c * layer->batch * sizeof(float)); + layer->delta = realloc(layer->delta, h * w * c * layer->batch * sizeof(float)); + layer->sums = realloc(layer->sums, h*w * sizeof(float)); +} + +void add_square_array(float *src, float *dest, int n) +{ + int i; + for(i = 0; i < n; ++i){ + dest[i] += src[i]*src[i]; + } +} +void sub_square_array(float *src, float *dest, int n) +{ + int i; + for(i = 0; i < n; ++i){ + dest[i] -= src[i]*src[i]; + } +} + +void forward_normalization_layer(const normalization_layer layer, float *in) +{ + int i,j,k; + memset(layer.sums, 0, layer.h*layer.w*sizeof(float)); + int imsize = layer.h*layer.w; + for(j = 0; j < layer.size/2; ++j){ + if(j < layer.c) add_square_array(in+j*imsize, layer.sums, imsize); + } + for(k = 0; k < layer.c; ++k){ + int next = k+layer.size/2; + int prev = k-layer.size/2-1; + if(next < layer.c) add_square_array(in+next*imsize, layer.sums, imsize); + if(prev > 0) sub_square_array(in+prev*imsize, layer.sums, imsize); + for(i = 0; i < imsize; ++i){ + layer.output[k*imsize + i] = in[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta); + } + } +} + +void backward_normalization_layer(const normalization_layer layer, float *in, float *delta) +{ + //TODO! +} + +void visualize_normalization_layer(normalization_layer layer, char *window) +{ + image delta = get_normalization_image(layer); + image dc = collapse_image_layers(delta, 1); + char buff[256]; + sprintf(buff, "%s: Output", window); + show_image(dc, buff); + save_image(dc, buff); + free_image(dc); +} diff --git a/src/normalization_layer.h b/src/normalization_layer.h new file mode 100644 index 00000000..fcf8af11 --- /dev/null +++ b/src/normalization_layer.h @@ -0,0 +1,26 @@ +#ifndef NORMALIZATION_LAYER_H +#define NORMALIZATION_LAYER_H + +#include "image.h" + +typedef struct { + int batch; + int h,w,c; + int size; + float alpha; + float beta; + float kappa; + float *delta; + float *output; + float *sums; +} normalization_layer; + +image get_normalization_image(normalization_layer layer); +normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa); +void resize_normalization_layer(normalization_layer *layer, int h, int w, int c); +void forward_normalization_layer(const normalization_layer layer, float *in); +void backward_normalization_layer(const normalization_layer layer, float *in, float *delta); +void visualize_normalization_layer(normalization_layer layer, char *window); + +#endif + diff --git a/src/opencl.c b/src/opencl.c new file mode 100644 index 00000000..193fba32 --- /dev/null +++ b/src/opencl.c @@ -0,0 +1,77 @@ +#include "opencl.h" +#include +#include +#include + +cl_info cl = {0}; + +void check_error(cl_info info) +{ + if (info.error != CL_SUCCESS) { + printf("\n Error number %d", info.error); + } +} + +cl_info cl_init() +{ + cl_info info; + info.initialized = 0; + cl_uint platforms, devices; + // Fetch the Platform and Device IDs; we only want one. + info.error=clGetPlatformIDs(1, &info.platform, &platforms); + check_error(info); + info.error=clGetDeviceIDs(info.platform, CL_DEVICE_TYPE_ALL, 1, &info.device, &devices); + check_error(info); + + cl_context_properties properties[]={ + CL_CONTEXT_PLATFORM, (cl_context_properties)info.platform, + 0}; + // Note that nVidia's OpenCL requires the platform property + info.context=clCreateContext(properties, 1, &info.device, 0, 0, &info.error); + check_error(info); + info.queue = clCreateCommandQueue(info.context, info.device, 0, &info.error); + check_error(info); + info.initialized = 1; + return info; +} + +cl_program cl_fprog(char *filename, char *options, cl_info info) +{ + size_t srcsize; + char src[8192]; + memset(src, 0, 8192); + FILE *fil=fopen(filename,"r"); + srcsize=fread(src, sizeof src, 1, fil); + fclose(fil); + const char *srcptr[]={src}; + // Submit the source code of the example kernel to OpenCL + cl_program prog=clCreateProgramWithSource(info.context,1, srcptr, &srcsize, &info.error); + check_error(info); + char build_c[4096]; + // and compile it (after this we could extract the compiled version) + info.error=clBuildProgram(prog, 0, 0, options, 0, 0); + if ( info.error != CL_SUCCESS ) { + fprintf(stderr, "Error Building Program: %d\n", info.error); + clGetProgramBuildInfo( prog, info.device, CL_PROGRAM_BUILD_LOG, 4096, build_c, 0); + fprintf(stderr, "Build Log for %s program:\n%s\n", filename, build_c); + } + return prog; +} + +void cl_setup() +{ + if(!cl.initialized){ + cl = cl_init(); + } +} + +cl_kernel get_kernel(char *filename, char *kernelname, char *options) +{ + cl_setup(); + cl_program prog = cl_fprog(filename, options, cl); + cl_kernel kernel=clCreateKernel(prog, kernelname, &cl.error); + check_error(cl); + return kernel; +} + + diff --git a/src/opencl.h b/src/opencl.h new file mode 100644 index 00000000..59efbae0 --- /dev/null +++ b/src/opencl.h @@ -0,0 +1,21 @@ +#ifdef __APPLE__ +#include +#else +#include +#endif + +typedef struct { + int initialized; + cl_int error; + cl_platform_id platform; + cl_device_id device; + cl_context context; + cl_command_queue queue; +}cl_info; + +extern cl_info cl; + +void cl_setup(); +void check_error(cl_info info); +cl_kernel get_kernel(char *filename, char *kernelname, char *options); + diff --git a/src/parser.c b/src/parser.c index cf35a94a..4aa0a79b 100644 --- a/src/parser.c +++ b/src/parser.c @@ -7,6 +7,7 @@ #include "convolutional_layer.h" #include "connected_layer.h" #include "maxpool_layer.h" +#include "normalization_layer.h" #include "softmax_layer.h" #include "list.h" #include "option_list.h" @@ -21,6 +22,7 @@ int is_convolutional(section *s); int is_connected(section *s); int is_maxpool(section *s); int is_softmax(section *s); +int is_normalization(section *s); list *read_cfg(char *filename); void free_section(section *s) @@ -52,6 +54,7 @@ convolutional_layer *parse_convolutional(list *options, network net, int count) h = option_find_int(options, "height",1); w = option_find_int(options, "width",1); c = option_find_int(options, "channels",1); + net.batch = option_find_int(options, "batch",1); }else{ image m = get_network_image_layer(net, count-1); h = m.h; @@ -59,7 +62,7 @@ convolutional_layer *parse_convolutional(list *options, network net, int count) c = m.c; if(h == 0) error("Layer before convolutional layer must output image."); } - convolutional_layer *layer = make_convolutional_layer(h,w,c,n,size,stride, activation); + convolutional_layer *layer = make_convolutional_layer(net.batch,h,w,c,n,size,stride, activation); char *data = option_find_str(options, "data", 0); if(data){ char *curr = data; @@ -90,10 +93,11 @@ connected_layer *parse_connected(list *options, network net, int count) ACTIVATION activation = get_activation(activation_s); if(count == 0){ input = option_find_int(options, "input",1); + net.batch = option_find_int(options, "batch",1); }else{ input = get_network_output_size_layer(net, count-1); } - connected_layer *layer = make_connected_layer(input, output, activation); + connected_layer *layer = make_connected_layer(net.batch, input, output, activation); char *data = option_find_str(options, "data", 0); if(data){ char *curr = data; @@ -120,10 +124,11 @@ softmax_layer *parse_softmax(list *options, network net, int count) int input; if(count == 0){ input = option_find_int(options, "input",1); + net.batch = option_find_int(options, "batch",1); }else{ input = get_network_output_size_layer(net, count-1); } - softmax_layer *layer = make_softmax_layer(input); + softmax_layer *layer = make_softmax_layer(net.batch, input); option_unused(options); return layer; } @@ -136,6 +141,7 @@ maxpool_layer *parse_maxpool(list *options, network net, int count) h = option_find_int(options, "height",1); w = option_find_int(options, "width",1); c = option_find_int(options, "channels",1); + net.batch = option_find_int(options, "batch",1); }else{ image m = get_network_image_layer(net, count-1); h = m.h; @@ -143,7 +149,31 @@ maxpool_layer *parse_maxpool(list *options, network net, int count) c = m.c; if(h == 0) error("Layer before convolutional layer must output image."); } - maxpool_layer *layer = make_maxpool_layer(h,w,c,stride); + maxpool_layer *layer = make_maxpool_layer(net.batch,h,w,c,stride); + option_unused(options); + return layer; +} + +normalization_layer *parse_normalization(list *options, network net, int count) +{ + int h,w,c; + int size = option_find_int(options, "size",1); + float alpha = option_find_float(options, "alpha", 0.); + float beta = option_find_float(options, "beta", 1.); + float kappa = option_find_float(options, "kappa", 1.); + if(count == 0){ + h = option_find_int(options, "height",1); + w = option_find_int(options, "width",1); + c = option_find_int(options, "channels",1); + net.batch = option_find_int(options, "batch",1); + }else{ + image m = get_network_image_layer(net, count-1); + h = m.h; + w = m.w; + c = m.c; + if(h == 0) error("Layer before convolutional layer must output image."); + } + normalization_layer *layer = make_normalization_layer(net.batch,h,w,c,size, alpha, beta, kappa); option_unused(options); return layer; } @@ -151,7 +181,7 @@ maxpool_layer *parse_maxpool(list *options, network net, int count) network parse_network_cfg(char *filename) { list *sections = read_cfg(filename); - network net = make_network(sections->size); + network net = make_network(sections->size, 0); node *n = sections->front; int count = 0; @@ -162,18 +192,27 @@ network parse_network_cfg(char *filename) convolutional_layer *layer = parse_convolutional(options, net, count); net.types[count] = CONVOLUTIONAL; net.layers[count] = layer; + net.batch = layer->batch; }else if(is_connected(s)){ connected_layer *layer = parse_connected(options, net, count); net.types[count] = CONNECTED; net.layers[count] = layer; + net.batch = layer->batch; }else if(is_softmax(s)){ softmax_layer *layer = parse_softmax(options, net, count); net.types[count] = SOFTMAX; net.layers[count] = layer; + net.batch = layer->batch; }else if(is_maxpool(s)){ maxpool_layer *layer = parse_maxpool(options, net, count); net.types[count] = MAXPOOL; net.layers[count] = layer; + net.batch = layer->batch; + }else if(is_normalization(s)){ + normalization_layer *layer = parse_normalization(options, net, count); + net.types[count] = NORMALIZATION; + net.layers[count] = layer; + net.batch = layer->batch; }else{ fprintf(stderr, "Type not recognized: %s\n", s->type); } @@ -208,6 +247,11 @@ int is_softmax(section *s) return (strcmp(s->type, "[soft]")==0 || strcmp(s->type, "[softmax]")==0); } +int is_normalization(section *s) +{ + return (strcmp(s->type, "[lrnorm]")==0 + || strcmp(s->type, "[localresponsenormalization]")==0); +} int read_option(char *s, list *options) { diff --git a/src/softmax_layer.c b/src/softmax_layer.c index b6b7ff35..12684238 100644 --- a/src/softmax_layer.c +++ b/src/softmax_layer.c @@ -3,13 +3,14 @@ #include #include -softmax_layer *make_softmax_layer(int inputs) +softmax_layer *make_softmax_layer(int batch, int inputs) { fprintf(stderr, "Softmax Layer: %d inputs\n", inputs); softmax_layer *layer = calloc(1, sizeof(softmax_layer)); + layer->batch = batch; layer->inputs = inputs; - layer->output = calloc(inputs, sizeof(float)); - layer->delta = calloc(inputs, sizeof(float)); + layer->output = calloc(inputs*batch, sizeof(float)); + layer->delta = calloc(inputs*batch, sizeof(float)); return layer; } @@ -28,28 +29,30 @@ void forward_softmax_layer(const softmax_layer layer, float *input) */ void forward_softmax_layer(const softmax_layer layer, float *input) { - int i; - float sum = 0; - float largest = 0; - for(i = 0; i < layer.inputs; ++i){ - if(input[i] > largest) largest = input[i]; - } - for(i = 0; i < layer.inputs; ++i){ - sum += exp(input[i]-largest); - //printf("%f, ", input[i]); - } - //printf("\n"); - if(sum) sum = largest+log(sum); - else sum = largest-100; - for(i = 0; i < layer.inputs; ++i){ - layer.output[i] = exp(input[i]-sum); + int i,b; + for(b = 0; b < layer.batch; ++b){ + float sum = 0; + float largest = 0; + for(i = 0; i < layer.inputs; ++i){ + if(input[i+b*layer.inputs] > largest) largest = input[i+b*layer.inputs]; + } + for(i = 0; i < layer.inputs; ++i){ + sum += exp(input[i+b*layer.inputs]-largest); + //printf("%f, ", input[i]); + } + //printf("\n"); + if(sum) sum = largest+log(sum); + else sum = largest-100; + for(i = 0; i < layer.inputs; ++i){ + layer.output[i+b*layer.inputs] = exp(input[i+b*layer.inputs]-sum); + } } } void backward_softmax_layer(const softmax_layer layer, float *input, float *delta) { int i; - for(i = 0; i < layer.inputs; ++i){ + for(i = 0; i < layer.inputs*layer.batch; ++i){ delta[i] = layer.delta[i]; } } diff --git a/src/softmax_layer.h b/src/softmax_layer.h index bfcd390f..414030c6 100644 --- a/src/softmax_layer.h +++ b/src/softmax_layer.h @@ -3,11 +3,12 @@ typedef struct { int inputs; + int batch; float *delta; float *output; } softmax_layer; -softmax_layer *make_softmax_layer(int inputs); +softmax_layer *make_softmax_layer(int batch, int inputs); void forward_softmax_layer(const softmax_layer layer, float *input); void backward_softmax_layer(const softmax_layer layer, float *input, float *delta); diff --git a/src/tests.c b/src/tests.c index 1c7b01d9..0319947f 100644 --- a/src/tests.c +++ b/src/tests.c @@ -1,5 +1,4 @@ #include "connected_layer.h" -//#include "old_conv.h" #include "convolutional_layer.h" #include "maxpool_layer.h" #include "network.h" @@ -19,649 +18,796 @@ void test_convolve() { - image dog = load_image("dog.jpg",300,400); - printf("dog channels %d\n", dog.c); - image kernel = make_random_image(3,3,dog.c); - image edge = make_image(dog.h, dog.w, 1); - int i; - clock_t start = clock(), end; - for(i = 0; i < 1000; ++i){ - convolve(dog, kernel, 1, 0, edge, 1); - } - end = clock(); - printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); - show_image_layers(edge, "Test Convolve"); + image dog = load_image("dog.jpg",300,400); + printf("dog channels %d\n", dog.c); + image kernel = make_random_image(3,3,dog.c); + image edge = make_image(dog.h, dog.w, 1); + int i; + clock_t start = clock(), end; + for(i = 0; i < 1000; ++i){ + convolve(dog, kernel, 1, 0, edge, 1); + } + end = clock(); + printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); + show_image_layers(edge, "Test Convolve"); } void test_convolve_matrix() { - image dog = load_image("dog.jpg",300,400); - printf("dog channels %d\n", dog.c); - - int size = 11; - int stride = 4; - int n = 40; - float *filters = make_random_image(size, size, dog.c*n).data; + image dog = load_image("dog.jpg",300,400); + printf("dog channels %d\n", dog.c); - int mw = ((dog.h-size)/stride+1)*((dog.w-size)/stride+1); - int mh = (size*size*dog.c); - float *matrix = calloc(mh*mw, sizeof(float)); + int size = 11; + int stride = 4; + int n = 40; + float *filters = make_random_image(size, size, dog.c*n).data; - image edge = make_image((dog.h-size)/stride+1, (dog.w-size)/stride+1, n); + int mw = ((dog.h-size)/stride+1)*((dog.w-size)/stride+1); + int mh = (size*size*dog.c); + float *matrix = calloc(mh*mw, sizeof(float)); + + image edge = make_image((dog.h-size)/stride+1, (dog.w-size)/stride+1, n); - int i; - clock_t start = clock(), end; - for(i = 0; i < 1000; ++i){ - im2col_cpu(dog.data, dog.c, dog.h, dog.w, size, stride, matrix); - gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw); - } - end = clock(); - printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); - show_image_layers(edge, "Test Convolve"); - cvWaitKey(0); + int i; + clock_t start = clock(), end; + for(i = 0; i < 1000; ++i){ + im2col_cpu(dog.data, dog.c, dog.h, dog.w, size, stride, matrix); + gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw); + } + end = clock(); + printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); + show_image_layers(edge, "Test Convolve"); + cvWaitKey(0); } void test_color() { - image dog = load_image("test_color.png", 300, 400); - show_image_layers(dog, "Test Color"); + image dog = load_image("test_color.png", 300, 400); + show_image_layers(dog, "Test Color"); } void verify_convolutional_layer() { - srand(0); - int i; - int n = 1; - int stride = 1; - int size = 3; - float eps = .00000001; - image test = make_random_image(5,5, 1); - convolutional_layer layer = *make_convolutional_layer(test.h,test.w,test.c, n, size, stride, RELU); - image out = get_convolutional_image(layer); - float **jacobian = calloc(test.h*test.w*test.c, sizeof(float)); - - forward_convolutional_layer(layer, test.data); - image base = copy_image(out); + srand(0); + int i; + int n = 1; + int stride = 1; + int size = 3; + float eps = .00000001; + image test = make_random_image(5,5, 1); + convolutional_layer layer = *make_convolutional_layer(1,test.h,test.w,test.c, n, size, stride, RELU); + image out = get_convolutional_image(layer); + float **jacobian = calloc(test.h*test.w*test.c, sizeof(float)); - for(i = 0; i < test.h*test.w*test.c; ++i){ - test.data[i] += eps; - forward_convolutional_layer(layer, test.data); - image partial = copy_image(out); - subtract_image(partial, base); - scale_image(partial, 1/eps); - jacobian[i] = partial.data; - test.data[i] -= eps; - } - float **jacobian2 = calloc(out.h*out.w*out.c, sizeof(float)); - image in_delta = make_image(test.h, test.w, test.c); - image out_delta = get_convolutional_delta(layer); - for(i = 0; i < out.h*out.w*out.c; ++i){ - out_delta.data[i] = 1; - backward_convolutional_layer(layer, in_delta.data); - image partial = copy_image(in_delta); - jacobian2[i] = partial.data; - out_delta.data[i] = 0; - } - int j; - float *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float)); - float *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float)); - for(i = 0; i < test.h*test.w*test.c; ++i){ - for(j =0 ; j < out.h*out.w*out.c; ++j){ - j1[i*out.h*out.w*out.c + j] = jacobian[i][j]; - j2[i*out.h*out.w*out.c + j] = jacobian2[j][i]; - printf("%f %f\n", jacobian[i][j], jacobian2[j][i]); - } - } + forward_convolutional_layer(layer, test.data); + image base = copy_image(out); + + for(i = 0; i < test.h*test.w*test.c; ++i){ + test.data[i] += eps; + forward_convolutional_layer(layer, test.data); + image partial = copy_image(out); + subtract_image(partial, base); + scale_image(partial, 1/eps); + jacobian[i] = partial.data; + test.data[i] -= eps; + } + float **jacobian2 = calloc(out.h*out.w*out.c, sizeof(float)); + image in_delta = make_image(test.h, test.w, test.c); + image out_delta = get_convolutional_delta(layer); + for(i = 0; i < out.h*out.w*out.c; ++i){ + out_delta.data[i] = 1; + backward_convolutional_layer(layer, in_delta.data); + image partial = copy_image(in_delta); + jacobian2[i] = partial.data; + out_delta.data[i] = 0; + } + int j; + float *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float)); + float *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float)); + for(i = 0; i < test.h*test.w*test.c; ++i){ + for(j =0 ; j < out.h*out.w*out.c; ++j){ + j1[i*out.h*out.w*out.c + j] = jacobian[i][j]; + j2[i*out.h*out.w*out.c + j] = jacobian2[j][i]; + printf("%f %f\n", jacobian[i][j], jacobian2[j][i]); + } + } - image mj1 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1); - image mj2 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2); - printf("%f %f\n", avg_image_layer(mj1,0), avg_image_layer(mj2,0)); - show_image(mj1, "forward jacobian"); - show_image(mj2, "backward jacobian"); + image mj1 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1); + image mj2 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2); + printf("%f %f\n", avg_image_layer(mj1,0), avg_image_layer(mj2,0)); + show_image(mj1, "forward jacobian"); + show_image(mj2, "backward jacobian"); } void test_load() { - image dog = load_image("dog.jpg", 300, 400); - show_image(dog, "Test Load"); - show_image_layers(dog, "Test Load"); + image dog = load_image("dog.jpg", 300, 400); + show_image(dog, "Test Load"); + show_image_layers(dog, "Test Load"); } void test_upsample() { - image dog = load_image("dog.jpg", 300, 400); - int n = 3; - image up = make_image(n*dog.h, n*dog.w, dog.c); - upsample_image(dog, n, up); - show_image(up, "Test Upsample"); - show_image_layers(up, "Test Upsample"); + image dog = load_image("dog.jpg", 300, 400); + int n = 3; + image up = make_image(n*dog.h, n*dog.w, dog.c); + upsample_image(dog, n, up); + show_image(up, "Test Upsample"); + show_image_layers(up, "Test Upsample"); } void test_rotate() { - int i; - image dog = load_image("dog.jpg",300,400); - clock_t start = clock(), end; - for(i = 0; i < 1001; ++i){ - rotate_image(dog); - } - end = clock(); - printf("Rotations: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); - show_image(dog, "Test Rotate"); + int i; + image dog = load_image("dog.jpg",300,400); + clock_t start = clock(), end; + for(i = 0; i < 1001; ++i){ + rotate_image(dog); + } + end = clock(); + printf("Rotations: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); + show_image(dog, "Test Rotate"); - image random = make_random_image(3,3,3); - show_image(random, "Test Rotate Random"); - rotate_image(random); - show_image(random, "Test Rotate Random"); - rotate_image(random); - show_image(random, "Test Rotate Random"); + image random = make_random_image(3,3,3); + show_image(random, "Test Rotate Random"); + rotate_image(random); + show_image(random, "Test Rotate Random"); + rotate_image(random); + show_image(random, "Test Rotate Random"); } void test_parser() { - network net = parse_network_cfg("test_parser.cfg"); - float input[1]; - int count = 0; - - float avgerr = 0; - while(++count < 100000000){ - float v = ((float)rand()/RAND_MAX); - float truth = v*v; - input[0] = v; - forward_network(net, input); - float *out = get_network_output(net); - float *delta = get_network_delta(net); - float err = pow((out[0]-truth),2.); - avgerr = .99 * avgerr + .01 * err; - if(count % 1000000 == 0) printf("%f %f :%f AVG %f \n", truth, out[0], err, avgerr); - delta[0] = truth - out[0]; - backward_network(net, input, &truth); - update_network(net, .001,0,0); - } + network net = parse_network_cfg("test_parser.cfg"); + float input[1]; + int count = 0; + + float avgerr = 0; + while(++count < 100000000){ + float v = ((float)rand()/RAND_MAX); + float truth = v*v; + input[0] = v; + forward_network(net, input); + float *out = get_network_output(net); + float *delta = get_network_delta(net); + float err = pow((out[0]-truth),2.); + avgerr = .99 * avgerr + .01 * err; + if(count % 1000000 == 0) printf("%f %f :%f AVG %f \n", truth, out[0], err, avgerr); + delta[0] = truth - out[0]; + backward_network(net, input, &truth); + update_network(net, .001,0,0); + } } void test_data() { - char *labels[] = {"cat","dog"}; - data train = load_data_image_pathfile_random("train_paths.txt", 101,labels, 2, 300, 400); - free_data(train); + char *labels[] = {"cat","dog"}; + data train = load_data_image_pathfile_random("train_paths.txt", 101,labels, 2, 300, 400); + free_data(train); } void train_full() { - network net = parse_network_cfg("cfg/imagenet.cfg"); - srand(2222222); - int i = 0; - char *labels[] = {"cat","dog"}; - float lr = .00001; - float momentum = .9; - float decay = 0.01; - while(1){ - i += 1000; - data train = load_data_image_pathfile_random("images/assira/train.list", 1000, labels, 2, 256, 256); - image im = float_to_image(256, 256, 3,train.X.vals[0]); - //visualize_network(net); - //cvWaitKey(100); - //show_image(im, "input"); - //cvWaitKey(100); - //scale_data_rows(train, 1./255.); - normalize_data_rows(train); - clock_t start = clock(), end; - float loss = train_network_sgd(net, train, 1000, lr, momentum, decay); - end = clock(); - printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay); - free_data(train); - if(i%10000==0){ - char buff[256]; - sprintf(buff, "cfg/assira_backup_%d.cfg", i); - save_network(net, buff); - } - //lr *= .99; - } + network net = parse_network_cfg("cfg/imagenet.cfg"); + srand(2222222); + int i = 0; + char *labels[] = {"cat","dog"}; + float lr = .00001; + float momentum = .9; + float decay = 0.01; + while(1){ + i += 1000; + data train = load_data_image_pathfile_random("images/assira/train.list", 1000, labels, 2, 256, 256); + //image im = float_to_image(256, 256, 3,train.X.vals[0]); + //visualize_network(net); + //cvWaitKey(100); + //show_image(im, "input"); + //cvWaitKey(100); + //scale_data_rows(train, 1./255.); + normalize_data_rows(train); + clock_t start = clock(), end; + float loss = train_network_sgd(net, train, 1000, lr, momentum, decay); + end = clock(); + printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay); + free_data(train); + if(i%10000==0){ + char buff[256]; + sprintf(buff, "cfg/assira_backup_%d.cfg", i); + save_network(net, buff); + } + //lr *= .99; + } +} + +void test_visualize() +{ + network net = parse_network_cfg("cfg/voc_imagenet.cfg"); + srand(2222222); + visualize_network(net); + cvWaitKey(0); } void test_full() { - network net = parse_network_cfg("cfg/backup_1300.cfg"); - srand(2222222); - int i,j; - int total = 100; - char *labels[] = {"cat","dog"}; - FILE *fp = fopen("preds.txt","w"); - for(i = 0; i < total; ++i){ - visualize_network(net); - cvWaitKey(100); - data test = load_data_image_pathfile_part("images/assira/test.list", i, total, labels, 2, 256, 256); - image im = float_to_image(256, 256, 3,test.X.vals[0]); - show_image(im, "input"); - cvWaitKey(100); - normalize_data_rows(test); - for(j = 0; j < test.X.rows; ++j){ - float *x = test.X.vals[j]; - forward_network(net, x); - int class = get_predicted_class_network(net); - fprintf(fp, "%d\n", class); - } - free_data(test); - } - fclose(fp); + network net = parse_network_cfg("cfg/backup_1300.cfg"); + srand(2222222); + int i,j; + int total = 100; + char *labels[] = {"cat","dog"}; + FILE *fp = fopen("preds.txt","w"); + for(i = 0; i < total; ++i){ + visualize_network(net); + cvWaitKey(100); + data test = load_data_image_pathfile_part("images/assira/test.list", i, total, labels, 2, 256, 256); + image im = float_to_image(256, 256, 3,test.X.vals[0]); + show_image(im, "input"); + cvWaitKey(100); + normalize_data_rows(test); + for(j = 0; j < test.X.rows; ++j){ + float *x = test.X.vals[j]; + forward_network(net, x); + int class = get_predicted_class_network(net); + fprintf(fp, "%d\n", class); + } + free_data(test); + } + fclose(fp); +} + +void test_cifar10() +{ + data test = load_cifar10_data("images/cifar10/test_batch.bin"); + scale_data_rows(test, 1./255); + network net = parse_network_cfg("cfg/cifar10.cfg"); + int count = 0; + float lr = .000005; + float momentum = .99; + float decay = 0.001; + decay = 0; + int batch = 10000; + while(++count <= 10000){ + char buff[256]; + sprintf(buff, "images/cifar10/data_batch_%d.bin", rand()%5+1); + data train = load_cifar10_data(buff); + scale_data_rows(train, 1./255); + train_network_sgd(net, train, batch, lr, momentum, decay); + //printf("%5f %5f\n",(double)count*batch/train.X.rows, loss); + + float test_acc = network_accuracy(net, test); + printf("%5f %5f\n",(double)count*batch/train.X.rows/5, 1-test_acc); + free_data(train); + } + +} + +void test_vince() +{ + network net = parse_network_cfg("cfg/vince.cfg"); + data train = load_categorical_data_csv("images/vince.txt", 144, 2); + normalize_data_rows(train); + + int count = 0; + float lr = .00005; + float momentum = .9; + float decay = 0.0001; + decay = 0; + int batch = 10000; + while(++count <= 10000){ + float loss = train_network_sgd(net, train, batch, lr, momentum, decay); + printf("%5f %5f\n",(double)count*batch/train.X.rows, loss); + } } void test_nist() { - srand(444444); - srand(888888); - network net = parse_network_cfg("nist.cfg"); - data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); - data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10); - normalize_data_rows(train); - normalize_data_rows(test); - //randomize_data(train); - int count = 0; - float lr = .0005; - float momentum = .9; - float decay = 0.001; - clock_t start = clock(), end; - while(++count <= 100){ - //visualize_network(net); - float loss = train_network_sgd(net, train, 1000, lr, momentum, decay); - printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, loss, lr, momentum, decay); - end = clock(); - printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); - start=end; - //cvWaitKey(100); - //lr /= 2; - if(count%5 == 0){ - float train_acc = network_accuracy(net, train); - fprintf(stderr, "\nTRAIN: %f\n", train_acc); - float test_acc = network_accuracy(net, test); - fprintf(stderr, "TEST: %f\n\n", test_acc); - printf("%d, %f, %f\n", count, train_acc, test_acc); - //lr *= .5; - } - } + srand(444444); + srand(888888); + network net = parse_network_cfg("cfg/nist_basic.cfg"); + data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); + data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10); + normalize_data_rows(train); + normalize_data_rows(test); + //randomize_data(train); + int count = 0; + float lr = .00005; + float momentum = .9; + float decay = 0.0001; + decay = 0; + //clock_t start = clock(), end; + int batch = 10000; + while(++count <= 10000){ + float loss = train_network_sgd(net, train, batch, lr, momentum, decay); + printf("%5f %5f\n",(double)count*batch/train.X.rows, loss); + //printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*1000, loss, lr, momentum, decay); + //end = clock(); + //printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); + //start=end; + /* + if(count%5 == 0){ + float train_acc = network_accuracy(net, train); + fprintf(stderr, "\nTRAIN: %f\n", train_acc); + float test_acc = network_accuracy(net, test); + fprintf(stderr, "TEST: %f\n\n", test_acc); + printf("%d, %f, %f\n", count, train_acc, test_acc); + //lr *= .5; + } + */ + } } void test_ensemble() { - int i; - srand(888888); - data d = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); - normalize_data_rows(d); - data test = load_categorical_data_csv("mnist/mnist_test.csv", 0,10); - normalize_data_rows(test); - data train = d; - // data *split = split_data(d, 1, 10); - // data train = split[0]; - // data test = split[1]; - matrix prediction = make_matrix(test.y.rows, test.y.cols); - int n = 30; - for(i = 0; i < n; ++i){ - int count = 0; - float lr = .0005; - float momentum = .9; - float decay = .01; - network net = parse_network_cfg("nist.cfg"); - while(++count <= 15){ - float acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay); - printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay ); - lr /= 2; - } - matrix partial = network_predict_data(net, test); - float acc = matrix_accuracy(test.y, partial); - printf("Model Accuracy: %lf\n", acc); - matrix_add_matrix(partial, prediction); - acc = matrix_accuracy(test.y, prediction); - printf("Current Ensemble Accuracy: %lf\n", acc); - free_matrix(partial); - } - float acc = matrix_accuracy(test.y, prediction); - printf("Full Ensemble Accuracy: %lf\n", acc); + int i; + srand(888888); + data d = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); + normalize_data_rows(d); + data test = load_categorical_data_csv("mnist/mnist_test.csv", 0,10); + normalize_data_rows(test); + data train = d; + // data *split = split_data(d, 1, 10); + // data train = split[0]; + // data test = split[1]; + matrix prediction = make_matrix(test.y.rows, test.y.cols); + int n = 30; + for(i = 0; i < n; ++i){ + int count = 0; + float lr = .0005; + float momentum = .9; + float decay = .01; + network net = parse_network_cfg("nist.cfg"); + while(++count <= 15){ + float acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay); + printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay ); + lr /= 2; + } + matrix partial = network_predict_data(net, test); + float acc = matrix_accuracy(test.y, partial); + printf("Model Accuracy: %lf\n", acc); + matrix_add_matrix(partial, prediction); + acc = matrix_accuracy(test.y, prediction); + printf("Current Ensemble Accuracy: %lf\n", acc); + free_matrix(partial); + } + float acc = matrix_accuracy(test.y, prediction); + printf("Full Ensemble Accuracy: %lf\n", acc); } void test_random_classify() { - network net = parse_network_cfg("connected.cfg"); - matrix m = csv_to_matrix("train.csv"); - //matrix ho = hold_out_matrix(&m, 2500); - float *truth = pop_column(&m, 0); - //float *ho_truth = pop_column(&ho, 0); - int i; - clock_t start = clock(), end; - int count = 0; - while(++count <= 300){ - for(i = 0; i < m.rows; ++i){ - int index = rand()%m.rows; - //image p = float_to_image(1690,1,1,m.vals[index]); - //normalize_image(p); - forward_network(net, m.vals[index]); - float *out = get_network_output(net); - float *delta = get_network_delta(net); - //printf("%f\n", out[0]); - delta[0] = truth[index] - out[0]; - // printf("%f\n", delta[0]); - //printf("%f %f\n", truth[index], out[0]); - //backward_network(net, m.vals[index], ); - update_network(net, .00001, 0,0); - } - //float test_acc = error_network(net, m, truth); - //float valid_acc = error_network(net, ho, ho_truth); - //printf("%f, %f\n", test_acc, valid_acc); - //fprintf(stderr, "%5d: %f Valid: %f\n",count, test_acc, valid_acc); - //if(valid_acc > .70) break; - } - end = clock(); - FILE *fp = fopen("submission/out.txt", "w"); - matrix test = csv_to_matrix("test.csv"); - truth = pop_column(&test, 0); - for(i = 0; i < test.rows; ++i){ - forward_network(net, test.vals[i]); - float *out = get_network_output(net); - if(fabs(out[0]) < .5) fprintf(fp, "0\n"); - else fprintf(fp, "1\n"); - } - fclose(fp); - printf("Neural Net Learning: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); + network net = parse_network_cfg("connected.cfg"); + matrix m = csv_to_matrix("train.csv"); + //matrix ho = hold_out_matrix(&m, 2500); + float *truth = pop_column(&m, 0); + //float *ho_truth = pop_column(&ho, 0); + int i; + clock_t start = clock(), end; + int count = 0; + while(++count <= 300){ + for(i = 0; i < m.rows; ++i){ + int index = rand()%m.rows; + //image p = float_to_image(1690,1,1,m.vals[index]); + //normalize_image(p); + forward_network(net, m.vals[index]); + float *out = get_network_output(net); + float *delta = get_network_delta(net); + //printf("%f\n", out[0]); + delta[0] = truth[index] - out[0]; + // printf("%f\n", delta[0]); + //printf("%f %f\n", truth[index], out[0]); + //backward_network(net, m.vals[index], ); + update_network(net, .00001, 0,0); + } + //float test_acc = error_network(net, m, truth); + //float valid_acc = error_network(net, ho, ho_truth); + //printf("%f, %f\n", test_acc, valid_acc); + //fprintf(stderr, "%5d: %f Valid: %f\n",count, test_acc, valid_acc); + //if(valid_acc > .70) break; + } + end = clock(); + FILE *fp = fopen("submission/out.txt", "w"); + matrix test = csv_to_matrix("test.csv"); + truth = pop_column(&test, 0); + for(i = 0; i < test.rows; ++i){ + forward_network(net, test.vals[i]); + float *out = get_network_output(net); + if(fabs(out[0]) < .5) fprintf(fp, "0\n"); + else fprintf(fp, "1\n"); + } + fclose(fp); + printf("Neural Net Learning: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); } void test_split() { - data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); - data *split = split_data(train, 0, 13); - printf("%d, %d, %d\n", train.X.rows, split[0].X.rows, split[1].X.rows); + data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); + data *split = split_data(train, 0, 13); + printf("%d, %d, %d\n", train.X.rows, split[0].X.rows, split[1].X.rows); } void test_im2row() { - int h = 20; - int w = 20; - int c = 3; - int stride = 1; - int size = 11; - image test = make_random_image(h,w,c); - int mc = 1; - int mw = ((h-size)/stride+1)*((w-size)/stride+1); - int mh = (size*size*c); - int msize = mc*mw*mh; - float *matrix = calloc(msize, sizeof(float)); - int i; - for(i = 0; i < 1000; ++i){ - im2col_cpu(test.data, c, h, w, size, stride, matrix); - //image render = float_to_image(mh, mw, mc, matrix); - } + int h = 20; + int w = 20; + int c = 3; + int stride = 1; + int size = 11; + image test = make_random_image(h,w,c); + int mc = 1; + int mw = ((h-size)/stride+1)*((w-size)/stride+1); + int mh = (size*size*c); + int msize = mc*mw*mh; + float *matrix = calloc(msize, sizeof(float)); + int i; + for(i = 0; i < 1000; ++i){ + im2col_cpu(test.data, c, h, w, size, stride, matrix); + //image render = float_to_image(mh, mw, mc, matrix); + } +} + +void flip_network() +{ + network net = parse_network_cfg("cfg/voc_imagenet_orig.cfg"); + save_network(net, "cfg/voc_imagenet_rev.cfg"); } void train_VOC() { - network net = parse_network_cfg("cfg/voc_start.cfg"); - srand(2222222); - int i = 20; - char *labels[] = {"aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa","train","tvmonitor"}; - float lr = .00001; - float momentum = .9; - float decay = 0.01; - while(i++ < 1000 || 1){ - data train = load_data_image_pathfile_random("images/VOC2012/val_paths.txt", 1000, labels, 20, 300, 400); + network net = parse_network_cfg("cfg/voc_start.cfg"); + srand(2222222); + int i = 20; + char *labels[] = {"aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa","train","tvmonitor"}; + float lr = .00001; + float momentum = .9; + float decay = 0.01; + while(i++ < 1000 || 1){ + data train = load_data_image_pathfile_random("images/VOC2012/val_paths.txt", 1000, labels, 20, 300, 400); - image im = float_to_image(300, 400, 3,train.X.vals[0]); - show_image(im, "input"); - visualize_network(net); - cvWaitKey(100); + image im = float_to_image(300, 400, 3,train.X.vals[0]); + show_image(im, "input"); + visualize_network(net); + cvWaitKey(100); - normalize_data_rows(train); - clock_t start = clock(), end; - float loss = train_network_sgd(net, train, 1000, lr, momentum, decay); - end = clock(); - printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay); - free_data(train); - if(i%10==0){ - char buff[256]; - sprintf(buff, "cfg/voc_clean_ramp_%d.cfg", i); - save_network(net, buff); - } - //lr *= .99; - } + normalize_data_rows(train); + clock_t start = clock(), end; + float loss = train_network_sgd(net, train, 1000, lr, momentum, decay); + end = clock(); + printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay); + free_data(train); + if(i%10==0){ + char buff[256]; + sprintf(buff, "cfg/voc_clean_ramp_%d.cfg", i); + save_network(net, buff); + } + //lr *= .99; + } } int voc_size(int x) { - x = x-1+3; - x = x-1+3; - x = x-1+3; - x = (x-1)*2+1; - x = x-1+5; - x = (x-1)*2+1; - x = (x-1)*4+11; - return x; + x = x-1+3; + x = x-1+3; + x = x-1+3; + x = (x-1)*2+1; + x = x-1+5; + x = (x-1)*2+1; + x = (x-1)*4+11; + return x; } image features_output_size(network net, IplImage *src, int outh, int outw) { - int h = voc_size(outh); - int w = voc_size(outw); - printf("%d %d\n", h, w); + int h = voc_size(outh); + int w = voc_size(outw); + fprintf(stderr, "%d %d\n", h, w); - IplImage *sized = cvCreateImage(cvSize(w,h), src->depth, src->nChannels); - cvResize(src, sized, CV_INTER_LINEAR); - image im = ipl_to_image(sized); - reset_network_size(net, im.h, im.w, im.c); - forward_network(net, im.data); - image out = get_network_image_layer(net, 6); - //printf("%d %d\n%d %d\n", outh, out.h, outw, out.w); - free_image(im); - cvReleaseImage(&sized); - return copy_image(out); + IplImage *sized = cvCreateImage(cvSize(w,h), src->depth, src->nChannels); + cvResize(src, sized, CV_INTER_LINEAR); + image im = ipl_to_image(sized); + normalize_array(im.data, im.h*im.w*im.c); + resize_network(net, im.h, im.w, im.c); + forward_network(net, im.data); + image out = get_network_image_layer(net, 6); + free_image(im); + cvReleaseImage(&sized); + return copy_image(out); } -void features_VOC(int part, int total) +void features_VOC_image_size(char *image_path, int h, int w) { - int i,j, count = 0; - network net = parse_network_cfg("cfg/voc_imagenet.cfg"); - char *path_file = "images/VOC2012/all_paths.txt"; - char *out_dir = "voc_features/"; - list *paths = get_paths(path_file); - node *n = paths->front; - int size = paths->size; - for(count = 0; count < part*size/total; ++count) n = n->next; - while(n && count++ < (part+1)*size/total){ - char *path = (char *)n->val; - char buff[1024]; - sprintf(buff, "%s%s.txt",out_dir, path); - printf("%s\n", path); - FILE *fp = fopen(buff, "w"); - if(fp == 0) file_error(buff); + int j; + network net = parse_network_cfg("cfg/voc_imagenet.cfg"); + fprintf(stderr, "%s\n", image_path); - IplImage* src = 0; - if( (src = cvLoadImage(path,-1)) == 0 ) - { - printf("Cannot load file image %s\n", path); - exit(0); - } - int w = src->width; - int h = src->height; - int sbin = 8; - int interval = 10; - double scale = pow(2., 1./interval); - int m = (wfront; + int h = voc_size(1), w = voc_size(1); + int num = get_network_image(net).c; + image **vizs = calloc(num, sizeof(image*)); + float **score = calloc(num, sizeof(float *)); + for(i = 0; i < num; ++i){ + vizs[i] = calloc(topk, sizeof(image)); + for(j = 0; j < topk; ++j) vizs[i][j] = make_image(h,w,3); + score[i] = calloc(topk, sizeof(float)); + } - for(i = 0; i < interval; ++i){ - double factor = 1./pow(scale, i); - double ih = round(h*factor); - double iw = round(w*factor); - int ex_h = round(ih/4.) - 2; - int ex_w = round(iw/4.) - 2; - ims[i] = features_output_size(net, src, ex_h, ex_w); + int count = 0; + while(n){ + ++count; + char *image_path = (char *)n->val; + image im = load_image(image_path, 0, 0); + n = n->next; + if(im.h < 200 || im.w < 200) continue; + printf("Processing %dx%d image\n", im.h, im.w); + resize_network(net, im.h, im.w, im.c); + //scale_image(im, 1./255); + translate_image(im, -144); + forward_network(net, im.data); + image out = get_network_image(net); - ih = round(h*factor); - iw = round(w*factor); - ex_h = round(ih/8.) - 2; - ex_w = round(iw/8.) - 2; - ims[i+interval] = features_output_size(net, src, ex_h, ex_w); - for(j = i+interval; j < max_scale; j += interval){ - factor /= 2.; - ih = round(h*factor); - iw = round(w*factor); - ex_h = round(ih/8.) - 2; - ex_w = round(iw/8.) - 2; - ims[j+interval] = features_output_size(net, src, ex_h, ex_w); - } - } - for(i = 0; i < max_scale+interval; ++i){ - image out = ims[i]; - //printf("%d, %d\n", out.h, out.w); - fprintf(fp, "%d, %d, %d\n",out.c, out.h, out.w); - for(j = 0; j < out.c*out.h*out.w; ++j){ - if(j != 0)fprintf(fp, ","); - fprintf(fp, "%g", out.data[j]); - } - fprintf(fp, "\n"); - free_image(out); - } - free(ims); - fclose(fp); - cvReleaseImage(&src); - n = n->next; - } + int dh = (im.h - h)/(out.h-1); + int dw = (im.w - w)/(out.w-1); + //printf("%d %d\n", dh, dw); + for(k = 0; k < out.c; ++k){ + float topv = 0; + int topi = -1; + int topj = -1; + for(i = 0; i < out.h; ++i){ + for(j = 0; j < out.w; ++j){ + float val = get_pixel(out, i, j, k); + if(val > topv){ + topv = val; + topi = i; + topj = j; + } + } + } + if(topv){ + image sub = get_sub_image(im, dh*topi, dw*topj, h, w); + for(l = 0; l < topk; ++l){ + if(topv > score[k][l]){ + float swap = score[k][l]; + score[k][l] = topv; + topv = swap; + + image swapi = vizs[k][l]; + vizs[k][l] = sub; + sub = swapi; + } + } + free_image(sub); + } + } + free_image(im); + if(count%50 == 0){ + image grid = grid_images(vizs, num, topk); + //show_image(grid, "IMAGENET Visualization"); + save_image(grid, "IMAGENET Grid Single Nonorm"); + free_image(grid); + } + } + //cvWaitKey(0); +} + +void visualize_imagenet_features(char *filename) +{ + int i,j,k; + network net = parse_network_cfg("cfg/voc_imagenet.cfg"); + list *plist = get_paths(filename); + node *n = plist->front; + int h = voc_size(1), w = voc_size(1); + int num = get_network_image(net).c; + image *vizs = calloc(num, sizeof(image)); + for(i = 0; i < num; ++i) vizs[i] = make_image(h, w, 3); + while(n){ + char *image_path = (char *)n->val; + image im = load_image(image_path, 0, 0); + printf("Processing %dx%d image\n", im.h, im.w); + resize_network(net, im.h, im.w, im.c); + forward_network(net, im.data); + image out = get_network_image(net); + + int dh = (im.h - h)/h; + int dw = (im.w - w)/w; + for(i = 0; i < out.h; ++i){ + for(j = 0; j < out.w; ++j){ + image sub = get_sub_image(im, dh*i, dw*j, h, w); + for(k = 0; k < out.c; ++k){ + float val = get_pixel(out, i, j, k); + //printf("%f, ", val); + image sub_c = copy_image(sub); + scale_image(sub_c, val); + add_into_image(sub_c, vizs[k], 0, 0); + free_image(sub_c); + } + free_image(sub); + } + } + //printf("\n"); + show_images(vizs, 10, "IMAGENET Visualization"); + cvWaitKey(1000); + n = n->next; + } + cvWaitKey(0); +} + +void visualize_cat() +{ + network net = parse_network_cfg("cfg/voc_imagenet.cfg"); + image im = load_image("data/cat.png", 0, 0); + printf("Processing %dx%d image\n", im.h, im.w); + resize_network(net, im.h, im.w, im.c); + forward_network(net, im.data); + + image out = get_network_image(net); + visualize_network(net); + cvWaitKey(1000); + cvWaitKey(0); } void features_VOC_image(char *image_file, char *image_dir, char *out_dir) { int flip = 1; - int interval = 4; - int i,j; - network net = parse_network_cfg("cfg/voc_imagenet.cfg"); - char image_path[1024]; - sprintf(image_path, "%s%s",image_dir, image_file); - char out_path[1024]; - if (flip)sprintf(out_path, "%s%d/%s_r.txt",out_dir, interval, image_file); + int interval = 4; + int i,j; + network net = parse_network_cfg("cfg/voc_imagenet.cfg"); + char image_path[1024]; + sprintf(image_path, "%s/%s",image_dir, image_file); + char out_path[1024]; + if (flip)sprintf(out_path, "%s%d/%s_r.txt",out_dir, interval, image_file); else sprintf(out_path, "%s%d/%s.txt",out_dir, interval, image_file); - printf("%s\n", image_file); - FILE *fp = fopen(out_path, "w"); - if(fp == 0) file_error(out_path); + printf("%s\n", image_file); + FILE *fp = fopen(out_path, "w"); + if(fp == 0) file_error(out_path); - IplImage* src = 0; - if( (src = cvLoadImage(image_path,-1)) == 0 ) file_error(image_path); -if(flip)cvFlip(src, 0, 1); - int w = src->width; - int h = src->height; - int sbin = 8; - double scale = pow(2., 1./interval); - int m = (wwidth; + int h = src->height; + int sbin = 8; + double scale = pow(2., 1./interval); + int m = (w= interval"); + image *ims = calloc(max_scale+interval, sizeof(image)); - for(i = 0; i < interval; ++i){ - double factor = 1./pow(scale, i); - double ih = round(h*factor); - double iw = round(w*factor); - int ex_h = round(ih/4.) - 2; - int ex_w = round(iw/4.) - 2; - ims[i] = features_output_size(net, src, ex_h, ex_w); + for(i = 0; i < interval; ++i){ + double factor = 1./pow(scale, i); + double ih = round(h*factor); + double iw = round(w*factor); + int ex_h = round(ih/4.) - 2; + int ex_w = round(iw/4.) - 2; + ims[i] = features_output_size(net, src, ex_h, ex_w); - ih = round(h*factor); - iw = round(w*factor); - ex_h = round(ih/8.) - 2; - ex_w = round(iw/8.) - 2; - ims[i+interval] = features_output_size(net, src, ex_h, ex_w); - for(j = i+interval; j < max_scale; j += interval){ - factor /= 2.; - ih = round(h*factor); - iw = round(w*factor); - ex_h = round(ih/8.) - 2; - ex_w = round(iw/8.) - 2; - ims[j+interval] = features_output_size(net, src, ex_h, ex_w); - } - } - for(i = 0; i < max_scale+interval; ++i){ - image out = ims[i]; - fprintf(fp, "%d, %d, %d\n",out.c, out.h, out.w); - for(j = 0; j < out.c*out.h*out.w; ++j){ - if(j != 0)fprintf(fp, ","); - fprintf(fp, "%g", out.data[j]); - } - fprintf(fp, "\n"); - free_image(out); - } - free(ims); - fclose(fp); - cvReleaseImage(&src); + ih = round(h*factor); + iw = round(w*factor); + ex_h = round(ih/8.) - 2; + ex_w = round(iw/8.) - 2; + ims[i+interval] = features_output_size(net, src, ex_h, ex_w); + for(j = i+interval; j < max_scale; j += interval){ + factor /= 2.; + ih = round(h*factor); + iw = round(w*factor); + ex_h = round(ih/8.) - 2; + ex_w = round(iw/8.) - 2; + ims[j+interval] = features_output_size(net, src, ex_h, ex_w); + } + } + for(i = 0; i < max_scale+interval; ++i){ + image out = ims[i]; + fprintf(fp, "%d, %d, %d\n",out.c, out.h, out.w); + for(j = 0; j < out.c*out.h*out.w; ++j){ + if(j != 0)fprintf(fp, ","); + fprintf(fp, "%g", out.data[j]); + } + fprintf(fp, "\n"); + free_image(out); + } + free(ims); + fclose(fp); + cvReleaseImage(&src); } void test_distribution() { - IplImage* img = 0; - if( (img = cvLoadImage("im_small.jpg",-1)) == 0 ) file_error("im_small.jpg"); - network net = parse_network_cfg("cfg/voc_features.cfg"); - int h = img->height/8-2; - int w = img->width/8-2; - image out = features_output_size(net, img, h, w); - int c = out.c; - out.c = 1; - show_image(out, "output"); - out.c = c; - image input = ipl_to_image(img); - show_image(input, "input"); - CvScalar s; - int i,j; - image affects = make_image(input.h, input.w, 1); - int count = 0; - for(i = 0; iheight; i += 1){ - for(j = 0; j < img->width; j += 1){ - IplImage *copy = cvCloneImage(img); - s=cvGet2D(copy,i,j); // get the (i,j) pixel value - printf("%d/%d\n", count++, img->height*img->width); - s.val[0]=0; - s.val[1]=0; - s.val[2]=0; - cvSet2D(copy,i,j,s); // set the (i,j) pixel value - image mod = features_output_size(net, copy, h, w); - image dist = image_distance(out, mod); - show_image(affects, "affects"); - cvWaitKey(1); - cvReleaseImage(©); - //affects.data[i*affects.w + j] += dist.data[3*dist.w+5]; - affects.data[i*affects.w + j] += dist.data[1*dist.w+1]; - free_image(mod); - free_image(dist); - } - } - show_image(affects, "Origins"); - cvWaitKey(0); - cvWaitKey(0); + IplImage* img = 0; + if( (img = cvLoadImage("im_small.jpg",-1)) == 0 ) file_error("im_small.jpg"); + network net = parse_network_cfg("cfg/voc_features.cfg"); + int h = img->height/8-2; + int w = img->width/8-2; + image out = features_output_size(net, img, h, w); + int c = out.c; + out.c = 1; + show_image(out, "output"); + out.c = c; + image input = ipl_to_image(img); + show_image(input, "input"); + CvScalar s; + int i,j; + image affects = make_image(input.h, input.w, 1); + int count = 0; + for(i = 0; iheight; i += 1){ + for(j = 0; j < img->width; j += 1){ + IplImage *copy = cvCloneImage(img); + s=cvGet2D(copy,i,j); // get the (i,j) pixel value + printf("%d/%d\n", count++, img->height*img->width); + s.val[0]=0; + s.val[1]=0; + s.val[2]=0; + cvSet2D(copy,i,j,s); // set the (i,j) pixel value + image mod = features_output_size(net, copy, h, w); + image dist = image_distance(out, mod); + show_image(affects, "affects"); + cvWaitKey(1); + cvReleaseImage(©); + //affects.data[i*affects.w + j] += dist.data[3*dist.w+5]; + affects.data[i*affects.w + j] += dist.data[1*dist.w+1]; + free_image(mod); + free_image(dist); + } + } + show_image(affects, "Origins"); + cvWaitKey(0); + cvWaitKey(0); } int main(int argc, char *argv[]) { - //train_full(); - //test_distribution(); - //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW); + //train_full(); + //test_distribution(); + //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW); - //test_blas(); - //test_convolve_matrix(); - // test_im2row(); - //test_split(); - //test_ensemble(); - //test_nist(); - //test_full(); - //train_VOC(); - features_VOC_image(argv[1], argv[2], argv[3]); - printf("Success!\n"); - //test_random_preprocess(); - //test_random_classify(); - //test_parser(); - //test_backpropagate(); - //test_ann(); - //test_convolve(); - //test_upsample(); - //test_rotate(); - //test_load(); - //test_network(); - //test_convolutional_layer(); - //verify_convolutional_layer(); - //test_color(); - //cvWaitKey(0); - return 0; + //test_blas(); + //test_visualize(); + //test_gpu_blas(); + //test_blas(); + //test_convolve_matrix(); + // test_im2row(); + //test_split(); + //test_ensemble(); + //test_nist(); + //test_cifar10(); + //test_vince(); + //test_full(); + //train_VOC(); + //features_VOC_image(argv[1], argv[2], argv[3]); + //features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3])); + //visualize_imagenet_features("data/assira/train.list"); + visualize_imagenet_topk("data/VOC2012.list"); + //visualize_cat(); + //flip_network(); + //test_visualize(); + fprintf(stderr, "Success!\n"); + //test_random_preprocess(); + //test_random_classify(); + //test_parser(); + //test_backpropagate(); + //test_ann(); + //test_convolve(); + //test_upsample(); + //test_rotate(); + //test_load(); + //test_network(); + //test_convolutional_layer(); + //verify_convolutional_layer(); + //test_color(); + //cvWaitKey(0); + return 0; } diff --git a/test.jpg b/test.jpg deleted file mode 100644 index f7b6cb8d..00000000 Binary files a/test.jpg and /dev/null differ diff --git a/test_color.png b/test_color.png deleted file mode 100644 index 1a1836e8..00000000 Binary files a/test_color.png and /dev/null differ diff --git a/test_dog.jpg b/test_dog.jpg deleted file mode 100644 index aa98311a..00000000 Binary files a/test_dog.jpg and /dev/null differ diff --git a/test_hinton.jpg b/test_hinton.jpg deleted file mode 100644 index 25b38210..00000000 Binary files a/test_hinton.jpg and /dev/null differ