gonna change im2col

This commit is contained in:
Joseph Redmon 2015-03-21 12:25:14 -07:00
parent dcb000b553
commit 4af116e996
16 changed files with 153 additions and 78 deletions

View File

@ -8,7 +8,7 @@ OBJDIR=./obj/
CC=gcc
NVCC=nvcc
OPTS=-O3
OPTS=-O0
LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++
COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/
CFLAGS=-Wall -Wfatal-errors
@ -22,7 +22,7 @@ CFLAGS+=$(OPTS)
ifeq ($(GPU), 1)
COMMON+=-DGPU
CFLAGS+=-DGPU
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o

View File

@ -8,12 +8,19 @@ __device__ float logistic_activate_kernel(float x){return 1./(1. + exp(-x));}
__device__ float relu_activate_kernel(float x){return x*(x>0);}
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1*x;}
__device__ float tanh_activate_kernel(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
__device__ float plse_activate_kernel(float x)
{
if(x < -4) return .01 * (x + 4);
if(x > 4) return .01 * (x - 4) + 1;
return .125*x + .5;
}
__device__ float linear_gradient_kernel(float x){return 1;}
__device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
__device__ float relu_gradient_kernel(float x){return (x>0);}
__device__ float ramp_gradient_kernel(float x){return (x>0)+.1;}
__device__ float tanh_gradient_kernel(float x){return 1-x*x;}
__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01 : .125;}
__device__ float activate_kernel(float x, ACTIVATION a)
{
@ -28,6 +35,8 @@ __device__ float activate_kernel(float x, ACTIVATION a)
return ramp_activate_kernel(x);
case TANH:
return tanh_activate_kernel(x);
case PLSE:
return plse_activate_kernel(x);
}
return 0;
}
@ -45,6 +54,8 @@ __device__ float gradient_kernel(float x, ACTIVATION a)
return ramp_gradient_kernel(x);
case TANH:
return tanh_gradient_kernel(x);
case PLSE:
return plse_gradient_kernel(x);
}
return 0;
}

View File

@ -18,6 +18,8 @@ char *get_activation_string(ACTIVATION a)
return "linear";
case TANH:
return "tanh";
case PLSE:
return "plse";
default:
break;
}
@ -28,6 +30,7 @@ ACTIVATION get_activation(char *s)
{
if (strcmp(s, "logistic")==0) return LOGISTIC;
if (strcmp(s, "relu")==0) return RELU;
if (strcmp(s, "plse")==0) return PLSE;
if (strcmp(s, "linear")==0) return LINEAR;
if (strcmp(s, "ramp")==0) return RAMP;
if (strcmp(s, "tanh")==0) return TANH;
@ -48,6 +51,8 @@ float activate(float x, ACTIVATION a)
return ramp_activate(x);
case TANH:
return tanh_activate(x);
case PLSE:
return plse_activate(x);
}
return 0;
}
@ -73,6 +78,8 @@ float gradient(float x, ACTIVATION a)
return ramp_gradient(x);
case TANH:
return tanh_gradient(x);
case PLSE:
return plse_gradient(x);
}
return 0;
}

View File

@ -3,7 +3,7 @@
#define ACTIVATIONS_H
typedef enum{
LOGISTIC, RELU, LINEAR, RAMP, TANH
LOGISTIC, RELU, LINEAR, RAMP, TANH, PLSE
}ACTIVATION;
ACTIVATION get_activation(char *s);
@ -23,12 +23,19 @@ static inline float logistic_activate(float x){return 1./(1. + exp(-x));}
static inline float relu_activate(float x){return x*(x>0);}
static inline float ramp_activate(float x){return x*(x>0)+.1*x;}
static inline float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
static inline float plse_activate(float x)
{
if(x < -4) return .01 * (x + 4);
if(x > 4) return .01 * (x - 4) + 1;
return .125*x + .5;
}
static inline float linear_gradient(float x){return 1;}
static inline float logistic_gradient(float x){return (1-x)*x;}
static inline float relu_gradient(float x){return (x>0);}
static inline float ramp_gradient(float x){return (x>0)+.1;}
static inline float tanh_gradient(float x){return 1-x*x;}
static inline float plse_gradient(float x){return (x < 0 || x > 1) ? .01 : .125;}
#endif

