fast sort of working

This commit is contained in:
Joseph Redmon 2015-01-19 22:06:18 -08:00
parent 08b757a0bf
commit 6e1d5b45de
7 changed files with 255 additions and 30 deletions

View File

@ -210,10 +210,10 @@ void train_imagenet(char *cfgfile)
//network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
srand(time(0));
network net = parse_network_cfg(cfgfile);
set_learning_network(&net, net.learning_rate, 0, net.decay);
set_learning_network(&net, net.learning_rate*10., net.momentum, net.decay);
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 1024;
int i = 0;
int i = 6600;
char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
list *plist = get_paths("/data/imagenet/cls.train.list");
char **paths = (char **)list_to_array(plist);
@ -228,9 +228,9 @@ void train_imagenet(char *cfgfile)
time=clock();
pthread_join(load_thread, 0);
train = buffer;
//normalize_data_rows(train);
translate_data_rows(train, -128);
scale_data_rows(train, 1./128);
normalize_data_rows(train);
//translate_data_rows(train, -128);
//scale_data_rows(train, 1./128);
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
@ -539,12 +539,14 @@ void visualize_cat()
void test_correct_nist()
{
network net = parse_network_cfg("cfg/nist_conv.cfg");
test_learn_bias(*(convolutional_layer *)net.layers[0]);
srand(222222);
network net = parse_network_cfg("cfg/nist.cfg");
net = parse_network_cfg("cfg/nist_conv.cfg");
data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
translate_data_rows(train, -144);
translate_data_rows(test, -144);
normalize_data_rows(train);
normalize_data_rows(test);
int count = 0;
int iters = 1000/net.batch;
@ -555,11 +557,12 @@ void test_correct_nist()
float test_acc = network_accuracy(net, test);
printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay);
}
save_network(net, "cfg/nist_gpu.cfg");
gpu_index = -1;
count = 0;
srand(222222);
net = parse_network_cfg("cfg/nist.cfg");
net = parse_network_cfg("cfg/nist_conv.cfg");
while(++count <= 5){
clock_t start = clock(), end;
float loss = train_network_sgd(net, train, iters);
@ -567,6 +570,7 @@ void test_correct_nist()
float test_acc = network_accuracy(net, test);
printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay);
}
save_network(net, "cfg/nist_cpu.cfg");
}
void test_correct_alexnet()

View File

