checkpoint

This commit is contained in:
Joseph Redmon 2014-11-18 13:51:04 -08:00
parent b13ad6d5fd
commit d407bffde9
23 changed files with 194 additions and 96 deletions

View File

@ -14,6 +14,7 @@ endif
UNAME = $(shell uname) UNAME = $(shell uname)
OPTS=-Ofast -flto OPTS=-Ofast -flto
#OPTS=-O3
ifeq ($(UNAME), Darwin) ifeq ($(UNAME), Darwin)
COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
ifeq ($(GPU), 1) ifeq ($(GPU), 1)

View File

@ -128,7 +128,7 @@ void activate_array_ongpu(cl_mem x, int n, ACTIVATION a)
size_t gsize = n; size_t gsize = n;
clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
@ -158,7 +158,7 @@ void gradient_array_ongpu(cl_mem x, int n, ACTIVATION a, cl_mem delta)
size_t gsize = n; size_t gsize = n;
clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
#endif #endif

View File

@ -87,7 +87,7 @@ void axpy_ongpu_offset(int N, float ALPHA, cl_mem X, int OFFX, int INCX, cl_mem
const size_t global_size[] = {N}; const size_t global_size[] = {N};
clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
@ -113,7 +113,7 @@ void copy_ongpu_offset(int N, cl_mem X, int OFFX, int INCX, cl_mem Y, int OFFY,
const size_t global_size[] = {N}; const size_t global_size[] = {N};
clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
void scal_ongpu(int N, float ALPHA, cl_mem X, int INCX) void scal_ongpu(int N, float ALPHA, cl_mem X, int INCX)
@ -131,7 +131,7 @@ void scal_ongpu(int N, float ALPHA, cl_mem X, int INCX)
const size_t global_size[] = {N}; const size_t global_size[] = {N};
clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
#endif #endif

View File

@ -265,10 +265,8 @@ void test_rotate()
void test_parser() void test_parser()
{ {
network net = parse_network_cfg("cfg/test_parser.cfg"); network net = parse_network_cfg("cfg/trained_imagenet.cfg");
save_network(net, "cfg/test_parser_1.cfg"); save_network(net, "cfg/trained_imagenet_smaller.cfg");
network net2 = parse_network_cfg("cfg/test_parser_1.cfg");
save_network(net2, "cfg/test_parser_2.cfg");
} }
void test_data() void test_data()
@ -294,7 +292,8 @@ void train_asirra()
normalize_data_rows(train); normalize_data_rows(train);
printf("Loaded: %lf seconds\n", sec(clock()-time)); printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock(); time=clock();
float loss = train_network_data_gpu(net, train, imgs); //float loss = train_network_data(net, train, imgs);
float loss = 0;
printf("%d: %f, Time: %lf seconds\n", i*net.batch*imgs, loss, sec(clock()-time)); printf("%d: %f, Time: %lf seconds\n", i*net.batch*imgs, loss, sec(clock()-time));
free_data(train); free_data(train);
if(i%10==0){ if(i%10==0){
@ -309,7 +308,8 @@ void train_asirra()
void train_imagenet() void train_imagenet()
{ {
float avg_loss = 1; float avg_loss = 1;
network net = parse_network_cfg("/home/pjreddie/imagenet_backup/imagenet_nin_2680.cfg"); network net = parse_network_cfg("/home/pjreddie/imagenet_backup/imagenet_2280.cfg");
//network net = parse_network_cfg("cfg/imagenet2.cfg");
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 1000/net.batch+1; int imgs = 1000/net.batch+1;
srand(time(0)); srand(time(0));
@ -335,7 +335,7 @@ void train_imagenet()
free_data(train); free_data(train);
if(i%10==0){ if(i%10==0){
char buff[256]; char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/imagenet_nin_%d.cfg", i); sprintf(buff, "/home/pjreddie/imagenet_backup/imagenet_%d.cfg", i);
save_network(net, buff); save_network(net, buff);
} }
} }
@ -408,7 +408,7 @@ void test_imagenet()
char filename[256]; char filename[256];
int indexes[10]; int indexes[10];
while(1){ while(1){
gets(filename); fgets(filename, 256, stdin);
image im = load_image_color(filename, 256, 256); image im = load_image_color(filename, 256, 256);
z_normalize_image(im); z_normalize_image(im);
printf("%d %d %d\n", im.h, im.w, im.c); printf("%d %d %d\n", im.h, im.w, im.c);
@ -548,35 +548,16 @@ void train_nist()
data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10); data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10); data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
translate_data_rows(train, -144); translate_data_rows(train, -144);
//scale_data_rows(train, 1./128);
translate_data_rows(test, -144); translate_data_rows(test, -144);
//scale_data_rows(test, 1./128);
//randomize_data(train);
int count = 0; int count = 0;
//clock_t start = clock(), end; int iters = 50000/net.batch;
int iters = 10000/net.batch;
while(++count <= 2000){ while(++count <= 2000){
clock_t start = clock(), end; clock_t start = clock(), end;
float loss = train_network_sgd(net, train, iters); float loss = train_network_sgd(net, train, iters);
end = clock(); end = clock();
float test_acc = network_accuracy(net, test); float test_acc = network_accuracy(net, test);
//float test_acc = 0; printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC);
printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay);
/*printf("%f %f %f %f %f\n", mean_array(get_network_output_layer(net,0), 100),
mean_array(get_network_output_layer(net,1), 100),
mean_array(get_network_output_layer(net,2), 100),
mean_array(get_network_output_layer(net,3), 100),
mean_array(get_network_output_layer(net,4), 100));
*/
//save_network(net, "cfg/nist_final2.cfg");
//printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*1000, loss, lr, momentum, decay);
//end = clock();
//printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
//start=end;
//lr *= .5;
} }
//save_network(net, "cfg/nist_basic_trained.cfg");
} }
void test_ensemble() void test_ensemble()
@ -1052,6 +1033,7 @@ int main(int argc, char *argv[])
} }
if(0==strcmp(argv[1], "train")) train_imagenet(); if(0==strcmp(argv[1], "train")) train_imagenet();
else if(0==strcmp(argv[1], "asirra")) train_asirra(); else if(0==strcmp(argv[1], "asirra")) train_asirra();
else if(0==strcmp(argv[1], "nist")) train_nist();
else if(0==strcmp(argv[1], "train_small")) train_imagenet_small(); else if(0==strcmp(argv[1], "train_small")) train_imagenet_small();
else if(0==strcmp(argv[1], "test_correct")) test_gpu_net(); else if(0==strcmp(argv[1], "test_correct")) test_gpu_net();
else if(0==strcmp(argv[1], "test")) test_imagenet(); else if(0==strcmp(argv[1], "test")) test_imagenet();