View File

@ -16,7 +16,7 @@ void train_captcha(char *cfgfile, char *weightfile)
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 1024;
int i = net.seen/imgs;
list *plist = get_paths("/data/captcha/train.base");
list *plist = get_paths("/data/captcha/train.auto5");
char **paths = (char **)list_to_array(plist);
printf("%d\n", plist->size);
clock_t time;
@ -34,7 +34,7 @@ void train_captcha(char *cfgfile, char *weightfile)
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
free_data(train);
if(i%100==0){
if(i%10==0){
char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
save_weights(net, buff);
@ -56,11 +56,11 @@ void decode_captcha(char *cfgfile, char *weightfile)
printf("Enter filename: ");
fgets(filename, 256, stdin);
strtok(filename, "\n");
image im = load_image_color(filename, 60, 200);
image im = load_image_color(filename, 57, 300);
scale_image(im, 1./255.);
float *X = im.data;
float *predictions = network_predict(net, X);
image out = float_to_image(60, 200, 3, predictions);
image out = float_to_image(57, 300, 1, predictions);
show_image(out, "decoded");
cvWaitKey(0);
free_image(im);
@ -87,7 +87,7 @@ void encode_captcha(char *cfgfile, char *weightfile)
while(1){
++i;
time=clock();
data train = load_data_captcha_encode(paths, imgs, plist->size, 60, 200);
data train = load_data_captcha_encode(paths, imgs, plist->size, 57, 300);
scale_data_rows(train, 1./255);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
@ -114,10 +114,10 @@ void validate_captcha(char *cfgfile, char *weightfile)
if(weightfile){
load_weights(&net, weightfile);
}
int imgs = 1000;
int numchars = 37;
list *plist = get_paths("/data/captcha/valid.base");
list *plist = get_paths("/data/captcha/solved.hard");
char **paths = (char **)list_to_array(plist);
int imgs = plist->size;
data valid = load_data_captcha(paths, imgs, 0, 10, 60, 200);
translate_data_rows(valid, -128);
scale_data_rows(valid, 1./128);

View File

@ -56,6 +56,7 @@ extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch,
extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
{
clock_t time = clock();
int i;
int m = layer.n;
int k = layer.size*layer.size*layer.c;
@ -63,15 +64,31 @@ extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, netwo
convolutional_out_width(layer);
bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
cudaDeviceSynchronize();
printf("bias %f\n", sec(clock() - time));
time = clock();
float imt=0;
float gemt = 0;
for(i = 0; i < layer.batch; ++i){
time = clock();
im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
cudaDeviceSynchronize();
imt += sec(clock()-time);
time = clock();
float * a = layer.filters_gpu;
float * b = layer.col_image_gpu;
float * c = layer.output_gpu;
gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
cudaDeviceSynchronize();
gemt += sec(clock()-time);
time = clock();
}
activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation);
cudaDeviceSynchronize();
printf("activate %f\n", sec(clock() - time));
printf("im2col %f\n", imt);
printf("gemm %f\n", gemt);
}
extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state)

View File

@ -59,6 +59,18 @@ float *cuda_make_array(float *x, int n)
return x_gpu;
}
void cuda_random(float *x_gpu, int n)
{
static curandGenerator_t gen;
static int init = 0;
if(!init){
curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(gen, 0ULL);
}
curandGenerateUniform(gen, x_gpu, n);
check_error(cudaPeekAtLastError());
}
float cuda_compare(float *x_gpu, float *x, int n, char *s)
{
float *tmp = calloc(n, sizeof(float));

View File

@ -8,6 +8,7 @@ extern int gpu_index;
#define BLOCK 256
#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"
void check_error(cudaError_t status);
@ -17,6 +18,7 @@ int *cuda_make_int_array(int n);
void cuda_push_array(float *x_gpu, float *x, int n);
void cuda_pull_array(float *x_gpu, float *x, int n);
void cuda_free(float *x_gpu);
void cuda_random(float *x_gpu, int n);
float cuda_compare(float *x_gpu, float *x, int n, char *s);
dim3 cuda_gridsize(size_t n);

View File

@ -112,7 +112,12 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int
randomize_boxes(boxes, count);
float x, y, h, w;
int id;
int i, j;
int i;
if(background){
for(i = 0; i < num_height*num_width*(4+classes+background); i += 4+classes+background){
truth[i] = 1;
}
}
for(i = 0; i < count; ++i){
x = boxes[i].x;
y = boxes[i].y;
@ -137,21 +142,15 @@ void fill_truth_detection(char *path, float *truth, int classes, int height, int
int index = (i+j*num_width)*(4+classes+background);
if(truth[index+classes+background]) continue;
if(background) truth[index++] = 0;
truth[index+id] = 1;
index += classes+background;
index += classes;
truth[index++] = dh;
truth[index++] = dw;
truth[index++] = h*(height+jitter)/height;
truth[index++] = w*(width+jitter)/width;
}
free(boxes);
if(background){
for(i = 0; i < num_height*num_width*(4+classes+background); i += 4+classes+background){
int object = 0;
for(j = i; j < i+classes; ++j) if (truth[j]) object = 1;
truth[i+classes] = !object;
}
}
}
#define NUMCHARS 37
@ -202,6 +201,7 @@ data load_data_captcha_encode(char **paths, int n, int m, int h, int w)
data d;
d.shallow = 0;
d.X = load_image_paths(paths, n, h, w);
d.X.cols = 17100;
d.y = d.X;
if(m) free(paths);
return d;

View File

@ -108,7 +108,7 @@ void validate_detection(char *cfgfile, char *weightfile)
char **paths = (char **)list_to_array(plist);
int im_size = 448;
int classes = 20;
int background = 0;
int background = 1;
int num_output = 7*7*(4+classes+background);
int m = plist->size;
@ -143,7 +143,7 @@ void validate_detection(char *cfgfile, char *weightfile)
float x = (c + pred.vals[j][ci + 1])/7.;
float h = pred.vals[j][ci + 2];
float w = pred.vals[j][ci + 3];
printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class], y, x, h, w);
printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class+background], y, x, h, w);
}
}
}

