diff --git a/Makefile b/Makefile index 44c930f8..415c5226 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ endif LDFLAGS=`pkg-config --libs opencv` -lm VPATH=./src/ -OBJ=network.o image.o tests.o convolutional_layer.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o +OBJ=network.o image.o tests.o convolutional_layer.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o all: cnn diff --git a/dog.jpg b/dog.jpg index 16d05ab1..3b9f7abd 100644 Binary files a/dog.jpg and b/dog.jpg differ diff --git a/src/connected_layer.c b/src/connected_layer.c index 0344c71a..6871b2ee 100644 --- a/src/connected_layer.c +++ b/src/connected_layer.c @@ -1,5 +1,6 @@ #include "connected_layer.h" #include "utils.h" +#include "mini_blas.h" #include #include @@ -35,55 +36,99 @@ connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activa return layer; } +void update_connected_layer(connected_layer layer, double step, double momentum, double decay) +{ + int i; + for(i = 0; i < layer.outputs; ++i){ + layer.bias_momentum[i] = step*(layer.bias_updates[i]) + momentum*layer.bias_momentum[i]; + layer.biases[i] += layer.bias_momentum[i]; + } + for(i = 0; i < layer.outputs*layer.inputs; ++i){ + layer.weight_momentum[i] = step*(layer.weight_updates[i] - decay*layer.weights[i]) + momentum*layer.weight_momentum[i]; + layer.weights[i] += layer.weight_momentum[i]; + } + memset(layer.bias_updates, 0, layer.outputs*sizeof(double)); + memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(double)); +} + void forward_connected_layer(connected_layer layer, double *input) { - int i, j; + int i; + memcpy(layer.output, layer.biases, layer.outputs*sizeof(double)); + int m = 1; + int k = layer.inputs; + int n = layer.outputs; + double *a = input; + double *b = layer.weights; + double *c = layer.output; + gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); for(i = 0; i < layer.outputs; ++i){ - layer.output[i] = layer.biases[i]; - for(j = 0; j < layer.inputs; ++j){ - layer.output[i] += input[j]*layer.weights[i*layer.inputs + j]; - } layer.output[i] = activate(layer.output[i], layer.activation); } } void learn_connected_layer(connected_layer layer, double *input) { - int i, j; + int i; for(i = 0; i < layer.outputs; ++i){ layer.delta[i] *= gradient(layer.output[i], layer.activation); layer.bias_updates[i] += layer.delta[i]; - for(j = 0; j < layer.inputs; ++j){ - layer.weight_updates[i*layer.inputs + j] += layer.delta[i]*input[j]; - } } -} - -void update_connected_layer(connected_layer layer, double step, double momentum, double decay) -{ - int i,j; - for(i = 0; i < layer.outputs; ++i){ - layer.bias_momentum[i] = step*(layer.bias_updates[i]) + momentum*layer.bias_momentum[i]; - layer.biases[i] += layer.bias_momentum[i]; - for(j = 0; j < layer.inputs; ++j){ - int index = i*layer.inputs+j; - layer.weight_momentum[index] = step*(layer.weight_updates[index] - decay*layer.weights[index]) + momentum*layer.weight_momentum[index]; - layer.weights[index] += layer.weight_momentum[index]; - } - } - memset(layer.bias_updates, 0, layer.outputs*sizeof(double)); - memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(double)); + int m = layer.inputs; + int k = 1; + int n = layer.outputs; + double *a = input; + double *b = layer.delta; + double *c = layer.weight_updates; + gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); } void backward_connected_layer(connected_layer layer, double *input, double *delta) { - int i, j; + memset(delta, 0, layer.inputs*sizeof(double)); - for(j = 0; j < layer.inputs; ++j){ - delta[j] = 0; - for(i = 0; i < layer.outputs; ++i){ - delta[j] += layer.delta[i]*layer.weights[i*layer.inputs + j]; - } - } + int m = layer.inputs; + int k = layer.outputs; + int n = 1; + + double *a = layer.weights; + double *b = layer.delta; + double *c = delta; + + gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); } +/* + void forward_connected_layer(connected_layer layer, double *input) + { + int i, j; + for(i = 0; i < layer.outputs; ++i){ + layer.output[i] = layer.biases[i]; + for(j = 0; j < layer.inputs; ++j){ + layer.output[i] += input[j]*layer.weights[i*layer.inputs + j]; + } + layer.output[i] = activate(layer.output[i], layer.activation); + } + } + void learn_connected_layer(connected_layer layer, double *input) + { + int i, j; + for(i = 0; i < layer.outputs; ++i){ + layer.delta[i] *= gradient(layer.output[i], layer.activation); + layer.bias_updates[i] += layer.delta[i]; + for(j = 0; j < layer.inputs; ++j){ + layer.weight_updates[i*layer.inputs + j] += layer.delta[i]*input[j]; + } + } + } + void backward_connected_layer(connected_layer layer, double *input, double *delta) + { + int i, j; + for(j = 0; j < layer.inputs; ++j){ + delta[j] = 0; + for(i = 0; i < layer.outputs; ++i){ + delta[j] += layer.delta[i]*layer.weights[i*layer.inputs + j]; + } + } + } + */ diff --git a/src/image.c b/src/image.c index 74b88325..df8e1b8f 100644 --- a/src/image.c +++ b/src/image.c @@ -207,6 +207,7 @@ image make_random_image(int h, int w, int c) int i; for(i = 0; i < h*w*c; ++i){ out.data[i] = rand_normal(); + //out.data[i] = rand()%3; } return out; } diff --git a/src/mini_blas.c b/src/mini_blas.c new file mode 100644 index 00000000..b15ba8e6 --- /dev/null +++ b/src/mini_blas.c @@ -0,0 +1,80 @@ + +void gemm(int TA, int TB, int M, int N, int K, double ALPHA, + double *A, int lda, + double *B, int ldb, + double BETA, + double *C, int ldc) +{ + // Assume TA = TB = 0, beta = 1 LULZ + int i,j,k; + for(i = 0; i < M; ++i){ + for(k = 0; k < K; ++k){ + for(j = 0; j < N; ++j){ + C[i*ldc+j] += ALPHA*A[i*lda+k]*B[k*ldb+j]; + } + } + } +} + +void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix) +{ + int i; + int mc = c; + int mw = (size*size); + int mh = ((h-size)/stride+1)*((w-size)/stride+1); + int msize = mc*mw*mh; + for(i = 0; i < msize; ++i){ + int channel = i/(mh*mw); + int block = (i%(mh*mw))/mw; + int position = i%mw; + int block_h = block/((w-size)/stride+1); + int block_w = block%((w-size)/stride+1); + int ph, pw, pc; + ph = position/size+block_h; + pw = position%size+block_w; + pc = channel; + matrix[i] = image[pc*h*w+ph*w+pw]; + } +} +void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix) +{ + int b,p; + int blocks = ((h-size)/stride+1)*((w-size)/stride+1); + int pixels = (size*size*c); + for(b = 0; b < blocks; ++b){ + int block_h = b/((w-size)/stride+1); + int block_w = b%((w-size)/stride+1); + for(p = 0; p < pixels; ++p){ + int ph, pw, pc; + int position = p%(size*size); + pc = p/(size*size); + ph = position/size+block_h; + pw = position%size+block_w; + matrix[b+p*blocks] = image[pc*h*w+ph*w+pw]; + } + } +} + +//From Berkeley Vision's Caffe! +void im2col_cpu(double* data_im, const int channels, + const int height, const int width, const int ksize, const int stride, + double* data_col) + { + int c,h,w; + int height_col = (height - ksize) / stride + 1; + int width_col = (width - ksize) / stride + 1; + int channels_col = channels * ksize * ksize; + for ( c = 0; c < channels_col; ++c) { + int w_offset = c % ksize; + int h_offset = (c / ksize) % ksize; + int c_im = c / ksize / ksize; + for ( h = 0; h < height_col; ++h) { + for ( w = 0; w < width_col; ++w) { + data_col[(c * height_col + h) * width_col + w] = + data_im[(c_im * height + h * stride + h_offset) * width + + w * stride + w_offset]; + } + } + } +} + diff --git a/src/mini_blas.h b/src/mini_blas.h new file mode 100644 index 00000000..cdf3a86e --- /dev/null +++ b/src/mini_blas.h @@ -0,0 +1,10 @@ +void gemm(int TA, int TB, int M, int N, int K, double ALPHA, + double *A, int lda, + double *B, int ldb, + double BETA, + double *C, int ldc); +void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix); +void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix); +void im2col_cpu(double* data_im, const int channels, + const int height, const int width, const int ksize, const int stride, + double* data_col); diff --git a/src/tests.c b/src/tests.c index 2a50bacf..ce131e7f 100644 --- a/src/tests.c +++ b/src/tests.c @@ -7,6 +7,7 @@ #include "data.h" #include "matrix.h" #include "utils.h" +#include "mini_blas.h" #include #include @@ -28,6 +29,35 @@ void test_convolve() show_image_layers(edge, "Test Convolve"); } +void test_convolve_matrix() +{ + image dog = load_image("dog.jpg"); + printf("dog channels %d\n", dog.c); + + int size = 11; + int stride = 1; + int n = 40; + double *filters = make_random_image(size, size, dog.c*n).data; + + int mw = ((dog.h-size)/stride+1)*((dog.w-size)/stride+1); + int mh = (size*size*dog.c); + double *matrix = calloc(mh*mw, sizeof(double)); + + image edge = make_image((dog.h-size)/stride+1, (dog.w-size)/stride+1, n); + + + int i; + clock_t start = clock(), end; + for(i = 0; i < 1000; ++i){ + im2col_cpu(dog.data, dog.c, dog.h, dog.w, size, stride, matrix); + gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw); + } + end = clock(); + printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); + show_image_layers(edge, "Test Convolve"); + cvWaitKey(0); +} + void test_color() { image dog = load_image("test_color.png"); @@ -199,7 +229,7 @@ void test_nist() { srand(444444); srand(888888); - network net = parse_network_cfg("nist.cfg"); + network net = parse_network_cfg("nist_basic.cfg"); data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10); data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10); normalize_data_rows(train); @@ -216,8 +246,8 @@ void test_nist() end = clock(); printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); start=end; - visualize_network(net); - cvWaitKey(100); + //visualize_network(net); + //cvWaitKey(100); //lr /= 2; if(count%5 == 0 && 0){ double train_acc = network_accuracy(net, train); @@ -336,13 +366,59 @@ void test_split() printf("%d, %d, %d\n", train.X.rows, split[0].X.rows, split[1].X.rows); } +double *random_matrix(int rows, int cols) +{ + int i, j; + double *m = calloc(rows*cols, sizeof(double)); + for(i = 0; i < rows; ++i){ + for(j = 0; j < cols; ++j){ + m[i*cols+j] = (double)rand()/RAND_MAX; + } + } + return m; +} + +void test_blas() +{ + int m = 6025, n = 20, k = 11*11*3; + double *a = random_matrix(m,k); + double *b = random_matrix(k,n); + double *c = random_matrix(m,n); + int i; + for(i = 0; i<1000; ++i){ + gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); + } +} + +void test_im2row() +{ + int h = 20; + int w = 20; + int c = 3; + int stride = 1; + int size = 11; + image test = make_random_image(h,w,c); + int mc = 1; + int mw = ((h-size)/stride+1)*((w-size)/stride+1); + int mh = (size*size*c); + int msize = mc*mw*mh; + double *matrix = calloc(msize, sizeof(double)); + int i; + for(i = 0; i < 1000; ++i){ + im2col_cpu(test.data, c, h, w, size, stride, matrix); + image render = double_to_image(mh, mw, mc, matrix); + } +} int main() { + //test_blas(); + test_convolve_matrix(); +// test_im2row(); //test_kernel_update(); //test_split(); //test_ensemble(); - test_nist(); + //test_nist(); //test_full(); //test_random_preprocess(); //test_random_classify();