@ -305,6 +305,27 @@ void learn_bias_convolutional_layer_ongpu(convolutional_layer layer)
check_error(cl);
}
void test_learn_bias(convolutional_layer l)
{
int i;
int size = convolutional_out_height(l) * convolutional_out_width(l);
for(i = 0; i < size*l.batch*l.n; ++i){
l.delta[i] = rand_uniform();
}
for(i = 0; i < l.n; ++i){
l.bias_updates[i] = rand_uniform();
}
cl_write_array(l.delta_cl, l.delta, size*l.batch*l.n);
cl_write_array(l.bias_updates_cl, l.bias_updates, l.n);
float *gpu = calloc(l.n, sizeof(float));
cl_read_array(l.bias_updates_cl, gpu, l.n);
for(i = 0; i < l.n; ++i) printf("%.9g %.9g\n", l.bias_updates[i], gpu[i]);
learn_bias_convolutional_layer_ongpu(l);
learn_bias_convolutional_layer(l);
cl_read_array(l.bias_updates_cl, gpu, l.n);
for(i = 0; i < l.n; ++i) printf("%.9g %.9g\n", l.bias_updates[i], gpu[i]);
}
cl_kernel get_convolutional_bias_kernel()
{
static int init = 0;

View File

@ -19,7 +19,7 @@ __kernel void learn_bias(int batch, int n, int size, __global float *delta, __gl
for(b = 0; b < batch; ++b){
for(i = 0; i < size; i += BLOCK){
int index = p + i + size*(filter + n*b);
sum += (index < size) ? delta[index] : 0;
sum += (p+i < size) ? delta[index] : 0;
}
}
part[p] = sum;

View File

@ -162,6 +162,26 @@ cl_kernel get_gemm_nn_kernel()
return gemm_kernel;
}
#define TILE 64
#define TILE_K 16
#define WPT 8
#define THREADS (TILE*TILE)/(WPT*WPT)
cl_kernel get_gemm_nn_fast_kernel()
{
static int init = 0;
static cl_kernel gemm_kernel;
if(!init){
gemm_kernel = get_kernel("src/gemm_fast.cl", "gemm_nn_fast", "-D TILE=" STR(TILE)
" -cl-nv-verbose "
" -D TILE_K=" STR(TILE_K)
" -D WPT=" STR(WPT)
" -D THREADS=" STR(THREADS));
init = 1;
}
return gemm_kernel;
}
void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
cl_mem A_gpu, int lda,
cl_mem B_gpu, int ldb,
@ -171,6 +191,45 @@ void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
gemm_ongpu_offset(TA, TB, M, N, K, ALPHA, A_gpu, 0, lda, B_gpu, 0, ldb, BETA, C_gpu, 0, ldc);
}
void gemm_ongpu_fast(int TA, int TB, int M, int N, int K, float ALPHA,
cl_mem A_gpu, int lda,
cl_mem B_gpu, int ldb,
float BETA,
cl_mem C_gpu, int ldc)
{
int a_off = 0;
int b_off = 0;
int c_off = 0;
//printf("gpu: %d %d %d %d %d\n",TA, TB, M, N, K);
cl_kernel gemm_kernel = get_gemm_nn_fast_kernel();
cl_command_queue queue = cl.queue;
cl_uint i = 0;
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(a_off), (void*) &a_off);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(b_off), (void*) &b_off);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(c_off), (void*) &c_off);
cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
check_error(cl);
const size_t global_size[] = {THREADS*((N-1)/TILE + 1), (M-1)/TILE + 1};
const size_t local_size[] = {THREADS, 1};
cl.error = clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
check_error(cl);
}
void gemm_ongpu_offset(int TA, int TB, int M, int N, int K, float ALPHA,
cl_mem A_gpu, int a_off, int lda,
cl_mem B_gpu, int b_off, int ldb,
@ -214,7 +273,7 @@ void gemm_ongpu_offset(int TA, int TB, int M, int N, int K, float ALPHA,
cl.error = clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
check_error(cl);
#endif
#endif
}
void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
@ -244,7 +303,9 @@ void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
size, C, &cl.error);
check_error(cl);
gemm_ongpu(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
// TODO
//gemm_ongpu(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
gemm_ongpu_fast(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
clEnqueueReadBuffer(queue, C_gpu, CL_TRUE, 0, size, C, 0, 0, 0);
check_error(cl);
@ -303,7 +364,7 @@ void time_ongpu(int TA, int TB, int m, int k, int n)
for(i = 0; i<iter; ++i){
gemm_ongpu(TA,TB,m,n,k,1,a_cl,lda,b_cl,ldb,1,c_cl,n);
}
double flop = m*n*k*iter;
double flop = ((double)m)*n*(2.*k + 2.)*iter;
double gflop = flop/pow(10., 9);
end = clock();
double seconds = sec(end-start);
@ -316,6 +377,39 @@ void time_ongpu(int TA, int TB, int m, int k, int n)
free(c);
}
void time_ongpu_fast(int TA, int TB, int m, int k, int n)
{
int iter = 10;
float *a = random_matrix(m,k);
float *b = random_matrix(k,n);
int lda = (!TA)?k:m;
int ldb = (!TB)?n:k;
float *c = random_matrix(m,n);
cl_mem a_cl = cl_make_array(a, m*k);
cl_mem b_cl = cl_make_array(b, k*n);
cl_mem c_cl = cl_make_array(c, m*n);
int i;
clock_t start = clock(), end;
for(i = 0; i<iter; ++i){
gemm_ongpu_fast(TA,TB,m,n,k,1,a_cl,lda,b_cl,ldb,1,c_cl,n);
}
double flop = ((double)m)*n*(2.*k + 2.)*iter;
double gflop = flop/pow(10., 9);
end = clock();
double seconds = sec(end-start);
printf("Fast Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf s, %lf GFLOPS\n",m,k,k,n, TA, TB, seconds, gflop/seconds);
clReleaseMemObject(a_cl);
clReleaseMemObject(b_cl);
clReleaseMemObject(c_cl);
free(a);
free(b);
free(c);
}
void test_gpu_accuracy(int TA, int TB, int m, int k, int n)
{
srand(0);
@ -335,8 +429,10 @@ void test_gpu_accuracy(int TA, int TB, int m, int k, int n)
int i;
//pm(m,k,b);
gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c_gpu,n);
//printf("GPU\n");
//pm(m, n, c_gpu);
gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
//printf("\n\nCPU\n");
//pm(m, n, c);
double sse = 0;
for(i = 0; i < m*n; ++i) {
@ -365,21 +461,47 @@ void test_gpu_blas()
test_gpu_accuracy(0,1,1000,10,100);
test_gpu_accuracy(1,1,1000,10,100);
*/
time_ongpu(0,0,512,256,1152);
time_ongpu(0,0,128,1200,4096);
time_ongpu(0,0,128,1200,4096);
time_ongpu(0,0,128,1200,4096);
time_ongpu(0,1,128,1200,4096);
time_ongpu(1,0,1200,4096,128);
time_ongpu(1,0,4096,1200,128);
time_ongpu(1,0,1200,128,4096);
test_gpu_accuracy(0,0,128,128,128);
test_gpu_accuracy(0,0,512,256,1152);
test_gpu_accuracy(0,0,131,4093,1199);
test_gpu_accuracy(0,1,131,4093,1199);
test_gpu_accuracy(1,0,131,4093,1199);
test_gpu_accuracy(1,1,131,4093,1199);
/*
time_ongpu(0,0,64,2916,363);
time_ongpu_fast(0,0,64,2916,363);
time_ongpu(0,0,64,2916,363);
time_ongpu_fast(0,0,64,2916,363);
time_ongpu(0,0,64,2916,363);
time_ongpu_fast(0,0,64,2916,363);
time_ongpu(0,0,192,729,1600);
time_ongpu_fast(0,0,192,729,1600);
time_ongpu(0,0,384,196,1728);
time_ongpu_fast(0,0,384,196,1728);
time_ongpu(0,0,256,196,3456);
time_ongpu_fast(0,0,256,196,3456);
time_ongpu(0,0,256,196,2304);
time_ongpu_fast(0,0,256,196,2304);
time_ongpu(0,0,128,4096,12544);
time_ongpu_fast(0,0,128,4096,12544);
time_ongpu(0,0,128,4096,4096);
time_ongpu_fast(0,0,128,4096,4096);
*/
// time_ongpu(1,0,2304,196,256);
// time_ongpu_fast(1,0,2304,196,256);
// time_ongpu(0,1,256,2304,196);
// time_ongpu_fast(0,1,256,2304,196);
time_ongpu(0,0,2048,2048,2048);
time_ongpu_fast(0,0,2048,2048,2048);
time_ongpu(0,0,2048,2048,2048);
time_ongpu_fast(0,0,2048,2048,2048);
time_ongpu(0,0,2048,2048,2048);
time_ongpu_fast(0,0,2048,2048,2048);
/*
test_gpu_accuracy(0,0,131,4093,1199);
test_gpu_accuracy(0,1,131,4093,1199);
test_gpu_accuracy(1,0,131,4093,1199);
test_gpu_accuracy(1,1,131,4093,1199);
*/
/*
time_ongpu(0,0,1024,1024,1024);

View File

@ -215,4 +215,3 @@ __kernel void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
C[row*ldc+col] = ALPHA*val + BETA*C[row*ldc+col];
}
}

79
src/gemm_fast.cl Normal file
View File

@ -0,0 +1,79 @@
__kernel void gemm_nn_fast(int TA, int TB, int M, int N, int K, float ALPHA,
__global float *A, int a_off, int lda,
__global float *B, int b_off, int ldb,
float BETA,
__global float *C, int c_off, int ldc)
{
int i, j, k, x, y;
A += a_off;
B += b_off;
C += c_off;
__local float Asub[TILE] [TILE_K];
__local float Bsub[TILE_K][TILE];
int ctile = get_group_id(0);
int rtile = get_group_id(1);
float Breg;
float Areg[WPT];
float acc[WPT][WPT];
A += rtile*TILE*lda;
B += ctile*TILE;
C += rtile*TILE*ldc + ctile*TILE;
for(i = 0; i < WPT; ++i){
for(j = 0; j < WPT; ++j){
acc[i][j] = 0;
}
}
int offset = get_local_id(0);
for(i = 0; i < K; i += TILE_K){
for(j = 0; j < TILE*TILE_K; j += THREADS){
int index = j+offset;
int row = index / TILE_K;
int col = index % TILE_K;
Asub[row][col] = A[row*lda + col];
row = index / TILE;
col = index % TILE;
Bsub[row][col] = B[row*ldb + col];
}
A += TILE_K;
B += TILE_K*ldb;
barrier(CLK_LOCAL_MEM_FENCE);
for(k = 0; k < TILE_K; ++k){
for(y = 0; y < WPT; ++y){
int row = (offset + (y*WPT)*THREADS)/TILE;
//Areg[y] = Asub[y*WPT][k];
}
for(y = 0; y < WPT; ++y){
for(x = 0; x < WPT; ++x){
int index = offset + (y*WPT + x)*THREADS;
int row = index / TILE;
int col = index % TILE;
acc[y][x] += Asub[row][k]*Bsub[k][col];
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
for(y = 0; y < WPT; ++y){
for(x = 0; x < WPT; ++x){
int index = offset + (y*WPT + x)*THREADS;
int row = index / TILE;
int col = index % TILE;
C[row*ldc+col] = ALPHA*acc[y][x] + BETA*C[row*ldc+col];
}
}
}

View File

@ -132,11 +132,11 @@ cl_program cl_fprog(char *filename, char *options, cl_info info)
char build_c[1024*64];
// and compile it (after this we could extract the compiled version)
info.error=clBuildProgram(prog, 0, 0, options, 0, 0);
if ( info.error != CL_SUCCESS ) {
//if ( info.error != CL_SUCCESS ) {
fprintf(stderr, "Error Building Program: %d\n", info.error);
clGetProgramBuildInfo( prog, info.device, CL_PROGRAM_BUILD_LOG, 1024*64, build_c, 0);
fprintf(stderr, "Build Log for %s program:\n%s\n", filename, build_c);
}
//}
check_error(info);
return prog;
}
@ -205,7 +205,7 @@ cl_mem cl_make_array(float *x, int n)
CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR,
sizeof(float)*n, x, &cl.error);
check_error(cl);
activate_array_ongpu(mem, n, LINEAR);
//activate_array_ongpu(mem, n, LINEAR);
return mem;
}