View File

@ -8,23 +8,24 @@
int get_detection_layer_locations(detection_layer layer)
{
return layer.inputs / (layer.classes+layer.coords+layer.rescore);
return layer.inputs / (layer.classes+layer.coords+layer.rescore+layer.background);
}
int get_detection_layer_output_size(detection_layer layer)
{
return get_detection_layer_locations(layer)*(layer.classes+layer.coords);
return get_detection_layer_locations(layer)*(layer.background + layer.classes + layer.coords);
}
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore)
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background)
{
detection_layer *layer = calloc(1, sizeof(detection_layer));
layer->batch = batch;
layer->inputs = inputs;
layer->classes = classes;
layer->coords = coords;
layer->rescore = rescore;
layer->background = background;
int outputs = get_detection_layer_output_size(*layer);
layer->output = calloc(batch*outputs, sizeof(float));
layer->delta = calloc(batch*outputs, sizeof(float));
@ -39,38 +40,13 @@ detection_layer *make_detection_layer(int batch, int inputs, int classes, int co
return layer;
}
void forward_detection_layer(const detection_layer layer, network_state state)
void dark_zone(detection_layer layer, int class, int start, network_state state)
{
int in_i = 0;
int out_i = 0;
int locations = get_detection_layer_locations(layer);
int i,j;
for(i = 0; i < layer.batch*locations; ++i){
int mask = (!state.truth || state.truth[out_i + layer.classes + 2]);
float scale = 1;
if(layer.rescore) scale = state.input[in_i++];
for(j = 0; j < layer.classes; ++j){
layer.output[out_i++] = scale*state.input[in_i++];
}
if(!layer.rescore){
softmax_array(layer.output + out_i - layer.classes, layer.classes, layer.output + out_i - layer.classes);
activate_array(state.input+in_i, layer.coords, LOGISTIC);
}
for(j = 0; j < layer.coords; ++j){
layer.output[out_i++] = mask*state.input[in_i++];
}
}
}
void dark_zone(detection_layer layer, int index, network_state state)
{
int size = layer.classes+layer.rescore+layer.coords;
int index = start+layer.background+class;
int size = layer.classes+layer.coords+layer.background;
int location = (index%(7*7*size)) / size ;
int r = location / 7;
int c = location % 7;
int class = index%size;
if(layer.rescore) --class;
int dr, dc;
for(dr = -1; dr <= 1; ++dr){
for(dc = -1; dc <= 1; ++dc){
@ -79,7 +55,44 @@ void dark_zone(detection_layer layer, int index, network_state state)
if((c + dc) > 6 || (c + dc) < 0) continue;
int di = (dr*7 + dc) * size;
if(state.truth[index+di]) continue;
layer.delta[index + di] = 0;
layer.output[index + di] = 0;
//if(!state.truth[start+di]) continue;
//layer.output[start + di] = 1;
}
}
}
void forward_detection_layer(const detection_layer layer, network_state state)
{
int in_i = 0;
int out_i = 0;
int locations = get_detection_layer_locations(layer);
int i,j;
for(i = 0; i < layer.batch*locations; ++i){
int mask = (!state.truth || state.truth[out_i + layer.background + layer.classes + 2]);
float scale = 1;
if(layer.rescore) scale = state.input[in_i++];
if(layer.background) layer.output[out_i++] = scale*state.input[in_i++];
for(j = 0; j < layer.classes; ++j){
layer.output[out_i++] = scale*state.input[in_i++];
}
if(layer.background){
softmax_array(layer.output + out_i - layer.classes-layer.background, layer.classes+layer.background, layer.output + out_i - layer.classes-layer.background);
activate_array(state.input+in_i, layer.coords, LOGISTIC);
}
for(j = 0; j < layer.coords; ++j){
layer.output[out_i++] = mask*state.input[in_i++];
}
}
if(layer.background || 1){
for(i = 0; i < layer.batch*locations; ++i){
int index = i*(layer.classes+layer.coords+layer.background);
for(j= 0; j < layer.classes; ++j){
if(state.truth[index+j+layer.background]){
//dark_zone(layer, j, index, state);
}
}
}
}
}
@ -94,21 +107,17 @@ void backward_detection_layer(const detection_layer layer, network_state state)
float scale = 1;
float latent_delta = 0;
if(layer.rescore) scale = state.input[in_i++];
if(!layer.rescore){
for(j = 0; j < layer.classes-1; ++j){
if(state.truth[out_i + j]) dark_zone(layer, out_i+j, state);
}
}
if(layer.background) state.delta[in_i++] = scale*layer.delta[out_i++];
for(j = 0; j < layer.classes; ++j){
latent_delta += state.input[in_i]*layer.delta[out_i];
state.delta[in_i++] = scale*layer.delta[out_i++];
}
if (!layer.rescore) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i);
if (layer.background) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i);
for(j = 0; j < layer.coords; ++j){
state.delta[in_i++] = layer.delta[out_i++];
}
if(layer.rescore) state.delta[in_i-layer.coords-layer.classes-layer.rescore] = latent_delta;
if(layer.rescore) state.delta[in_i-layer.coords-layer.classes-layer.rescore-layer.background] = latent_delta;
}
}

