From 4ac78c89269138b4623993f9f1d81829d8e88131 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Tue, 20 Jan 2015 13:26:46 -0800 Subject: [PATCH] I am so done with opencl, switching to cuda --- src/cnn.c | 2 +- src/gemm.c | 6 +----- src/gemm_fast.cl | 37 +++++++++++++++++-------------------- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/src/cnn.c b/src/cnn.c index be93e8c0..fed69d0c 100644 --- a/src/cnn.c +++ b/src/cnn.c @@ -210,7 +210,7 @@ void train_imagenet(char *cfgfile) //network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg"); srand(time(0)); network net = parse_network_cfg(cfgfile); - set_learning_network(&net, net.learning_rate*10., net.momentum, net.decay); + set_learning_network(&net, net.learning_rate*100., net.momentum, net.decay); printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); int imgs = 1024; int i = 6600; diff --git a/src/gemm.c b/src/gemm.c index 83949914..9797b85c 100644 --- a/src/gemm.c +++ b/src/gemm.c @@ -164,8 +164,7 @@ cl_kernel get_gemm_nn_kernel() #define TILE 64 #define TILE_K 16 -#define WPT 8 -#define THREADS (TILE*TILE)/(WPT*WPT) +#define THREADS 64 cl_kernel get_gemm_nn_fast_kernel() { @@ -175,7 +174,6 @@ cl_kernel get_gemm_nn_fast_kernel() gemm_kernel = get_kernel("src/gemm_fast.cl", "gemm_nn_fast", "-D TILE=" STR(TILE) " -cl-nv-verbose " " -D TILE_K=" STR(TILE_K) - " -D WPT=" STR(WPT) " -D THREADS=" STR(THREADS)); init = 1; } @@ -464,7 +462,6 @@ void test_gpu_blas() test_gpu_accuracy(0,0,128,128,128); -/* time_ongpu(0,0,64,2916,363); time_ongpu_fast(0,0,64,2916,363); time_ongpu(0,0,64,2916,363); @@ -483,7 +480,6 @@ void test_gpu_blas() time_ongpu_fast(0,0,128,4096,12544); time_ongpu(0,0,128,4096,4096); time_ongpu_fast(0,0,128,4096,4096); - */ // time_ongpu(1,0,2304,196,256); // time_ongpu_fast(1,0,2304,196,256); // time_ongpu(0,1,256,2304,196); diff --git a/src/gemm_fast.cl b/src/gemm_fast.cl index 9a982087..2a76396f 100644 --- a/src/gemm_fast.cl +++ b/src/gemm_fast.cl @@ -16,16 +16,15 @@ __kernel void gemm_nn_fast(int TA, int TB, int M, int N, int K, float ALPHA, int ctile = get_group_id(0); int rtile = get_group_id(1); - float Breg; - float Areg[WPT]; - float acc[WPT][WPT]; + float Areg[TILE]; + float acc[TILE][TILE/THREADS]; A += rtile*TILE*lda; B += ctile*TILE; C += rtile*TILE*ldc + ctile*TILE; - for(i = 0; i < WPT; ++i){ - for(j = 0; j < WPT; ++j){ + for(i = 0; i < TILE; ++i){ + for(j = 0; j < TILE/THREADS; ++j){ acc[i][j] = 0; } } @@ -51,28 +50,26 @@ __kernel void gemm_nn_fast(int TA, int TB, int M, int N, int K, float ALPHA, barrier(CLK_LOCAL_MEM_FENCE); for(k = 0; k < TILE_K; ++k){ - for(y = 0; y < WPT; ++y){ - int row = (offset + (y*WPT)*THREADS)/TILE; - //Areg[y] = Asub[y*WPT][k]; + #pragma unroll + for(y = 0; y < TILE; ++y){ + Areg[y] = Asub[y][k]; } - for(y = 0; y < WPT; ++y){ - for(x = 0; x < WPT; ++x){ - int index = offset + (y*WPT + x)*THREADS; - int row = index / TILE; - int col = index % TILE; - acc[y][x] += Asub[row][k]*Bsub[k][col]; + for(x = 0; x < TILE; x += THREADS){ + float Breg = Bsub[k][x+offset]; + #pragma unroll + for(y = 0; y < TILE; ++y){ + acc[y][x/THREADS] += Breg * Areg[y]; } } } barrier(CLK_LOCAL_MEM_FENCE); } - for(y = 0; y < WPT; ++y){ - for(x = 0; x < WPT; ++x){ - int index = offset + (y*WPT + x)*THREADS; - int row = index / TILE; - int col = index % TILE; - C[row*ldc+col] = ALPHA*acc[y][x] + BETA*C[row*ldc+col]; + for(i = 0; i < TILE; ++i){ + for(j = 0; j < TILE/THREADS; ++j){ + int col = j*THREADS + offset; + int row = i; + C[row*ldc+col] = ALPHA*acc[i][j] + BETA*C[row*ldc+col]; } } }