From 158bb1bee9951875dbe3474d84c6663431e18301 Mon Sep 17 00:00:00 2001
From: Joseph Redmon
Date: Tue, 21 Oct 2014 14:49:18 -0700
Subject: [PATCH] softmax on gpu

---
 Makefile                  |   2 +-
 src/cnn.c                 |  17 ++++---
 src/connected_layer.c     |   7 +++
 src/convolutional_layer.c |  33 ++++++++++--
 src/gemm.c                |  52 +++++++++++++------
 src/maxpool_layer.c       |  89 ++++++++++++++++++++++++++++++--
 src/maxpool_layer.cl      |  73 +++++++++++++++++++++++++++
 src/maxpool_layer.h       |  13 ++++-
 src/mini_blas.c           |   2 +-
 src/network.c             |  71 ++++++++++++++++++--------
 src/opencl.c              |  29 ++++++++++-
 src/opencl.h              |   5 ++
 src/softmax_layer.c       | 103 +++++++++++++++++++++++---------------
 src/softmax_layer.cl      |  21 ++++++++
 src/softmax_layer.h       |  13 ++++-
 src/utils.c               |   5 ++
 src/utils.h               |   2 +
 17 files changed, 440 insertions(+), 97 deletions(-)
 create mode 100644 src/maxpool_layer.cl
 create mode 100644 src/softmax_layer.cl

diff --git a/Makefile b/Makefile
index 315e6269..b5ad1eb0 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 CC=gcc
-GPU=0
+GPU=1
 COMMON=-Wall -Wfatal-errors `pkg-config --cflags opencv` -I/usr/local/cuda/include/
 ifeq ($(GPU), 1)
 COMMON+=-DGPU
diff --git a/src/cnn.c b/src/cnn.c
index bfba26a1..7e90a809 100644
--- a/src/cnn.c
+++ b/src/cnn.c
@@ -281,15 +281,17 @@ void test_data()
 void train_assira()
 {
     network net = parse_network_cfg("cfg/assira.cfg");
+    int imgs = 1000/net.batch+1;
+    //imgs = 1;
     srand(2222222);
     int i = 0;
     char *labels[] = {"cat","dog"};
     while(1){
         i += 1000;
-        data train = load_data_image_pathfile_random("data/assira/train.list", 1000, labels, 2, 256, 256);
+        data train = load_data_image_pathfile_random("data/assira/train.list", imgs*net.batch, labels, 2, 256, 256);

         normalize_data_rows(train);
         clock_t start = clock(), end;
-        float loss = train_network_sgd_gpu(net, train, 10);
+        float loss = train_network_sgd_gpu(net, train, imgs);
         end = clock();
         printf("%d: %f, Time: %lf seconds\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC );
         free_data(train);
@@ -358,7 +360,7 @@ void train_cifar10()
     data train = load_all_cifar10();
     while(++count <= 10000){
         clock_t start = clock(), end;
-        float loss = train_network_sgd_gpu(net, train, iters);
+        float loss = train_network_sgd(net, train, iters);
         end = clock();
         //visualize_network(net);
         //cvWaitKey(5000);
@@ -369,7 +371,7 @@ void train_cifar10()
             float test_acc = network_accuracy(net, test);
             printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay);
             char buff[256];
-            sprintf(buff, "/home/pjreddie/cifar/cifar2_%d.cfg", count);
+            sprintf(buff, "/home/pjreddie/cifar/cifar10_2_%d.cfg", count);
             save_network(net, buff);
         }else{
             printf("%d: Loss: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, (float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay);
@@ -435,7 +437,7 @@ void train_nist()
     int iters = 10000/net.batch;
     while(++count <= 2000){
         clock_t start = clock(), end;
-        float loss = train_network_sgd(net, train, iters);
+        float loss = train_network_sgd_gpu(net, train, iters);
         end = clock();
         float test_acc = network_accuracy(net, test);
         //float test_acc = 0;
@@ -893,7 +895,8 @@ void test_distribution()

 int main(int argc, char *argv[])
 {
-    //train_assira();
+    //test_blas();
+    train_assira();
     //test_distribution();
     //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW);
@@ -907,7 +910,7 @@ int main(int argc, char *argv[])
     //test_ensemble();
     //test_nist_single();
     //test_nist();
-    train_nist();
+    //train_nist();
     //test_convolutional_layer();
     //test_col2im();
     //test_cifar10();
diff --git a/src/connected_layer.c b/src/connected_layer.c
index ba83dc3c..b41ae91f 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -108,6 +108,12 @@ void backward_connected_layer(connected_layer layer, float *input, float *delta

 #ifdef GPU

+void pull_connected_layer(connected_layer layer)
+{
+    cl_read_array(layer.weights_cl, layer.weights, layer.inputs*layer.outputs);
+    cl_read_array(layer.biases_cl, layer.biases, layer.outputs);
+}
+
 void update_connected_layer_gpu(connected_layer layer)
 {
     axpy_ongpu(layer.outputs, layer.learning_rate, layer.bias_updates_cl, 1, layer.biases_cl, 1);
@@ -116,6 +122,7 @@ void update_connected_layer_gpu(connected_layer layer)
     scal_ongpu(layer.inputs*layer.outputs, 1.-layer.learning_rate*layer.decay, layer.weights_cl, 1);
     axpy_ongpu(layer.inputs*layer.outputs, layer.learning_rate, layer.weight_updates_cl, 1, layer.weights_cl, 1);
     scal_ongpu(layer.inputs*layer.outputs, layer.momentum, layer.weight_updates_cl, 1);
+    pull_connected_layer(layer);
 }

 void forward_connected_layer_gpu(connected_layer layer, cl_mem input)
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 00de1538..0ed5a995 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -2,6 +2,7 @@
 #include "utils.h"
 #include "mini_blas.h"
 #include
+#include <time.h>

 int convolutional_out_height(convolutional_layer layer)
 {
@@ -341,6 +342,8 @@ void bias_output_gpu(const convolutional_layer layer)
     check_error(cl);
 }

+//#define TIMEIT
+
 void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in)
 {
     int i;
@@ -349,10 +352,21 @@ void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in)
     int n = convolutional_out_height(layer)*
         convolutional_out_width(layer);

-    //cl_write_array(layer.filters_cl, layer.filters, m*k);
-    //cl_write_array(layer.biases_cl, layer.biases, m);
     bias_output_gpu(layer);
+
+    #ifdef TIMEIT
+    clock_t time = clock();
+    printf("Forward\n");
+    #endif
+    im2col_ongpu(in, layer.batch, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_cl);
+
+    #ifdef TIMEIT
+    clFinish(cl.queue);
+    printf("Im2col %f\n", sec(clock()-time));
+    time = clock();
+    #endif
+
     for(i = 0; i < layer.batch; ++i){
         cl_mem a = layer.filters_cl;
         cl_mem b = cl_sub_array(layer.col_image_cl, i*k*n, k*n);
@@ -361,8 +375,14 @@ void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in)
         clReleaseMemObject(b);
         clReleaseMemObject(c);
     }
+    #ifdef TIMEIT
+    clFinish(cl.queue);
+    printf("Gemm %f\n", sec(clock()-time));
+    #endif
     activate_array_ongpu(layer.output_cl, m*n*layer.batch, layer.activation);
-    //cl_read_array(layer.output_cl, layer.output, m*n*layer.batch);
+    #ifdef TIMEIT
+    cl_read_array(layer.output_cl, layer.output, m*n*layer.batch);
+    #endif
 }

 void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl)
@@ -408,6 +428,12 @@ void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl
     }
 }

+void pull_convolutional_layer(convolutional_layer layer)
+{
+    cl_read_array(layer.filters_cl, layer.filters, layer.c*layer.n*layer.size*layer.size);
+    cl_read_array(layer.biases_cl, layer.biases, layer.n);
+}
+
 void update_convolutional_layer_gpu(convolutional_layer layer)
 {
     int size = layer.size*layer.size*layer.c*layer.n;
@@ -417,6 +443,7 @@ void update_convolutional_layer_gpu(convolutional_layer layer)
     scal_ongpu(size, 1.-layer.learning_rate*layer.decay, layer.filters_cl, 1);
     axpy_ongpu(size, layer.learning_rate,
             layer.filter_updates_cl, 1, layer.filters_cl, 1);
     scal_ongpu(size, layer.momentum, layer.filter_updates_cl, 1);
+    pull_convolutional_layer(layer);
 }
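
The TIMEIT blocks above only give meaningful numbers because clFinish() runs before the clock is read: the enqueue calls return as soon as the work is queued, so without a synchronization point clock() would measure the enqueue overhead rather than the kernel itself. A minimal sketch of that pattern, assuming the OpenCL headers and the sec() helper this patch adds to utils.c are available; the time_section() name is illustrative, not part of the patch:

    /* Illustrative only: wait for all queued OpenCL work, then report
       wall-clock time elapsed since 'start' using sec() from utils.c. */
    float time_section(cl_command_queue queue, clock_t start)
    {
        clFinish(queue);              /* block until every queued command has finished */
        return sec(clock() - start);  /* (float)clocks/CLOCKS_PER_SEC */
    }
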
diff --git a/src/gemm.c b/src/gemm.c
index 65542bcc..fa78daf9 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -1,4 +1,5 @@
 #include "mini_blas.h"
+#include

 void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
         float *A, int lda,
@@ -35,7 +36,7 @@ void gemm_nt(int M, int N, int K, float ALPHA,
         for(j = 0; j < N; ++j){
             register float sum = 0;
             for(k = 0; k < K; ++k){
-                sum += ALPHA*A[i*lda+k]*B[k+j*ldb];
+                sum += ALPHA*A[i*lda+k]*B[j*ldb + k];
             }
             C[i*ldc+j] += sum;
         }
@@ -57,6 +58,7 @@ void gemm_tn(int M, int N, int K, float ALPHA,
         }
     }
 }
+
 void gemm_tt(int M, int N, int K, float ALPHA,
         float *A, int lda,
         float *B, int ldb,
@@ -65,9 +67,11 @@ void gemm_tt(int M, int N, int K, float ALPHA,
     int i,j,k;
     for(i = 0; i < M; ++i){
         for(j = 0; j < N; ++j){
+            register float sum = 0;
             for(k = 0; k < K; ++k){
-                C[i*ldc+j] += ALPHA*A[i+k*lda]*B[k+j*ldb];
+                sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
             }
+            C[i*ldc+j] += sum;
         }
     }
 }
@@ -121,13 +125,31 @@ cl_kernel get_gemm_kernel()
     return gemm_kernel;
 }

+void gemm_ongpu_old(int TA, int TB, int M, int N, int K, float ALPHA,
+        cl_mem A_gpu, int lda,
+        cl_mem B_gpu, int ldb,
+        float BETA,
+        cl_mem C_gpu, int ldc);
+
 void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
         cl_mem A_gpu, int lda,
         cl_mem B_gpu, int ldb,
         float BETA,
         cl_mem C_gpu, int ldc)
 {
-    //printf("gpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
+    cl_setup();
+    //cl.error = clblasSgemm(clblasRowMajor, TA?clblasTrans:clblasNoTrans, TB?clblasTrans:clblasNoTrans,M, N, K,ALPHA, A_gpu, 0, lda,B_gpu, 0, ldb,BETA, C_gpu, 0, ldc,1, &queue, 0, NULL, &event);
+    //check_error(cl);
+    gemm_ongpu_old(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
+}
+
+void gemm_ongpu_old(int TA, int TB, int M, int N, int K, float ALPHA,
+        cl_mem A_gpu, int lda,
+        cl_mem B_gpu, int ldb,
+        float BETA,
+        cl_mem C_gpu, int ldc)
+{
+    //printf("gpu: %d %d %d %d %d\n",TA, TB, M, N, K);
     cl_setup();
     cl_kernel gemm_kernel = get_gemm_kernel();
     cl_command_queue queue = cl.queue;
@@ -213,11 +235,11 @@ void time_gpu_random_matrix(int TA, int TB, int m, int k, int n)
     float *c = random_matrix(m,n);
     int i;
     clock_t start = clock(), end;
-    for(i = 0; i<1000; ++i){
+    for(i = 0; i<10; ++i){
         gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
     }
     end = clock();
-    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
+    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf s\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
     free(a);
     free(b);
     free(c);
@@ -270,19 +292,19 @@ void test_gpu_blas()
     test_gpu_accuracy(0,1,1000,10,100);
     test_gpu_accuracy(1,1,1000,10,100);

-/*
-   time_gpu_random_matrix(0,0,1000,1000,100);
-   time_random_matrix(0,0,1000,1000,100);
+    /*
+       time_gpu_random_matrix(0,0,1000,1000,100);
+       time_random_matrix(0,0,1000,1000,100);

-   time_gpu_random_matrix(0,1,1000,1000,100);
-   time_random_matrix(0,1,1000,1000,100);
+       time_gpu_random_matrix(0,1,1000,1000,100);
+       time_random_matrix(0,1,1000,1000,100);

-   time_gpu_random_matrix(1,0,1000,1000,100);
-   time_random_matrix(1,0,1000,1000,100);
+       time_gpu_random_matrix(1,0,1000,1000,100);
+       time_random_matrix(1,0,1000,1000,100);

-   time_gpu_random_matrix(1,1,1000,1000,100);
-   time_random_matrix(1,1,1000,1000,100);
-   */
+       time_gpu_random_matrix(1,1,1000,1000,100);
+       time_random_matrix(1,1,1000,1000,100);
+       */
 }
 #endif
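
For reference, the GEMM routines exercised above take BLAS-style row-major arguments: C = ALPHA*op(A)*op(B) + BETA*C, where TA and TB select the transposes and lda/ldb/ldc are the row strides of A, B, and C. A small CPU usage sketch under that assumption; the concrete matrices are made up for illustration:

    /* C (2x2) += A (2x3) * B (3x2), no transposes, unit ALPHA and BETA. */
    float A[] = {1, 2, 3,
                 4, 5, 6};
    float B[] = {1, 0,
                 0, 1,
                 1, 1};
    float C[] = {0, 0,
                 0, 0};
    gemm(0, 0, 2, 2, 3, 1, A, 3, B, 2, 1, C, 2);
    /* C is now {4, 5, 10, 11}: C[i][j] = sum_k A[i][k]*B[k][j]. */
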
diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c
index 01eed453..65315410 100644
--- a/src/maxpool_layer.c
+++ b/src/maxpool_layer.c
@@ -27,9 +27,15 @@ maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int
     layer->c = c;
     layer->size = size;
     layer->stride = stride;
-    layer->indexes = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(int));
-    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(float));
-    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(float));
+    int output_size = ((h-1)/stride+1) * ((w-1)/stride+1) * c * batch;
+    layer->indexes = calloc(output_size, sizeof(int));
+    layer->output = calloc(output_size, sizeof(float));
+    layer->delta = calloc(output_size, sizeof(float));
+    #ifdef GPU
+    layer->indexes_cl = cl_make_int_array(layer->indexes, output_size);
+    layer->output_cl = cl_make_array(layer->output, output_size);
+    layer->delta_cl = cl_make_array(layer->delta, output_size);
+    #endif
     return layer;
 }
@@ -66,7 +72,7 @@ void forward_maxpool_layer(const maxpool_layer layer, float *input)
                         int index = cur_w + layer.w*(cur_h + layer.h*(k + b*layer.c));
                         int valid = (cur_h >= 0 && cur_h < layer.h &&
                                      cur_w >= 0 && cur_w < layer.w);
-                        float val = (valid != 0) ? input[index] : -INFINITY;
+                        float val = (valid != 0) ? input[index] : -FLT_MAX;
                         max_i = (val > max) ? index : max_i;
                         max   = (val > max) ? val   : max;
                     }
@@ -79,7 +85,7 @@
         }
     }
 }
-void backward_maxpool_layer(const maxpool_layer layer, float *input, float *delta)
+void backward_maxpool_layer(const maxpool_layer layer, float *delta)
 {
     int i;
     int h = (layer.h-1)/layer.stride + 1;
@@ -92,3 +98,76 @@ void backward_maxpool_layer(const maxpool_layer layer, float *input, float *delt
     }
 }

+#ifdef GPU
+cl_kernel get_forward_kernel()
+{
+    static int init = 0;
+    static cl_kernel kernel;
+    if(!init){
+        kernel = get_kernel("src/maxpool_layer.cl", "forward", 0);
+        init = 1;
+    }
+    return kernel;
+}
+
+void forward_maxpool_layer_gpu(maxpool_layer layer, cl_mem input)
+{
+    int h = (layer.h-1)/layer.stride + 1;
+    int w = (layer.w-1)/layer.stride + 1;
+    int c = layer.c;
+    cl_setup();
+    cl_kernel kernel = get_forward_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.h), (void*) &layer.h);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.w), (void*) &layer.w);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.c), (void*) &layer.c);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.stride), (void*) &layer.stride);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.size), (void*) &layer.size);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(input), (void*) &input);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.output_cl), (void*) &layer.output_cl);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.indexes_cl), (void*) &layer.indexes_cl);
+    check_error(cl);
+
+    const size_t global_size[] = {h*w*c*layer.batch};
+
+    clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
+    check_error(cl);
+}
+
+cl_kernel get_backward_kernel()
+{
+    static int init = 0;
+    static cl_kernel kernel;
+    if(!init){
+        kernel = get_kernel("src/maxpool_layer.cl", "backward", 0);
+        init = 1;
+    }
+    return kernel;
+}
+
+void backward_maxpool_layer_gpu(maxpool_layer layer, cl_mem delta)
+{
+    cl_setup();
+    cl_kernel kernel = get_backward_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.h), (void*) &layer.h);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.w), (void*) &layer.w);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.c), (void*) &layer.c);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.stride), (void*) &layer.stride);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.size), (void*) &layer.size);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.delta_cl), (void*) &layer.delta_cl);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(delta), (void*) &delta);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.indexes_cl), (void*) &layer.indexes_cl);
+    check_error(cl);
+
+    const size_t global_size[] = {layer.h*layer.w*layer.c*layer.batch};
+
+    clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
+    check_error(cl);
+}
+
+#endif
diff --git a/src/maxpool_layer.cl b/src/maxpool_layer.cl
new file mode 100644
index 00000000..fc793d0c
--- /dev/null
+++ b/src/maxpool_layer.cl
@@ -0,0 +1,73 @@
+
+__kernel void forward(int in_h, int in_w, int in_c, int stride, int size, __global float *input, __global float *output, __global int *indexes)
+{
+    int h = (in_h-1)/stride + 1;
+    int w = (in_w-1)/stride + 1;
+    int c = in_c;
+
+    int id = get_global_id(0);
+    int j = id % w;
+    id /= w;
+    int i = id % h;
+    id /= h;
+    int k = id % c;
+    id /= c;
+    int b = id;
+
+    int w_offset = (-size-1)/2 + 1;
+    int h_offset = (-size-1)/2 + 1;
+
+    int out_index = j + w*(i + h*(k + c*b));
+    float max = -INFINITY;
+    int max_i = -1;
+    int l, m;
+    for(l = 0; l < size; ++l){
+        for(m = 0; m < size; ++m){
+            int cur_h = h_offset + i*stride + l;
+            int cur_w = w_offset + j*stride + m;
+            int index = cur_w + in_w*(cur_h + in_h*(k + b*in_c));
+            int valid = (cur_h >= 0 && cur_h < in_h &&
+                    cur_w >= 0 && cur_w < in_w);
+            float val = (valid != 0) ? input[index] : -INFINITY;
+            max_i = (val > max) ? index : max_i;
+            max   = (val > max) ? val   : max;
+        }
+    }
+    output[out_index] = max;
+    indexes[out_index] = max_i;
+}
+
+__kernel void backward(int in_h, int in_w, int in_c, int stride, int size, __global float *delta, __global float *prev_delta, __global int *indexes)
+{
+    int h = (in_h-1)/stride + 1;
+    int w = (in_w-1)/stride + 1;
+    int c = in_c;
+    int area = (size-1)/stride;
+
+    int id = get_global_id(0);
+    int index = id;
+    int j = id % in_w;
+    id /= in_w;
+    int i = id % in_h;
+    id /= in_h;
+    int k = id % in_c;
+    id /= in_c;
+    int b = id;
+
+    int w_offset = (-size-1)/2 + 1;
+    int h_offset = (-size-1)/2 + 1;
+
+    float d = 0;
+    int l, m;
+    for(l = -area; l < area+1; ++l){
+        for(m = -area; m < area+1; ++m){
+            int out_w = (j-w_offset)/stride + m;
+            int out_h = (i-h_offset)/stride + l;
+            int out_index = out_w + w*(out_h + h*(k + c*b));
+            int valid = (out_w >= 0 && out_w < w &&
+                     out_h >= 0 && out_h < h);
+            d += (valid && indexes[out_index] == index) ? delta[out_index] : 0;
+        }
+    }
+    prev_delta[index] = d;
+}
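
Both kernels above launch one work-item per output element (global size h*w*c*batch) and recover coordinates from the flat id by repeated modulo/divide, innermost dimension first; the forward kernel also records the argmax position in indexes so the backward kernel can route each delta to the input that won the pooling window. A host-side mirror of that index math, for illustration only (decode_pool_id is not part of the patch):

    /* Decompose a flat work-item id into (column, row, channel, batch). */
    void decode_pool_id(int id, int w, int h, int c, int *j, int *i, int *k, int *b)
    {
        *j = id % w; id /= w;   /* output column       */
        *i = id % h; id /= h;   /* output row          */
        *k = id % c; id /= c;   /* channel             */
        *b = id;                /* image in the batch  */
    }
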
diff --git a/src/maxpool_layer.h b/src/maxpool_layer.h
index 9edb214d..dc45c550 100644
--- a/src/maxpool_layer.h
+++ b/src/maxpool_layer.h
@@ -2,6 +2,7 @@
 #define MAXPOOL_LAYER_H

 #include "image.h"
+#include "opencl.h"

 typedef struct {
     int batch;
@@ -11,13 +12,23 @@ typedef struct {
     int *indexes;
     float *delta;
     float *output;
+    #ifdef GPU
+    cl_mem indexes_cl;
+    cl_mem output_cl;
+    cl_mem delta_cl;
+    #endif
 } maxpool_layer;

 image get_maxpool_image(maxpool_layer layer);
 maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int stride);
 void resize_maxpool_layer(maxpool_layer *layer, int h, int w, int c);
 void forward_maxpool_layer(const maxpool_layer layer, float *input);
-void backward_maxpool_layer(const maxpool_layer layer, float *input, float *delta);
+void backward_maxpool_layer(const maxpool_layer layer, float *delta);
+
+#ifdef GPU
+void forward_maxpool_layer_gpu(maxpool_layer layer, cl_mem input);
+void backward_maxpool_layer_gpu(maxpool_layer layer, cl_mem delta);
+#endif

 #endif
diff --git a/src/mini_blas.c b/src/mini_blas.c
index 0227b37c..4d929719 100644
--- a/src/mini_blas.c
+++ b/src/mini_blas.c
@@ -41,7 +41,7 @@ void time_random_matrix(int TA, int TB, int m, int k, int n)
     float *c = random_matrix(m,n);
     int i;
     clock_t start = clock(), end;
-    for(i = 0; i<1000; ++i){
+    for(i = 0; i<10; ++i){
         gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
     }
     end = clock();
diff --git a/src/network.c b/src/network.c
index f9b46673..6696769d 100644
--- a/src/network.c
+++ b/src/network.c
@@ -1,4 +1,5 @@
 #include
+#include
 #include "network.h"
 #include "image.h"
 #include "data.h"
@@ -31,8 +32,10 @@ network make_network(int n, int batch)
 }

 #ifdef GPU
+
 void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train)
 {
+    //printf("start\n");
     int i;
     for(i = 0; i < net.n; ++i){
         if(net.types[i] == CONVOLUTIONAL){
@@ -49,28 +52,28 @@ void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train)
             forward_connected_layer_gpu(layer, input);
             input = layer.output_cl;
         }
-        /*
-        else if(net.types[i] == SOFTMAX){
-            softmax_layer layer = *(softmax_layer *)net.layers[i];
-            forward_softmax_layer(layer, input);
-            input = layer.output;
-        }
-        else if(net.types[i] == CROP){
-            crop_layer layer = *(crop_layer *)net.layers[i];
-            forward_crop_layer(layer, input);
-            input = layer.output;
-        }
         else if(net.types[i] == MAXPOOL){
             maxpool_layer layer = *(maxpool_layer *)net.layers[i];
-            forward_maxpool_layer(layer, input);
-            input = layer.output;
+            forward_maxpool_layer_gpu(layer, input);
+            input = layer.output_cl;
         }
-        else if(net.types[i] == NORMALIZATION){
-            normalization_layer layer = *(normalization_layer *)net.layers[i];
-            forward_normalization_layer(layer, input);
-            input = layer.output;
+        else if(net.types[i] == SOFTMAX){
+            softmax_layer layer = *(softmax_layer *)net.layers[i];
+            forward_softmax_layer_gpu(layer, input);
+            input = layer.output_cl;
         }
-        */
+        /*
+        else if(net.types[i] == CROP){
+            crop_layer layer = *(crop_layer *)net.layers[i];
+            forward_crop_layer(layer, input);
+            input = layer.output;
+        }
+        else if(net.types[i] == NORMALIZATION){
+            normalization_layer layer = *(normalization_layer *)net.layers[i];
+            forward_normalization_layer(layer, input);
+            input = layer.output;
+        }
+        */
     }
 }
@@ -99,6 +102,14 @@ void backward_network_gpu(network net, cl_mem input)
             connected_layer layer = *(connected_layer *)net.layers[i];
             backward_connected_layer_gpu(layer, prev_input, prev_delta);
         }
+        else if(net.types[i] == MAXPOOL){
+            maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+            backward_maxpool_layer_gpu(layer, prev_delta);
+        }
+        else if(net.types[i] == SOFTMAX){
+            softmax_layer layer = *(softmax_layer *)net.layers[i];
+            backward_softmax_layer_gpu(layer, prev_delta);
+        }
     }
 }
@@ -127,6 +138,14 @@ cl_mem get_network_output_cl_layer(network net, int i)
         connected_layer layer = *(connected_layer *)net.layers[i];
         return layer.output_cl;
     }
+    else if(net.types[i] == MAXPOOL){
+        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        return layer.output_cl;
+    }
+    else if(net.types[i] == SOFTMAX){
+        softmax_layer layer = *(softmax_layer *)net.layers[i];
+        return layer.output_cl;
+    }
     return 0;
 }
@@ -140,6 +159,14 @@ cl_mem get_network_delta_cl_layer(network net, int i)
         connected_layer layer = *(connected_layer *)net.layers[i];
         return layer.delta_cl;
     }
+    else if(net.types[i] == MAXPOOL){
+        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        return layer.delta_cl;
+    }
+    else if(net.types[i] == SOFTMAX){
+        softmax_layer layer = *(softmax_layer *)net.layers[i];
+        return layer.delta_cl;
+    }
     return 0;
 }
@@ -330,7 +357,7 @@ void backward_network(network net, float *input)
         }
         else if(net.types[i] == MAXPOOL){
             maxpool_layer layer = *(maxpool_layer *)net.layers[i];
-            if(i != 0) backward_maxpool_layer(layer, prev_input, prev_delta);
+            if(i != 0) backward_maxpool_layer(layer, prev_delta);
         }
         else if(net.types[i] == NORMALIZATION){
             normalization_layer layer = *(normalization_layer *)net.layers[i];
@@ -338,7 +365,7 @@ void backward_network(network net, float *input)
         }
         else if(net.types[i] == SOFTMAX){
             softmax_layer layer = *(softmax_layer *)net.layers[i];
-            if(i != 0) backward_softmax_layer(layer, prev_input, prev_delta);
+            if(i != 0) backward_softmax_layer(layer, prev_delta);
         }
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
@@ -351,6 +378,7 @@ void backward_network(network net, float *input)
     }
 }

+
 #ifdef GPU
 float train_network_datum_gpu(network net, float *x, float *y)
 {
@@ -364,13 +392,12 @@ float train_network_datum_gpu(network net, float *x, float *y)
         cl_write_array(*net.truth_cl, y, y_size);
     }
     forward_network_gpu(net, *net.input_cl, *net.truth_cl, 1);
-    //int class = get_predicted_class_network(net);
     backward_network_gpu(net, *net.input_cl);
     float error = get_network_cost(net);
     update_network_gpu(net);
-    //return (y[class]?1:0);
     return error;
 }
+
 float train_network_sgd_gpu(network net, data d, int n)
 {
     int batch = net.batch;
diff --git a/src/opencl.c b/src/opencl.c
index 5aec33c9..a2e73668 100644
--- a/src/opencl.c
+++ b/src/opencl.c
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+//#include
 #include "opencl.h"
 #include "utils.h"
@@ -80,9 +81,9 @@ cl_info cl_init()
     }

     int index = getpid()%num_devices;
+    index = 0;
     printf("%d rand, %d devices, %d index\n", getpid(), num_devices, index);
-    //info.device = devices[index];
-    info.device = devices[0];
+    info.device = devices[index];

     fprintf(stderr, "Found %d device(s)\n", num_devices);
     check_error(info);
@@ -94,10 +95,24 @@ cl_info cl_init()
     check_error(info);
     info.queue = clCreateCommandQueue(info.context, info.device, 0, &info.error);
     check_error(info);
+    for(i = 0; i < NUM_QUEUES; ++i){
+        info.queues[i] = clCreateCommandQueue(info.context, info.device, 0, &info.error);
+        check_error(info);
+    }
+    //info.error = clblasSetup();
+    check_error(info);
     info.initialized = 1;
     return info;
 }

+void wait_for_queues()
+{
+    int i;
+    for(i = 0; i < NUM_QUEUES; ++i){
+        clFinish(cl.queues[i]);
+    }
+}
+
 cl_program cl_fprog(char *filename, char *options, cl_info info)
 {
     size_t srcsize;
@@ -180,4 +195,14 @@ cl_mem cl_make_array(float *x, int n)
     return mem;
 }

+cl_mem cl_make_int_array(int *x, int n)
+{
+    cl_setup();
+    cl_mem mem = clCreateBuffer(cl.context,
+            CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR,
+            sizeof(int)*n, x, &cl.error);
+    check_error(cl);
+    return mem;
+}
+
 #endif
diff --git a/src/opencl.h b/src/opencl.h
index 9cf3acd4..aedc0565 100644
--- a/src/opencl.h
+++ b/src/opencl.h
@@ -7,6 +7,8 @@
 #include
 #endif

+#define NUM_QUEUES 8
+
 typedef struct {
     int initialized;
     cl_int error;
@@ -14,16 +16,19 @@ typedef struct {
     cl_device_id device;
     cl_context context;
     cl_command_queue queue;
+    cl_command_queue queues[NUM_QUEUES];
 }cl_info;

 extern cl_info cl;

 void cl_setup();
+void wait_for_queues();
 void check_error(cl_info info);
 cl_kernel get_kernel(char *filename, char *kernelname, char *options);
 void cl_read_array(cl_mem mem, float *x, int n);
 void cl_write_array(cl_mem mem, float *x, int n);
 cl_mem cl_make_array(float *x, int n);
+cl_mem cl_make_int_array(int *x, int n);
 void cl_copy_array(cl_mem src, cl_mem dst, int n);
 cl_mem cl_sub_array(cl_mem src, int offset, int size);
 #endif
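
The new helpers follow the existing cl_make_array() convention: the buffer is created with CL_MEM_COPY_HOST_PTR, so the host array is copied into it at creation time, and wait_for_queues() simply calls clFinish on each of the NUM_QUEUES command queues. A hypothetical usage sketch under those assumptions; the buffer and its size are made up for illustration:

    int indexes[64] = {0};
    cl_mem indexes_cl = cl_make_int_array(indexes, 64);  /* device copy of the 64 host ints */
    /* ... enqueue kernels on cl.queue or cl.queues[i] ... */
    wait_for_queues();                                   /* block until every queue has drained */
    clReleaseMemObject(indexes_cl);
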
diff --git a/src/softmax_layer.c b/src/softmax_layer.c
index b6e9fe9e..dae332ed 100644
--- a/src/softmax_layer.c
+++ b/src/softmax_layer.c
@@ -1,5 +1,6 @@
 #include "softmax_layer.h"
 #include "mini_blas.h"
+#include <float.h>
 #include
 #include
 #include
@@ -13,36 +14,25 @@ softmax_layer *make_softmax_layer(int batch, int inputs)
     layer->output = calloc(inputs*batch, sizeof(float));
     layer->delta = calloc(inputs*batch, sizeof(float));
     layer->jacobian = calloc(inputs*inputs*batch, sizeof(float));
+    #ifdef GPU
+    layer->output_cl = cl_make_array(layer->output, inputs*batch);
+    layer->delta_cl = cl_make_array(layer->delta, inputs*batch);
+    #endif
     return layer;
 }

-/* UNSTABLE!
-void forward_softmax_layer(const softmax_layer layer, float *input)
-{
-    int i;
-    float sum = 0;
-    for(i = 0; i < layer.inputs; ++i){
-        sum += exp(input[i]);
-    }
-    for(i = 0; i < layer.inputs; ++i){
-        layer.output[i] = exp(input[i])/sum;
-    }
-}
-*/
 void forward_softmax_layer(const softmax_layer layer, float *input)
 {
     int i,b;
     for(b = 0; b < layer.batch; ++b){
         float sum = 0;
-        float largest = 0;
+        float largest = -FLT_MAX;
         for(i = 0; i < layer.inputs; ++i){
             if(input[i+b*layer.inputs] > largest) largest = input[i+b*layer.inputs];
         }
         for(i = 0; i < layer.inputs; ++i){
             sum += exp(input[i+b*layer.inputs]-largest);
-            //printf("%f, ", input[i]);
         }
-        //printf("\n");
         if(sum) sum = largest+log(sum);
         else sum = largest-100;
         for(i = 0; i < layer.inputs; ++i){
@@ -51,33 +41,68 @@ void forward_softmax_layer(const softmax_layer layer, float *input)
     }
 }

-void backward_softmax_layer(const softmax_layer layer, float *input, float *delta)
+void backward_softmax_layer(const softmax_layer layer, float *delta)
 {
-/*
-    int i,j,b;
-    for(b = 0; b < layer.batch; ++b){
-        for(i = 0; i < layer.inputs; ++i){
-            for(j = 0; j < layer.inputs; ++j){
-                int d = (i==j);
-                layer.jacobian[b*layer.inputs*layer.inputs + i*layer.inputs + j] =
-                    layer.output[b*layer.inputs + i] * (d - layer.output[b*layer.inputs + j]);
-            }
-        }
-    }
-    for(b = 0; b < layer.batch; ++b){
-        int M = layer.inputs;
-        int N = 1;
-        int K = layer.inputs;
-        float *A = layer.jacobian + b*layer.inputs*layer.inputs;
-        float *B = layer.delta + b*layer.inputs;
-        float *C = delta + b*layer.inputs;
-        gemm(0,0,M,N,K,1,A,K,B,N,0,C,N);
-    }
-    */
-
     int i;
     for(i = 0; i < layer.inputs*layer.batch; ++i){
         delta[i] = layer.delta[i];
     }
 }
+
+#ifdef GPU
+cl_kernel get_softmax_forward_kernel()
+{
+    static int init = 0;
+    static cl_kernel kernel;
+    if(!init){
+        kernel = get_kernel("src/softmax_layer.cl", "forward", 0);
+        init = 1;
+    }
+    return kernel;
+}
+
+void forward_softmax_layer_gpu(const softmax_layer layer, cl_mem input)
+{
+    cl_setup();
+    cl_kernel kernel = get_softmax_forward_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.inputs), (void*) &layer.inputs);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(input), (void*) &input);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.output_cl), (void*) &layer.output_cl);
+    check_error(cl);
+
+    const size_t global_size[] = {layer.batch};
+
+    clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
+    check_error(cl);
+}
+
+void backward_softmax_layer_gpu(const softmax_layer layer, cl_mem delta)
+{
+    copy_ongpu(layer.batch*layer.inputs, layer.delta_cl, 1, delta, 1);
+}
+#endif
+
+/* This is if you want softmax w/o log-loss classification. You probably don't.
+    int i,j,b;
+    for(b = 0; b < layer.batch; ++b){
+        for(i = 0; i < layer.inputs; ++i){
+            for(j = 0; j < layer.inputs; ++j){
+                int d = (i==j);
+                layer.jacobian[b*layer.inputs*layer.inputs + i*layer.inputs + j] =
+                    layer.output[b*layer.inputs + i] * (d - layer.output[b*layer.inputs + j]);
+            }
+        }
+    }
+    for(b = 0; b < layer.batch; ++b){
+        int M = layer.inputs;
+        int N = 1;
+        int K = layer.inputs;
+        float *A = layer.jacobian + b*layer.inputs*layer.inputs;
+        float *B = layer.delta + b*layer.inputs;
+        float *C = delta + b*layer.inputs;
+        gemm(0,0,M,N,K,1,A,K,B,N,0,C,N);
+    }
+    */
diff --git a/src/softmax_layer.cl b/src/softmax_layer.cl
new file mode 100644
index 00000000..77da521e
--- /dev/null
+++ b/src/softmax_layer.cl
@@ -0,0 +1,21 @@
+
+__kernel void forward(int n, __global float *input, __global float *output)
+{
+    int b = get_global_id(0);
+
+    int i;
+    float sum = 0;
+    float largest = -INFINITY;
+    for(i = 0; i < n; ++i){
+        float val = input[i+b*n];
+        largest = (val>largest) ? val : largest;
+    }
+    for(i = 0; i < n; ++i){
+        sum += exp(input[i+b*n]-largest);
+    }
+    sum = (sum != 0) ? largest+log(sum) : largest-100;
+    for(i = 0; i < n; ++i){
+        output[i+b*n] = exp(input[i+b*n]-sum);
+    }
+}
+
diff --git a/src/softmax_layer.h b/src/softmax_layer.h
index 22752508..2f9f9798 100644
--- a/src/softmax_layer.h
+++ b/src/softmax_layer.h
@@ -1,16 +1,27 @@
 #ifndef SOFTMAX_LAYER_H
 #define SOFTMAX_LAYER_H

+#include "opencl.h"
+
 typedef struct {
     int inputs;
     int batch;
     float *delta;
     float *output;
     float *jacobian;
+    #ifdef GPU
+    cl_mem delta_cl;
+    cl_mem output_cl;
+    #endif
 } softmax_layer;

 softmax_layer *make_softmax_layer(int batch, int inputs);
 void forward_softmax_layer(const softmax_layer layer, float *input);
-void backward_softmax_layer(const softmax_layer layer, float *input, float *delta);
+void backward_softmax_layer(const softmax_layer layer, float *delta);
+
+#ifdef GPU
+void forward_softmax_layer_gpu(const softmax_layer layer, cl_mem input);
+void backward_softmax_layer_gpu(const softmax_layer layer, cl_mem delta);
+#endif

 #endif
diff --git a/src/utils.c b/src/utils.c
index 8a65ba7b..a883ad86 100644
--- a/src/utils.c
+++ b/src/utils.c
@@ -4,6 +4,11 @@
 #include
 #include

+float sec(clock_t clocks)
+{
+    return (float)clocks/CLOCKS_PER_SEC;
+}
+
 void error(char *s)
 {
     fprintf(stderr, "Error: %s\n", s);
diff --git a/src/utils.h b/src/utils.h
index f38af337..49948f53 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -1,6 +1,7 @@
 #ifndef UTILS_H
 #define UTILS_H
 #include
+#include <time.h>
 #include "list.h"

 void error(char *s);
@@ -25,5 +26,6 @@ float sum_array(float *a, int n);
 float mean_array(float *a, int n);
 float variance_array(float *a, int n);
 float **one_hot_encode(float *a, int n, int k);
+float sec(clock_t clocks);

 #endif
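
Both the rewritten forward_softmax_layer() and the new softmax_layer.cl kernel use the same stabilization: find the largest input, subtract it before exponentiating, fold it back in through largest + log(sum), and exponentiate against that shifted log-denominator. The result equals exp(x_i)/sum_j exp(x_j), but no intermediate exp() can overflow. A self-contained sketch of the idea; softmax_stable() is illustrative, not part of the patch:

    #include <float.h>
    #include <math.h>

    void softmax_stable(const float *input, int n, float *output)
    {
        int i;
        float largest = -FLT_MAX;
        for(i = 0; i < n; ++i) if(input[i] > largest) largest = input[i];

        float sum = 0;
        for(i = 0; i < n; ++i) sum += expf(input[i] - largest);

        float shift = largest + logf(sum);            /* log of the true denominator */
        for(i = 0; i < n; ++i) output[i] = expf(input[i] - shift);
    }
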