diff --git a/Makefile b/Makefile index fc060c89..bdf1e8dc 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ CFLAGS+=-DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o endif diff --git a/src/connected_layer.c b/src/connected_layer.c index bdab6d84..bff3602a 100644 --- a/src/connected_layer.c +++ b/src/connected_layer.c @@ -9,99 +9,97 @@ #include #include -connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation) +connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation) { int i; - connected_layer *layer = calloc(1, sizeof(connected_layer)); + connected_layer l = {0}; + l.type = CONNECTED; - layer->inputs = inputs; - layer->outputs = outputs; - layer->batch=batch; + l.inputs = inputs; + l.outputs = outputs; + l.batch=batch; - layer->output = calloc(batch*outputs, sizeof(float*)); - layer->delta = calloc(batch*outputs, sizeof(float*)); + l.output = calloc(batch*outputs, sizeof(float*)); + l.delta = calloc(batch*outputs, sizeof(float*)); - layer->weight_updates = calloc(inputs*outputs, sizeof(float)); - layer->bias_updates = calloc(outputs, sizeof(float)); + l.weight_updates = calloc(inputs*outputs, sizeof(float)); + l.bias_updates = calloc(outputs, sizeof(float)); - layer->weight_prev = calloc(inputs*outputs, sizeof(float)); - layer->bias_prev = calloc(outputs, sizeof(float)); - - layer->weights = calloc(inputs*outputs, sizeof(float)); - layer->biases = calloc(outputs, sizeof(float)); + l.weights = calloc(inputs*outputs, sizeof(float)); + l.biases = calloc(outputs, sizeof(float)); float scale = 1./sqrt(inputs); for(i = 0; i < inputs*outputs; ++i){ - layer->weights[i] = 2*scale*rand_uniform() - scale; + l.weights[i] = 2*scale*rand_uniform() - scale; } for(i = 0; i < outputs; ++i){ - layer->biases[i] = scale; + l.biases[i] = scale; } #ifdef GPU - layer->weights_gpu = cuda_make_array(layer->weights, inputs*outputs); - layer->biases_gpu = cuda_make_array(layer->biases, outputs); + l.weights_gpu = cuda_make_array(l.weights, inputs*outputs); + l.biases_gpu = cuda_make_array(l.biases, outputs); - layer->weight_updates_gpu = cuda_make_array(layer->weight_updates, inputs*outputs); - layer->bias_updates_gpu = cuda_make_array(layer->bias_updates, outputs); + l.weight_updates_gpu = cuda_make_array(l.weight_updates, inputs*outputs); + l.bias_updates_gpu = cuda_make_array(l.bias_updates, outputs); - layer->output_gpu = cuda_make_array(layer->output, outputs*batch); - layer->delta_gpu = cuda_make_array(layer->delta, outputs*batch); + l.output_gpu = cuda_make_array(l.output, outputs*batch); + l.delta_gpu = cuda_make_array(l.delta, outputs*batch); #endif - layer->activation = activation; + l.activation = activation; fprintf(stderr, "Connected Layer: %d inputs, %d outputs\n", inputs, outputs); - return layer; + return l; } -void update_connected_layer(connected_layer layer, int batch, float learning_rate, float momentum, float decay) +void update_connected_layer(connected_layer l, int batch, float learning_rate, float momentum, float decay) { - axpy_cpu(layer.outputs, learning_rate/batch, layer.bias_updates, 1, layer.biases, 1); - scal_cpu(layer.outputs, momentum, layer.bias_updates, 1); + axpy_cpu(l.outputs, learning_rate/batch, l.bias_updates, 1, l.biases, 1); + scal_cpu(l.outputs, momentum, l.bias_updates, 1); - axpy_cpu(layer.inputs*layer.outputs, -decay*batch, layer.weights, 1, layer.weight_updates, 1); - axpy_cpu(layer.inputs*layer.outputs, learning_rate/batch, layer.weight_updates, 1, layer.weights, 1); - scal_cpu(layer.inputs*layer.outputs, momentum, layer.weight_updates, 1); + axpy_cpu(l.inputs*l.outputs, -decay*batch, l.weights, 1, l.weight_updates, 1); + axpy_cpu(l.inputs*l.outputs, learning_rate/batch, l.weight_updates, 1, l.weights, 1); + scal_cpu(l.inputs*l.outputs, momentum, l.weight_updates, 1); } -void forward_connected_layer(connected_layer layer, network_state state) +void forward_connected_layer(connected_layer l, network_state state) { int i; - for(i = 0; i < layer.batch; ++i){ - copy_cpu(layer.outputs, layer.biases, 1, layer.output + i*layer.outputs, 1); + for(i = 0; i < l.batch; ++i){ + copy_cpu(l.outputs, l.biases, 1, l.output + i*l.outputs, 1); } - int m = layer.batch; - int k = layer.inputs; - int n = layer.outputs; + int m = l.batch; + int k = l.inputs; + int n = l.outputs; float *a = state.input; - float *b = layer.weights; - float *c = layer.output; + float *b = l.weights; + float *c = l.output; gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); - activate_array(layer.output, layer.outputs*layer.batch, layer.activation); + activate_array(l.output, l.outputs*l.batch, l.activation); } -void backward_connected_layer(connected_layer layer, network_state state) +void backward_connected_layer(connected_layer l, network_state state) { int i; - gradient_array(layer.output, layer.outputs*layer.batch, layer.activation, layer.delta); - for(i = 0; i < layer.batch; ++i){ - axpy_cpu(layer.outputs, 1, layer.delta + i*layer.outputs, 1, layer.bias_updates, 1); + gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta); + for(i = 0; i < l.batch; ++i){ + axpy_cpu(l.outputs, 1, l.delta + i*l.outputs, 1, l.bias_updates, 1); } - int m = layer.inputs; - int k = layer.batch; - int n = layer.outputs; + int m = l.inputs; + int k = l.batch; + int n = l.outputs; float *a = state.input; - float *b = layer.delta; - float *c = layer.weight_updates; + float *b = l.delta; + float *c = l.weight_updates; gemm(1,0,m,n,k,1,a,m,b,n,1,c,n); - m = layer.batch; - k = layer.outputs; - n = layer.inputs; + m = l.batch; + k = l.outputs; + n = l.inputs; - a = layer.delta; - b = layer.weights; + a = l.delta; + b = l.weights; c = state.delta; if(c) gemm(0,1,m,n,k,1,a,k,b,k,0,c,n); @@ -109,69 +107,69 @@ void backward_connected_layer(connected_layer layer, network_state state) #ifdef GPU -void pull_connected_layer(connected_layer layer) +void pull_connected_layer(connected_layer l) { - cuda_pull_array(layer.weights_gpu, layer.weights, layer.inputs*layer.outputs); - cuda_pull_array(layer.biases_gpu, layer.biases, layer.outputs); - cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.inputs*layer.outputs); - cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.outputs); + cuda_pull_array(l.weights_gpu, l.weights, l.inputs*l.outputs); + cuda_pull_array(l.biases_gpu, l.biases, l.outputs); + cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.inputs*l.outputs); + cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs); } -void push_connected_layer(connected_layer layer) +void push_connected_layer(connected_layer l) { - cuda_push_array(layer.weights_gpu, layer.weights, layer.inputs*layer.outputs); - cuda_push_array(layer.biases_gpu, layer.biases, layer.outputs); - cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.inputs*layer.outputs); - cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.outputs); + cuda_push_array(l.weights_gpu, l.weights, l.inputs*l.outputs); + cuda_push_array(l.biases_gpu, l.biases, l.outputs); + cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.inputs*l.outputs); + cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs); } -void update_connected_layer_gpu(connected_layer layer, int batch, float learning_rate, float momentum, float decay) +void update_connected_layer_gpu(connected_layer l, int batch, float learning_rate, float momentum, float decay) { - axpy_ongpu(layer.outputs, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1); - scal_ongpu(layer.outputs, momentum, layer.bias_updates_gpu, 1); + axpy_ongpu(l.outputs, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1); + scal_ongpu(l.outputs, momentum, l.bias_updates_gpu, 1); - axpy_ongpu(layer.inputs*layer.outputs, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1); - axpy_ongpu(layer.inputs*layer.outputs, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1); - scal_ongpu(layer.inputs*layer.outputs, momentum, layer.weight_updates_gpu, 1); + axpy_ongpu(l.inputs*l.outputs, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); + axpy_ongpu(l.inputs*l.outputs, learning_rate/batch, l.weight_updates_gpu, 1, l.weights_gpu, 1); + scal_ongpu(l.inputs*l.outputs, momentum, l.weight_updates_gpu, 1); } -void forward_connected_layer_gpu(connected_layer layer, network_state state) +void forward_connected_layer_gpu(connected_layer l, network_state state) { int i; - for(i = 0; i < layer.batch; ++i){ - copy_ongpu_offset(layer.outputs, layer.biases_gpu, 0, 1, layer.output_gpu, i*layer.outputs, 1); + for(i = 0; i < l.batch; ++i){ + copy_ongpu_offset(l.outputs, l.biases_gpu, 0, 1, l.output_gpu, i*l.outputs, 1); } - int m = layer.batch; - int k = layer.inputs; - int n = layer.outputs; + int m = l.batch; + int k = l.inputs; + int n = l.outputs; float * a = state.input; - float * b = layer.weights_gpu; - float * c = layer.output_gpu; + float * b = l.weights_gpu; + float * c = l.output_gpu; gemm_ongpu(0,0,m,n,k,1,a,k,b,n,1,c,n); - activate_array_ongpu(layer.output_gpu, layer.outputs*layer.batch, layer.activation); + activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation); } -void backward_connected_layer_gpu(connected_layer layer, network_state state) +void backward_connected_layer_gpu(connected_layer l, network_state state) { int i; - gradient_array_ongpu(layer.output_gpu, layer.outputs*layer.batch, layer.activation, layer.delta_gpu); - for(i = 0; i < layer.batch; ++i){ - axpy_ongpu_offset(layer.outputs, 1, layer.delta_gpu, i*layer.outputs, 1, layer.bias_updates_gpu, 0, 1); + gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu); + for(i = 0; i < l.batch; ++i){ + axpy_ongpu_offset(l.outputs, 1, l.delta_gpu, i*l.outputs, 1, l.bias_updates_gpu, 0, 1); } - int m = layer.inputs; - int k = layer.batch; - int n = layer.outputs; + int m = l.inputs; + int k = l.batch; + int n = l.outputs; float * a = state.input; - float * b = layer.delta_gpu; - float * c = layer.weight_updates_gpu; + float * b = l.delta_gpu; + float * c = l.weight_updates_gpu; gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n); - m = layer.batch; - k = layer.outputs; - n = layer.inputs; + m = l.batch; + k = l.outputs; + n = l.inputs; - a = layer.delta_gpu; - b = layer.weights_gpu; + a = l.delta_gpu; + b = l.weights_gpu; c = state.delta; if(c) gemm_ongpu(0,1,m,n,k,1,a,k,b,k,0,c,n); diff --git a/src/connected_layer.h b/src/connected_layer.h index 33002d2e..cea5a023 100644 --- a/src/connected_layer.h +++ b/src/connected_layer.h @@ -3,38 +3,11 @@ #include "activations.h" #include "params.h" +#include "layer.h" -typedef struct{ - int batch; - int inputs; - int outputs; - float *weights; - float *biases; +typedef layer connected_layer; - float *weight_updates; - float *bias_updates; - - float *weight_prev; - float *bias_prev; - - float *output; - float *delta; - - #ifdef GPU - float * weights_gpu; - float * biases_gpu; - - float * weight_updates_gpu; - float * bias_updates_gpu; - - float * output_gpu; - float * delta_gpu; - #endif - ACTIVATION activation; - -} connected_layer; - -connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation); +connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation); void forward_connected_layer(connected_layer layer, network_state state); void backward_connected_layer(connected_layer layer, network_state state); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index cd357d39..b6437d4d 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -7,111 +7,117 @@ #include #include -int convolutional_out_height(convolutional_layer layer) +int convolutional_out_height(convolutional_layer l) { - int h = layer.h; - if (!layer.pad) h -= layer.size; + int h = l.h; + if (!l.pad) h -= l.size; else h -= 1; - return h/layer.stride + 1; + return h/l.stride + 1; } -int convolutional_out_width(convolutional_layer layer) +int convolutional_out_width(convolutional_layer l) { - int w = layer.w; - if (!layer.pad) w -= layer.size; + int w = l.w; + if (!l.pad) w -= l.size; else w -= 1; - return w/layer.stride + 1; + return w/l.stride + 1; } -image get_convolutional_image(convolutional_layer layer) +image get_convolutional_image(convolutional_layer l) { int h,w,c; - h = convolutional_out_height(layer); - w = convolutional_out_width(layer); - c = layer.n; - return float_to_image(w,h,c,layer.output); + h = convolutional_out_height(l); + w = convolutional_out_width(l); + c = l.n; + return float_to_image(w,h,c,l.output); } -image get_convolutional_delta(convolutional_layer layer) +image get_convolutional_delta(convolutional_layer l) { int h,w,c; - h = convolutional_out_height(layer); - w = convolutional_out_width(layer); - c = layer.n; - return float_to_image(w,h,c,layer.delta); + h = convolutional_out_height(l); + w = convolutional_out_width(l); + c = l.n; + return float_to_image(w,h,c,l.delta); } -convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation) +convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation) { int i; - convolutional_layer *layer = calloc(1, sizeof(convolutional_layer)); + convolutional_layer l = {0}; + l.type = CONVOLUTIONAL; - layer->h = h; - layer->w = w; - layer->c = c; - layer->n = n; - layer->batch = batch; - layer->stride = stride; - layer->size = size; - layer->pad = pad; + l.h = h; + l.w = w; + l.c = c; + l.n = n; + l.batch = batch; + l.stride = stride; + l.size = size; + l.pad = pad; - layer->filters = calloc(c*n*size*size, sizeof(float)); - layer->filter_updates = calloc(c*n*size*size, sizeof(float)); + l.filters = calloc(c*n*size*size, sizeof(float)); + l.filter_updates = calloc(c*n*size*size, sizeof(float)); - layer->biases = calloc(n, sizeof(float)); - layer->bias_updates = calloc(n, sizeof(float)); + l.biases = calloc(n, sizeof(float)); + l.bias_updates = calloc(n, sizeof(float)); float scale = 1./sqrt(size*size*c); - for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = 2*scale*rand_uniform() - scale; + for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale; for(i = 0; i < n; ++i){ - layer->biases[i] = scale; + l.biases[i] = scale; } - int out_h = convolutional_out_height(*layer); - int out_w = convolutional_out_width(*layer); + int out_h = convolutional_out_height(l); + int out_w = convolutional_out_width(l); + l.out_h = out_h; + l.out_w = out_w; + l.out_c = n; + l.outputs = l.out_h * l.out_w * l.out_c; + l.inputs = l.w * l.h * l.c; - layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(float)); - layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float)); - layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float)); + l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float)); + l.output = calloc(l.batch*out_h * out_w * n, sizeof(float)); + l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float)); #ifdef GPU - layer->filters_gpu = cuda_make_array(layer->filters, c*n*size*size); - layer->filter_updates_gpu = cuda_make_array(layer->filter_updates, c*n*size*size); + l.filters_gpu = cuda_make_array(l.filters, c*n*size*size); + l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size); - layer->biases_gpu = cuda_make_array(layer->biases, n); - layer->bias_updates_gpu = cuda_make_array(layer->bias_updates, n); + l.biases_gpu = cuda_make_array(l.biases, n); + l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); - layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*size*size*c); - layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*n); - layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*n); + l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c); + l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n); + l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); #endif - layer->activation = activation; + l.activation = activation; fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); - return layer; + return l; } -void resize_convolutional_layer(convolutional_layer *layer, int h, int w) +void resize_convolutional_layer(convolutional_layer *l, int h, int w) { - layer->h = h; - layer->w = w; - int out_h = convolutional_out_height(*layer); - int out_w = convolutional_out_width(*layer); + l->h = h; + l->w = w; + int out_h = convolutional_out_height(*l); + int out_w = convolutional_out_width(*l); - layer->col_image = realloc(layer->col_image, - out_h*out_w*layer->size*layer->size*layer->c*sizeof(float)); - layer->output = realloc(layer->output, - layer->batch*out_h * out_w * layer->n*sizeof(float)); - layer->delta = realloc(layer->delta, - layer->batch*out_h * out_w * layer->n*sizeof(float)); + l->col_image = realloc(l->col_image, + out_h*out_w*l->size*l->size*l->c*sizeof(float)); + l->output = realloc(l->output, + l->batch*out_h * out_w * l->n*sizeof(float)); + l->delta = realloc(l->delta, + l->batch*out_h * out_w * l->n*sizeof(float)); #ifdef GPU - cuda_free(layer->col_image_gpu); - cuda_free(layer->delta_gpu); - cuda_free(layer->output_gpu); + cuda_free(l->col_image_gpu); + cuda_free(l->delta_gpu); + cuda_free(l->output_gpu); - layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*layer->size*layer->size*layer->c); - layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*layer->n); - layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*layer->n); + l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c); + l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n); + l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n); #endif } @@ -138,104 +144,104 @@ void backward_bias(float *bias_updates, float *delta, int batch, int n, int size } -void forward_convolutional_layer(const convolutional_layer layer, network_state state) +void forward_convolutional_layer(const convolutional_layer l, network_state state) { - int out_h = convolutional_out_height(layer); - int out_w = convolutional_out_width(layer); + int out_h = convolutional_out_height(l); + int out_w = convolutional_out_width(l); int i; - bias_output(layer.output, layer.biases, layer.batch, layer.n, out_h*out_w); + bias_output(l.output, l.biases, l.batch, l.n, out_h*out_w); - int m = layer.n; - int k = layer.size*layer.size*layer.c; + int m = l.n; + int k = l.size*l.size*l.c; int n = out_h*out_w; - float *a = layer.filters; - float *b = layer.col_image; - float *c = layer.output; + float *a = l.filters; + float *b = l.col_image; + float *c = l.output; - for(i = 0; i < layer.batch; ++i){ - im2col_cpu(state.input, layer.c, layer.h, layer.w, - layer.size, layer.stride, layer.pad, b); + for(i = 0; i < l.batch; ++i){ + im2col_cpu(state.input, l.c, l.h, l.w, + l.size, l.stride, l.pad, b); gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); c += n*m; - state.input += layer.c*layer.h*layer.w; + state.input += l.c*l.h*l.w; } - activate_array(layer.output, m*n*layer.batch, layer.activation); + activate_array(l.output, m*n*l.batch, l.activation); } -void backward_convolutional_layer(convolutional_layer layer, network_state state) +void backward_convolutional_layer(convolutional_layer l, network_state state) { int i; - int m = layer.n; - int n = layer.size*layer.size*layer.c; - int k = convolutional_out_height(layer)* - convolutional_out_width(layer); + int m = l.n; + int n = l.size*l.size*l.c; + int k = convolutional_out_height(l)* + convolutional_out_width(l); - gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta); - backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, k); + gradient_array(l.output, m*k*l.batch, l.activation, l.delta); + backward_bias(l.bias_updates, l.delta, l.batch, l.n, k); - if(state.delta) memset(state.delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float)); + if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float)); - for(i = 0; i < layer.batch; ++i){ - float *a = layer.delta + i*m*k; - float *b = layer.col_image; - float *c = layer.filter_updates; + for(i = 0; i < l.batch; ++i){ + float *a = l.delta + i*m*k; + float *b = l.col_image; + float *c = l.filter_updates; - float *im = state.input+i*layer.c*layer.h*layer.w; + float *im = state.input+i*l.c*l.h*l.w; - im2col_cpu(im, layer.c, layer.h, layer.w, - layer.size, layer.stride, layer.pad, b); + im2col_cpu(im, l.c, l.h, l.w, + l.size, l.stride, l.pad, b); gemm(0,1,m,n,k,1,a,k,b,k,1,c,n); if(state.delta){ - a = layer.filters; - b = layer.delta + i*m*k; - c = layer.col_image; + a = l.filters; + b = l.delta + i*m*k; + c = l.col_image; gemm(1,0,n,k,m,1,a,n,b,k,0,c,k); - col2im_cpu(layer.col_image, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, state.delta+i*layer.c*layer.h*layer.w); + col2im_cpu(l.col_image, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w); } } } -void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay) +void update_convolutional_layer(convolutional_layer l, int batch, float learning_rate, float momentum, float decay) { - int size = layer.size*layer.size*layer.c*layer.n; - axpy_cpu(layer.n, learning_rate/batch, layer.bias_updates, 1, layer.biases, 1); - scal_cpu(layer.n, momentum, layer.bias_updates, 1); + int size = l.size*l.size*l.c*l.n; + axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1); + scal_cpu(l.n, momentum, l.bias_updates, 1); - axpy_cpu(size, -decay*batch, layer.filters, 1, layer.filter_updates, 1); - axpy_cpu(size, learning_rate/batch, layer.filter_updates, 1, layer.filters, 1); - scal_cpu(size, momentum, layer.filter_updates, 1); + axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1); + axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1); + scal_cpu(size, momentum, l.filter_updates, 1); } -image get_convolutional_filter(convolutional_layer layer, int i) +image get_convolutional_filter(convolutional_layer l, int i) { - int h = layer.size; - int w = layer.size; - int c = layer.c; - return float_to_image(w,h,c,layer.filters+i*h*w*c); + int h = l.size; + int w = l.size; + int c = l.c; + return float_to_image(w,h,c,l.filters+i*h*w*c); } -image *get_filters(convolutional_layer layer) +image *get_filters(convolutional_layer l) { - image *filters = calloc(layer.n, sizeof(image)); + image *filters = calloc(l.n, sizeof(image)); int i; - for(i = 0; i < layer.n; ++i){ - filters[i] = copy_image(get_convolutional_filter(layer, i)); + for(i = 0; i < l.n; ++i){ + filters[i] = copy_image(get_convolutional_filter(l, i)); } return filters; } -image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters) +image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_filters) { - image *single_filters = get_filters(layer); - show_images(single_filters, layer.n, window); + image *single_filters = get_filters(l); + show_images(single_filters, l.n, window); - image delta = get_convolutional_image(layer); + image delta = get_convolutional_image(l); image dc = collapse_image_layers(delta, 1); char buff[256]; sprintf(buff, "%s: Output", window); diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 5cf8adca..334759cd 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -5,38 +5,9 @@ #include "params.h" #include "image.h" #include "activations.h" +#include "layer.h" -typedef struct { - int batch; - int h,w,c; - int n; - int size; - int stride; - int pad; - float *filters; - float *filter_updates; - - float *biases; - float *bias_updates; - - float *col_image; - float *delta; - float *output; - - #ifdef GPU - float * filters_gpu; - float * filter_updates_gpu; - - float * biases_gpu; - float * bias_updates_gpu; - - float * col_image_gpu; - float * delta_gpu; - float * output_gpu; - #endif - - ACTIVATION activation; -} convolutional_layer; +typedef layer convolutional_layer; #ifdef GPU void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state); @@ -50,7 +21,7 @@ void bias_output_gpu(float *output, float *biases, int batch, int n, int size); void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size); #endif -convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation); +convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation); void resize_convolutional_layer(convolutional_layer *layer, int h, int w); void forward_convolutional_layer(const convolutional_layer layer, network_state state); void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay); diff --git a/src/cost_layer.c b/src/cost_layer.c index 1f36232b..24f6ffa3 100644 --- a/src/cost_layer.c +++ b/src/cost_layer.c @@ -26,70 +26,73 @@ char *get_cost_string(COST_TYPE a) return "sse"; } -cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type) +cost_layer make_cost_layer(int batch, int inputs, COST_TYPE cost_type) { fprintf(stderr, "Cost Layer: %d inputs\n", inputs); - cost_layer *layer = calloc(1, sizeof(cost_layer)); - layer->batch = batch; - layer->inputs = inputs; - layer->type = type; - layer->delta = calloc(inputs*batch, sizeof(float)); - layer->output = calloc(1, sizeof(float)); + cost_layer l = {0}; + l.type = COST; + + l.batch = batch; + l.inputs = inputs; + l.outputs = inputs; + l.cost_type = cost_type; + l.delta = calloc(inputs*batch, sizeof(float)); + l.output = calloc(1, sizeof(float)); #ifdef GPU - layer->delta_gpu = cuda_make_array(layer->delta, inputs*batch); + l.delta_gpu = cuda_make_array(l.delta, inputs*batch); #endif - return layer; + return l; } -void forward_cost_layer(cost_layer layer, network_state state) +void forward_cost_layer(cost_layer l, network_state state) { if (!state.truth) return; - if(layer.type == MASKED){ + if(l.cost_type == MASKED){ int i; - for(i = 0; i < layer.batch*layer.inputs; ++i){ + for(i = 0; i < l.batch*l.inputs; ++i){ if(state.truth[i] == 0) state.input[i] = 0; } } - copy_cpu(layer.batch*layer.inputs, state.truth, 1, layer.delta, 1); - axpy_cpu(layer.batch*layer.inputs, -1, state.input, 1, layer.delta, 1); - *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); - //printf("cost: %f\n", *layer.output); + copy_cpu(l.batch*l.inputs, state.truth, 1, l.delta, 1); + axpy_cpu(l.batch*l.inputs, -1, state.input, 1, l.delta, 1); + *(l.output) = dot_cpu(l.batch*l.inputs, l.delta, 1, l.delta, 1); + //printf("cost: %f\n", *l.output); } -void backward_cost_layer(const cost_layer layer, network_state state) +void backward_cost_layer(const cost_layer l, network_state state) { - copy_cpu(layer.batch*layer.inputs, layer.delta, 1, state.delta, 1); + copy_cpu(l.batch*l.inputs, l.delta, 1, state.delta, 1); } #ifdef GPU -void pull_cost_layer(cost_layer layer) +void pull_cost_layer(cost_layer l) { - cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs); + cuda_pull_array(l.delta_gpu, l.delta, l.batch*l.inputs); } -void push_cost_layer(cost_layer layer) +void push_cost_layer(cost_layer l) { - cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs); + cuda_push_array(l.delta_gpu, l.delta, l.batch*l.inputs); } -void forward_cost_layer_gpu(cost_layer layer, network_state state) +void forward_cost_layer_gpu(cost_layer l, network_state state) { if (!state.truth) return; - if (layer.type == MASKED) { - mask_ongpu(layer.batch*layer.inputs, state.input, state.truth); + if (l.cost_type == MASKED) { + mask_ongpu(l.batch*l.inputs, state.input, state.truth); } - copy_ongpu(layer.batch*layer.inputs, state.truth, 1, layer.delta_gpu, 1); - axpy_ongpu(layer.batch*layer.inputs, -1, state.input, 1, layer.delta_gpu, 1); + copy_ongpu(l.batch*l.inputs, state.truth, 1, l.delta_gpu, 1); + axpy_ongpu(l.batch*l.inputs, -1, state.input, 1, l.delta_gpu, 1); - cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs); - *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); + cuda_pull_array(l.delta_gpu, l.delta, l.batch*l.inputs); + *(l.output) = dot_cpu(l.batch*l.inputs, l.delta, 1, l.delta, 1); } -void backward_cost_layer_gpu(const cost_layer layer, network_state state) +void backward_cost_layer_gpu(const cost_layer l, network_state state) { - copy_ongpu(layer.batch*layer.inputs, layer.delta_gpu, 1, state.delta, 1); + copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1); } #endif diff --git a/src/cost_layer.h b/src/cost_layer.h index 0b92a11b..07323239 100644 --- a/src/cost_layer.h +++ b/src/cost_layer.h @@ -1,33 +1,19 @@ #ifndef COST_LAYER_H #define COST_LAYER_H #include "params.h" +#include "layer.h" -typedef enum{ - SSE, MASKED -} COST_TYPE; - -typedef struct { - int inputs; - int batch; - int coords; - int classes; - float *delta; - float *output; - COST_TYPE type; - #ifdef GPU - float * delta_gpu; - #endif -} cost_layer; +typedef layer cost_layer; COST_TYPE get_cost_type(char *s); char *get_cost_string(COST_TYPE a); -cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type); -void forward_cost_layer(const cost_layer layer, network_state state); -void backward_cost_layer(const cost_layer layer, network_state state); +cost_layer make_cost_layer(int batch, int inputs, COST_TYPE type); +void forward_cost_layer(const cost_layer l, network_state state); +void backward_cost_layer(const cost_layer l, network_state state); #ifdef GPU -void forward_cost_layer_gpu(cost_layer layer, network_state state); -void backward_cost_layer_gpu(const cost_layer layer, network_state state); +void forward_cost_layer_gpu(cost_layer l, network_state state); +void backward_cost_layer_gpu(const cost_layer l, network_state state); #endif #endif diff --git a/src/crop_layer.c b/src/crop_layer.c index 7ae4aa5d..13190219 100644 --- a/src/crop_layer.c +++ b/src/crop_layer.c @@ -2,63 +2,69 @@ #include "cuda.h" #include -image get_crop_image(crop_layer layer) +image get_crop_image(crop_layer l) { - int h = layer.crop_height; - int w = layer.crop_width; - int c = layer.c; - return float_to_image(w,h,c,layer.output); + int h = l.out_h; + int w = l.out_w; + int c = l.out_c; + return float_to_image(w,h,c,l.output); } -crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure) +crop_layer make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure) { fprintf(stderr, "Crop Layer: %d x %d -> %d x %d x %d image\n", h,w,crop_height,crop_width,c); - crop_layer *layer = calloc(1, sizeof(crop_layer)); - layer->batch = batch; - layer->h = h; - layer->w = w; - layer->c = c; - layer->flip = flip; - layer->angle = angle; - layer->saturation = saturation; - layer->exposure = exposure; - layer->crop_width = crop_width; - layer->crop_height = crop_height; - layer->output = calloc(crop_width*crop_height * c*batch, sizeof(float)); + crop_layer l = {0}; + l.type = CROP; + l.batch = batch; + l.h = h; + l.w = w; + l.c = c; + l.flip = flip; + l.angle = angle; + l.saturation = saturation; + l.exposure = exposure; + l.crop_width = crop_width; + l.crop_height = crop_height; + l.out_w = crop_width; + l.out_h = crop_height; + l.out_c = c; + l.inputs = l.w * l.h * l.c; + l.outputs = l.out_w * l.out_h * l.out_c; + l.output = calloc(crop_width*crop_height * c*batch, sizeof(float)); #ifdef GPU - layer->output_gpu = cuda_make_array(layer->output, crop_width*crop_height*c*batch); - layer->rand_gpu = cuda_make_array(0, layer->batch*8); + l.output_gpu = cuda_make_array(l.output, crop_width*crop_height*c*batch); + l.rand_gpu = cuda_make_array(0, l.batch*8); #endif - return layer; + return l; } -void forward_crop_layer(const crop_layer layer, network_state state) +void forward_crop_layer(const crop_layer l, network_state state) { int i,j,c,b,row,col; int index; int count = 0; - int flip = (layer.flip && rand()%2); - int dh = rand()%(layer.h - layer.crop_height + 1); - int dw = rand()%(layer.w - layer.crop_width + 1); + int flip = (l.flip && rand()%2); + int dh = rand()%(l.h - l.crop_height + 1); + int dw = rand()%(l.w - l.crop_width + 1); float scale = 2; float trans = -1; if(!state.train){ flip = 0; - dh = (layer.h - layer.crop_height)/2; - dw = (layer.w - layer.crop_width)/2; + dh = (l.h - l.crop_height)/2; + dw = (l.w - l.crop_width)/2; } - for(b = 0; b < layer.batch; ++b){ - for(c = 0; c < layer.c; ++c){ - for(i = 0; i < layer.crop_height; ++i){ - for(j = 0; j < layer.crop_width; ++j){ + for(b = 0; b < l.batch; ++b){ + for(c = 0; c < l.c; ++c){ + for(i = 0; i < l.crop_height; ++i){ + for(j = 0; j < l.crop_width; ++j){ if(flip){ - col = layer.w - dw - j - 1; + col = l.w - dw - j - 1; }else{ col = j + dw; } row = i + dh; - index = col+layer.w*(row+layer.h*(c + layer.c*b)); - layer.output[count++] = state.input[index]*scale + trans; + index = col+l.w*(row+l.h*(c + l.c*b)); + l.output[count++] = state.input[index]*scale + trans; } } } diff --git a/src/crop_layer.h b/src/crop_layer.h index 00333391..81641862 100644 --- a/src/crop_layer.h +++ b/src/crop_layer.h @@ -3,29 +3,16 @@ #include "image.h" #include "params.h" +#include "layer.h" -typedef struct { - int batch; - int h,w,c; - int crop_width; - int crop_height; - int flip; - float angle; - float saturation; - float exposure; - float *output; -#ifdef GPU - float *output_gpu; - float *rand_gpu; -#endif -} crop_layer; +typedef layer crop_layer; -image get_crop_image(crop_layer layer); -crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure); -void forward_crop_layer(const crop_layer layer, network_state state); +image get_crop_image(crop_layer l); +crop_layer make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure); +void forward_crop_layer(const crop_layer l, network_state state); #ifdef GPU -void forward_crop_layer_gpu(crop_layer layer, network_state state); +void forward_crop_layer_gpu(crop_layer l, network_state state); #endif #endif diff --git a/src/darknet.c b/src/darknet.c index 411efdf7..37f80ecd 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -72,15 +72,6 @@ void partial(char *cfgfile, char *weightfile, char *outfile, int max) save_weights(net, outfile); } -void convert(char *cfgfile, char *outfile, char *weightfile) -{ - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - save_network(net, outfile); -} - void visualize(char *cfgfile, char *weightfile) { network net = parse_network_cfg(cfgfile); @@ -120,8 +111,6 @@ int main(int argc, char **argv) run_captcha(argc, argv); } else if (0 == strcmp(argv[1], "change")){ change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0); - } else if (0 == strcmp(argv[1], "convert")){ - convert(argv[2], argv[3], (argc > 4) ? argv[4] : 0); } else if (0 == strcmp(argv[1], "partial")){ partial(argv[2], argv[3], argv[4], atoi(argv[5])); } else if (0 == strcmp(argv[1], "visualize")){ diff --git a/src/data.c b/src/data.c index 0aad98ca..902f30cf 100644 --- a/src/data.c +++ b/src/data.c @@ -174,7 +174,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, } int index = (i+j*num_boxes)*(4+classes+background); - if(truth[index+classes+background+2]) continue; + //if(truth[index+classes+background+2]) continue; if(background) truth[index++] = 0; truth[index+id] = 1; index += classes; diff --git a/src/deconvolutional_layer.c b/src/deconvolutional_layer.c index 532045c2..524fc958 100644 --- a/src/deconvolutional_layer.c +++ b/src/deconvolutional_layer.c @@ -8,172 +8,179 @@ #include #include -int deconvolutional_out_height(deconvolutional_layer layer) +int deconvolutional_out_height(deconvolutional_layer l) { - int h = layer.stride*(layer.h - 1) + layer.size; + int h = l.stride*(l.h - 1) + l.size; return h; } -int deconvolutional_out_width(deconvolutional_layer layer) +int deconvolutional_out_width(deconvolutional_layer l) { - int w = layer.stride*(layer.w - 1) + layer.size; + int w = l.stride*(l.w - 1) + l.size; return w; } -int deconvolutional_out_size(deconvolutional_layer layer) +int deconvolutional_out_size(deconvolutional_layer l) { - return deconvolutional_out_height(layer) * deconvolutional_out_width(layer); + return deconvolutional_out_height(l) * deconvolutional_out_width(l); } -image get_deconvolutional_image(deconvolutional_layer layer) +image get_deconvolutional_image(deconvolutional_layer l) { int h,w,c; - h = deconvolutional_out_height(layer); - w = deconvolutional_out_width(layer); - c = layer.n; - return float_to_image(w,h,c,layer.output); + h = deconvolutional_out_height(l); + w = deconvolutional_out_width(l); + c = l.n; + return float_to_image(w,h,c,l.output); } -image get_deconvolutional_delta(deconvolutional_layer layer) +image get_deconvolutional_delta(deconvolutional_layer l) { int h,w,c; - h = deconvolutional_out_height(layer); - w = deconvolutional_out_width(layer); - c = layer.n; - return float_to_image(w,h,c,layer.delta); + h = deconvolutional_out_height(l); + w = deconvolutional_out_width(l); + c = l.n; + return float_to_image(w,h,c,l.delta); } -deconvolutional_layer *make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation) +deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation) { int i; - deconvolutional_layer *layer = calloc(1, sizeof(deconvolutional_layer)); + deconvolutional_layer l = {0}; + l.type = DECONVOLUTIONAL; - layer->h = h; - layer->w = w; - layer->c = c; - layer->n = n; - layer->batch = batch; - layer->stride = stride; - layer->size = size; + l.h = h; + l.w = w; + l.c = c; + l.n = n; + l.batch = batch; + l.stride = stride; + l.size = size; - layer->filters = calloc(c*n*size*size, sizeof(float)); - layer->filter_updates = calloc(c*n*size*size, sizeof(float)); + l.filters = calloc(c*n*size*size, sizeof(float)); + l.filter_updates = calloc(c*n*size*size, sizeof(float)); - layer->biases = calloc(n, sizeof(float)); - layer->bias_updates = calloc(n, sizeof(float)); + l.biases = calloc(n, sizeof(float)); + l.bias_updates = calloc(n, sizeof(float)); float scale = 1./sqrt(size*size*c); - for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*rand_normal(); + for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_normal(); for(i = 0; i < n; ++i){ - layer->biases[i] = scale; + l.biases[i] = scale; } - int out_h = deconvolutional_out_height(*layer); - int out_w = deconvolutional_out_width(*layer); + int out_h = deconvolutional_out_height(l); + int out_w = deconvolutional_out_width(l); - layer->col_image = calloc(h*w*size*size*n, sizeof(float)); - layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float)); - layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float)); + l.out_h = out_h; + l.out_w = out_w; + l.out_c = n; + l.outputs = l.out_w * l.out_h * l.out_c; + l.inputs = l.w * l.h * l.c; + + l.col_image = calloc(h*w*size*size*n, sizeof(float)); + l.output = calloc(l.batch*out_h * out_w * n, sizeof(float)); + l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float)); #ifdef GPU - layer->filters_gpu = cuda_make_array(layer->filters, c*n*size*size); - layer->filter_updates_gpu = cuda_make_array(layer->filter_updates, c*n*size*size); + l.filters_gpu = cuda_make_array(l.filters, c*n*size*size); + l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size); - layer->biases_gpu = cuda_make_array(layer->biases, n); - layer->bias_updates_gpu = cuda_make_array(layer->bias_updates, n); + l.biases_gpu = cuda_make_array(l.biases, n); + l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); - layer->col_image_gpu = cuda_make_array(layer->col_image, h*w*size*size*n); - layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*n); - layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*n); + l.col_image_gpu = cuda_make_array(l.col_image, h*w*size*size*n); + l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n); + l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); #endif - layer->activation = activation; + l.activation = activation; fprintf(stderr, "Deconvolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); - return layer; + return l; } -void resize_deconvolutional_layer(deconvolutional_layer *layer, int h, int w) +void resize_deconvolutional_layer(deconvolutional_layer *l, int h, int w) { - layer->h = h; - layer->w = w; - int out_h = deconvolutional_out_height(*layer); - int out_w = deconvolutional_out_width(*layer); + l->h = h; + l->w = w; + int out_h = deconvolutional_out_height(*l); + int out_w = deconvolutional_out_width(*l); - layer->col_image = realloc(layer->col_image, - out_h*out_w*layer->size*layer->size*layer->c*sizeof(float)); - layer->output = realloc(layer->output, - layer->batch*out_h * out_w * layer->n*sizeof(float)); - layer->delta = realloc(layer->delta, - layer->batch*out_h * out_w * layer->n*sizeof(float)); + l->col_image = realloc(l->col_image, + out_h*out_w*l->size*l->size*l->c*sizeof(float)); + l->output = realloc(l->output, + l->batch*out_h * out_w * l->n*sizeof(float)); + l->delta = realloc(l->delta, + l->batch*out_h * out_w * l->n*sizeof(float)); #ifdef GPU - cuda_free(layer->col_image_gpu); - cuda_free(layer->delta_gpu); - cuda_free(layer->output_gpu); + cuda_free(l->col_image_gpu); + cuda_free(l->delta_gpu); + cuda_free(l->output_gpu); - layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*layer->size*layer->size*layer->c); - layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*layer->n); - layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*layer->n); + l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c); + l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n); + l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n); #endif } -void forward_deconvolutional_layer(const deconvolutional_layer layer, network_state state) +void forward_deconvolutional_layer(const deconvolutional_layer l, network_state state) { int i; - int out_h = deconvolutional_out_height(layer); - int out_w = deconvolutional_out_width(layer); + int out_h = deconvolutional_out_height(l); + int out_w = deconvolutional_out_width(l); int size = out_h*out_w; - int m = layer.size*layer.size*layer.n; - int n = layer.h*layer.w; - int k = layer.c; + int m = l.size*l.size*l.n; + int n = l.h*l.w; + int k = l.c; - bias_output(layer.output, layer.biases, layer.batch, layer.n, size); + bias_output(l.output, l.biases, l.batch, l.n, size); - for(i = 0; i < layer.batch; ++i){ - float *a = layer.filters; - float *b = state.input + i*layer.c*layer.h*layer.w; - float *c = layer.col_image; + for(i = 0; i < l.batch; ++i){ + float *a = l.filters; + float *b = state.input + i*l.c*l.h*l.w; + float *c = l.col_image; gemm(1,0,m,n,k,1,a,m,b,n,0,c,n); - col2im_cpu(c, layer.n, out_h, out_w, layer.size, layer.stride, 0, layer.output+i*layer.n*size); + col2im_cpu(c, l.n, out_h, out_w, l.size, l.stride, 0, l.output+i*l.n*size); } - activate_array(layer.output, layer.batch*layer.n*size, layer.activation); + activate_array(l.output, l.batch*l.n*size, l.activation); } -void backward_deconvolutional_layer(deconvolutional_layer layer, network_state state) +void backward_deconvolutional_layer(deconvolutional_layer l, network_state state) { - float alpha = 1./layer.batch; - int out_h = deconvolutional_out_height(layer); - int out_w = deconvolutional_out_width(layer); + float alpha = 1./l.batch; + int out_h = deconvolutional_out_height(l); + int out_w = deconvolutional_out_width(l); int size = out_h*out_w; int i; - gradient_array(layer.output, size*layer.n*layer.batch, layer.activation, layer.delta); - backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, size); + gradient_array(l.output, size*l.n*l.batch, l.activation, l.delta); + backward_bias(l.bias_updates, l.delta, l.batch, l.n, size); - if(state.delta) memset(state.delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float)); + if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float)); - for(i = 0; i < layer.batch; ++i){ - int m = layer.c; - int n = layer.size*layer.size*layer.n; - int k = layer.h*layer.w; + for(i = 0; i < l.batch; ++i){ + int m = l.c; + int n = l.size*l.size*l.n; + int k = l.h*l.w; float *a = state.input + i*m*n; - float *b = layer.col_image; - float *c = layer.filter_updates; + float *b = l.col_image; + float *c = l.filter_updates; - im2col_cpu(layer.delta + i*layer.n*size, layer.n, out_h, out_w, - layer.size, layer.stride, 0, b); + im2col_cpu(l.delta + i*l.n*size, l.n, out_h, out_w, + l.size, l.stride, 0, b); gemm(0,1,m,n,k,alpha,a,k,b,k,1,c,n); if(state.delta){ - int m = layer.c; - int n = layer.h*layer.w; - int k = layer.size*layer.size*layer.n; + int m = l.c; + int n = l.h*l.w; + int k = l.size*l.size*l.n; - float *a = layer.filters; - float *b = layer.col_image; + float *a = l.filters; + float *b = l.col_image; float *c = state.delta + i*n*m; gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); @@ -181,15 +188,15 @@ void backward_deconvolutional_layer(deconvolutional_layer layer, network_state s } } -void update_deconvolutional_layer(deconvolutional_layer layer, float learning_rate, float momentum, float decay) +void update_deconvolutional_layer(deconvolutional_layer l, float learning_rate, float momentum, float decay) { - int size = layer.size*layer.size*layer.c*layer.n; - axpy_cpu(layer.n, learning_rate, layer.bias_updates, 1, layer.biases, 1); - scal_cpu(layer.n, momentum, layer.bias_updates, 1); + int size = l.size*l.size*l.c*l.n; + axpy_cpu(l.n, learning_rate, l.bias_updates, 1, l.biases, 1); + scal_cpu(l.n, momentum, l.bias_updates, 1); - axpy_cpu(size, -decay, layer.filters, 1, layer.filter_updates, 1); - axpy_cpu(size, learning_rate, layer.filter_updates, 1, layer.filters, 1); - scal_cpu(size, momentum, layer.filter_updates, 1); + axpy_cpu(size, -decay, l.filters, 1, l.filter_updates, 1); + axpy_cpu(size, learning_rate, l.filter_updates, 1, l.filters, 1); + scal_cpu(size, momentum, l.filter_updates, 1); } diff --git a/src/deconvolutional_layer.h b/src/deconvolutional_layer.h index 0ece76f2..74498c77 100644 --- a/src/deconvolutional_layer.h +++ b/src/deconvolutional_layer.h @@ -5,37 +5,9 @@ #include "params.h" #include "image.h" #include "activations.h" +#include "layer.h" -typedef struct { - int batch; - int h,w,c; - int n; - int size; - int stride; - float *filters; - float *filter_updates; - - float *biases; - float *bias_updates; - - float *col_image; - float *delta; - float *output; - - #ifdef GPU - float * filters_gpu; - float * filter_updates_gpu; - - float * biases_gpu; - float * bias_updates_gpu; - - float * col_image_gpu; - float * delta_gpu; - float * output_gpu; - #endif - - ACTIVATION activation; -} deconvolutional_layer; +typedef layer deconvolutional_layer; #ifdef GPU void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, network_state state); @@ -45,7 +17,7 @@ void push_deconvolutional_layer(deconvolutional_layer layer); void pull_deconvolutional_layer(deconvolutional_layer layer); #endif -deconvolutional_layer *make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation); +deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation); void resize_deconvolutional_layer(deconvolutional_layer *layer, int h, int w); void forward_deconvolutional_layer(const deconvolutional_layer layer, network_state state); void update_deconvolutional_layer(deconvolutional_layer layer, float learning_rate, float momentum, float decay); diff --git a/src/detection.c b/src/detection.c index dafececc..a1ba888c 100644 --- a/src/detection.c +++ b/src/detection.c @@ -115,6 +115,7 @@ void train_localization(char *cfgfile, char *weightfile) time=clock(); float loss = train_network(net, train); + //TODO float *out = get_network_output_gpu(net); image im = float_to_image(net.w, net.h, 3, train.X.vals[127]); image copy = copy_image(im); @@ -149,7 +150,7 @@ void train_detection_teststuff(char *cfgfile, char *weightfile) if(weightfile){ load_weights(&net, weightfile); } - detection_layer *layer = get_network_detection_layer(net); + detection_layer layer = get_network_detection_layer(net); net.learning_rate = 0; net.decay = 0; printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); @@ -157,9 +158,9 @@ void train_detection_teststuff(char *cfgfile, char *weightfile) int i = net.seen/imgs; data train, buffer; - int classes = layer->classes; - int background = layer->background; - int side = sqrt(get_detection_layer_locations(*layer)); + int classes = layer.classes; + int background = layer.background; + int side = sqrt(get_detection_layer_locations(layer)); char **paths; list *plist; @@ -174,7 +175,7 @@ void train_detection_teststuff(char *cfgfile, char *weightfile) paths = (char **)list_to_array(plist); pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); clock_t time; - cost_layer clayer = *((cost_layer *)net.layers[net.n-1]); + cost_layer clayer = net.layers[net.n-1]; while(1){ i += 1; time=clock(); @@ -235,15 +236,15 @@ void train_detection(char *cfgfile, char *weightfile) if(weightfile){ load_weights(&net, weightfile); } - detection_layer *layer = get_network_detection_layer(net); + detection_layer layer = get_network_detection_layer(net); printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); int imgs = 128; int i = net.seen/imgs; data train, buffer; - int classes = layer->classes; - int background = layer->background; - int side = sqrt(get_detection_layer_locations(*layer)); + int classes = layer.classes; + int background = layer.background; + int side = sqrt(get_detection_layer_locations(layer)); char **paths; list *plist; @@ -325,7 +326,7 @@ void validate_detection(char *cfgfile, char *weightfile) if(weightfile){ load_weights(&net, weightfile); } - detection_layer *layer = get_network_detection_layer(net); + detection_layer layer = get_network_detection_layer(net); fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); srand(time(0)); @@ -336,10 +337,10 @@ void validate_detection(char *cfgfile, char *weightfile) //list *plist = get_paths("/home/pjreddie/data/voc/train.txt"); char **paths = (char **)list_to_array(plist); - int classes = layer->classes; - int nuisance = layer->nuisance; - int background = (layer->background && !nuisance); - int num_boxes = sqrt(get_detection_layer_locations(*layer)); + int classes = layer.classes; + int nuisance = layer.nuisance; + int background = (layer.background && !nuisance); + int num_boxes = sqrt(get_detection_layer_locations(layer)); int per_box = 4+classes+background+nuisance; int num_output = num_boxes*num_boxes*per_box; @@ -393,7 +394,7 @@ void validate_detection_post(char *cfgfile, char *weightfile) load_weights(&post, "/home/pjreddie/imagenet_backup/localize_1000.weights"); set_batch_network(&post, 1); - detection_layer *layer = get_network_detection_layer(net); + detection_layer layer = get_network_detection_layer(net); fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); srand(time(0)); @@ -404,10 +405,10 @@ void validate_detection_post(char *cfgfile, char *weightfile) //list *plist = get_paths("/home/pjreddie/data/voc/train.txt"); char **paths = (char **)list_to_array(plist); - int classes = layer->classes; - int nuisance = layer->nuisance; - int background = (layer->background && !nuisance); - int num_boxes = sqrt(get_detection_layer_locations(*layer)); + int classes = layer.classes; + int nuisance = layer.nuisance; + int background = (layer.background && !nuisance); + int num_boxes = sqrt(get_detection_layer_locations(layer)); int per_box = 4+classes+background+nuisance; diff --git a/src/detection_layer.c b/src/detection_layer.c index 831439e6..395146b3 100644 --- a/src/detection_layer.c +++ b/src/detection_layer.c @@ -8,47 +8,49 @@ #include #include -int get_detection_layer_locations(detection_layer layer) +int get_detection_layer_locations(detection_layer l) { - return layer.inputs / (layer.classes+layer.coords+layer.rescore+layer.background); + return l.inputs / (l.classes+l.coords+l.rescore+l.background); } -int get_detection_layer_output_size(detection_layer layer) +int get_detection_layer_output_size(detection_layer l) { - return get_detection_layer_locations(layer)*(layer.background + layer.classes + layer.coords); + return get_detection_layer_locations(l)*(l.background + l.classes + l.coords); } -detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance) +detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance) { - detection_layer *layer = calloc(1, sizeof(detection_layer)); + detection_layer l = {0}; + l.type = DETECTION; - layer->batch = batch; - layer->inputs = inputs; - layer->classes = classes; - layer->coords = coords; - layer->rescore = rescore; - layer->nuisance = nuisance; - layer->cost = calloc(1, sizeof(float)); - layer->does_cost=1; - layer->background = background; - int outputs = get_detection_layer_output_size(*layer); - layer->output = calloc(batch*outputs, sizeof(float)); - layer->delta = calloc(batch*outputs, sizeof(float)); + l.batch = batch; + l.inputs = inputs; + l.classes = classes; + l.coords = coords; + l.rescore = rescore; + l.nuisance = nuisance; + l.cost = calloc(1, sizeof(float)); + l.does_cost=1; + l.background = background; + int outputs = get_detection_layer_output_size(l); + l.outputs = outputs; + l.output = calloc(batch*outputs, sizeof(float)); + l.delta = calloc(batch*outputs, sizeof(float)); #ifdef GPU - layer->output_gpu = cuda_make_array(0, batch*outputs); - layer->delta_gpu = cuda_make_array(0, batch*outputs); + l.output_gpu = cuda_make_array(0, batch*outputs); + l.delta_gpu = cuda_make_array(0, batch*outputs); #endif fprintf(stderr, "Detection Layer\n"); srand(0); - return layer; + return l; } -void dark_zone(detection_layer layer, int class, int start, network_state state) +void dark_zone(detection_layer l, int class, int start, network_state state) { - int index = start+layer.background+class; - int size = layer.classes+layer.coords+layer.background; + int index = start+l.background+class; + int size = l.classes+l.coords+l.background; int location = (index%(7*7*size)) / size ; int r = location / 7; int c = location % 7; @@ -60,9 +62,9 @@ void dark_zone(detection_layer layer, int class, int start, network_state state) if((c + dc) > 6 || (c + dc) < 0) continue; int di = (dr*7 + dc) * size; if(state.truth[index+di]) continue; - layer.output[index + di] = 0; + l.output[index + di] = 0; //if(!state.truth[start+di]) continue; - //layer.output[start + di] = 1; + //l.output[start + di] = 1; } } } @@ -299,47 +301,47 @@ dbox diou(box a, box b) return dd; } -void forward_detection_layer(const detection_layer layer, network_state state) +void forward_detection_layer(const detection_layer l, network_state state) { int in_i = 0; int out_i = 0; - int locations = get_detection_layer_locations(layer); + int locations = get_detection_layer_locations(l); int i,j; - for(i = 0; i < layer.batch*locations; ++i){ - int mask = (!state.truth || state.truth[out_i + layer.background + layer.classes + 2]); + for(i = 0; i < l.batch*locations; ++i){ + int mask = (!state.truth || state.truth[out_i + l.background + l.classes + 2]); float scale = 1; - if(layer.rescore) scale = state.input[in_i++]; - else if(layer.nuisance){ - layer.output[out_i++] = 1-state.input[in_i++]; + if(l.rescore) scale = state.input[in_i++]; + else if(l.nuisance){ + l.output[out_i++] = 1-state.input[in_i++]; scale = mask; } - else if(layer.background) layer.output[out_i++] = scale*state.input[in_i++]; + else if(l.background) l.output[out_i++] = scale*state.input[in_i++]; - for(j = 0; j < layer.classes; ++j){ - layer.output[out_i++] = scale*state.input[in_i++]; + for(j = 0; j < l.classes; ++j){ + l.output[out_i++] = scale*state.input[in_i++]; } - if(layer.nuisance){ + if(l.nuisance){ - }else if(layer.background){ - softmax_array(layer.output + out_i - layer.classes-layer.background, layer.classes+layer.background, layer.output + out_i - layer.classes-layer.background); - activate_array(state.input+in_i, layer.coords, LOGISTIC); + }else if(l.background){ + softmax_array(l.output + out_i - l.classes-l.background, l.classes+l.background, l.output + out_i - l.classes-l.background); + activate_array(state.input+in_i, l.coords, LOGISTIC); } - for(j = 0; j < layer.coords; ++j){ - layer.output[out_i++] = mask*state.input[in_i++]; + for(j = 0; j < l.coords; ++j){ + l.output[out_i++] = mask*state.input[in_i++]; } } - if(layer.does_cost && state.train && 0){ + if(l.does_cost && state.train && 0){ int count = 0; float avg = 0; - *(layer.cost) = 0; - int size = get_detection_layer_output_size(layer) * layer.batch; - memset(layer.delta, 0, size * sizeof(float)); - for (i = 0; i < layer.batch*locations; ++i) { - int classes = layer.nuisance+layer.classes; - int offset = i*(classes+layer.coords); + *(l.cost) = 0; + int size = get_detection_layer_output_size(l) * l.batch; + memset(l.delta, 0, size * sizeof(float)); + for (i = 0; i < l.batch*locations; ++i) { + int classes = l.nuisance+l.classes; + int offset = i*(classes+l.coords); for (j = offset; j < offset+classes; ++j) { - *(layer.cost) += pow(state.truth[j] - layer.output[j], 2); - layer.delta[j] = state.truth[j] - layer.output[j]; + *(l.cost) += pow(state.truth[j] - l.output[j], 2); + l.delta[j] = state.truth[j] - l.output[j]; } box truth; truth.x = state.truth[j+0]; @@ -347,17 +349,17 @@ void forward_detection_layer(const detection_layer layer, network_state state) truth.w = state.truth[j+2]; truth.h = state.truth[j+3]; box out; - out.x = layer.output[j+0]; - out.y = layer.output[j+1]; - out.w = layer.output[j+2]; - out.h = layer.output[j+3]; + out.x = l.output[j+0]; + out.y = l.output[j+1]; + out.w = l.output[j+2]; + out.h = l.output[j+3]; if(!(truth.w*truth.h)) continue; //printf("iou: %f\n", iou); dbox d = diou(out, truth); - layer.delta[j+0] = d.dx; - layer.delta[j+1] = d.dy; - layer.delta[j+2] = d.dw; - layer.delta[j+3] = d.dh; + l.delta[j+0] = d.dx; + l.delta[j+1] = d.dy; + l.delta[j+2] = d.dw; + l.delta[j+3] = d.dh; int sqr = 1; if(sqr){ @@ -367,7 +369,7 @@ void forward_detection_layer(const detection_layer layer, network_state state) out.h *= out.h; } float iou = box_iou(truth, out); - *(layer.cost) += pow((1-iou), 2); + *(l.cost) += pow((1-iou), 2); avg += iou; ++count; } @@ -375,24 +377,24 @@ void forward_detection_layer(const detection_layer layer, network_state state) } /* int count = 0; - for(i = 0; i < layer.batch*locations; ++i){ - for(j = 0; j < layer.classes+layer.background; ++j){ - printf("%f, ", layer.output[count++]); + for(i = 0; i < l.batch*locations; ++i){ + for(j = 0; j < l.classes+l.background; ++j){ + printf("%f, ", l.output[count++]); } printf("\n"); - for(j = 0; j < layer.coords; ++j){ - printf("%f, ", layer.output[count++]); + for(j = 0; j < l.coords; ++j){ + printf("%f, ", l.output[count++]); } printf("\n"); } */ /* - if(layer.background || 1){ - for(i = 0; i < layer.batch*locations; ++i){ - int index = i*(layer.classes+layer.coords+layer.background); - for(j= 0; j < layer.classes; ++j){ - if(state.truth[index+j+layer.background]){ -//dark_zone(layer, j, index, state); + if(l.background || 1){ + for(i = 0; i < l.batch*locations; ++i){ + int index = i*(l.classes+l.coords+l.background); + for(j= 0; j < l.classes; ++j){ + if(state.truth[index+j+l.background]){ +//dark_zone(l, j, index, state); } } } @@ -400,66 +402,66 @@ void forward_detection_layer(const detection_layer layer, network_state state) */ } -void backward_detection_layer(const detection_layer layer, network_state state) +void backward_detection_layer(const detection_layer l, network_state state) { - int locations = get_detection_layer_locations(layer); + int locations = get_detection_layer_locations(l); int i,j; int in_i = 0; int out_i = 0; - for(i = 0; i < layer.batch*locations; ++i){ + for(i = 0; i < l.batch*locations; ++i){ float scale = 1; float latent_delta = 0; - if(layer.rescore) scale = state.input[in_i++]; - else if (layer.nuisance) state.delta[in_i++] = -layer.delta[out_i++]; - else if (layer.background) state.delta[in_i++] = scale*layer.delta[out_i++]; - for(j = 0; j < layer.classes; ++j){ - latent_delta += state.input[in_i]*layer.delta[out_i]; - state.delta[in_i++] = scale*layer.delta[out_i++]; + if(l.rescore) scale = state.input[in_i++]; + else if (l.nuisance) state.delta[in_i++] = -l.delta[out_i++]; + else if (l.background) state.delta[in_i++] = scale*l.delta[out_i++]; + for(j = 0; j < l.classes; ++j){ + latent_delta += state.input[in_i]*l.delta[out_i]; + state.delta[in_i++] = scale*l.delta[out_i++]; } - if (layer.nuisance) { + if (l.nuisance) { - }else if (layer.background) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i); - for(j = 0; j < layer.coords; ++j){ - state.delta[in_i++] = layer.delta[out_i++]; + }else if (l.background) gradient_array(l.output + out_i, l.coords, LOGISTIC, l.delta + out_i); + for(j = 0; j < l.coords; ++j){ + state.delta[in_i++] = l.delta[out_i++]; } - if(layer.rescore) state.delta[in_i-layer.coords-layer.classes-layer.rescore-layer.background] = latent_delta; + if(l.rescore) state.delta[in_i-l.coords-l.classes-l.rescore-l.background] = latent_delta; } } #ifdef GPU -void forward_detection_layer_gpu(const detection_layer layer, network_state state) +void forward_detection_layer_gpu(const detection_layer l, network_state state) { - int outputs = get_detection_layer_output_size(layer); - float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); + int outputs = get_detection_layer_output_size(l); + float *in_cpu = calloc(l.batch*l.inputs, sizeof(float)); float *truth_cpu = 0; if(state.truth){ - truth_cpu = calloc(layer.batch*outputs, sizeof(float)); - cuda_pull_array(state.truth, truth_cpu, layer.batch*outputs); + truth_cpu = calloc(l.batch*outputs, sizeof(float)); + cuda_pull_array(state.truth, truth_cpu, l.batch*outputs); } - cuda_pull_array(state.input, in_cpu, layer.batch*layer.inputs); + cuda_pull_array(state.input, in_cpu, l.batch*l.inputs); network_state cpu_state; cpu_state.train = state.train; cpu_state.truth = truth_cpu; cpu_state.input = in_cpu; - forward_detection_layer(layer, cpu_state); - cuda_push_array(layer.output_gpu, layer.output, layer.batch*outputs); - cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*outputs); + forward_detection_layer(l, cpu_state); + cuda_push_array(l.output_gpu, l.output, l.batch*outputs); + cuda_push_array(l.delta_gpu, l.delta, l.batch*outputs); free(cpu_state.input); if(cpu_state.truth) free(cpu_state.truth); } -void backward_detection_layer_gpu(detection_layer layer, network_state state) +void backward_detection_layer_gpu(detection_layer l, network_state state) { - int outputs = get_detection_layer_output_size(layer); + int outputs = get_detection_layer_output_size(l); - float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); - float *delta_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); + float *in_cpu = calloc(l.batch*l.inputs, sizeof(float)); + float *delta_cpu = calloc(l.batch*l.inputs, sizeof(float)); float *truth_cpu = 0; if(state.truth){ - truth_cpu = calloc(layer.batch*outputs, sizeof(float)); - cuda_pull_array(state.truth, truth_cpu, layer.batch*outputs); + truth_cpu = calloc(l.batch*outputs, sizeof(float)); + cuda_pull_array(state.truth, truth_cpu, l.batch*outputs); } network_state cpu_state; cpu_state.train = state.train; @@ -467,10 +469,10 @@ void backward_detection_layer_gpu(detection_layer layer, network_state state) cpu_state.truth = truth_cpu; cpu_state.delta = delta_cpu; - cuda_pull_array(state.input, in_cpu, layer.batch*layer.inputs); - cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*outputs); - backward_detection_layer(layer, cpu_state); - cuda_push_array(state.delta, delta_cpu, layer.batch*layer.inputs); + cuda_pull_array(state.input, in_cpu, l.batch*l.inputs); + cuda_pull_array(l.delta_gpu, l.delta, l.batch*outputs); + backward_detection_layer(l, cpu_state); + cuda_push_array(state.delta, delta_cpu, l.batch*l.inputs); free(in_cpu); free(delta_cpu); diff --git a/src/detection_layer.h b/src/detection_layer.h index 0aa5f665..dfc5db97 100644 --- a/src/detection_layer.h +++ b/src/detection_layer.h @@ -2,34 +2,19 @@ #define DETECTION_LAYER_H #include "params.h" +#include "layer.h" -typedef struct { - int batch; - int inputs; - int classes; - int coords; - int background; - int rescore; - int nuisance; - int does_cost; - float *cost; - float *output; - float *delta; - #ifdef GPU - float * output_gpu; - float * delta_gpu; - #endif -} detection_layer; +typedef layer detection_layer; -detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance); -void forward_detection_layer(const detection_layer layer, network_state state); -void backward_detection_layer(const detection_layer layer, network_state state); -int get_detection_layer_output_size(detection_layer layer); -int get_detection_layer_locations(detection_layer layer); +detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance); +void forward_detection_layer(const detection_layer l, network_state state); +void backward_detection_layer(const detection_layer l, network_state state); +int get_detection_layer_output_size(detection_layer l); +int get_detection_layer_locations(detection_layer l); #ifdef GPU -void forward_detection_layer_gpu(const detection_layer layer, network_state state); -void backward_detection_layer_gpu(detection_layer layer, network_state state); +void forward_detection_layer_gpu(const detection_layer l, network_state state); +void backward_detection_layer_gpu(detection_layer l, network_state state); #endif #endif diff --git a/src/dropout_layer.c b/src/dropout_layer.c index 7fbf8ff2..97dd47f5 100644 --- a/src/dropout_layer.c +++ b/src/dropout_layer.c @@ -5,51 +5,53 @@ #include #include -dropout_layer *make_dropout_layer(int batch, int inputs, float probability) +dropout_layer make_dropout_layer(int batch, int inputs, float probability) { fprintf(stderr, "Dropout Layer: %d inputs, %f probability\n", inputs, probability); - dropout_layer *layer = calloc(1, sizeof(dropout_layer)); - layer->probability = probability; - layer->inputs = inputs; - layer->batch = batch; - layer->rand = calloc(inputs*batch, sizeof(float)); - layer->scale = 1./(1.-probability); + dropout_layer l = {0}; + l.type = DROPOUT; + l.probability = probability; + l.inputs = inputs; + l.outputs = inputs; + l.batch = batch; + l.rand = calloc(inputs*batch, sizeof(float)); + l.scale = 1./(1.-probability); #ifdef GPU - layer->rand_gpu = cuda_make_array(layer->rand, inputs*batch); + l.rand_gpu = cuda_make_array(l.rand, inputs*batch); #endif - return layer; + return l; } -void resize_dropout_layer(dropout_layer *layer, int inputs) +void resize_dropout_layer(dropout_layer *l, int inputs) { - layer->rand = realloc(layer->rand, layer->inputs*layer->batch*sizeof(float)); + l->rand = realloc(l->rand, l->inputs*l->batch*sizeof(float)); #ifdef GPU - cuda_free(layer->rand_gpu); + cuda_free(l->rand_gpu); - layer->rand_gpu = cuda_make_array(layer->rand, inputs*layer->batch); + l->rand_gpu = cuda_make_array(l->rand, inputs*l->batch); #endif } -void forward_dropout_layer(dropout_layer layer, network_state state) +void forward_dropout_layer(dropout_layer l, network_state state) { int i; if (!state.train) return; - for(i = 0; i < layer.batch * layer.inputs; ++i){ + for(i = 0; i < l.batch * l.inputs; ++i){ float r = rand_uniform(); - layer.rand[i] = r; - if(r < layer.probability) state.input[i] = 0; - else state.input[i] *= layer.scale; + l.rand[i] = r; + if(r < l.probability) state.input[i] = 0; + else state.input[i] *= l.scale; } } -void backward_dropout_layer(dropout_layer layer, network_state state) +void backward_dropout_layer(dropout_layer l, network_state state) { int i; if(!state.delta) return; - for(i = 0; i < layer.batch * layer.inputs; ++i){ - float r = layer.rand[i]; - if(r < layer.probability) state.delta[i] = 0; - else state.delta[i] *= layer.scale; + for(i = 0; i < l.batch * l.inputs; ++i){ + float r = l.rand[i]; + if(r < l.probability) state.delta[i] = 0; + else state.delta[i] *= l.scale; } } diff --git a/src/dropout_layer.h b/src/dropout_layer.h index d12d4a18..b1dc883b 100644 --- a/src/dropout_layer.h +++ b/src/dropout_layer.h @@ -1,27 +1,20 @@ #ifndef DROPOUT_LAYER_H #define DROPOUT_LAYER_H + #include "params.h" +#include "layer.h" -typedef struct{ - int batch; - int inputs; - float probability; - float scale; - float *rand; - #ifdef GPU - float * rand_gpu; - #endif -} dropout_layer; +typedef layer dropout_layer; -dropout_layer *make_dropout_layer(int batch, int inputs, float probability); +dropout_layer make_dropout_layer(int batch, int inputs, float probability); -void forward_dropout_layer(dropout_layer layer, network_state state); -void backward_dropout_layer(dropout_layer layer, network_state state); -void resize_dropout_layer(dropout_layer *layer, int inputs); +void forward_dropout_layer(dropout_layer l, network_state state); +void backward_dropout_layer(dropout_layer l, network_state state); +void resize_dropout_layer(dropout_layer *l, int inputs); #ifdef GPU -void forward_dropout_layer_gpu(dropout_layer layer, network_state state); -void backward_dropout_layer_gpu(dropout_layer layer, network_state state); +void forward_dropout_layer_gpu(dropout_layer l, network_state state); +void backward_dropout_layer_gpu(dropout_layer l, network_state state); #endif #endif diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c index 76402fa7..c7739f1d 100644 --- a/src/maxpool_layer.c +++ b/src/maxpool_layer.c @@ -2,109 +2,115 @@ #include "cuda.h" #include -image get_maxpool_image(maxpool_layer layer) +image get_maxpool_image(maxpool_layer l) { - int h = (layer.h-1)/layer.stride + 1; - int w = (layer.w-1)/layer.stride + 1; - int c = layer.c; - return float_to_image(w,h,c,layer.output); + int h = (l.h-1)/l.stride + 1; + int w = (l.w-1)/l.stride + 1; + int c = l.c; + return float_to_image(w,h,c,l.output); } -image get_maxpool_delta(maxpool_layer layer) +image get_maxpool_delta(maxpool_layer l) { - int h = (layer.h-1)/layer.stride + 1; - int w = (layer.w-1)/layer.stride + 1; - int c = layer.c; - return float_to_image(w,h,c,layer.delta); + int h = (l.h-1)/l.stride + 1; + int w = (l.w-1)/l.stride + 1; + int c = l.c; + return float_to_image(w,h,c,l.delta); } -maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int stride) +maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride) { fprintf(stderr, "Maxpool Layer: %d x %d x %d image, %d size, %d stride\n", h,w,c,size,stride); - maxpool_layer *layer = calloc(1, sizeof(maxpool_layer)); - layer->batch = batch; - layer->h = h; - layer->w = w; - layer->c = c; - layer->size = size; - layer->stride = stride; + maxpool_layer l = {0}; + l.type = MAXPOOL; + l.batch = batch; + l.h = h; + l.w = w; + l.c = c; + l.out_h = h; + l.out_w = w; + l.out_c = c; + l.outputs = l.out_h * l.out_w * l.out_c; + l.inputs = l.outputs; + l.size = size; + l.stride = stride; int output_size = ((h-1)/stride+1) * ((w-1)/stride+1) * c * batch; - layer->indexes = calloc(output_size, sizeof(int)); - layer->output = calloc(output_size, sizeof(float)); - layer->delta = calloc(output_size, sizeof(float)); + l.indexes = calloc(output_size, sizeof(int)); + l.output = calloc(output_size, sizeof(float)); + l.delta = calloc(output_size, sizeof(float)); #ifdef GPU - layer->indexes_gpu = cuda_make_int_array(output_size); - layer->output_gpu = cuda_make_array(layer->output, output_size); - layer->delta_gpu = cuda_make_array(layer->delta, output_size); + l.indexes_gpu = cuda_make_int_array(output_size); + l.output_gpu = cuda_make_array(l.output, output_size); + l.delta_gpu = cuda_make_array(l.delta, output_size); #endif - return layer; + return l; } -void resize_maxpool_layer(maxpool_layer *layer, int h, int w) +void resize_maxpool_layer(maxpool_layer *l, int h, int w) { - layer->h = h; - layer->w = w; - int output_size = ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * layer->c * layer->batch; - layer->output = realloc(layer->output, output_size * sizeof(float)); - layer->delta = realloc(layer->delta, output_size * sizeof(float)); + l->h = h; + l->w = w; + int output_size = ((h-1)/l->stride+1) * ((w-1)/l->stride+1) * l->c * l->batch; + l->output = realloc(l->output, output_size * sizeof(float)); + l->delta = realloc(l->delta, output_size * sizeof(float)); #ifdef GPU - cuda_free((float *)layer->indexes_gpu); - cuda_free(layer->output_gpu); - cuda_free(layer->delta_gpu); - layer->indexes_gpu = cuda_make_int_array(output_size); - layer->output_gpu = cuda_make_array(layer->output, output_size); - layer->delta_gpu = cuda_make_array(layer->delta, output_size); + cuda_free((float *)l->indexes_gpu); + cuda_free(l->output_gpu); + cuda_free(l->delta_gpu); + l->indexes_gpu = cuda_make_int_array(output_size); + l->output_gpu = cuda_make_array(l->output, output_size); + l->delta_gpu = cuda_make_array(l->delta, output_size); #endif } -void forward_maxpool_layer(const maxpool_layer layer, network_state state) +void forward_maxpool_layer(const maxpool_layer l, network_state state) { - int b,i,j,k,l,m; - int w_offset = (-layer.size-1)/2 + 1; - int h_offset = (-layer.size-1)/2 + 1; + int b,i,j,k,m,n; + int w_offset = (-l.size-1)/2 + 1; + int h_offset = (-l.size-1)/2 + 1; - int h = (layer.h-1)/layer.stride + 1; - int w = (layer.w-1)/layer.stride + 1; - int c = layer.c; + int h = (l.h-1)/l.stride + 1; + int w = (l.w-1)/l.stride + 1; + int c = l.c; - for(b = 0; b < layer.batch; ++b){ + for(b = 0; b < l.batch; ++b){ for(k = 0; k < c; ++k){ for(i = 0; i < h; ++i){ for(j = 0; j < w; ++j){ int out_index = j + w*(i + h*(k + c*b)); float max = -FLT_MAX; int max_i = -1; - for(l = 0; l < layer.size; ++l){ - for(m = 0; m < layer.size; ++m){ - int cur_h = h_offset + i*layer.stride + l; - int cur_w = w_offset + j*layer.stride + m; - int index = cur_w + layer.w*(cur_h + layer.h*(k + b*layer.c)); - int valid = (cur_h >= 0 && cur_h < layer.h && - cur_w >= 0 && cur_w < layer.w); + for(n = 0; n < l.size; ++n){ + for(m = 0; m < l.size; ++m){ + int cur_h = h_offset + i*l.stride + n; + int cur_w = w_offset + j*l.stride + m; + int index = cur_w + l.w*(cur_h + l.h*(k + b*l.c)); + int valid = (cur_h >= 0 && cur_h < l.h && + cur_w >= 0 && cur_w < l.w); float val = (valid != 0) ? state.input[index] : -FLT_MAX; max_i = (val > max) ? index : max_i; max = (val > max) ? val : max; } } - layer.output[out_index] = max; - layer.indexes[out_index] = max_i; + l.output[out_index] = max; + l.indexes[out_index] = max_i; } } } } } -void backward_maxpool_layer(const maxpool_layer layer, network_state state) +void backward_maxpool_layer(const maxpool_layer l, network_state state) { int i; - int h = (layer.h-1)/layer.stride + 1; - int w = (layer.w-1)/layer.stride + 1; - int c = layer.c; - memset(state.delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float)); - for(i = 0; i < h*w*c*layer.batch; ++i){ - int index = layer.indexes[i]; - state.delta[index] += layer.delta[i]; + int h = (l.h-1)/l.stride + 1; + int w = (l.w-1)/l.stride + 1; + int c = l.c; + memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float)); + for(i = 0; i < h*w*c*l.batch; ++i){ + int index = l.indexes[i]; + state.delta[index] += l.delta[i]; } } diff --git a/src/maxpool_layer.h b/src/maxpool_layer.h index cbd6a767..4456863b 100644 --- a/src/maxpool_layer.h +++ b/src/maxpool_layer.h @@ -4,31 +4,19 @@ #include "image.h" #include "params.h" #include "cuda.h" +#include "layer.h" -typedef struct { - int batch; - int h,w,c; - int stride; - int size; - int *indexes; - float *delta; - float *output; - #ifdef GPU - int *indexes_gpu; - float *output_gpu; - float *delta_gpu; - #endif -} maxpool_layer; +typedef layer maxpool_layer; -image get_maxpool_image(maxpool_layer layer); -maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int stride); -void resize_maxpool_layer(maxpool_layer *layer, int h, int w); -void forward_maxpool_layer(const maxpool_layer layer, network_state state); -void backward_maxpool_layer(const maxpool_layer layer, network_state state); +image get_maxpool_image(maxpool_layer l); +maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride); +void resize_maxpool_layer(maxpool_layer *l, int h, int w); +void forward_maxpool_layer(const maxpool_layer l, network_state state); +void backward_maxpool_layer(const maxpool_layer l, network_state state); #ifdef GPU -void forward_maxpool_layer_gpu(maxpool_layer layer, network_state state); -void backward_maxpool_layer_gpu(maxpool_layer layer, network_state state); +void forward_maxpool_layer_gpu(maxpool_layer l, network_state state); +void backward_maxpool_layer_gpu(maxpool_layer l, network_state state); #endif #endif diff --git a/src/network.c b/src/network.c index 01a61288..68790e58 100644 --- a/src/network.c +++ b/src/network.c @@ -12,7 +12,6 @@ #include "detection_layer.h" #include "maxpool_layer.h" #include "cost_layer.h" -#include "normalization_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" #include "route_layer.h" @@ -32,8 +31,6 @@ char *get_layer_string(LAYER_TYPE a) return "softmax"; case DETECTION: return "detection"; - case NORMALIZATION: - return "normalization"; case DROPOUT: return "dropout"; case CROP: @@ -50,16 +47,9 @@ char *get_layer_string(LAYER_TYPE a) network make_network(int n) { - network net; + network net = {0}; net.n = n; - net.layers = calloc(net.n, sizeof(void *)); - net.types = calloc(net.n, sizeof(LAYER_TYPE)); - net.outputs = 0; - net.output = 0; - net.seen = 0; - net.batch = 0; - net.inputs = 0; - net.h = net.w = net.c = 0; + net.layers = calloc(net.n, sizeof(layer)); #ifdef GPU net.input_gpu = calloc(1, sizeof(float *)); net.truth_gpu = calloc(1, sizeof(float *)); @@ -71,40 +61,29 @@ void forward_network(network net, network_state state) { int i; for(i = 0; i < net.n; ++i){ - if(net.types[i] == CONVOLUTIONAL){ - forward_convolutional_layer(*(convolutional_layer *)net.layers[i], state); + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + forward_convolutional_layer(l, state); + } else if(l.type == DECONVOLUTIONAL){ + forward_deconvolutional_layer(l, state); + } else if(l.type == DETECTION){ + forward_detection_layer(l, state); + } else if(l.type == CONNECTED){ + forward_connected_layer(l, state); + } else if(l.type == CROP){ + forward_crop_layer(l, state); + } else if(l.type == COST){ + forward_cost_layer(l, state); + } else if(l.type == SOFTMAX){ + forward_softmax_layer(l, state); + } else if(l.type == MAXPOOL){ + forward_maxpool_layer(l, state); + } else if(l.type == DROPOUT){ + forward_dropout_layer(l, state); + } else if(l.type == ROUTE){ + forward_route_layer(l, net); } - else if(net.types[i] == DECONVOLUTIONAL){ - forward_deconvolutional_layer(*(deconvolutional_layer *)net.layers[i], state); - } - else if(net.types[i] == DETECTION){ - forward_detection_layer(*(detection_layer *)net.layers[i], state); - } - else if(net.types[i] == CONNECTED){ - forward_connected_layer(*(connected_layer *)net.layers[i], state); - } - else if(net.types[i] == CROP){ - forward_crop_layer(*(crop_layer *)net.layers[i], state); - } - else if(net.types[i] == COST){ - forward_cost_layer(*(cost_layer *)net.layers[i], state); - } - else if(net.types[i] == SOFTMAX){ - forward_softmax_layer(*(softmax_layer *)net.layers[i], state); - } - else if(net.types[i] == MAXPOOL){ - forward_maxpool_layer(*(maxpool_layer *)net.layers[i], state); - } - else if(net.types[i] == NORMALIZATION){ - forward_normalization_layer(*(normalization_layer *)net.layers[i], state); - } - else if(net.types[i] == DROPOUT){ - forward_dropout_layer(*(dropout_layer *)net.layers[i], state); - } - else if(net.types[i] == ROUTE){ - forward_route_layer(*(route_layer *)net.layers[i], net); - } - state.input = get_network_output_layer(net, i); + state.input = l.output; } } @@ -113,99 +92,35 @@ void update_network(network net) int i; int update_batch = net.batch*net.subdivisions; for(i = 0; i < net.n; ++i){ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - update_convolutional_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay); - } - else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - update_deconvolutional_layer(layer, net.learning_rate, net.momentum, net.decay); - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - update_connected_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay); + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + update_convolutional_layer(l, update_batch, net.learning_rate, net.momentum, net.decay); + } else if(l.type == DECONVOLUTIONAL){ + update_deconvolutional_layer(l, net.learning_rate, net.momentum, net.decay); + } else if(l.type == CONNECTED){ + update_connected_layer(l, update_batch, net.learning_rate, net.momentum, net.decay); } } } -float *get_network_output_layer(network net, int i) -{ - if(net.types[i] == CONVOLUTIONAL){ - return ((convolutional_layer *)net.layers[i]) -> output; - } else if(net.types[i] == DECONVOLUTIONAL){ - return ((deconvolutional_layer *)net.layers[i]) -> output; - } else if(net.types[i] == MAXPOOL){ - return ((maxpool_layer *)net.layers[i]) -> output; - } else if(net.types[i] == DETECTION){ - return ((detection_layer *)net.layers[i]) -> output; - } else if(net.types[i] == SOFTMAX){ - return ((softmax_layer *)net.layers[i]) -> output; - } else if(net.types[i] == DROPOUT){ - return get_network_output_layer(net, i-1); - } else if(net.types[i] == CONNECTED){ - return ((connected_layer *)net.layers[i]) -> output; - } else if(net.types[i] == CROP){ - return ((crop_layer *)net.layers[i]) -> output; - } else if(net.types[i] == NORMALIZATION){ - return ((normalization_layer *)net.layers[i]) -> output; - } else if(net.types[i] == ROUTE){ - return ((route_layer *)net.layers[i]) -> output; - } - return 0; -} - float *get_network_output(network net) { int i; - for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; - return get_network_output_layer(net, i); -} - -float *get_network_delta_layer(network net, int i) -{ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - return layer.delta; - } else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - return layer.delta; - } else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - return layer.delta; - } else if(net.types[i] == SOFTMAX){ - softmax_layer layer = *(softmax_layer *)net.layers[i]; - return layer.delta; - } else if(net.types[i] == DETECTION){ - detection_layer layer = *(detection_layer *)net.layers[i]; - return layer.delta; - } else if(net.types[i] == DROPOUT){ - if(i == 0) return 0; - return get_network_delta_layer(net, i-1); - } else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - return layer.delta; - } else if(net.types[i] == ROUTE){ - return ((route_layer *)net.layers[i]) -> delta; - } - return 0; + for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break; + return net.layers[i].output; } float get_network_cost(network net) { - if(net.types[net.n-1] == COST){ - return ((cost_layer *)net.layers[net.n-1])->output[0]; + if(net.layers[net.n-1].type == COST){ + return net.layers[net.n-1].output[0]; } - if(net.types[net.n-1] == DETECTION){ - return ((detection_layer *)net.layers[net.n-1])->cost[0]; + if(net.layers[net.n-1].type == DETECTION){ + return net.layers[net.n-1].cost[0]; } return 0; } -float *get_network_delta(network net) -{ - return get_network_delta_layer(net, net.n-1); -} - int get_predicted_class_network(network net) { float *out = get_network_output(net); @@ -222,46 +137,29 @@ void backward_network(network net, network_state state) state.input = original_input; state.delta = 0; }else{ - state.input = get_network_output_layer(net, i-1); - state.delta = get_network_delta_layer(net, i-1); + layer prev = net.layers[i-1]; + state.input = prev.output; + state.delta = prev.delta; } - - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - backward_convolutional_layer(layer, state); - } else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - backward_deconvolutional_layer(layer, state); - } - else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - if(i != 0) backward_maxpool_layer(layer, state); - } - else if(net.types[i] == DROPOUT){ - dropout_layer layer = *(dropout_layer *)net.layers[i]; - backward_dropout_layer(layer, state); - } - else if(net.types[i] == DETECTION){ - detection_layer layer = *(detection_layer *)net.layers[i]; - backward_detection_layer(layer, state); - } - else if(net.types[i] == NORMALIZATION){ - normalization_layer layer = *(normalization_layer *)net.layers[i]; - if(i != 0) backward_normalization_layer(layer, state); - } - else if(net.types[i] == SOFTMAX){ - softmax_layer layer = *(softmax_layer *)net.layers[i]; - if(i != 0) backward_softmax_layer(layer, state); - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - backward_connected_layer(layer, state); - } else if(net.types[i] == COST){ - cost_layer layer = *(cost_layer *)net.layers[i]; - backward_cost_layer(layer, state); - } else if(net.types[i] == ROUTE){ - route_layer layer = *(route_layer *)net.layers[i]; - backward_route_layer(layer, net); + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + backward_convolutional_layer(l, state); + } else if(l.type == DECONVOLUTIONAL){ + backward_deconvolutional_layer(l, state); + } else if(l.type == MAXPOOL){ + if(i != 0) backward_maxpool_layer(l, state); + } else if(l.type == DROPOUT){ + backward_dropout_layer(l, state); + } else if(l.type == DETECTION){ + backward_detection_layer(l, state); + } else if(l.type == SOFTMAX){ + if(i != 0) backward_softmax_layer(l, state); + } else if(l.type == CONNECTED){ + backward_connected_layer(l, state); + } else if(l.type == COST){ + backward_cost_layer(l, state); + } else if(l.type == ROUTE){ + backward_route_layer(l, net); } } } @@ -347,127 +245,11 @@ void set_batch_network(network *net, int b) net->batch = b; int i; for(i = 0; i < net->n; ++i){ - if(net->types[i] == CONVOLUTIONAL){ - convolutional_layer *layer = (convolutional_layer *)net->layers[i]; - layer->batch = b; - }else if(net->types[i] == DECONVOLUTIONAL){ - deconvolutional_layer *layer = (deconvolutional_layer *)net->layers[i]; - layer->batch = b; - } - else if(net->types[i] == MAXPOOL){ - maxpool_layer *layer = (maxpool_layer *)net->layers[i]; - layer->batch = b; - } - else if(net->types[i] == CONNECTED){ - connected_layer *layer = (connected_layer *)net->layers[i]; - layer->batch = b; - } else if(net->types[i] == DROPOUT){ - dropout_layer *layer = (dropout_layer *) net->layers[i]; - layer->batch = b; - } else if(net->types[i] == DETECTION){ - detection_layer *layer = (detection_layer *) net->layers[i]; - layer->batch = b; - } - else if(net->types[i] == SOFTMAX){ - softmax_layer *layer = (softmax_layer *)net->layers[i]; - layer->batch = b; - } - else if(net->types[i] == COST){ - cost_layer *layer = (cost_layer *)net->layers[i]; - layer->batch = b; - } - else if(net->types[i] == CROP){ - crop_layer *layer = (crop_layer *)net->layers[i]; - layer->batch = b; - } - else if(net->types[i] == ROUTE){ - route_layer *layer = (route_layer *)net->layers[i]; - layer->batch = b; - } + net->layers[i].batch = b; } } - -int get_network_input_size_layer(network net, int i) -{ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - return layer.h*layer.w*layer.c; - } - if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - return layer.h*layer.w*layer.c; - } - else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - return layer.h*layer.w*layer.c; - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - return layer.inputs; - } else if(net.types[i] == DROPOUT){ - dropout_layer layer = *(dropout_layer *) net.layers[i]; - return layer.inputs; - } else if(net.types[i] == DETECTION){ - detection_layer layer = *(detection_layer *) net.layers[i]; - return layer.inputs; - } else if(net.types[i] == CROP){ - crop_layer layer = *(crop_layer *) net.layers[i]; - return layer.c*layer.h*layer.w; - } - else if(net.types[i] == SOFTMAX){ - softmax_layer layer = *(softmax_layer *)net.layers[i]; - return layer.inputs; - } - fprintf(stderr, "Can't find input size\n"); - return 0; -} - -int get_network_output_size_layer(network net, int i) -{ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - image output = get_convolutional_image(layer); - return output.h*output.w*output.c; - } - else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - image output = get_deconvolutional_image(layer); - return output.h*output.w*output.c; - } - else if(net.types[i] == DETECTION){ - detection_layer layer = *(detection_layer *)net.layers[i]; - return get_detection_layer_output_size(layer); - } - else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - image output = get_maxpool_image(layer); - return output.h*output.w*output.c; - } - else if(net.types[i] == CROP){ - crop_layer layer = *(crop_layer *) net.layers[i]; - return layer.c*layer.crop_height*layer.crop_width; - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - return layer.outputs; - } - else if(net.types[i] == DROPOUT){ - dropout_layer layer = *(dropout_layer *) net.layers[i]; - return layer.inputs; - } - else if(net.types[i] == SOFTMAX){ - softmax_layer layer = *(softmax_layer *)net.layers[i]; - return layer.inputs; - } - else if(net.types[i] == ROUTE){ - route_layer layer = *(route_layer *)net.layers[i]; - return layer.outputs; - } - fprintf(stderr, "Can't find output size\n"); - return 0; -} - +/* int resize_network(network net, int h, int w, int c) { fprintf(stderr, "Might be broken, careful!!"); @@ -497,74 +279,47 @@ int resize_network(network net, int h, int w, int c) }else if(net.types[i] == DROPOUT){ dropout_layer *layer = (dropout_layer *)net.layers[i]; resize_dropout_layer(layer, h*w*c); - }else if(net.types[i] == NORMALIZATION){ - normalization_layer *layer = (normalization_layer *)net.layers[i]; - resize_normalization_layer(layer, h, w); - image output = get_normalization_image(*layer); - h = output.h; - w = output.w; - c = output.c; }else{ error("Cannot resize this type of layer"); } } return 0; } +*/ int get_network_output_size(network net) { int i; - for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; - return get_network_output_size_layer(net, i); + for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break; + return net.layers[i].outputs; } int get_network_input_size(network net) { - return get_network_input_size_layer(net, 0); + return net.layers[0].inputs; } -detection_layer *get_network_detection_layer(network net) +detection_layer get_network_detection_layer(network net) { int i; for(i = 0; i < net.n; ++i){ - if(net.types[i] == DETECTION){ - detection_layer *layer = (detection_layer *)net.layers[i]; - return layer; + if(net.layers[i].type == DETECTION){ + return net.layers[i]; } } - return 0; + fprintf(stderr, "Detection layer not found!!\n"); + detection_layer l = {0}; + return l; } image get_network_image_layer(network net, int i) { - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - return get_convolutional_image(layer); + layer l = net.layers[i]; + if (l.out_w && l.out_h && l.out_c){ + return float_to_image(l.out_w, l.out_h, l.out_c, l.output); } - else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - return get_deconvolutional_image(layer); - } - else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - return get_maxpool_image(layer); - } - else if(net.types[i] == NORMALIZATION){ - normalization_layer layer = *(normalization_layer *)net.layers[i]; - return get_normalization_image(layer); - } - else if(net.types[i] == DROPOUT){ - return get_network_image_layer(net, i-1); - } - else if(net.types[i] == CROP){ - crop_layer layer = *(crop_layer *)net.layers[i]; - return get_crop_image(layer); - } - else if(net.types[i] == ROUTE){ - route_layer layer = *(route_layer *)net.layers[i]; - return get_network_image_layer(net, layer.input_layers[0]); - } - return make_empty_image(0,0,0); + image def = {0}; + return def; } image get_network_image(network net) @@ -574,7 +329,8 @@ image get_network_image(network net) image m = get_network_image_layer(net, i); if(m.h != 0) return m; } - return make_empty_image(0,0,0); + image def = {0}; + return def; } void visualize_network(network net) @@ -582,16 +338,11 @@ void visualize_network(network net) image *prev = 0; int i; char buff[256]; - //show_image(get_network_image_layer(net, 0), "Crop"); for(i = 0; i < net.n; ++i){ sprintf(buff, "Layer %d", i); - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - prev = visualize_convolutional_layer(layer, buff, prev); - } - if(net.types[i] == NORMALIZATION){ - normalization_layer layer = *(normalization_layer *)net.layers[i]; - visualize_normalization_layer(layer, buff); + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + prev = visualize_convolutional_layer(l, buff, prev); } } } @@ -672,36 +423,9 @@ void print_network(network net) { int i,j; for(i = 0; i < net.n; ++i){ - float *output = 0; - int n = 0; - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - output = layer.output; - image m = get_convolutional_image(layer); - n = m.h*m.w*m.c; - } - else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - output = layer.output; - image m = get_maxpool_image(layer); - n = m.h*m.w*m.c; - } - else if(net.types[i] == CROP){ - crop_layer layer = *(crop_layer *)net.layers[i]; - output = layer.output; - image m = get_crop_image(layer); - n = m.h*m.w*m.c; - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - output = layer.output; - n = layer.outputs; - } - else if(net.types[i] == SOFTMAX){ - softmax_layer layer = *(softmax_layer *)net.layers[i]; - output = layer.output; - n = layer.inputs; - } + layer l = net.layers[i]; + float *output = l.output; + int n = l.outputs; float mean = mean_array(output, n); float vari = variance_array(output, n); fprintf(stderr, "Layer %d - Mean: %f, Variance: %f\n",i,mean, vari); diff --git a/src/network.h b/src/network.h index 28eab69a..9a8033c8 100644 --- a/src/network.h +++ b/src/network.h @@ -4,22 +4,9 @@ #include "image.h" #include "detection_layer.h" +#include "layer.h" #include "data.h" -typedef enum { - CONVOLUTIONAL, - DECONVOLUTIONAL, - CONNECTED, - MAXPOOL, - SOFTMAX, - DETECTION, - NORMALIZATION, - DROPOUT, - CROP, - ROUTE, - COST -} LAYER_TYPE; - typedef struct { int n; int batch; @@ -28,8 +15,7 @@ typedef struct { float learning_rate; float momentum; float decay; - void **layers; - LAYER_TYPE *types; + layer *layers; int outputs; float *output; @@ -83,7 +69,7 @@ int resize_network(network net, int h, int w, int c); void set_batch_network(network *net, int b); int get_network_input_size(network net); float get_network_cost(network net); -detection_layer *get_network_detection_layer(network net); +detection_layer get_network_detection_layer(network net); int get_network_nuisance(network net); int get_network_background(network net); diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 7ff5d15e..da21d63f 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -15,7 +15,6 @@ extern "C" { #include "deconvolutional_layer.h" #include "maxpool_layer.h" #include "cost_layer.h" -#include "normalization_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" #include "route_layer.h" @@ -29,37 +28,29 @@ void forward_network_gpu(network net, network_state state) { int i; for(i = 0; i < net.n; ++i){ - if(net.types[i] == CONVOLUTIONAL){ - forward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state); + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + forward_convolutional_layer_gpu(l, state); + } else if(l.type == DECONVOLUTIONAL){ + forward_deconvolutional_layer_gpu(l, state); + } else if(l.type == DETECTION){ + forward_detection_layer_gpu(l, state); + } else if(l.type == CONNECTED){ + forward_connected_layer_gpu(l, state); + } else if(l.type == CROP){ + forward_crop_layer_gpu(l, state); + } else if(l.type == COST){ + forward_cost_layer_gpu(l, state); + } else if(l.type == SOFTMAX){ + forward_softmax_layer_gpu(l, state); + } else if(l.type == MAXPOOL){ + forward_maxpool_layer_gpu(l, state); + } else if(l.type == DROPOUT){ + forward_dropout_layer_gpu(l, state); + } else if(l.type == ROUTE){ + forward_route_layer_gpu(l, net); } - else if(net.types[i] == DECONVOLUTIONAL){ - forward_deconvolutional_layer_gpu(*(deconvolutional_layer *)net.layers[i], state); - } - else if(net.types[i] == COST){ - forward_cost_layer_gpu(*(cost_layer *)net.layers[i], state); - } - else if(net.types[i] == CONNECTED){ - forward_connected_layer_gpu(*(connected_layer *)net.layers[i], state); - } - else if(net.types[i] == DETECTION){ - forward_detection_layer_gpu(*(detection_layer *)net.layers[i], state); - } - else if(net.types[i] == MAXPOOL){ - forward_maxpool_layer_gpu(*(maxpool_layer *)net.layers[i], state); - } - else if(net.types[i] == SOFTMAX){ - forward_softmax_layer_gpu(*(softmax_layer *)net.layers[i], state); - } - else if(net.types[i] == DROPOUT){ - forward_dropout_layer_gpu(*(dropout_layer *)net.layers[i], state); - } - else if(net.types[i] == CROP){ - forward_crop_layer_gpu(*(crop_layer *)net.layers[i], state); - } - else if(net.types[i] == ROUTE){ - forward_route_layer_gpu(*(route_layer *)net.layers[i], net); - } - state.input = get_network_output_gpu_layer(net, i); + state.input = l.output_gpu; } } @@ -68,40 +59,33 @@ void backward_network_gpu(network net, network_state state) int i; float * original_input = state.input; for(i = net.n-1; i >= 0; --i){ + layer l = net.layers[i]; if(i == 0){ state.input = original_input; state.delta = 0; }else{ - state.input = get_network_output_gpu_layer(net, i-1); - state.delta = get_network_delta_gpu_layer(net, i-1); + layer prev = net.layers[i-1]; + state.input = prev.output_gpu; + state.delta = prev.delta_gpu; } - - if(net.types[i] == CONVOLUTIONAL){ - backward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state); - } - else if(net.types[i] == DECONVOLUTIONAL){ - backward_deconvolutional_layer_gpu(*(deconvolutional_layer *)net.layers[i], state); - } - else if(net.types[i] == COST){ - backward_cost_layer_gpu(*(cost_layer *)net.layers[i], state); - } - else if(net.types[i] == CONNECTED){ - backward_connected_layer_gpu(*(connected_layer *)net.layers[i], state); - } - else if(net.types[i] == DETECTION){ - backward_detection_layer_gpu(*(detection_layer *)net.layers[i], state); - } - else if(net.types[i] == MAXPOOL){ - backward_maxpool_layer_gpu(*(maxpool_layer *)net.layers[i], state); - } - else if(net.types[i] == DROPOUT){ - backward_dropout_layer_gpu(*(dropout_layer *)net.layers[i], state); - } - else if(net.types[i] == SOFTMAX){ - backward_softmax_layer_gpu(*(softmax_layer *)net.layers[i], state); - } - else if(net.types[i] == ROUTE){ - backward_route_layer_gpu(*(route_layer *)net.layers[i], net); + if(l.type == CONVOLUTIONAL){ + backward_convolutional_layer_gpu(l, state); + } else if(l.type == DECONVOLUTIONAL){ + backward_deconvolutional_layer_gpu(l, state); + } else if(l.type == MAXPOOL){ + if(i != 0) backward_maxpool_layer_gpu(l, state); + } else if(l.type == DROPOUT){ + backward_dropout_layer_gpu(l, state); + } else if(l.type == DETECTION){ + backward_detection_layer_gpu(l, state); + } else if(l.type == SOFTMAX){ + if(i != 0) backward_softmax_layer_gpu(l, state); + } else if(l.type == CONNECTED){ + backward_connected_layer_gpu(l, state); + } else if(l.type == COST){ + backward_cost_layer_gpu(l, state); + } else if(l.type == ROUTE){ + backward_route_layer_gpu(l, net); } } } @@ -111,89 +95,17 @@ void update_network_gpu(network net) int i; int update_batch = net.batch*net.subdivisions; for(i = 0; i < net.n; ++i){ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - update_convolutional_layer_gpu(layer, update_batch, net.learning_rate, net.momentum, net.decay); - } - else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - update_deconvolutional_layer_gpu(layer, net.learning_rate, net.momentum, net.decay); - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - update_connected_layer_gpu(layer, update_batch, net.learning_rate, net.momentum, net.decay); + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + update_convolutional_layer_gpu(l, update_batch, net.learning_rate, net.momentum, net.decay); + } else if(l.type == DECONVOLUTIONAL){ + update_deconvolutional_layer_gpu(l, net.learning_rate, net.momentum, net.decay); + } else if(l.type == CONNECTED){ + update_connected_layer_gpu(l, update_batch, net.learning_rate, net.momentum, net.decay); } } } -float * get_network_output_gpu_layer(network net, int i) -{ - if(net.types[i] == CONVOLUTIONAL){ - return ((convolutional_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == DECONVOLUTIONAL){ - return ((deconvolutional_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == DETECTION){ - return ((detection_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == CONNECTED){ - return ((connected_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == MAXPOOL){ - return ((maxpool_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == CROP){ - return ((crop_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == SOFTMAX){ - return ((softmax_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == ROUTE){ - return ((route_layer *)net.layers[i]) -> output_gpu; - } - else if(net.types[i] == DROPOUT){ - return get_network_output_gpu_layer(net, i-1); - } - return 0; -} - -float * get_network_delta_gpu_layer(network net, int i) -{ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - return layer.delta_gpu; - } - else if(net.types[i] == DETECTION){ - detection_layer layer = *(detection_layer *)net.layers[i]; - return layer.delta_gpu; - } - else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - return layer.delta_gpu; - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - return layer.delta_gpu; - } - else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - return layer.delta_gpu; - } - else if(net.types[i] == ROUTE){ - route_layer layer = *(route_layer *)net.layers[i]; - return layer.delta_gpu; - } - else if(net.types[i] == SOFTMAX){ - softmax_layer layer = *(softmax_layer *)net.layers[i]; - return layer.delta_gpu; - } else if(net.types[i] == DROPOUT){ - if(i == 0) return 0; - return get_network_delta_gpu_layer(net, i-1); - } - return 0; -} - float train_network_datum_gpu(network net, float *x, float *y) { network_state state; @@ -219,33 +131,22 @@ float train_network_datum_gpu(network net, float *x, float *y) float *get_network_output_layer_gpu(network net, int i) { - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - return layer.output; - } - else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; - return layer.output; - } - else if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *)net.layers[i]; - cuda_pull_array(layer.output_gpu, layer.output, layer.outputs*layer.batch); - return layer.output; - } - else if(net.types[i] == DETECTION){ - detection_layer layer = *(detection_layer *)net.layers[i]; - int outputs = get_detection_layer_output_size(layer); - cuda_pull_array(layer.output_gpu, layer.output, outputs*layer.batch); - return layer.output; - } - else if(net.types[i] == MAXPOOL){ - maxpool_layer layer = *(maxpool_layer *)net.layers[i]; - return layer.output; - } - else if(net.types[i] == SOFTMAX){ - softmax_layer layer = *(softmax_layer *)net.layers[i]; - pull_softmax_layer_output(layer); - return layer.output; + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + return l.output; + } else if(l.type == DECONVOLUTIONAL){ + return l.output; + } else if(l.type == CONNECTED){ + cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch); + return l.output; + } else if(l.type == DETECTION){ + cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch); + return l.output; + } else if(l.type == MAXPOOL){ + return l.output; + } else if(l.type == SOFTMAX){ + pull_softmax_layer_output(l); + return l.output; } return 0; } @@ -253,7 +154,7 @@ float *get_network_output_layer_gpu(network net, int i) float *get_network_output_gpu(network net) { int i; - for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; + for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break; return get_network_output_layer_gpu(net, i); } diff --git a/src/normalization_layer.c b/src/normalization_layer.c deleted file mode 100644 index 93c2ad94..00000000 --- a/src/normalization_layer.c +++ /dev/null @@ -1,96 +0,0 @@ -#include "normalization_layer.h" -#include - -image get_normalization_image(normalization_layer layer) -{ - int h = layer.h; - int w = layer.w; - int c = layer.c; - return float_to_image(w,h,c,layer.output); -} - -image get_normalization_delta(normalization_layer layer) -{ - int h = layer.h; - int w = layer.w; - int c = layer.c; - return float_to_image(w,h,c,layer.delta); -} - -normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa) -{ - fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size); - normalization_layer *layer = calloc(1, sizeof(normalization_layer)); - layer->batch = batch; - layer->h = h; - layer->w = w; - layer->c = c; - layer->kappa = kappa; - layer->size = size; - layer->alpha = alpha; - layer->beta = beta; - layer->output = calloc(h * w * c * batch, sizeof(float)); - layer->delta = calloc(h * w * c * batch, sizeof(float)); - layer->sums = calloc(h*w, sizeof(float)); - return layer; -} - -void resize_normalization_layer(normalization_layer *layer, int h, int w) -{ - layer->h = h; - layer->w = w; - layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float)); - layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float)); - layer->sums = realloc(layer->sums, h*w * sizeof(float)); -} - -void add_square_array(float *src, float *dest, int n) -{ - int i; - for(i = 0; i < n; ++i){ - dest[i] += src[i]*src[i]; - } -} -void sub_square_array(float *src, float *dest, int n) -{ - int i; - for(i = 0; i < n; ++i){ - dest[i] -= src[i]*src[i]; - } -} - -void forward_normalization_layer(const normalization_layer layer, network_state state) -{ - int i,j,k; - memset(layer.sums, 0, layer.h*layer.w*sizeof(float)); - int imsize = layer.h*layer.w; - for(j = 0; j < layer.size/2; ++j){ - if(j < layer.c) add_square_array(state.input+j*imsize, layer.sums, imsize); - } - for(k = 0; k < layer.c; ++k){ - int next = k+layer.size/2; - int prev = k-layer.size/2-1; - if(next < layer.c) add_square_array(state.input+next*imsize, layer.sums, imsize); - if(prev > 0) sub_square_array(state.input+prev*imsize, layer.sums, imsize); - for(i = 0; i < imsize; ++i){ - layer.output[k*imsize + i] = state.input[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta); - } - } -} - -void backward_normalization_layer(const normalization_layer layer, network_state state) -{ - // TODO! - // OR NOT TODO!! -} - -void visualize_normalization_layer(normalization_layer layer, char *window) -{ - image delta = get_normalization_image(layer); - image dc = collapse_image_layers(delta, 1); - char buff[256]; - sprintf(buff, "%s: Output", window); - show_image(dc, buff); - save_image(dc, buff); - free_image(dc); -} diff --git a/src/normalization_layer.h b/src/normalization_layer.h deleted file mode 100644 index 11f2827d..00000000 --- a/src/normalization_layer.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef NORMALIZATION_LAYER_H -#define NORMALIZATION_LAYER_H - -#include "image.h" -#include "params.h" - -typedef struct { - int batch; - int h,w,c; - int size; - float alpha; - float beta; - float kappa; - float *delta; - float *output; - float *sums; -} normalization_layer; - -image get_normalization_image(normalization_layer layer); -normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa); -void resize_normalization_layer(normalization_layer *layer, int h, int w); -void forward_normalization_layer(const normalization_layer layer, network_state state); -void backward_normalization_layer(const normalization_layer layer, network_state state); -void visualize_normalization_layer(normalization_layer layer, char *window); - -#endif - diff --git a/src/old.c b/src/old.c index 13a9be71..52b87fbf 100644 --- a/src/old.c +++ b/src/old.c @@ -1,3 +1,254 @@ +void save_network(network net, char *filename) +{ + FILE *fp = fopen(filename, "w"); + if(!fp) file_error(filename); + int i; + for(i = 0; i < net.n; ++i) + { + if(net.types[i] == CONVOLUTIONAL) + print_convolutional_cfg(fp, (convolutional_layer *)net.layers[i], net, i); + else if(net.types[i] == DECONVOLUTIONAL) + print_deconvolutional_cfg(fp, (deconvolutional_layer *)net.layers[i], net, i); + else if(net.types[i] == CONNECTED) + print_connected_cfg(fp, (connected_layer *)net.layers[i], net, i); + else if(net.types[i] == CROP) + print_crop_cfg(fp, (crop_layer *)net.layers[i], net, i); + else if(net.types[i] == MAXPOOL) + print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], net, i); + else if(net.types[i] == DROPOUT) + print_dropout_cfg(fp, (dropout_layer *)net.layers[i], net, i); + else if(net.types[i] == SOFTMAX) + print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i); + else if(net.types[i] == DETECTION) + print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i); + else if(net.types[i] == COST) + print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i); + } + fclose(fp); +} + +void print_convolutional_cfg(FILE *fp, convolutional_layer *l, network net, int count) +{ +#ifdef GPU + if(gpu_index >= 0) pull_convolutional_layer(*l); +#endif + int i; + fprintf(fp, "[convolutional]\n"); + fprintf(fp, "filters=%d\n" + "size=%d\n" + "stride=%d\n" + "pad=%d\n" + "activation=%s\n", + l->n, l->size, l->stride, l->pad, + get_activation_string(l->activation)); + fprintf(fp, "biases="); + for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]); + fprintf(fp, "\n"); + fprintf(fp, "weights="); + for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]); + fprintf(fp, "\n\n"); +} + +void print_deconvolutional_cfg(FILE *fp, deconvolutional_layer *l, network net, int count) +{ +#ifdef GPU + if(gpu_index >= 0) pull_deconvolutional_layer(*l); +#endif + int i; + fprintf(fp, "[deconvolutional]\n"); + fprintf(fp, "filters=%d\n" + "size=%d\n" + "stride=%d\n" + "activation=%s\n", + l->n, l->size, l->stride, + get_activation_string(l->activation)); + fprintf(fp, "biases="); + for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]); + fprintf(fp, "\n"); + fprintf(fp, "weights="); + for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]); + fprintf(fp, "\n\n"); +} + +void print_dropout_cfg(FILE *fp, dropout_layer *l, network net, int count) +{ + fprintf(fp, "[dropout]\n"); + fprintf(fp, "probability=%g\n\n", l->probability); +} + +void print_connected_cfg(FILE *fp, connected_layer *l, network net, int count) +{ +#ifdef GPU + if(gpu_index >= 0) pull_connected_layer(*l); +#endif + int i; + fprintf(fp, "[connected]\n"); + fprintf(fp, "output=%d\n" + "activation=%s\n", + l->outputs, + get_activation_string(l->activation)); + fprintf(fp, "biases="); + for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]); + fprintf(fp, "\n"); + fprintf(fp, "weights="); + for(i = 0; i < l->outputs*l->inputs; ++i) fprintf(fp, "%g,", l->weights[i]); + fprintf(fp, "\n\n"); +} + +void print_crop_cfg(FILE *fp, crop_layer *l, network net, int count) +{ + fprintf(fp, "[crop]\n"); + fprintf(fp, "crop_height=%d\ncrop_width=%d\nflip=%d\n\n", l->crop_height, l->crop_width, l->flip); +} + +void print_maxpool_cfg(FILE *fp, maxpool_layer *l, network net, int count) +{ + fprintf(fp, "[maxpool]\n"); + fprintf(fp, "size=%d\nstride=%d\n\n", l->size, l->stride); +} + +void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count) +{ + fprintf(fp, "[softmax]\n"); + fprintf(fp, "\n"); +} + +void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count) +{ + fprintf(fp, "[detection]\n"); + fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\nnuisance=%d\n", l->classes, l->coords, l->rescore, l->nuisance); + fprintf(fp, "\n"); +} + +void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count) +{ + fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type)); + fprintf(fp, "\n"); +} + + +#ifndef NORMALIZATION_LAYER_H +#define NORMALIZATION_LAYER_H + +#include "image.h" +#include "params.h" + +typedef struct { + int batch; + int h,w,c; + int size; + float alpha; + float beta; + float kappa; + float *delta; + float *output; + float *sums; +} normalization_layer; + +image get_normalization_image(normalization_layer layer); +normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa); +void resize_normalization_layer(normalization_layer *layer, int h, int w); +void forward_normalization_layer(const normalization_layer layer, network_state state); +void backward_normalization_layer(const normalization_layer layer, network_state state); +void visualize_normalization_layer(normalization_layer layer, char *window); + +#endif +#include "normalization_layer.h" +#include + +image get_normalization_image(normalization_layer layer) +{ + int h = layer.h; + int w = layer.w; + int c = layer.c; + return float_to_image(w,h,c,layer.output); +} + +image get_normalization_delta(normalization_layer layer) +{ + int h = layer.h; + int w = layer.w; + int c = layer.c; + return float_to_image(w,h,c,layer.delta); +} + +normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa) +{ + fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size); + normalization_layer *layer = calloc(1, sizeof(normalization_layer)); + layer->batch = batch; + layer->h = h; + layer->w = w; + layer->c = c; + layer->kappa = kappa; + layer->size = size; + layer->alpha = alpha; + layer->beta = beta; + layer->output = calloc(h * w * c * batch, sizeof(float)); + layer->delta = calloc(h * w * c * batch, sizeof(float)); + layer->sums = calloc(h*w, sizeof(float)); + return layer; +} + +void resize_normalization_layer(normalization_layer *layer, int h, int w) +{ + layer->h = h; + layer->w = w; + layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float)); + layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float)); + layer->sums = realloc(layer->sums, h*w * sizeof(float)); +} + +void add_square_array(float *src, float *dest, int n) +{ + int i; + for(i = 0; i < n; ++i){ + dest[i] += src[i]*src[i]; + } +} +void sub_square_array(float *src, float *dest, int n) +{ + int i; + for(i = 0; i < n; ++i){ + dest[i] -= src[i]*src[i]; + } +} + +void forward_normalization_layer(const normalization_layer layer, network_state state) +{ + int i,j,k; + memset(layer.sums, 0, layer.h*layer.w*sizeof(float)); + int imsize = layer.h*layer.w; + for(j = 0; j < layer.size/2; ++j){ + if(j < layer.c) add_square_array(state.input+j*imsize, layer.sums, imsize); + } + for(k = 0; k < layer.c; ++k){ + int next = k+layer.size/2; + int prev = k-layer.size/2-1; + if(next < layer.c) add_square_array(state.input+next*imsize, layer.sums, imsize); + if(prev > 0) sub_square_array(state.input+prev*imsize, layer.sums, imsize); + for(i = 0; i < imsize; ++i){ + layer.output[k*imsize + i] = state.input[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta); + } + } +} + +void backward_normalization_layer(const normalization_layer layer, network_state state) +{ + // TODO! + // OR NOT TODO!! +} + +void visualize_normalization_layer(normalization_layer layer, char *window) +{ + image delta = get_normalization_image(layer); + image dc = collapse_image_layers(delta, 1); + char buff[256]; + sprintf(buff, "%s: Output", window); + show_image(dc, buff); + save_image(dc, buff); + free_image(dc); +} void test_load() { diff --git a/src/parser.c b/src/parser.c index 46bd8ef8..48567a11 100644 --- a/src/parser.c +++ b/src/parser.c @@ -10,7 +10,6 @@ #include "deconvolutional_layer.h" #include "connected_layer.h" #include "maxpool_layer.h" -#include "normalization_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" #include "detection_layer.h" @@ -34,7 +33,6 @@ int is_softmax(section *s); int is_crop(section *s); int is_cost(section *s); int is_detection(section *s); -int is_normalization(section *s); int is_route(section *s); list *read_cfg(char *filename); @@ -78,7 +76,7 @@ typedef struct size_params{ int c; } size_params; -deconvolutional_layer *parse_deconvolutional(list *options, size_params params) +deconvolutional_layer parse_deconvolutional(list *options, size_params params) { int n = option_find_int(options, "filters",1); int size = option_find_int(options, "size",1); @@ -93,20 +91,20 @@ deconvolutional_layer *parse_deconvolutional(list *options, size_params params) batch=params.batch; if(!(h && w && c)) error("Layer before deconvolutional layer must output image."); - deconvolutional_layer *layer = make_deconvolutional_layer(batch,h,w,c,n,size,stride,activation); + deconvolutional_layer layer = make_deconvolutional_layer(batch,h,w,c,n,size,stride,activation); char *weights = option_find_str(options, "weights", 0); char *biases = option_find_str(options, "biases", 0); - parse_data(weights, layer->filters, c*n*size*size); - parse_data(biases, layer->biases, n); + parse_data(weights, layer.filters, c*n*size*size); + parse_data(biases, layer.biases, n); #ifdef GPU - if(weights || biases) push_deconvolutional_layer(*layer); + if(weights || biases) push_deconvolutional_layer(layer); #endif option_unused(options); return layer; } -convolutional_layer *parse_convolutional(list *options, size_params params) +convolutional_layer parse_convolutional(list *options, size_params params) { int n = option_find_int(options, "filters",1); int size = option_find_int(options, "size",1); @@ -122,68 +120,68 @@ convolutional_layer *parse_convolutional(list *options, size_params params) batch=params.batch; if(!(h && w && c)) error("Layer before convolutional layer must output image."); - convolutional_layer *layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation); + convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation); char *weights = option_find_str(options, "weights", 0); char *biases = option_find_str(options, "biases", 0); - parse_data(weights, layer->filters, c*n*size*size); - parse_data(biases, layer->biases, n); + parse_data(weights, layer.filters, c*n*size*size); + parse_data(biases, layer.biases, n); #ifdef GPU - if(weights || biases) push_convolutional_layer(*layer); + if(weights || biases) push_convolutional_layer(layer); #endif option_unused(options); return layer; } -connected_layer *parse_connected(list *options, size_params params) +connected_layer parse_connected(list *options, size_params params) { int output = option_find_int(options, "output",1); char *activation_s = option_find_str(options, "activation", "logistic"); ACTIVATION activation = get_activation(activation_s); - connected_layer *layer = make_connected_layer(params.batch, params.inputs, output, activation); + connected_layer layer = make_connected_layer(params.batch, params.inputs, output, activation); char *weights = option_find_str(options, "weights", 0); char *biases = option_find_str(options, "biases", 0); - parse_data(biases, layer->biases, output); - parse_data(weights, layer->weights, params.inputs*output); + parse_data(biases, layer.biases, output); + parse_data(weights, layer.weights, params.inputs*output); #ifdef GPU - if(weights || biases) push_connected_layer(*layer); + if(weights || biases) push_connected_layer(layer); #endif option_unused(options); return layer; } -softmax_layer *parse_softmax(list *options, size_params params) +softmax_layer parse_softmax(list *options, size_params params) { int groups = option_find_int(options, "groups",1); - softmax_layer *layer = make_softmax_layer(params.batch, params.inputs, groups); + softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups); option_unused(options); return layer; } -detection_layer *parse_detection(list *options, size_params params) +detection_layer parse_detection(list *options, size_params params) { int coords = option_find_int(options, "coords", 1); int classes = option_find_int(options, "classes", 1); int rescore = option_find_int(options, "rescore", 1); int nuisance = option_find_int(options, "nuisance", 0); int background = option_find_int(options, "background", 1); - detection_layer *layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore, background, nuisance); + detection_layer layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore, background, nuisance); option_unused(options); return layer; } -cost_layer *parse_cost(list *options, size_params params) +cost_layer parse_cost(list *options, size_params params) { char *type_s = option_find_str(options, "type", "sse"); COST_TYPE type = get_cost_type(type_s); - cost_layer *layer = make_cost_layer(params.batch, params.inputs, type); + cost_layer layer = make_cost_layer(params.batch, params.inputs, type); option_unused(options); return layer; } -crop_layer *parse_crop(list *options, size_params params) +crop_layer parse_crop(list *options, size_params params) { int crop_height = option_find_int(options, "crop_height",1); int crop_width = option_find_int(options, "crop_width",1); @@ -199,12 +197,12 @@ crop_layer *parse_crop(list *options, size_params params) batch=params.batch; if(!(h && w && c)) error("Layer before crop layer must output image."); - crop_layer *layer = make_crop_layer(batch,h,w,c,crop_height,crop_width,flip, angle, saturation, exposure); + crop_layer l = make_crop_layer(batch,h,w,c,crop_height,crop_width,flip, angle, saturation, exposure); option_unused(options); - return layer; + return l; } -maxpool_layer *parse_maxpool(list *options, size_params params) +maxpool_layer parse_maxpool(list *options, size_params params) { int stride = option_find_int(options, "stride",1); int size = option_find_int(options, "size",stride); @@ -216,39 +214,20 @@ maxpool_layer *parse_maxpool(list *options, size_params params) batch=params.batch; if(!(h && w && c)) error("Layer before maxpool layer must output image."); - maxpool_layer *layer = make_maxpool_layer(batch,h,w,c,size,stride); + maxpool_layer layer = make_maxpool_layer(batch,h,w,c,size,stride); option_unused(options); return layer; } -dropout_layer *parse_dropout(list *options, size_params params) +dropout_layer parse_dropout(list *options, size_params params) { float probability = option_find_float(options, "probability", .5); - dropout_layer *layer = make_dropout_layer(params.batch, params.inputs, probability); + dropout_layer layer = make_dropout_layer(params.batch, params.inputs, probability); option_unused(options); return layer; } -normalization_layer *parse_normalization(list *options, size_params params) -{ - int size = option_find_int(options, "size",1); - float alpha = option_find_float(options, "alpha", 0.); - float beta = option_find_float(options, "beta", 1.); - float kappa = option_find_float(options, "kappa", 1.); - - int batch,h,w,c; - h = params.h; - w = params.w; - c = params.c; - batch=params.batch; - if(!(h && w && c)) error("Layer before normalization layer must output image."); - - normalization_layer *layer = make_normalization_layer(batch,h,w,c,size, alpha, beta, kappa); - option_unused(options); - return layer; -} - -route_layer *parse_route(list *options, size_params params, network net) +route_layer parse_route(list *options, size_params params, network net) { char *l = option_find(options, "layers"); int len = strlen(l); @@ -265,11 +244,26 @@ route_layer *parse_route(list *options, size_params params, network net) int index = atoi(l); l = strchr(l, ',')+1; layers[i] = index; - sizes[i] = get_network_output_size_layer(net, index); + sizes[i] = net.layers[index].outputs; } int batch = params.batch; - route_layer *layer = make_route_layer(batch, n, layers, sizes); + route_layer layer = make_route_layer(batch, n, layers, sizes); + + convolutional_layer first = net.layers[layers[0]]; + layer.out_w = first.out_w; + layer.out_h = first.out_h; + layer.out_c = first.out_c; + for(i = 1; i < n; ++i){ + int index = layers[i]; + convolutional_layer next = net.layers[index]; + if(next.out_w == first.out_w && next.out_h == first.out_h){ + layer.out_c += next.out_c; + }else{ + layer.out_h = layer.out_w = layer.out_c = 0; + } + } + option_unused(options); return layer; } @@ -318,61 +312,44 @@ network parse_network_cfg(char *filename) fprintf(stderr, "%d: ", count); s = (section *)n->val; options = s->options; + layer l = {0}; if(is_convolutional(s)){ - convolutional_layer *layer = parse_convolutional(options, params); - net.types[count] = CONVOLUTIONAL; - net.layers[count] = layer; + l = parse_convolutional(options, params); }else if(is_deconvolutional(s)){ - deconvolutional_layer *layer = parse_deconvolutional(options, params); - net.types[count] = DECONVOLUTIONAL; - net.layers[count] = layer; + l = parse_deconvolutional(options, params); }else if(is_connected(s)){ - connected_layer *layer = parse_connected(options, params); - net.types[count] = CONNECTED; - net.layers[count] = layer; + l = parse_connected(options, params); }else if(is_crop(s)){ - crop_layer *layer = parse_crop(options, params); - net.types[count] = CROP; - net.layers[count] = layer; + l = parse_crop(options, params); }else if(is_cost(s)){ - cost_layer *layer = parse_cost(options, params); - net.types[count] = COST; - net.layers[count] = layer; + l = parse_cost(options, params); }else if(is_detection(s)){ - detection_layer *layer = parse_detection(options, params); - net.types[count] = DETECTION; - net.layers[count] = layer; + l = parse_detection(options, params); }else if(is_softmax(s)){ - softmax_layer *layer = parse_softmax(options, params); - net.types[count] = SOFTMAX; - net.layers[count] = layer; + l = parse_softmax(options, params); }else if(is_maxpool(s)){ - maxpool_layer *layer = parse_maxpool(options, params); - net.types[count] = MAXPOOL; - net.layers[count] = layer; - }else if(is_normalization(s)){ - normalization_layer *layer = parse_normalization(options, params); - net.types[count] = NORMALIZATION; - net.layers[count] = layer; + l = parse_maxpool(options, params); }else if(is_route(s)){ - route_layer *layer = parse_route(options, params, net); - net.types[count] = ROUTE; - net.layers[count] = layer; + l = parse_route(options, params, net); }else if(is_dropout(s)){ - dropout_layer *layer = parse_dropout(options, params); - net.types[count] = DROPOUT; - net.layers[count] = layer; + l = parse_dropout(options, params); + l.output = net.layers[count-1].output; + l.delta = net.layers[count-1].delta; + #ifdef GPU + l.output_gpu = net.layers[count-1].output_gpu; + l.delta_gpu = net.layers[count-1].delta_gpu; + #endif }else{ fprintf(stderr, "Type not recognized: %s\n", s->type); } + net.layers[count] = l; free_section(s); n = n->next; if(n){ - image im = get_network_image_layer(net, count); - params.h = im.h; - params.w = im.w; - params.c = im.c; - params.inputs = get_network_output_size_layer(net, count); + params.h = l.out_h; + params.w = l.out_w; + params.c = l.out_c; + params.inputs = l.outputs; } ++count; } @@ -429,11 +406,6 @@ int is_softmax(section *s) return (strcmp(s->type, "[soft]")==0 || strcmp(s->type, "[softmax]")==0); } -int is_normalization(section *s) -{ - return (strcmp(s->type, "[lrnorm]")==0 - || strcmp(s->type, "[localresponsenormalization]")==0); -} int is_route(section *s) { return (strcmp(s->type, "[route]")==0); @@ -492,114 +464,6 @@ list *read_cfg(char *filename) return sections; } -void print_convolutional_cfg(FILE *fp, convolutional_layer *l, network net, int count) -{ -#ifdef GPU - if(gpu_index >= 0) pull_convolutional_layer(*l); -#endif - int i; - fprintf(fp, "[convolutional]\n"); - fprintf(fp, "filters=%d\n" - "size=%d\n" - "stride=%d\n" - "pad=%d\n" - "activation=%s\n", - l->n, l->size, l->stride, l->pad, - get_activation_string(l->activation)); - fprintf(fp, "biases="); - for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]); - fprintf(fp, "\n"); - fprintf(fp, "weights="); - for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]); - fprintf(fp, "\n\n"); -} - -void print_deconvolutional_cfg(FILE *fp, deconvolutional_layer *l, network net, int count) -{ -#ifdef GPU - if(gpu_index >= 0) pull_deconvolutional_layer(*l); -#endif - int i; - fprintf(fp, "[deconvolutional]\n"); - fprintf(fp, "filters=%d\n" - "size=%d\n" - "stride=%d\n" - "activation=%s\n", - l->n, l->size, l->stride, - get_activation_string(l->activation)); - fprintf(fp, "biases="); - for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]); - fprintf(fp, "\n"); - fprintf(fp, "weights="); - for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]); - fprintf(fp, "\n\n"); -} - -void print_dropout_cfg(FILE *fp, dropout_layer *l, network net, int count) -{ - fprintf(fp, "[dropout]\n"); - fprintf(fp, "probability=%g\n\n", l->probability); -} - -void print_connected_cfg(FILE *fp, connected_layer *l, network net, int count) -{ -#ifdef GPU - if(gpu_index >= 0) pull_connected_layer(*l); -#endif - int i; - fprintf(fp, "[connected]\n"); - fprintf(fp, "output=%d\n" - "activation=%s\n", - l->outputs, - get_activation_string(l->activation)); - fprintf(fp, "biases="); - for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]); - fprintf(fp, "\n"); - fprintf(fp, "weights="); - for(i = 0; i < l->outputs*l->inputs; ++i) fprintf(fp, "%g,", l->weights[i]); - fprintf(fp, "\n\n"); -} - -void print_crop_cfg(FILE *fp, crop_layer *l, network net, int count) -{ - fprintf(fp, "[crop]\n"); - fprintf(fp, "crop_height=%d\ncrop_width=%d\nflip=%d\n\n", l->crop_height, l->crop_width, l->flip); -} - -void print_maxpool_cfg(FILE *fp, maxpool_layer *l, network net, int count) -{ - fprintf(fp, "[maxpool]\n"); - fprintf(fp, "size=%d\nstride=%d\n\n", l->size, l->stride); -} - -void print_normalization_cfg(FILE *fp, normalization_layer *l, network net, int count) -{ - fprintf(fp, "[localresponsenormalization]\n"); - fprintf(fp, "size=%d\n" - "alpha=%g\n" - "beta=%g\n" - "kappa=%g\n\n", l->size, l->alpha, l->beta, l->kappa); -} - -void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count) -{ - fprintf(fp, "[softmax]\n"); - fprintf(fp, "\n"); -} - -void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count) -{ - fprintf(fp, "[detection]\n"); - fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\nnuisance=%d\n", l->classes, l->coords, l->rescore, l->nuisance); - fprintf(fp, "\n"); -} - -void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count) -{ - fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type)); - fprintf(fp, "\n"); -} - void save_weights(network net, char *filename) { fprintf(stderr, "Saving weights to %s\n", filename); @@ -613,37 +477,35 @@ void save_weights(network net, char *filename) int i; for(i = 0; i < net.n; ++i){ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *) net.layers[i]; + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ #ifdef GPU if(gpu_index >= 0){ - pull_convolutional_layer(layer); + pull_convolutional_layer(l); } #endif - int num = layer.n*layer.c*layer.size*layer.size; - fwrite(layer.biases, sizeof(float), layer.n, fp); - fwrite(layer.filters, sizeof(float), num, fp); + int num = l.n*l.c*l.size*l.size; + fwrite(l.biases, sizeof(float), l.n, fp); + fwrite(l.filters, sizeof(float), num, fp); } - if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *) net.layers[i]; + if(l.type == DECONVOLUTIONAL){ #ifdef GPU if(gpu_index >= 0){ - pull_deconvolutional_layer(layer); + pull_deconvolutional_layer(l); } #endif - int num = layer.n*layer.c*layer.size*layer.size; - fwrite(layer.biases, sizeof(float), layer.n, fp); - fwrite(layer.filters, sizeof(float), num, fp); + int num = l.n*l.c*l.size*l.size; + fwrite(l.biases, sizeof(float), l.n, fp); + fwrite(l.filters, sizeof(float), num, fp); } - if(net.types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *) net.layers[i]; + if(l.type == CONNECTED){ #ifdef GPU if(gpu_index >= 0){ - pull_connected_layer(layer); + pull_connected_layer(l); } #endif - fwrite(layer.biases, sizeof(float), layer.outputs, fp); - fwrite(layer.weights, sizeof(float), layer.outputs*layer.inputs, fp); + fwrite(l.biases, sizeof(float), l.outputs, fp); + fwrite(l.weights, sizeof(float), l.outputs*l.inputs, fp); } } fclose(fp); @@ -663,35 +525,33 @@ void load_weights_upto(network *net, char *filename, int cutoff) int i; for(i = 0; i < net->n && i < cutoff; ++i){ - if(net->types[i] == CONVOLUTIONAL){ - convolutional_layer layer = *(convolutional_layer *) net->layers[i]; - int num = layer.n*layer.c*layer.size*layer.size; - fread(layer.biases, sizeof(float), layer.n, fp); - fread(layer.filters, sizeof(float), num, fp); + layer l = net->layers[i]; + if(l.type == CONVOLUTIONAL){ + int num = l.n*l.c*l.size*l.size; + fread(l.biases, sizeof(float), l.n, fp); + fread(l.filters, sizeof(float), num, fp); #ifdef GPU if(gpu_index >= 0){ - push_convolutional_layer(layer); + push_convolutional_layer(l); } #endif } - if(net->types[i] == DECONVOLUTIONAL){ - deconvolutional_layer layer = *(deconvolutional_layer *) net->layers[i]; - int num = layer.n*layer.c*layer.size*layer.size; - fread(layer.biases, sizeof(float), layer.n, fp); - fread(layer.filters, sizeof(float), num, fp); + if(l.type == DECONVOLUTIONAL){ + int num = l.n*l.c*l.size*l.size; + fread(l.biases, sizeof(float), l.n, fp); + fread(l.filters, sizeof(float), num, fp); #ifdef GPU if(gpu_index >= 0){ - push_deconvolutional_layer(layer); + push_deconvolutional_layer(l); } #endif } - if(net->types[i] == CONNECTED){ - connected_layer layer = *(connected_layer *) net->layers[i]; - fread(layer.biases, sizeof(float), layer.outputs, fp); - fread(layer.weights, sizeof(float), layer.outputs*layer.inputs, fp); + if(l.type == CONNECTED){ + fread(l.biases, sizeof(float), l.outputs, fp); + fread(l.weights, sizeof(float), l.outputs*l.inputs, fp); #ifdef GPU if(gpu_index >= 0){ - push_connected_layer(layer); + push_connected_layer(l); } #endif } @@ -704,34 +564,3 @@ void load_weights(network *net, char *filename) load_weights_upto(net, filename, net->n); } -void save_network(network net, char *filename) -{ - FILE *fp = fopen(filename, "w"); - if(!fp) file_error(filename); - int i; - for(i = 0; i < net.n; ++i) - { - if(net.types[i] == CONVOLUTIONAL) - print_convolutional_cfg(fp, (convolutional_layer *)net.layers[i], net, i); - else if(net.types[i] == DECONVOLUTIONAL) - print_deconvolutional_cfg(fp, (deconvolutional_layer *)net.layers[i], net, i); - else if(net.types[i] == CONNECTED) - print_connected_cfg(fp, (connected_layer *)net.layers[i], net, i); - else if(net.types[i] == CROP) - print_crop_cfg(fp, (crop_layer *)net.layers[i], net, i); - else if(net.types[i] == MAXPOOL) - print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], net, i); - else if(net.types[i] == DROPOUT) - print_dropout_cfg(fp, (dropout_layer *)net.layers[i], net, i); - else if(net.types[i] == NORMALIZATION) - print_normalization_cfg(fp, (normalization_layer *)net.layers[i], net, i); - else if(net.types[i] == SOFTMAX) - print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i); - else if(net.types[i] == DETECTION) - print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i); - else if(net.types[i] == COST) - print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i); - } - fclose(fp); -} - diff --git a/src/route_layer.c b/src/route_layer.c index c8897b14..e3802b7d 100644 --- a/src/route_layer.c +++ b/src/route_layer.c @@ -3,83 +3,89 @@ #include "blas.h" #include -route_layer *make_route_layer(int batch, int n, int *input_layers, int *input_sizes) +route_layer make_route_layer(int batch, int n, int *input_layers, int *input_sizes) { - printf("Route Layer:"); - route_layer *layer = calloc(1, sizeof(route_layer)); - layer->batch = batch; - layer->n = n; - layer->input_layers = input_layers; - layer->input_sizes = input_sizes; + fprintf(stderr,"Route Layer:"); + route_layer l = {0}; + l.type = ROUTE; + l.batch = batch; + l.n = n; + l.input_layers = input_layers; + l.input_sizes = input_sizes; int i; int outputs = 0; for(i = 0; i < n; ++i){ - printf(" %d", input_layers[i]); + fprintf(stderr," %d", input_layers[i]); outputs += input_sizes[i]; } - printf("\n"); - layer->outputs = outputs; - layer->delta = calloc(outputs*batch, sizeof(float)); - layer->output = calloc(outputs*batch, sizeof(float));; + fprintf(stderr, "\n"); + l.outputs = outputs; + l.inputs = outputs; + l.delta = calloc(outputs*batch, sizeof(float)); + l.output = calloc(outputs*batch, sizeof(float));; #ifdef GPU - layer->delta_gpu = cuda_make_array(0, outputs*batch); - layer->output_gpu = cuda_make_array(0, outputs*batch); + l.delta_gpu = cuda_make_array(0, outputs*batch); + l.output_gpu = cuda_make_array(0, outputs*batch); #endif - return layer; + return l; } -void forward_route_layer(const route_layer layer, network net) +void forward_route_layer(const route_layer l, network net) { int i, j; int offset = 0; - for(i = 0; i < layer.n; ++i){ - float *input = get_network_output_layer(net, layer.input_layers[i]); - int input_size = layer.input_sizes[i]; - for(j = 0; j < layer.batch; ++j){ - copy_cpu(input_size, input + j*input_size, 1, layer.output + offset + j*layer.outputs, 1); + for(i = 0; i < l.n; ++i){ + int index = l.input_layers[i]; + float *input = net.layers[index].output; + int input_size = l.input_sizes[i]; + for(j = 0; j < l.batch; ++j){ + copy_cpu(input_size, input + j*input_size, 1, l.output + offset + j*l.outputs, 1); } offset += input_size; } } -void backward_route_layer(const route_layer layer, network net) +void backward_route_layer(const route_layer l, network net) { int i, j; int offset = 0; - for(i = 0; i < layer.n; ++i){ - float *delta = get_network_delta_layer(net, layer.input_layers[i]); - int input_size = layer.input_sizes[i]; - for(j = 0; j < layer.batch; ++j){ - copy_cpu(input_size, layer.delta + offset + j*layer.outputs, 1, delta + j*input_size, 1); + for(i = 0; i < l.n; ++i){ + int index = l.input_layers[i]; + float *delta = net.layers[index].delta; + int input_size = l.input_sizes[i]; + for(j = 0; j < l.batch; ++j){ + copy_cpu(input_size, l.delta + offset + j*l.outputs, 1, delta + j*input_size, 1); } offset += input_size; } } #ifdef GPU -void forward_route_layer_gpu(const route_layer layer, network net) +void forward_route_layer_gpu(const route_layer l, network net) { int i, j; int offset = 0; - for(i = 0; i < layer.n; ++i){ - float *input = get_network_output_gpu_layer(net, layer.input_layers[i]); - int input_size = layer.input_sizes[i]; - for(j = 0; j < layer.batch; ++j){ - copy_ongpu(input_size, input + j*input_size, 1, layer.output_gpu + offset + j*layer.outputs, 1); + for(i = 0; i < l.n; ++i){ + int index = l.input_layers[i]; + float *input = net.layers[index].output_gpu; + int input_size = l.input_sizes[i]; + for(j = 0; j < l.batch; ++j){ + copy_ongpu(input_size, input + j*input_size, 1, l.output_gpu + offset + j*l.outputs, 1); } offset += input_size; } } -void backward_route_layer_gpu(const route_layer layer, network net) +void backward_route_layer_gpu(const route_layer l, network net) { int i, j; int offset = 0; - for(i = 0; i < layer.n; ++i){ - float *delta = get_network_delta_gpu_layer(net, layer.input_layers[i]); - int input_size = layer.input_sizes[i]; - for(j = 0; j < layer.batch; ++j){ - copy_ongpu(input_size, layer.delta_gpu + offset + j*layer.outputs, 1, delta + j*input_size, 1); + for(i = 0; i < l.n; ++i){ + int index = l.input_layers[i]; + float *delta = net.layers[index].delta_gpu; + int input_size = l.input_sizes[i]; + for(j = 0; j < l.batch; ++j){ + copy_ongpu(input_size, l.delta_gpu + offset + j*l.outputs, 1, delta + j*input_size, 1); } offset += input_size; } diff --git a/src/route_layer.h b/src/route_layer.h index 086ef879..1f0d6e32 100644 --- a/src/route_layer.h +++ b/src/route_layer.h @@ -1,28 +1,17 @@ #ifndef ROUTE_LAYER_H #define ROUTE_LAYER_H #include "network.h" +#include "layer.h" -typedef struct { - int batch; - int outputs; - int n; - int * input_layers; - int * input_sizes; - float * delta; - float * output; - #ifdef GPU - float * delta_gpu; - float * output_gpu; - #endif -} route_layer; +typedef layer route_layer; -route_layer *make_route_layer(int batch, int n, int *input_layers, int *input_size); -void forward_route_layer(const route_layer layer, network net); -void backward_route_layer(const route_layer layer, network net); +route_layer make_route_layer(int batch, int n, int *input_layers, int *input_size); +void forward_route_layer(const route_layer l, network net); +void backward_route_layer(const route_layer l, network net); #ifdef GPU -void forward_route_layer_gpu(const route_layer layer, network net); -void backward_route_layer_gpu(const route_layer layer, network net); +void forward_route_layer_gpu(const route_layer l, network net); +void backward_route_layer_gpu(const route_layer l, network net); #endif #endif diff --git a/src/softmax_layer.c b/src/softmax_layer.c index e344d166..ea22d059 100644 --- a/src/softmax_layer.c +++ b/src/softmax_layer.c @@ -7,21 +7,23 @@ #include #include -softmax_layer *make_softmax_layer(int batch, int inputs, int groups) +softmax_layer make_softmax_layer(int batch, int inputs, int groups) { assert(inputs%groups == 0); fprintf(stderr, "Softmax Layer: %d inputs\n", inputs); - softmax_layer *layer = calloc(1, sizeof(softmax_layer)); - layer->batch = batch; - layer->groups = groups; - layer->inputs = inputs; - layer->output = calloc(inputs*batch, sizeof(float)); - layer->delta = calloc(inputs*batch, sizeof(float)); + softmax_layer l = {0}; + l.type = SOFTMAX; + l.batch = batch; + l.groups = groups; + l.inputs = inputs; + l.outputs = inputs; + l.output = calloc(inputs*batch, sizeof(float)); + l.delta = calloc(inputs*batch, sizeof(float)); #ifdef GPU - layer->output_gpu = cuda_make_array(layer->output, inputs*batch); - layer->delta_gpu = cuda_make_array(layer->delta, inputs*batch); + l.output_gpu = cuda_make_array(l.output, inputs*batch); + l.delta_gpu = cuda_make_array(l.delta, inputs*batch); #endif - return layer; + return l; } void softmax_array(float *input, int n, float *output) @@ -42,21 +44,21 @@ void softmax_array(float *input, int n, float *output) } } -void forward_softmax_layer(const softmax_layer layer, network_state state) +void forward_softmax_layer(const softmax_layer l, network_state state) { int b; - int inputs = layer.inputs / layer.groups; - int batch = layer.batch * layer.groups; + int inputs = l.inputs / l.groups; + int batch = l.batch * l.groups; for(b = 0; b < batch; ++b){ - softmax_array(state.input+b*inputs, inputs, layer.output+b*inputs); + softmax_array(state.input+b*inputs, inputs, l.output+b*inputs); } } -void backward_softmax_layer(const softmax_layer layer, network_state state) +void backward_softmax_layer(const softmax_layer l, network_state state) { int i; - for(i = 0; i < layer.inputs*layer.batch; ++i){ - state.delta[i] = layer.delta[i]; + for(i = 0; i < l.inputs*l.batch; ++i){ + state.delta[i] = l.delta[i]; } } diff --git a/src/softmax_layer.h b/src/softmax_layer.h index ecdec1ed..f29c6521 100644 --- a/src/softmax_layer.h +++ b/src/softmax_layer.h @@ -1,28 +1,19 @@ #ifndef SOFTMAX_LAYER_H #define SOFTMAX_LAYER_H #include "params.h" +#include "layer.h" -typedef struct { - int inputs; - int batch; - int groups; - float *delta; - float *output; - #ifdef GPU - float * delta_gpu; - float * output_gpu; - #endif -} softmax_layer; +typedef layer softmax_layer; void softmax_array(float *input, int n, float *output); -softmax_layer *make_softmax_layer(int batch, int inputs, int groups); -void forward_softmax_layer(const softmax_layer layer, network_state state); -void backward_softmax_layer(const softmax_layer layer, network_state state); +softmax_layer make_softmax_layer(int batch, int inputs, int groups); +void forward_softmax_layer(const softmax_layer l, network_state state); +void backward_softmax_layer(const softmax_layer l, network_state state); #ifdef GPU -void pull_softmax_layer_output(const softmax_layer layer); -void forward_softmax_layer_gpu(const softmax_layer layer, network_state state); -void backward_softmax_layer_gpu(const softmax_layer layer, network_state state); +void pull_softmax_layer_output(const softmax_layer l); +void forward_softmax_layer_gpu(const softmax_layer l, network_state state); +void backward_softmax_layer_gpu(const softmax_layer l, network_state state); #endif #endif