View File

@ -8,6 +8,7 @@ typedef struct {
int inputs;
int classes;
int coords;
int background;
int rescore;
float *output;
float *delta;
@ -17,7 +18,7 @@ typedef struct {
#endif
} detection_layer;
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore);
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background);
void forward_detection_layer(const detection_layer layer, network_state state);
void backward_detection_layer(const detection_layer layer, network_state state);
int get_detection_layer_output_size(detection_layer layer);

View File

@ -14,10 +14,8 @@ __global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand
extern "C" void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
{
if (!state.train) return;
int j;
int size = layer.inputs*layer.batch;
for(j = 0; j < size; ++j) layer.rand[j] = rand_uniform();
cuda_push_array(layer.rand_gpu, layer.rand, layer.inputs*layer.batch);
cuda_random(layer.rand_gpu, size);
yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
check_error(cudaPeekAtLastError());

View File

@ -13,7 +13,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
load_weights(&net, weightfile);
}
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 1024;
int imgs = 128;
int i = net.seen/imgs;
char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
list *plist = get_paths("/data/imagenet/cls.train.list");

View File

@ -28,6 +28,7 @@ void forward_network_gpu(network net, network_state state)
{
int i;
for(i = 0; i < net.n; ++i){
//clock_t time = clock();
if(net.types[i] == CONVOLUTIONAL){
forward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state);
}
@ -56,6 +57,9 @@ void forward_network_gpu(network net, network_state state)
forward_crop_layer_gpu(*(crop_layer *)net.layers[i], state);
}
state.input = get_network_output_gpu_layer(net, i);
//cudaDeviceSynchronize();
//printf("forw %d: %s %f\n", i, get_layer_string(net.types[i]), sec(clock() - time));
//time = clock();
}
}
@ -64,7 +68,7 @@ void backward_network_gpu(network net, network_state state)
int i;
float * original_input = state.input;
for(i = net.n-1; i >= 0; --i){
//clock_t time = clock();
//clock_t time = clock();
if(i == 0){
state.input = original_input;
state.delta = 0;
@ -96,6 +100,9 @@ void backward_network_gpu(network net, network_state state)
else if(net.types[i] == SOFTMAX){
backward_softmax_layer_gpu(*(softmax_layer *)net.layers[i], state);
}
//cudaDeviceSynchronize();
//printf("back %d: %s %f\n", i, get_layer_string(net.types[i]), sec(clock() - time));
//time = clock();
}
}
@ -181,7 +188,7 @@ float * get_network_delta_gpu_layer(network net, int i)
float train_network_datum_gpu(network net, float *x, float *y)
{
//clock_t time = clock();
// clock_t time = clock();
network_state state;
int x_size = get_network_input_size(net)*net.batch;
int y_size = get_network_output_size(net)*net.batch;
@ -195,22 +202,26 @@ float train_network_datum_gpu(network net, float *x, float *y)
state.input = *net.input_gpu;
state.truth = *net.truth_gpu;
state.train = 1;
//printf("trans %f\n", sec(clock() - time));
//time = clock();
//cudaDeviceSynchronize();
//printf("trans %f\n", sec(clock() - time));
//time = clock();
forward_network_gpu(net, state);
//printf("forw %f\n", sec(clock() - time));
//time = clock();
//cudaDeviceSynchronize();
//printf("forw %f\n", sec(clock() - time));
//time = clock();
backward_network_gpu(net, state);
//printf("back %f\n", sec(clock() - time));
//time = clock();
//cudaDeviceSynchronize();
//printf("back %f\n", sec(clock() - time));
//time = clock();
update_network_gpu(net);
float error = get_network_cost(net);
//print_letters(y, 50);
//float *out = get_network_output_gpu(net);
//print_letters(out, 50);
//printf("updt %f\n", sec(clock() - time));
//time = clock();
//cudaDeviceSynchronize();
//printf("updt %f\n", sec(clock() - time));
//time = clock();
return error;
}
@ -256,7 +267,6 @@ float *get_network_output_gpu(network net)
float *network_predict_gpu(network net, float *input)
{
int size = get_network_input_size(net) * net.batch;
network_state state;
state.input = cuda_make_array(input, size);

View File

@ -165,7 +165,8 @@ detection_layer *parse_detection(list *options, size_params params)
int coords = option_find_int(options, "coords", 1);
int classes = option_find_int(options, "classes", 1);
int rescore = option_find_int(options, "rescore", 1);
detection_layer *layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore);
int background = option_find_int(options, "background", 1);
detection_layer *layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore, background);
option_unused(options);
return layer;
}