View File

@ -82,7 +82,7 @@ void col2im_ongpu(cl_mem data_col, int batch,
size_t global_size = channels*height*width*batch; size_t global_size = channels*height*width*batch;
clEnqueueNDRangeKernel(queue, kernel, 1, 0, cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0,
&global_size, 0, 0, 0, 0); &global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }

View File

@ -9,7 +9,6 @@
connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, float learning_rate, float momentum, float decay) connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, float learning_rate, float momentum, float decay)
{ {
fprintf(stderr, "Connected Layer: %d inputs, %d outputs\n", inputs, outputs);
int i; int i;
connected_layer *layer = calloc(1, sizeof(connected_layer)); connected_layer *layer = calloc(1, sizeof(connected_layer));
@ -51,6 +50,7 @@ connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVA
layer->delta_cl = cl_make_array(layer->delta, outputs*batch); layer->delta_cl = cl_make_array(layer->delta, outputs*batch);
#endif #endif
layer->activation = activation; layer->activation = activation;
fprintf(stderr, "Connected Layer: %d inputs, %d outputs\n", inputs, outputs);
return layer; return layer;
} }

View File

@ -304,7 +304,7 @@ void learn_bias_convolutional_layer_ongpu(convolutional_layer layer)
const size_t global_size[] = {layer.n}; const size_t global_size[] = {layer.n};
clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
@ -338,7 +338,7 @@ void bias_output_gpu(const convolutional_layer layer)
const size_t global_size[] = {layer.n*size, layer.batch}; const size_t global_size[] = {layer.n*size, layer.batch};
clEnqueueNDRangeKernel(queue, kernel, 2, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 2, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
@ -400,7 +400,6 @@ void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem delta_cl
gemm_ongpu_offset(0,1,m,n,k,1,a,i*m*k,k,b,i*k*n,k,1,c,0,n); gemm_ongpu_offset(0,1,m,n,k,1,a,i*m*k,k,b,i*k*n,k,1,c,0,n);
} }
//cl_read_array(layer.delta_cl, layer.delta, m*k*layer.batch);
if(delta_cl){ if(delta_cl){
m = layer.size*layer.size*layer.c; m = layer.size*layer.size*layer.c;

View File

@ -1,4 +1,5 @@
#include "cost_layer.h" #include "cost_layer.h"
#include "utils.h"
#include "mini_blas.h" #include "mini_blas.h"
#include <math.h> #include <math.h>
#include <stdlib.h> #include <stdlib.h>
@ -36,11 +37,12 @@ void forward_cost_layer_gpu(cost_layer layer, cl_mem input, cl_mem truth)
{ {
if (!truth) return; if (!truth) return;
copy_ongpu(layer.batch*layer.inputs, truth, 1, layer.delta_cl, 1); copy_ongpu(layer.batch*layer.inputs, truth, 1, layer.delta_cl, 1);
axpy_ongpu(layer.batch*layer.inputs, -1, input, 1, layer.delta_cl, 1); axpy_ongpu(layer.batch*layer.inputs, -1, input, 1, layer.delta_cl, 1);
cl_read_array(layer.delta_cl, layer.delta, layer.batch*layer.inputs); cl_read_array(layer.delta_cl, layer.delta, layer.batch*layer.inputs);
*(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
//printf("%f\n", *layer.output);
} }
void backward_cost_layer_gpu(const cost_layer layer, cl_mem input, cl_mem delta) void backward_cost_layer_gpu(const cost_layer layer, cl_mem input, cl_mem delta)

View File

@ -19,6 +19,12 @@ list *get_paths(char *filename)
return lines; return lines;
} }
void fill_truth_det(char *path, float *truth)
{
find_replace(path, "imgs", "det");
find_replace(path, ".JPEG", ".txt");
}
void fill_truth(char *path, char **labels, int k, float *truth) void fill_truth(char *path, char **labels, int k, float *truth)
{ {
int i; int i;
@ -83,7 +89,6 @@ void free_data(data d)
data load_data_image_pathfile_part(char *filename, int part, int total, char **labels, int k, int h, int w) data load_data_image_pathfile_part(char *filename, int part, int total, char **labels, int k, int h, int w)
{ {
clock_t time = clock();
list *plist = get_paths(filename); list *plist = get_paths(filename);
char **paths = (char **)list_to_array(plist); char **paths = (char **)list_to_array(plist);
int start = part*plist->size/total; int start = part*plist->size/total;

View File

@ -1,6 +1,7 @@
#include "dropout_layer.h" #include "dropout_layer.h"
#include "stdlib.h" #include "utils.h"
#include "stdio.h" #include <stdlib.h>
#include <stdio.h>
dropout_layer *make_dropout_layer(int batch, int inputs, float probability) dropout_layer *make_dropout_layer(int batch, int inputs, float probability)
{ {
@ -9,6 +10,10 @@ dropout_layer *make_dropout_layer(int batch, int inputs, float probability)
layer->probability = probability; layer->probability = probability;
layer->inputs = inputs; layer->inputs = inputs;
layer->batch = batch; layer->batch = batch;
#ifdef GPU
layer->rand = calloc(inputs*batch, sizeof(float));
layer->rand_cl = cl_make_array(layer->rand, inputs*batch);
#endif
return layer; return layer;
} }
@ -16,7 +21,7 @@ void forward_dropout_layer(dropout_layer layer, float *input)
{ {
int i; int i;
for(i = 0; i < layer.batch * layer.inputs; ++i){ for(i = 0; i < layer.batch * layer.inputs; ++i){
if((float)rand()/RAND_MAX < layer.probability) input[i] = 0; if(rand_uniform() < layer.probability) input[i] = 0;
else input[i] /= (1-layer.probability); else input[i] /= (1-layer.probability);
} }
} }
@ -24,3 +29,38 @@ void backward_dropout_layer(dropout_layer layer, float *input, float *delta)
{ {
// Don't do shit LULZ // Don't do shit LULZ
} }
#ifdef GPU
cl_kernel get_dropout_kernel()
{
static int init = 0;
static cl_kernel kernel;
if(!init){
kernel = get_kernel("src/dropout_layer.cl", "forward", 0);
init = 1;
}
return kernel;
}
void forward_dropout_layer_gpu(dropout_layer layer, cl_mem input)
{
int j;
int size = layer.inputs*layer.batch;
for(j = 0; j < size; ++j) layer.rand[j] = rand_uniform();
cl_write_array(layer.rand_cl, layer.rand, layer.inputs*layer.batch);
cl_kernel kernel = get_dropout_kernel();
cl_command_queue queue = cl.queue;
cl_uint i = 0;
cl.error = clSetKernelArg(kernel, i++, sizeof(input), (void*) &input);
cl.error = clSetKernelArg(kernel, i++, sizeof(layer.rand_cl), (void*) &layer.rand_cl);
cl.error = clSetKernelArg(kernel, i++, sizeof(layer.probability), (void*) &layer.probability);
check_error(cl);
const size_t global_size[] = {size};
cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl);
}
#endif

View File

@ -1,15 +1,23 @@
#ifndef DROPOUT_LAYER_H #ifndef DROPOUT_LAYER_H
#define DROPOUT_LAYER_H #define DROPOUT_LAYER_H
#include "opencl.h"
typedef struct{ typedef struct{
int batch; int batch;
int inputs; int inputs;
float probability; float probability;
#ifdef GPU
float *rand;
cl_mem rand_cl;
#endif
} dropout_layer; } dropout_layer;
dropout_layer *make_dropout_layer(int batch, int inputs, float probability); dropout_layer *make_dropout_layer(int batch, int inputs, float probability);
void forward_dropout_layer(dropout_layer layer, float *input); void forward_dropout_layer(dropout_layer layer, float *input);
void backward_dropout_layer(dropout_layer layer, float *input, float *delta); void backward_dropout_layer(dropout_layer layer, float *input, float *delta);
#ifdef GPU
void forward_dropout_layer_gpu(dropout_layer layer, cl_mem input);
#endif #endif
#endif

View File

@ -18,6 +18,7 @@ void forward_freeweight_layer(freeweight_layer layer, float *input)
input[i] *= 2.*((float)rand()/RAND_MAX); input[i] *= 2.*((float)rand()/RAND_MAX);
} }
} }
void backward_freeweight_layer(freeweight_layer layer, float *input, float *delta) void backward_freeweight_layer(freeweight_layer layer, float *input, float *delta)
{ {
// Don't do shit LULZ // Don't do shit LULZ

View File

@ -214,7 +214,7 @@ void gemm_ongpu_offset(int TA, int TB, int M, int N, int K, float ALPHA,
const size_t global_size[] = {ceil((float)N/BLOCK)*BLOCK, ceil((float)M/BLOCK)*BLOCK}; const size_t global_size[] = {ceil((float)N/BLOCK)*BLOCK, ceil((float)M/BLOCK)*BLOCK};
const size_t local_size[] = {BLOCK, BLOCK}; const size_t local_size[] = {BLOCK, BLOCK};
clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
check_error(cl); check_error(cl);
#endif #endif
} }
@ -368,6 +368,7 @@ void test_gpu_blas()
test_gpu_accuracy(0,1,1000,10,100); test_gpu_accuracy(0,1,1000,10,100);
test_gpu_accuracy(1,1,1000,10,100); test_gpu_accuracy(1,1,1000,10,100);
*/ */
time_ongpu(0,0,512,256,1152);
time_ongpu(0,0,128,1200,4096); time_ongpu(0,0,128,1200,4096);
time_ongpu(0,0,128,1200,4096); time_ongpu(0,0,128,1200,4096);
time_ongpu(0,0,128,1200,4096); time_ongpu(0,0,128,1200,4096);
@ -377,6 +378,7 @@ void test_gpu_blas()
time_ongpu(1,0,4096,1200,128); time_ongpu(1,0,4096,1200,128);
time_ongpu(1,0,1200,128,4096); time_ongpu(1,0,1200,128,4096);
test_gpu_accuracy(0,0,512,256,1152);
test_gpu_accuracy(0,0,131,4093,1199); test_gpu_accuracy(0,0,131,4093,1199);
test_gpu_accuracy(0,1,131,4093,1199); test_gpu_accuracy(0,1,131,4093,1199);
test_gpu_accuracy(1,0,131,4093,1199); test_gpu_accuracy(1,0,131,4093,1199);

View File

@ -106,7 +106,7 @@ void im2col_ongpu(cl_mem data_im, int batch,
size_t global_size = batch*channels_col*height_col*width_col; size_t global_size = batch*channels_col*height_col*width_col;
clEnqueueNDRangeKernel(queue, kernel, 1, 0, cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0,
&global_size, 0, 0, 0, 0); &global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }

View File

@ -132,7 +132,7 @@ void forward_maxpool_layer_gpu(maxpool_layer layer, cl_mem input)
const size_t global_size[] = {h*w*c*layer.batch}; const size_t global_size[] = {h*w*c*layer.batch};
clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }
@ -166,7 +166,7 @@ void backward_maxpool_layer_gpu(maxpool_layer layer, cl_mem delta)
const size_t global_size[] = {layer.h*layer.w*layer.c*layer.batch}; const size_t global_size[] = {layer.h*layer.w*layer.c*layer.batch};
clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
} }

View File

@ -53,6 +53,7 @@ void time_random_matrix(int TA, int TB, int m, int k, int n)
void test_blas() void test_blas()
{ {
time_random_matrix(0,0,100,100,100); time_random_matrix(0,0,100,100,100);
time_random_matrix(1,0,100,100,100); time_random_matrix(1,0,100,100,100);
time_random_matrix(0,1,100,100,100); time_random_matrix(0,1,100,100,100);

View File

@ -476,25 +476,11 @@ void visualize_network(network net)
} }
} }
void top_predictions(network net, int n, int *index) void top_predictions(network net, int k, int *index)
{ {
int i,j; int size = get_network_output_size(net);
int k = get_network_output_size(net);
float *out = get_network_output(net); float *out = get_network_output(net);
float thresh = FLT_MAX; top_k(out, size, k, index);
for(i = 0; i < n; ++i){
float max = -FLT_MAX;
int max_i = -1;
for(j = 0; j < k; ++j){
float val = out[j];
if(val > max && val < thresh){
max = val;
max_i = j;
}
}
index[i] = max_i;
thresh = max;
}
} }

View File

@ -22,7 +22,9 @@ void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train)
{ {
//printf("start\n"); //printf("start\n");
int i; int i;
// printf("Truth: %f\n", cl_checksum(truth, 1000*net.batch));
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
//printf("Truth %i: %f\n", i, cl_checksum(truth, 1000*net.batch));
//clock_t time = clock(); //clock_t time = clock();
if(net.types[i] == CONVOLUTIONAL){ if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i]; convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@ -48,6 +50,11 @@ void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train)
forward_softmax_layer_gpu(layer, input); forward_softmax_layer_gpu(layer, input);
input = layer.output_cl; input = layer.output_cl;
} }
else if(net.types[i] == DROPOUT){
if(!train) continue;
dropout_layer layer = *(dropout_layer *)net.layers[i];
forward_dropout_layer_gpu(layer, input);
}
//printf("%d %f\n", i, sec(clock()-time)); //printf("%d %f\n", i, sec(clock()-time));
/* /*
else if(net.types[i] == CROP){ else if(net.types[i] == CROP){
@ -134,6 +141,8 @@ cl_mem get_network_output_cl_layer(network net, int i)
else if(net.types[i] == SOFTMAX){ else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i]; softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.output_cl; return layer.output_cl;
} else if(net.types[i] == DROPOUT){
return get_network_output_cl_layer(net, i-1);
} }
return 0; return 0;
} }
@ -155,6 +164,8 @@ cl_mem get_network_delta_cl_layer(network net, int i)
else if(net.types[i] == SOFTMAX){ else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i]; softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.delta_cl; return layer.delta_cl;
} else if(net.types[i] == DROPOUT){
return get_network_delta_cl_layer(net, i-1);
} }
return 0; return 0;
} }
@ -173,14 +184,18 @@ float train_network_datum_gpu(network net, float *x, float *y)
} }
//printf("trans %f\n", sec(clock()-time)); //printf("trans %f\n", sec(clock()-time));
//time = clock(); //time = clock();
forward_network_gpu(net, *net.input_cl, *net.truth_cl, 1); forward_network_gpu(net, *net.input_cl, *net.truth_cl, 1);
//printf("forw %f\n", sec(clock()-time)); //printf("forw %f\n", sec(clock()-time));
//time = clock(); //time = clock();
backward_network_gpu(net, *net.input_cl); backward_network_gpu(net, *net.input_cl);
//printf("back %f\n", sec(clock()-time)); //printf("back %f\n", sec(clock()-time));
//time = clock(); //time = clock();
update_network_gpu(net); update_network_gpu(net);
float error = get_network_cost(net); float error = get_network_cost(net);
//printf("updt %f\n", sec(clock()-time)); //printf("updt %f\n", sec(clock()-time));
//time = clock(); //time = clock();
return error; return error;

View File

@ -11,14 +11,16 @@
#include "opencl.h" #include "opencl.h"
#include "utils.h" #include "utils.h"
#include "activations.h"
cl_info cl = {0}; cl_info cl = {0};
void check_error(cl_info info) void check_error(cl_info info)
{ {
clFinish(cl.queue); // clFinish(cl.queue);
if (info.error != CL_SUCCESS) { if (info.error != CL_SUCCESS) {
printf("\n Error number %d", info.error); printf("\n Error number %d", info.error);
abort();
exit(1); exit(1);
} }
} }
@ -72,6 +74,8 @@ cl_info cl_init()
printf(" DEVICE_MAX_CLOCK_FREQUENCY = %u\n", (unsigned int)buf_uint); printf(" DEVICE_MAX_CLOCK_FREQUENCY = %u\n", (unsigned int)buf_uint);
clGetDeviceInfo(devices[i], CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(buf_ulong), &buf_ulong, NULL); clGetDeviceInfo(devices[i], CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(buf_ulong), &buf_ulong, NULL);
printf(" DEVICE_GLOBAL_MEM_SIZE = %llu\n", (unsigned long long)buf_ulong); printf(" DEVICE_GLOBAL_MEM_SIZE = %llu\n", (unsigned long long)buf_ulong);
clGetDeviceInfo(devices[i], CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(buf_ulong), &buf_ulong, NULL);
printf(" DEVICE_MAX_MEM_ALLOC_SIZE = %llu\n", (unsigned long long)buf_ulong);
clGetDeviceInfo(devices[i], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(buf_ulong), &buf_ulong, NULL); clGetDeviceInfo(devices[i], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(buf_ulong), &buf_ulong, NULL);
printf(" DEVICE_MAX_WORK_GROUP_SIZE = %llu\n", (unsigned long long)buf_ulong); printf(" DEVICE_MAX_WORK_GROUP_SIZE = %llu\n", (unsigned long long)buf_ulong);
cl_uint items; cl_uint items;
@ -151,21 +155,31 @@ cl_kernel get_kernel(char *filename, char *kernelname, char *options)
void cl_read_array(cl_mem mem, float *x, int n) void cl_read_array(cl_mem mem, float *x, int n)
{ {
cl_setup(); cl_setup();
clEnqueueReadBuffer(cl.queue, mem, CL_TRUE, 0, sizeof(float)*n,x,0,0,0); cl.error = clEnqueueReadBuffer(cl.queue, mem, CL_TRUE, 0, sizeof(float)*n,x,0,0,0);
check_error(cl); check_error(cl);
} }
float cl_checksum(cl_mem mem, int n)
{
float *x = calloc(n, sizeof(float));
cl_read_array(mem, x, n);
float sum = sum_array(x, n);
free(x);
return sum;
}
void cl_write_array(cl_mem mem, float *x, int n) void cl_write_array(cl_mem mem, float *x, int n)
{ {
cl_setup(); cl_setup();
clEnqueueWriteBuffer(cl.queue, mem, CL_TRUE, 0,sizeof(float)*n,x,0,0,0); cl.error = clEnqueueWriteBuffer(cl.queue, mem, CL_TRUE, 0,sizeof(float)*n,x,0,0,0);
check_error(cl); check_error(cl);
} }
void cl_copy_array(cl_mem src, cl_mem dst, int n) void cl_copy_array(cl_mem src, cl_mem dst, int n)
{ {
cl_setup(); cl_setup();
clEnqueueCopyBuffer(cl.queue, src, dst, 0, 0, sizeof(float)*n,0,0,0); cl.error = clEnqueueCopyBuffer(cl.queue, src, dst, 0, 0, sizeof(float)*n,0,0,0);
check_error(cl); check_error(cl);
} }
@ -179,6 +193,7 @@ cl_mem cl_sub_array(cl_mem src, int offset, int size)
return sub; return sub;
} }
cl_mem cl_make_array(float *x, int n) cl_mem cl_make_array(float *x, int n)
{ {
cl_setup(); cl_setup();
@ -186,6 +201,7 @@ cl_mem cl_make_array(float *x, int n)
CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR, CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR,
sizeof(float)*n, x, &cl.error); sizeof(float)*n, x, &cl.error);
check_error(cl); check_error(cl);
activate_array_ongpu(mem, n, LINEAR);
return mem; return mem;
} }

View File

@ -28,5 +28,6 @@ cl_mem cl_make_array(float *x, int n);
cl_mem cl_make_int_array(int *x, int n); cl_mem cl_make_int_array(int *x, int n);
void cl_copy_array(cl_mem src, cl_mem dst, int n); void cl_copy_array(cl_mem src, cl_mem dst, int n);
cl_mem cl_sub_array(cl_mem src, int offset, int size); cl_mem cl_sub_array(cl_mem src, int offset, int size);
float cl_checksum(cl_mem mem, int n);
#endif #endif
#endif #endif

View File

@ -81,10 +81,10 @@ void forward_softmax_layer_gpu(const softmax_layer layer, cl_mem input)
const size_t global_size[] = {layer.batch}; const size_t global_size[] = {layer.batch};
clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0); cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
check_error(cl); check_error(cl);
/* /*
cl_read_array(layer.output_cl, layer.output, layer.inputs*layer.batch); cl_read_array(layer.output_cl, layer.output, layer.inputs*layer.batch);
int z; int z;
for(z = 0; z < layer.inputs*layer.batch; ++z) printf("%f,",layer.output[z]); for(z = 0; z < layer.inputs*layer.batch; ++z) printf("%f,",layer.output[z]);

View File

@ -1,14 +1,51 @@
#include "utils.h"
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <math.h> #include <math.h>
#include <float.h>
#include "utils.h"
char *find_replace(char *str, char *orig, char *rep)
{
static char buffer[4096];
char *p;
if(!(p = strstr(str, orig))) // Is 'orig' even in 'str'?
return str;
strncpy(buffer, str, p-str); // Copy characters from 'str' start to 'orig' st$
buffer[p-str] = '\0';
sprintf(buffer+(p-str), "%s%s", rep, p+strlen(orig));
return buffer;
}
float sec(clock_t clocks) float sec(clock_t clocks)
{ {
return (float)clocks/CLOCKS_PER_SEC; return (float)clocks/CLOCKS_PER_SEC;
} }
void top_k(float *a, int n, int k, int *index)
{
int i,j;
float thresh = FLT_MAX;
for(i = 0; i < k; ++i){
float max = -FLT_MAX;
int max_i = -1;
for(j = 0; j < n; ++j){
float val = a[j];
if(val > max && val < thresh){
max = val;
max_i = j;
}
}
index[i] = max_i;
thresh = max;
}
}
void error(char *s) void error(char *s)
{ {
fprintf(stderr, "Error: %s\n", s); fprintf(stderr, "Error: %s\n", s);
@ -79,7 +116,7 @@ char *fgetl(FILE *fp)
} }
int curr = strlen(line); int curr = strlen(line);
while(line[curr-1]!='\n'){ while(line[curr-1]!='\n'){
size *= 2; size *= 2;
line = realloc(line, size*sizeof(char)); line = realloc(line, size*sizeof(char));
@ -121,34 +158,34 @@ list *parse_csv_line(char *line)
int count_fields(char *line) int count_fields(char *line)
{ {
int count = 0; int count = 0;
int done = 0; int done = 0;
char *c; char *c;
for(c = line; !done; ++c){ for(c = line; !done; ++c){
done = (*c == '\0'); done = (*c == '\0');
if(*c == ',' || done) ++count; if(*c == ',' || done) ++count;
} }
return count; return count;
} }
float *parse_fields(char *line, int n) float *parse_fields(char *line, int n)
{ {
float *field = calloc(n, sizeof(float)); float *field = calloc(n, sizeof(float));
char *c, *p, *end; char *c, *p, *end;
int count = 0; int count = 0;
int done = 0; int done = 0;
for(c = line, p = line; !done; ++c){ for(c = line, p = line; !done; ++c){
done = (*c == '\0'); done = (*c == '\0');
if(*c == ',' || done){ if(*c == ',' || done){
*c = '\0'; *c = '\0';
field[count] = strtod(p, &end); field[count] = strtod(p, &end);
if(p == c) field[count] = nan(""); if(p == c) field[count] = nan("");
if(end != c && (end != c-1 || *end != '\r')) field[count] = nan(""); //DOS file formats! if(end != c && (end != c-1 || *end != '\r')) field[count] = nan(""); //DOS file formats!
p = c+1; p = c+1;
++count; ++count;
} }
} }
return field; return field;
} }
float sum_array(float *a, int n) float sum_array(float *a, int n)

View File

@ -4,11 +4,13 @@
#include <time.h> #include <time.h>
#include "list.h" #include "list.h"
char *find_replace(char *str, char *orig, char *rep);
void error(char *s); void error(char *s);
void malloc_error(); void malloc_error();
void file_error(char *s); void file_error(char *s);
void strip(char *s); void strip(char *s);
void strip_char(char *s, char bad); void strip_char(char *s, char bad);
void top_k(float *a, int n, int k, int *index);
list *split_str(char *s, char delim); list *split_str(char *s, char delim);
char *fgetl(FILE *fp); char *fgetl(FILE *fp);
list *parse_csv_line(char *line); list *parse_csv_line(char *line);