route handles input images well....ish

Joseph Redmon 2015-05-11 13:46:49 -07:00
parent dc0d7bb8a8
commit 516f019ba6
31 changed files with 1250 additions and 1819 deletions
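
The unifying thread across the files below: each per-layer struct (connected_layer, convolutional_layer, cost_layer, crop_layer, deconvolutional_layer, detection_layer) becomes a typedef of one shared `layer` struct from the new layer.h, constructors return the struct by value instead of a calloc'd pointer, and each constructor tags its result with a type enum. layer.h itself is not part of this excerpt, so the sketch below is illustrative only; any fields or enum values beyond those visible in the diff are assumptions.

/* Illustrative sketch only: layer.h is not shown in this diff. */
typedef enum {
    CONVOLUTIONAL, DECONVOLUTIONAL, CONNECTED, CROP, COST, DETECTION
    /* ...other layer kinds would follow... */
} LAYER_TYPE;

typedef struct {
    LAYER_TYPE type;          /* set by every constructor, e.g. l.type = CONNECTED; */
    int batch;
    int inputs, outputs;      /* now filled in uniformly by each make_*_layer() */
    int h, w, c;              /* input dimensions */
    int out_h, out_w, out_c;  /* output dimensions, newly recorded in this commit */
    float *output;
    float *delta;
    /* ...union of the per-layer fields (weights, filters, cost_type, ...) */
} layer;

/* The old struct names survive as aliases, so existing code keeps compiling: */
typedef layer connected_layer;
typedef layer cost_layer;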

Makefile

@@ -25,7 +25,7 @@ CFLAGS+=-DGPU
 LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
 endif
-OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o
+OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o
 ifeq ($(GPU), 1)
 OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
 endif

src/connected_layer.c

@@ -9,99 +9,97 @@
 #include <stdlib.h>
 #include <string.h>

-connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation)
+connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation)
 {
     int i;
-    connected_layer *layer = calloc(1, sizeof(connected_layer));
-    layer->inputs = inputs;
-    layer->outputs = outputs;
-    layer->batch=batch;
-    layer->output = calloc(batch*outputs, sizeof(float*));
-    layer->delta = calloc(batch*outputs, sizeof(float*));
-    layer->weight_updates = calloc(inputs*outputs, sizeof(float));
-    layer->bias_updates = calloc(outputs, sizeof(float));
-    layer->weight_prev = calloc(inputs*outputs, sizeof(float));
-    layer->bias_prev = calloc(outputs, sizeof(float));
-    layer->weights = calloc(inputs*outputs, sizeof(float));
-    layer->biases = calloc(outputs, sizeof(float));
+    connected_layer l = {0};
+    l.type = CONNECTED;
+    l.inputs = inputs;
+    l.outputs = outputs;
+    l.batch=batch;
+    l.output = calloc(batch*outputs, sizeof(float*));
+    l.delta = calloc(batch*outputs, sizeof(float*));
+    l.weight_updates = calloc(inputs*outputs, sizeof(float));
+    l.bias_updates = calloc(outputs, sizeof(float));
+    l.weights = calloc(inputs*outputs, sizeof(float));
+    l.biases = calloc(outputs, sizeof(float));
     float scale = 1./sqrt(inputs);
     for(i = 0; i < inputs*outputs; ++i){
-        layer->weights[i] = 2*scale*rand_uniform() - scale;
+        l.weights[i] = 2*scale*rand_uniform() - scale;
     }
     for(i = 0; i < outputs; ++i){
-        layer->biases[i] = scale;
+        l.biases[i] = scale;
     }
 #ifdef GPU
-    layer->weights_gpu = cuda_make_array(layer->weights, inputs*outputs);
-    layer->biases_gpu = cuda_make_array(layer->biases, outputs);
-    layer->weight_updates_gpu = cuda_make_array(layer->weight_updates, inputs*outputs);
-    layer->bias_updates_gpu = cuda_make_array(layer->bias_updates, outputs);
-    layer->output_gpu = cuda_make_array(layer->output, outputs*batch);
-    layer->delta_gpu = cuda_make_array(layer->delta, outputs*batch);
+    l.weights_gpu = cuda_make_array(l.weights, inputs*outputs);
+    l.biases_gpu = cuda_make_array(l.biases, outputs);
+    l.weight_updates_gpu = cuda_make_array(l.weight_updates, inputs*outputs);
+    l.bias_updates_gpu = cuda_make_array(l.bias_updates, outputs);
+    l.output_gpu = cuda_make_array(l.output, outputs*batch);
+    l.delta_gpu = cuda_make_array(l.delta, outputs*batch);
 #endif
-    layer->activation = activation;
+    l.activation = activation;
     fprintf(stderr, "Connected Layer: %d inputs, %d outputs\n", inputs, outputs);
-    return layer;
+    return l;
 }

-void update_connected_layer(connected_layer layer, int batch, float learning_rate, float momentum, float decay)
+void update_connected_layer(connected_layer l, int batch, float learning_rate, float momentum, float decay)
 {
-    axpy_cpu(layer.outputs, learning_rate/batch, layer.bias_updates, 1, layer.biases, 1);
-    scal_cpu(layer.outputs, momentum, layer.bias_updates, 1);
-    axpy_cpu(layer.inputs*layer.outputs, -decay*batch, layer.weights, 1, layer.weight_updates, 1);
-    axpy_cpu(layer.inputs*layer.outputs, learning_rate/batch, layer.weight_updates, 1, layer.weights, 1);
-    scal_cpu(layer.inputs*layer.outputs, momentum, layer.weight_updates, 1);
+    axpy_cpu(l.outputs, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
+    scal_cpu(l.outputs, momentum, l.bias_updates, 1);
+    axpy_cpu(l.inputs*l.outputs, -decay*batch, l.weights, 1, l.weight_updates, 1);
+    axpy_cpu(l.inputs*l.outputs, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
+    scal_cpu(l.inputs*l.outputs, momentum, l.weight_updates, 1);
 }

-void forward_connected_layer(connected_layer layer, network_state state)
+void forward_connected_layer(connected_layer l, network_state state)
 {
     int i;
-    for(i = 0; i < layer.batch; ++i){
-        copy_cpu(layer.outputs, layer.biases, 1, layer.output + i*layer.outputs, 1);
+    for(i = 0; i < l.batch; ++i){
+        copy_cpu(l.outputs, l.biases, 1, l.output + i*l.outputs, 1);
     }
-    int m = layer.batch;
-    int k = layer.inputs;
-    int n = layer.outputs;
+    int m = l.batch;
+    int k = l.inputs;
+    int n = l.outputs;
     float *a = state.input;
-    float *b = layer.weights;
-    float *c = layer.output;
+    float *b = l.weights;
+    float *c = l.output;
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-    activate_array(layer.output, layer.outputs*layer.batch, layer.activation);
+    activate_array(l.output, l.outputs*l.batch, l.activation);
 }

-void backward_connected_layer(connected_layer layer, network_state state)
+void backward_connected_layer(connected_layer l, network_state state)
 {
     int i;
-    gradient_array(layer.output, layer.outputs*layer.batch, layer.activation, layer.delta);
-    for(i = 0; i < layer.batch; ++i){
-        axpy_cpu(layer.outputs, 1, layer.delta + i*layer.outputs, 1, layer.bias_updates, 1);
+    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
+    for(i = 0; i < l.batch; ++i){
+        axpy_cpu(l.outputs, 1, l.delta + i*l.outputs, 1, l.bias_updates, 1);
     }
-    int m = layer.inputs;
-    int k = layer.batch;
-    int n = layer.outputs;
+    int m = l.inputs;
+    int k = l.batch;
+    int n = l.outputs;
     float *a = state.input;
-    float *b = layer.delta;
-    float *c = layer.weight_updates;
+    float *b = l.delta;
+    float *c = l.weight_updates;
     gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);

-    m = layer.batch;
-    k = layer.outputs;
-    n = layer.inputs;
-    a = layer.delta;
-    b = layer.weights;
+    m = l.batch;
+    k = l.outputs;
+    n = l.inputs;
+    a = l.delta;
+    b = l.weights;
     c = state.delta;

     if(c) gemm(0,1,m,n,k,1,a,k,b,k,0,c,n);
@@ -109,69 +107,69 @@ void backward_connected_layer(connected_layer layer, network_state state)
 #ifdef GPU

-void pull_connected_layer(connected_layer layer)
+void pull_connected_layer(connected_layer l)
 {
-    cuda_pull_array(layer.weights_gpu, layer.weights, layer.inputs*layer.outputs);
-    cuda_pull_array(layer.biases_gpu, layer.biases, layer.outputs);
-    cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.inputs*layer.outputs);
-    cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.outputs);
+    cuda_pull_array(l.weights_gpu, l.weights, l.inputs*l.outputs);
+    cuda_pull_array(l.biases_gpu, l.biases, l.outputs);
+    cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.inputs*l.outputs);
+    cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
 }

-void push_connected_layer(connected_layer layer)
+void push_connected_layer(connected_layer l)
 {
-    cuda_push_array(layer.weights_gpu, layer.weights, layer.inputs*layer.outputs);
-    cuda_push_array(layer.biases_gpu, layer.biases, layer.outputs);
-    cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.inputs*layer.outputs);
-    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.outputs);
+    cuda_push_array(l.weights_gpu, l.weights, l.inputs*l.outputs);
+    cuda_push_array(l.biases_gpu, l.biases, l.outputs);
+    cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.inputs*l.outputs);
+    cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
 }

-void update_connected_layer_gpu(connected_layer layer, int batch, float learning_rate, float momentum, float decay)
+void update_connected_layer_gpu(connected_layer l, int batch, float learning_rate, float momentum, float decay)
 {
-    axpy_ongpu(layer.outputs, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
-    scal_ongpu(layer.outputs, momentum, layer.bias_updates_gpu, 1);
-    axpy_ongpu(layer.inputs*layer.outputs, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
-    axpy_ongpu(layer.inputs*layer.outputs, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
-    scal_ongpu(layer.inputs*layer.outputs, momentum, layer.weight_updates_gpu, 1);
+    axpy_ongpu(l.outputs, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1);
+    scal_ongpu(l.outputs, momentum, l.bias_updates_gpu, 1);
+    axpy_ongpu(l.inputs*l.outputs, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
+    axpy_ongpu(l.inputs*l.outputs, learning_rate/batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
+    scal_ongpu(l.inputs*l.outputs, momentum, l.weight_updates_gpu, 1);
 }

-void forward_connected_layer_gpu(connected_layer layer, network_state state)
+void forward_connected_layer_gpu(connected_layer l, network_state state)
 {
     int i;
-    for(i = 0; i < layer.batch; ++i){
-        copy_ongpu_offset(layer.outputs, layer.biases_gpu, 0, 1, layer.output_gpu, i*layer.outputs, 1);
+    for(i = 0; i < l.batch; ++i){
+        copy_ongpu_offset(l.outputs, l.biases_gpu, 0, 1, l.output_gpu, i*l.outputs, 1);
     }
-    int m = layer.batch;
-    int k = layer.inputs;
-    int n = layer.outputs;
+    int m = l.batch;
+    int k = l.inputs;
+    int n = l.outputs;
     float * a = state.input;
-    float * b = layer.weights_gpu;
-    float * c = layer.output_gpu;
+    float * b = l.weights_gpu;
+    float * c = l.output_gpu;
     gemm_ongpu(0,0,m,n,k,1,a,k,b,n,1,c,n);
-    activate_array_ongpu(layer.output_gpu, layer.outputs*layer.batch, layer.activation);
+    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
 }

-void backward_connected_layer_gpu(connected_layer layer, network_state state)
+void backward_connected_layer_gpu(connected_layer l, network_state state)
 {
     int i;
-    gradient_array_ongpu(layer.output_gpu, layer.outputs*layer.batch, layer.activation, layer.delta_gpu);
-    for(i = 0; i < layer.batch; ++i){
-        axpy_ongpu_offset(layer.outputs, 1, layer.delta_gpu, i*layer.outputs, 1, layer.bias_updates_gpu, 0, 1);
+    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
+    for(i = 0; i < l.batch; ++i){
+        axpy_ongpu_offset(l.outputs, 1, l.delta_gpu, i*l.outputs, 1, l.bias_updates_gpu, 0, 1);
     }
-    int m = layer.inputs;
-    int k = layer.batch;
-    int n = layer.outputs;
+    int m = l.inputs;
+    int k = l.batch;
+    int n = l.outputs;
     float * a = state.input;
-    float * b = layer.delta_gpu;
-    float * c = layer.weight_updates_gpu;
+    float * b = l.delta_gpu;
+    float * c = l.weight_updates_gpu;
     gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n);

-    m = layer.batch;
-    k = layer.outputs;
-    n = layer.inputs;
-    a = layer.delta_gpu;
-    b = layer.weights_gpu;
+    m = l.batch;
+    k = l.outputs;
+    n = l.inputs;
+    a = l.delta_gpu;
+    b = l.weights_gpu;
     c = state.delta;

     if(c) gemm_ongpu(0,1,m,n,k,1,a,k,b,k,0,c,n);
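
For reference, darknet's gemm() takes (TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc); assuming that signature, the forward call above maps onto matrix shapes as follows (the batch/inputs/outputs numbers are hypothetical, chosen only for concreteness):

/* forward_connected_layer: gemm(0,0,m,n,k,1,a,k,b,n,1,c,n) computes
 *     C[m x n] += A[m x k] * B[k x n]
 * i.e. output[batch x outputs] += input[batch x inputs] * weights[inputs x outputs],
 * after the biases were copied into each row of the output.
 * With batch=4, inputs=784, outputs=100: m=4, k=784, n=100. */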

src/connected_layer.h

@@ -3,38 +3,11 @@
 #include "activations.h"
 #include "params.h"
+#include "layer.h"

-typedef struct{
-    int batch;
-    int inputs;
-    int outputs;
-    float *weights;
-    float *biases;
-    float *weight_updates;
-    float *bias_updates;
-    float *weight_prev;
-    float *bias_prev;
-    float *output;
-    float *delta;
-#ifdef GPU
-    float * weights_gpu;
-    float * biases_gpu;
-    float * weight_updates_gpu;
-    float * bias_updates_gpu;
-    float * output_gpu;
-    float * delta_gpu;
-#endif
-    ACTIVATION activation;
-} connected_layer;
-
-connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation);
+typedef layer connected_layer;
+
+connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation);

 void forward_connected_layer(connected_layer layer, network_state state);
 void backward_connected_layer(connected_layer layer, network_state state);

src/convolutional_layer.c

@@ -7,111 +7,117 @@
 #include <stdio.h>
 #include <time.h>

-int convolutional_out_height(convolutional_layer layer)
+int convolutional_out_height(convolutional_layer l)
 {
-    int h = layer.h;
-    if (!layer.pad) h -= layer.size;
+    int h = l.h;
+    if (!l.pad) h -= l.size;
     else h -= 1;
-    return h/layer.stride + 1;
+    return h/l.stride + 1;
 }

-int convolutional_out_width(convolutional_layer layer)
+int convolutional_out_width(convolutional_layer l)
 {
-    int w = layer.w;
-    if (!layer.pad) w -= layer.size;
+    int w = l.w;
+    if (!l.pad) w -= l.size;
     else w -= 1;
-    return w/layer.stride + 1;
+    return w/l.stride + 1;
 }

-image get_convolutional_image(convolutional_layer layer)
+image get_convolutional_image(convolutional_layer l)
 {
     int h,w,c;
-    h = convolutional_out_height(layer);
-    w = convolutional_out_width(layer);
-    c = layer.n;
-    return float_to_image(w,h,c,layer.output);
+    h = convolutional_out_height(l);
+    w = convolutional_out_width(l);
+    c = l.n;
+    return float_to_image(w,h,c,l.output);
 }

-image get_convolutional_delta(convolutional_layer layer)
+image get_convolutional_delta(convolutional_layer l)
 {
     int h,w,c;
-    h = convolutional_out_height(layer);
-    w = convolutional_out_width(layer);
-    c = layer.n;
-    return float_to_image(w,h,c,layer.delta);
+    h = convolutional_out_height(l);
+    w = convolutional_out_width(l);
+    c = l.n;
+    return float_to_image(w,h,c,l.delta);
 }

-convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation)
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation)
 {
     int i;
-    convolutional_layer *layer = calloc(1, sizeof(convolutional_layer));
-    layer->h = h;
-    layer->w = w;
-    layer->c = c;
-    layer->n = n;
-    layer->batch = batch;
-    layer->stride = stride;
-    layer->size = size;
-    layer->pad = pad;
-    layer->filters = calloc(c*n*size*size, sizeof(float));
-    layer->filter_updates = calloc(c*n*size*size, sizeof(float));
-    layer->biases = calloc(n, sizeof(float));
-    layer->bias_updates = calloc(n, sizeof(float));
+    convolutional_layer l = {0};
+    l.type = CONVOLUTIONAL;
+    l.h = h;
+    l.w = w;
+    l.c = c;
+    l.n = n;
+    l.batch = batch;
+    l.stride = stride;
+    l.size = size;
+    l.pad = pad;
+    l.filters = calloc(c*n*size*size, sizeof(float));
+    l.filter_updates = calloc(c*n*size*size, sizeof(float));
+    l.biases = calloc(n, sizeof(float));
+    l.bias_updates = calloc(n, sizeof(float));
     float scale = 1./sqrt(size*size*c);
-    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = 2*scale*rand_uniform() - scale;
+    for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale;
     for(i = 0; i < n; ++i){
-        layer->biases[i] = scale;
+        l.biases[i] = scale;
     }
-    int out_h = convolutional_out_height(*layer);
-    int out_w = convolutional_out_width(*layer);
-    layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
-    layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float));
-    layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float));
+    int out_h = convolutional_out_height(l);
+    int out_w = convolutional_out_width(l);
+    l.out_h = out_h;
+    l.out_w = out_w;
+    l.out_c = n;
+    l.outputs = l.out_h * l.out_w * l.out_c;
+    l.inputs = l.w * l.h * l.c;
+    l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
+    l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
+    l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
 #ifdef GPU
-    layer->filters_gpu = cuda_make_array(layer->filters, c*n*size*size);
-    layer->filter_updates_gpu = cuda_make_array(layer->filter_updates, c*n*size*size);
-    layer->biases_gpu = cuda_make_array(layer->biases, n);
-    layer->bias_updates_gpu = cuda_make_array(layer->bias_updates, n);
-    layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*size*size*c);
-    layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*n);
-    layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*n);
+    l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+    l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
+    l.biases_gpu = cuda_make_array(l.biases, n);
+    l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
+    l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
+    l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
+    l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 #endif
-    layer->activation = activation;
+    l.activation = activation;
     fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
-    return layer;
+    return l;
 }

-void resize_convolutional_layer(convolutional_layer *layer, int h, int w)
+void resize_convolutional_layer(convolutional_layer *l, int h, int w)
 {
-    layer->h = h;
-    layer->w = w;
-    int out_h = convolutional_out_height(*layer);
-    int out_w = convolutional_out_width(*layer);
-    layer->col_image = realloc(layer->col_image,
-            out_h*out_w*layer->size*layer->size*layer->c*sizeof(float));
-    layer->output = realloc(layer->output,
-            layer->batch*out_h * out_w * layer->n*sizeof(float));
-    layer->delta = realloc(layer->delta,
-            layer->batch*out_h * out_w * layer->n*sizeof(float));
+    l->h = h;
+    l->w = w;
+    int out_h = convolutional_out_height(*l);
+    int out_w = convolutional_out_width(*l);
+    l->col_image = realloc(l->col_image,
+            out_h*out_w*l->size*l->size*l->c*sizeof(float));
+    l->output = realloc(l->output,
+            l->batch*out_h * out_w * l->n*sizeof(float));
+    l->delta = realloc(l->delta,
+            l->batch*out_h * out_w * l->n*sizeof(float));
 #ifdef GPU
-    cuda_free(layer->col_image_gpu);
-    cuda_free(layer->delta_gpu);
-    cuda_free(layer->output_gpu);
-    layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*layer->size*layer->size*layer->c);
-    layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*layer->n);
-    layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*layer->n);
+    cuda_free(l->col_image_gpu);
+    cuda_free(l->delta_gpu);
+    cuda_free(l->output_gpu);
+    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
+    l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
+    l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
 #endif
 }
@@ -138,104 +144,104 @@ void backward_bias(float *bias_updates, float *delta, int batch, int n, int size)
 }

-void forward_convolutional_layer(const convolutional_layer layer, network_state state)
+void forward_convolutional_layer(const convolutional_layer l, network_state state)
 {
-    int out_h = convolutional_out_height(layer);
-    int out_w = convolutional_out_width(layer);
+    int out_h = convolutional_out_height(l);
+    int out_w = convolutional_out_width(l);
     int i;
-    bias_output(layer.output, layer.biases, layer.batch, layer.n, out_h*out_w);
-    int m = layer.n;
-    int k = layer.size*layer.size*layer.c;
+    bias_output(l.output, l.biases, l.batch, l.n, out_h*out_w);
+    int m = l.n;
+    int k = l.size*l.size*l.c;
     int n = out_h*out_w;
-    float *a = layer.filters;
-    float *b = layer.col_image;
-    float *c = layer.output;
-    for(i = 0; i < layer.batch; ++i){
-        im2col_cpu(state.input, layer.c, layer.h, layer.w,
-                layer.size, layer.stride, layer.pad, b);
+    float *a = l.filters;
+    float *b = l.col_image;
+    float *c = l.output;
+    for(i = 0; i < l.batch; ++i){
+        im2col_cpu(state.input, l.c, l.h, l.w,
+                l.size, l.stride, l.pad, b);
         gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
         c += n*m;
-        state.input += layer.c*layer.h*layer.w;
+        state.input += l.c*l.h*l.w;
     }
-    activate_array(layer.output, m*n*layer.batch, layer.activation);
+    activate_array(l.output, m*n*l.batch, l.activation);
 }

-void backward_convolutional_layer(convolutional_layer layer, network_state state)
+void backward_convolutional_layer(convolutional_layer l, network_state state)
 {
     int i;
-    int m = layer.n;
-    int n = layer.size*layer.size*layer.c;
-    int k = convolutional_out_height(layer)*
-        convolutional_out_width(layer);
-    gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta);
-    backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, k);
-    if(state.delta) memset(state.delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
-    for(i = 0; i < layer.batch; ++i){
-        float *a = layer.delta + i*m*k;
-        float *b = layer.col_image;
-        float *c = layer.filter_updates;
-        float *im = state.input+i*layer.c*layer.h*layer.w;
-        im2col_cpu(im, layer.c, layer.h, layer.w,
-                layer.size, layer.stride, layer.pad, b);
+    int m = l.n;
+    int n = l.size*l.size*l.c;
+    int k = convolutional_out_height(l)*
+        convolutional_out_width(l);
+    gradient_array(l.output, m*k*l.batch, l.activation, l.delta);
+    backward_bias(l.bias_updates, l.delta, l.batch, l.n, k);
+    if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
+    for(i = 0; i < l.batch; ++i){
+        float *a = l.delta + i*m*k;
+        float *b = l.col_image;
+        float *c = l.filter_updates;
+        float *im = state.input+i*l.c*l.h*l.w;
+        im2col_cpu(im, l.c, l.h, l.w,
+                l.size, l.stride, l.pad, b);
         gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
         if(state.delta){
-            a = layer.filters;
-            b = layer.delta + i*m*k;
-            c = layer.col_image;
+            a = l.filters;
+            b = l.delta + i*m*k;
+            c = l.col_image;
             gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
-            col2im_cpu(layer.col_image, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, state.delta+i*layer.c*layer.h*layer.w);
+            col2im_cpu(l.col_image, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
         }
     }
 }

-void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
+void update_convolutional_layer(convolutional_layer l, int batch, float learning_rate, float momentum, float decay)
 {
-    int size = layer.size*layer.size*layer.c*layer.n;
-    axpy_cpu(layer.n, learning_rate/batch, layer.bias_updates, 1, layer.biases, 1);
-    scal_cpu(layer.n, momentum, layer.bias_updates, 1);
-    axpy_cpu(size, -decay*batch, layer.filters, 1, layer.filter_updates, 1);
-    axpy_cpu(size, learning_rate/batch, layer.filter_updates, 1, layer.filters, 1);
-    scal_cpu(size, momentum, layer.filter_updates, 1);
+    int size = l.size*l.size*l.c*l.n;
+    axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
+    scal_cpu(l.n, momentum, l.bias_updates, 1);
+    axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1);
+    axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1);
+    scal_cpu(size, momentum, l.filter_updates, 1);
 }

-image get_convolutional_filter(convolutional_layer layer, int i)
+image get_convolutional_filter(convolutional_layer l, int i)
 {
-    int h = layer.size;
-    int w = layer.size;
-    int c = layer.c;
-    return float_to_image(w,h,c,layer.filters+i*h*w*c);
+    int h = l.size;
+    int w = l.size;
+    int c = l.c;
+    return float_to_image(w,h,c,l.filters+i*h*w*c);
 }

-image *get_filters(convolutional_layer layer)
+image *get_filters(convolutional_layer l)
 {
-    image *filters = calloc(layer.n, sizeof(image));
+    image *filters = calloc(l.n, sizeof(image));
     int i;
-    for(i = 0; i < layer.n; ++i){
-        filters[i] = copy_image(get_convolutional_filter(layer, i));
+    for(i = 0; i < l.n; ++i){
+        filters[i] = copy_image(get_convolutional_filter(l, i));
     }
     return filters;
 }

-image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters)
+image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_filters)
 {
-    image *single_filters = get_filters(layer);
-    show_images(single_filters, layer.n, window);
-    image delta = get_convolutional_image(layer);
+    image *single_filters = get_filters(l);
+    show_images(single_filters, l.n, window);
+    image delta = get_convolutional_image(l);
     image dc = collapse_image_layers(delta, 1);
     char buff[256];
     sprintf(buff, "%s: Output", window);
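
A quick check of the output-size arithmetic above, since it governs the col_image/output/delta allocations (the example numbers are hypothetical):

/* convolutional_out_height: valid convolution unless pad is set.
 *   pad == 0:  out_h = (h - size)/stride + 1   e.g. (224 - 3)/1 + 1 = 222
 *   pad != 0:  out_h = (h - 1)/stride + 1      e.g. (224 - 1)/1 + 1 = 224 */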

src/convolutional_layer.h

@@ -5,38 +5,9 @@
 #include "params.h"
 #include "image.h"
 #include "activations.h"
+#include "layer.h"

-typedef struct {
-    int batch;
-    int h,w,c;
-    int n;
-    int size;
-    int stride;
-    int pad;
-    float *filters;
-    float *filter_updates;
-    float *biases;
-    float *bias_updates;
-    float *col_image;
-    float *delta;
-    float *output;
-#ifdef GPU
-    float * filters_gpu;
-    float * filter_updates_gpu;
-    float * biases_gpu;
-    float * bias_updates_gpu;
-    float * col_image_gpu;
-    float * delta_gpu;
-    float * output_gpu;
-#endif
-    ACTIVATION activation;
-} convolutional_layer;
+typedef layer convolutional_layer;

 #ifdef GPU
 void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state);
@@ -50,7 +21,7 @@ void bias_output_gpu(float *output, float *biases, int batch, int n, int size);
 void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
 #endif

-convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation);
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation);
 void resize_convolutional_layer(convolutional_layer *layer, int h, int w);
 void forward_convolutional_layer(const convolutional_layer layer, network_state state);
 void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);

src/cost_layer.c

@@ -26,70 +26,73 @@ char *get_cost_string(COST_TYPE a)
     return "sse";
 }

-cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type)
+cost_layer make_cost_layer(int batch, int inputs, COST_TYPE cost_type)
 {
     fprintf(stderr, "Cost Layer: %d inputs\n", inputs);
-    cost_layer *layer = calloc(1, sizeof(cost_layer));
-    layer->batch = batch;
-    layer->inputs = inputs;
-    layer->type = type;
-    layer->delta = calloc(inputs*batch, sizeof(float));
-    layer->output = calloc(1, sizeof(float));
+    cost_layer l = {0};
+    l.type = COST;
+    l.batch = batch;
+    l.inputs = inputs;
+    l.outputs = inputs;
+    l.cost_type = cost_type;
+    l.delta = calloc(inputs*batch, sizeof(float));
+    l.output = calloc(1, sizeof(float));
 #ifdef GPU
-    layer->delta_gpu = cuda_make_array(layer->delta, inputs*batch);
+    l.delta_gpu = cuda_make_array(l.delta, inputs*batch);
 #endif
-    return layer;
+    return l;
 }

-void forward_cost_layer(cost_layer layer, network_state state)
+void forward_cost_layer(cost_layer l, network_state state)
 {
     if (!state.truth) return;
-    if(layer.type == MASKED){
+    if(l.cost_type == MASKED){
         int i;
-        for(i = 0; i < layer.batch*layer.inputs; ++i){
+        for(i = 0; i < l.batch*l.inputs; ++i){
             if(state.truth[i] == 0) state.input[i] = 0;
         }
     }
-    copy_cpu(layer.batch*layer.inputs, state.truth, 1, layer.delta, 1);
-    axpy_cpu(layer.batch*layer.inputs, -1, state.input, 1, layer.delta, 1);
-    *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
-    //printf("cost: %f\n", *layer.output);
+    copy_cpu(l.batch*l.inputs, state.truth, 1, l.delta, 1);
+    axpy_cpu(l.batch*l.inputs, -1, state.input, 1, l.delta, 1);
+    *(l.output) = dot_cpu(l.batch*l.inputs, l.delta, 1, l.delta, 1);
+    //printf("cost: %f\n", *l.output);
 }

-void backward_cost_layer(const cost_layer layer, network_state state)
+void backward_cost_layer(const cost_layer l, network_state state)
 {
-    copy_cpu(layer.batch*layer.inputs, layer.delta, 1, state.delta, 1);
+    copy_cpu(l.batch*l.inputs, l.delta, 1, state.delta, 1);
 }

 #ifdef GPU

-void pull_cost_layer(cost_layer layer)
+void pull_cost_layer(cost_layer l)
 {
-    cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
+    cuda_pull_array(l.delta_gpu, l.delta, l.batch*l.inputs);
 }

-void push_cost_layer(cost_layer layer)
+void push_cost_layer(cost_layer l)
 {
-    cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
+    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.inputs);
 }

-void forward_cost_layer_gpu(cost_layer layer, network_state state)
+void forward_cost_layer_gpu(cost_layer l, network_state state)
 {
     if (!state.truth) return;
-    if (layer.type == MASKED) {
-        mask_ongpu(layer.batch*layer.inputs, state.input, state.truth);
+    if (l.cost_type == MASKED) {
+        mask_ongpu(l.batch*l.inputs, state.input, state.truth);
     }
-    copy_ongpu(layer.batch*layer.inputs, state.truth, 1, layer.delta_gpu, 1);
-    axpy_ongpu(layer.batch*layer.inputs, -1, state.input, 1, layer.delta_gpu, 1);
-    cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
-    *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
+    copy_ongpu(l.batch*l.inputs, state.truth, 1, l.delta_gpu, 1);
+    axpy_ongpu(l.batch*l.inputs, -1, state.input, 1, l.delta_gpu, 1);
+    cuda_pull_array(l.delta_gpu, l.delta, l.batch*l.inputs);
+    *(l.output) = dot_cpu(l.batch*l.inputs, l.delta, 1, l.delta, 1);
 }

-void backward_cost_layer_gpu(const cost_layer layer, network_state state)
+void backward_cost_layer_gpu(const cost_layer l, network_state state)
 {
-    copy_ongpu(layer.batch*layer.inputs, layer.delta_gpu, 1, state.delta, 1);
+    copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
 }
 #endif
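
The forward pass above is plain sum-of-squared-error; written out, it restates the three BLAS-style calls and nothing more:

/* forward_cost_layer, in effect:
 *   delta = truth              (copy_cpu)
 *   delta = delta - input      (axpy_cpu with ALPHA = -1)
 *   *output = dot(delta, delta) = sum_i (truth[i] - input[i])^2 */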

src/cost_layer.h

@@ -1,33 +1,19 @@
 #ifndef COST_LAYER_H
 #define COST_LAYER_H
 #include "params.h"
+#include "layer.h"

-typedef enum{
-    SSE, MASKED
-} COST_TYPE;
-
-typedef struct {
-    int inputs;
-    int batch;
-    int coords;
-    int classes;
-    float *delta;
-    float *output;
-    COST_TYPE type;
-#ifdef GPU
-    float * delta_gpu;
-#endif
-} cost_layer;
+typedef layer cost_layer;

 COST_TYPE get_cost_type(char *s);
 char *get_cost_string(COST_TYPE a);
-cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type);
-void forward_cost_layer(const cost_layer layer, network_state state);
-void backward_cost_layer(const cost_layer layer, network_state state);
+cost_layer make_cost_layer(int batch, int inputs, COST_TYPE type);
+void forward_cost_layer(const cost_layer l, network_state state);
+void backward_cost_layer(const cost_layer l, network_state state);
 #ifdef GPU
-void forward_cost_layer_gpu(cost_layer layer, network_state state);
-void backward_cost_layer_gpu(const cost_layer layer, network_state state);
+void forward_cost_layer_gpu(cost_layer l, network_state state);
+void backward_cost_layer_gpu(const cost_layer l, network_state state);
 #endif
 #endif

src/crop_layer.c

@@ -2,63 +2,69 @@
 #include "cuda.h"
 #include <stdio.h>

-image get_crop_image(crop_layer layer)
+image get_crop_image(crop_layer l)
 {
-    int h = layer.crop_height;
-    int w = layer.crop_width;
-    int c = layer.c;
-    return float_to_image(w,h,c,layer.output);
+    int h = l.out_h;
+    int w = l.out_w;
+    int c = l.out_c;
+    return float_to_image(w,h,c,l.output);
 }

-crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure)
+crop_layer make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure)
 {
     fprintf(stderr, "Crop Layer: %d x %d -> %d x %d x %d image\n", h,w,crop_height,crop_width,c);
-    crop_layer *layer = calloc(1, sizeof(crop_layer));
-    layer->batch = batch;
-    layer->h = h;
-    layer->w = w;
-    layer->c = c;
-    layer->flip = flip;
-    layer->angle = angle;
-    layer->saturation = saturation;
-    layer->exposure = exposure;
-    layer->crop_width = crop_width;
-    layer->crop_height = crop_height;
-    layer->output = calloc(crop_width*crop_height * c*batch, sizeof(float));
+    crop_layer l = {0};
+    l.type = CROP;
+    l.batch = batch;
+    l.h = h;
+    l.w = w;
+    l.c = c;
+    l.flip = flip;
+    l.angle = angle;
+    l.saturation = saturation;
+    l.exposure = exposure;
+    l.crop_width = crop_width;
+    l.crop_height = crop_height;
+    l.out_w = crop_width;
+    l.out_h = crop_height;
+    l.out_c = c;
+    l.inputs = l.w * l.h * l.c;
+    l.outputs = l.out_w * l.out_h * l.out_c;
+    l.output = calloc(crop_width*crop_height * c*batch, sizeof(float));
 #ifdef GPU
-    layer->output_gpu = cuda_make_array(layer->output, crop_width*crop_height*c*batch);
-    layer->rand_gpu = cuda_make_array(0, layer->batch*8);
+    l.output_gpu = cuda_make_array(l.output, crop_width*crop_height*c*batch);
+    l.rand_gpu = cuda_make_array(0, l.batch*8);
 #endif
-    return layer;
+    return l;
 }

-void forward_crop_layer(const crop_layer layer, network_state state)
+void forward_crop_layer(const crop_layer l, network_state state)
 {
     int i,j,c,b,row,col;
     int index;
     int count = 0;
-    int flip = (layer.flip && rand()%2);
-    int dh = rand()%(layer.h - layer.crop_height + 1);
-    int dw = rand()%(layer.w - layer.crop_width + 1);
+    int flip = (l.flip && rand()%2);
+    int dh = rand()%(l.h - l.crop_height + 1);
+    int dw = rand()%(l.w - l.crop_width + 1);
     float scale = 2;
     float trans = -1;
     if(!state.train){
         flip = 0;
-        dh = (layer.h - layer.crop_height)/2;
-        dw = (layer.w - layer.crop_width)/2;
+        dh = (l.h - l.crop_height)/2;
+        dw = (l.w - l.crop_width)/2;
     }
-    for(b = 0; b < layer.batch; ++b){
-        for(c = 0; c < layer.c; ++c){
-            for(i = 0; i < layer.crop_height; ++i){
-                for(j = 0; j < layer.crop_width; ++j){
+    for(b = 0; b < l.batch; ++b){
+        for(c = 0; c < l.c; ++c){
+            for(i = 0; i < l.crop_height; ++i){
+                for(j = 0; j < l.crop_width; ++j){
                     if(flip){
-                        col = layer.w - dw - j - 1;
+                        col = l.w - dw - j - 1;
                     }else{
                         col = j + dw;
                     }
                     row = i + dh;
-                    index = col+layer.w*(row+layer.h*(c + layer.c*b));
-                    layer.output[count++] = state.input[index]*scale + trans;
+                    index = col+l.w*(row+l.h*(c + l.c*b));
+                    l.output[count++] = state.input[index]*scale + trans;
                 }
             }
         }
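
One detail worth calling out in forward_crop_layer: scale = 2 and trans = -1 remap pixel values, assuming inputs normalized to [0,1]:

/* out = in*scale + trans = in*2 - 1, so [0,1] inputs land in [-1,1]:
 *   0.0 -> -1.0,   0.5 -> 0.0,   1.0 -> 1.0 */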

src/crop_layer.h

@@ -3,29 +3,16 @@
 #include "image.h"
 #include "params.h"
+#include "layer.h"

-typedef struct {
-    int batch;
-    int h,w,c;
-    int crop_width;
-    int crop_height;
-    int flip;
-    float angle;
-    float saturation;
-    float exposure;
-    float *output;
-#ifdef GPU
-    float *output_gpu;
-    float *rand_gpu;
-#endif
-} crop_layer;
+typedef layer crop_layer;

-image get_crop_image(crop_layer layer);
-crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure);
-void forward_crop_layer(const crop_layer layer, network_state state);
+image get_crop_image(crop_layer l);
+crop_layer make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure);
+void forward_crop_layer(const crop_layer l, network_state state);
 #ifdef GPU
-void forward_crop_layer_gpu(crop_layer layer, network_state state);
+void forward_crop_layer_gpu(crop_layer l, network_state state);
 #endif
 #endif

src/darknet.c

@@ -72,15 +72,6 @@ void partial(char *cfgfile, char *weightfile, char *outfile, int max)
     save_weights(net, outfile);
 }

-void convert(char *cfgfile, char *outfile, char *weightfile)
-{
-    network net = parse_network_cfg(cfgfile);
-    if(weightfile){
-        load_weights(&net, weightfile);
-    }
-    save_network(net, outfile);
-}
-
 void visualize(char *cfgfile, char *weightfile)
 {
     network net = parse_network_cfg(cfgfile);
@@ -120,8 +111,6 @@ int main(int argc, char **argv)
         run_captcha(argc, argv);
     } else if (0 == strcmp(argv[1], "change")){
         change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);
-    } else if (0 == strcmp(argv[1], "convert")){
-        convert(argv[2], argv[3], (argc > 4) ? argv[4] : 0);
     } else if (0 == strcmp(argv[1], "partial")){
         partial(argv[2], argv[3], argv[4], atoi(argv[5]));
     } else if (0 == strcmp(argv[1], "visualize")){

src/data.c

@@ -174,7 +174,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes,
         }
         int index = (i+j*num_boxes)*(4+classes+background);
-        if(truth[index+classes+background+2]) continue;
+        //if(truth[index+classes+background+2]) continue;
         if(background) truth[index++] = 0;
         truth[index+id] = 1;
         index += classes;

src/deconvolutional_layer.c

@@ -8,172 +8,179 @@
 #include <stdio.h>
 #include <time.h>

-int deconvolutional_out_height(deconvolutional_layer layer)
+int deconvolutional_out_height(deconvolutional_layer l)
 {
-    int h = layer.stride*(layer.h - 1) + layer.size;
+    int h = l.stride*(l.h - 1) + l.size;
     return h;
 }

-int deconvolutional_out_width(deconvolutional_layer layer)
+int deconvolutional_out_width(deconvolutional_layer l)
 {
-    int w = layer.stride*(layer.w - 1) + layer.size;
+    int w = l.stride*(l.w - 1) + l.size;
     return w;
 }

-int deconvolutional_out_size(deconvolutional_layer layer)
+int deconvolutional_out_size(deconvolutional_layer l)
 {
-    return deconvolutional_out_height(layer) * deconvolutional_out_width(layer);
+    return deconvolutional_out_height(l) * deconvolutional_out_width(l);
 }

-image get_deconvolutional_image(deconvolutional_layer layer)
+image get_deconvolutional_image(deconvolutional_layer l)
 {
     int h,w,c;
-    h = deconvolutional_out_height(layer);
-    w = deconvolutional_out_width(layer);
-    c = layer.n;
-    return float_to_image(w,h,c,layer.output);
+    h = deconvolutional_out_height(l);
+    w = deconvolutional_out_width(l);
+    c = l.n;
+    return float_to_image(w,h,c,l.output);
 }

-image get_deconvolutional_delta(deconvolutional_layer layer)
+image get_deconvolutional_delta(deconvolutional_layer l)
 {
     int h,w,c;
-    h = deconvolutional_out_height(layer);
-    w = deconvolutional_out_width(layer);
-    c = layer.n;
-    return float_to_image(w,h,c,layer.delta);
+    h = deconvolutional_out_height(l);
+    w = deconvolutional_out_width(l);
+    c = l.n;
+    return float_to_image(w,h,c,l.delta);
 }

-deconvolutional_layer *make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
+deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
 {
     int i;
-    deconvolutional_layer *layer = calloc(1, sizeof(deconvolutional_layer));
-    layer->h = h;
-    layer->w = w;
-    layer->c = c;
-    layer->n = n;
-    layer->batch = batch;
-    layer->stride = stride;
-    layer->size = size;
-    layer->filters = calloc(c*n*size*size, sizeof(float));
-    layer->filter_updates = calloc(c*n*size*size, sizeof(float));
-    layer->biases = calloc(n, sizeof(float));
-    layer->bias_updates = calloc(n, sizeof(float));
+    deconvolutional_layer l = {0};
+    l.type = DECONVOLUTIONAL;
+    l.h = h;
+    l.w = w;
+    l.c = c;
+    l.n = n;
+    l.batch = batch;
+    l.stride = stride;
+    l.size = size;
+    l.filters = calloc(c*n*size*size, sizeof(float));
+    l.filter_updates = calloc(c*n*size*size, sizeof(float));
+    l.biases = calloc(n, sizeof(float));
+    l.bias_updates = calloc(n, sizeof(float));
     float scale = 1./sqrt(size*size*c);
-    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*rand_normal();
+    for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_normal();
     for(i = 0; i < n; ++i){
-        layer->biases[i] = scale;
+        l.biases[i] = scale;
     }
-    int out_h = deconvolutional_out_height(*layer);
-    int out_w = deconvolutional_out_width(*layer);
-    layer->col_image = calloc(h*w*size*size*n, sizeof(float));
-    layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float));
-    layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float));
+    int out_h = deconvolutional_out_height(l);
+    int out_w = deconvolutional_out_width(l);
+    l.out_h = out_h;
+    l.out_w = out_w;
+    l.out_c = n;
+    l.outputs = l.out_w * l.out_h * l.out_c;
+    l.inputs = l.w * l.h * l.c;
+    l.col_image = calloc(h*w*size*size*n, sizeof(float));
+    l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
+    l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
 #ifdef GPU
-    layer->filters_gpu = cuda_make_array(layer->filters, c*n*size*size);
-    layer->filter_updates_gpu = cuda_make_array(layer->filter_updates, c*n*size*size);
-    layer->biases_gpu = cuda_make_array(layer->biases, n);
-    layer->bias_updates_gpu = cuda_make_array(layer->bias_updates, n);
-    layer->col_image_gpu = cuda_make_array(layer->col_image, h*w*size*size*n);
-    layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*n);
-    layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*n);
+    l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+    l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
+    l.biases_gpu = cuda_make_array(l.biases, n);
+    l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
+    l.col_image_gpu = cuda_make_array(l.col_image, h*w*size*size*n);
+    l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
+    l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 #endif
-    layer->activation = activation;
+    l.activation = activation;
     fprintf(stderr, "Deconvolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
-    return layer;
+    return l;
 }

-void resize_deconvolutional_layer(deconvolutional_layer *layer, int h, int w)
+void resize_deconvolutional_layer(deconvolutional_layer *l, int h, int w)
 {
-    layer->h = h;
-    layer->w = w;
-    int out_h = deconvolutional_out_height(*layer);
-    int out_w = deconvolutional_out_width(*layer);
-    layer->col_image = realloc(layer->col_image,
-            out_h*out_w*layer->size*layer->size*layer->c*sizeof(float));
-    layer->output = realloc(layer->output,
-            layer->batch*out_h * out_w * layer->n*sizeof(float));
-    layer->delta = realloc(layer->delta,
-            layer->batch*out_h * out_w * layer->n*sizeof(float));
+    l->h = h;
+    l->w = w;
+    int out_h = deconvolutional_out_height(*l);
+    int out_w = deconvolutional_out_width(*l);
+    l->col_image = realloc(l->col_image,
+            out_h*out_w*l->size*l->size*l->c*sizeof(float));
+    l->output = realloc(l->output,
+            l->batch*out_h * out_w * l->n*sizeof(float));
+    l->delta = realloc(l->delta,
+            l->batch*out_h * out_w * l->n*sizeof(float));
 #ifdef GPU
-    cuda_free(layer->col_image_gpu);
-    cuda_free(layer->delta_gpu);
-    cuda_free(layer->output_gpu);
-    layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*layer->size*layer->size*layer->c);
-    layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*layer->n);
-    layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*layer->n);
+    cuda_free(l->col_image_gpu);
+    cuda_free(l->delta_gpu);
+    cuda_free(l->output_gpu);
+    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
+    l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
+    l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
 #endif
 }

-void forward_deconvolutional_layer(const deconvolutional_layer layer, network_state state)
+void forward_deconvolutional_layer(const deconvolutional_layer l, network_state state)
 {
     int i;
-    int out_h = deconvolutional_out_height(layer);
-    int out_w = deconvolutional_out_width(layer);
+    int out_h = deconvolutional_out_height(l);
+    int out_w = deconvolutional_out_width(l);
     int size = out_h*out_w;
-    int m = layer.size*layer.size*layer.n;
-    int n = layer.h*layer.w;
-    int k = layer.c;
-    bias_output(layer.output, layer.biases, layer.batch, layer.n, size);
-    for(i = 0; i < layer.batch; ++i){
-        float *a = layer.filters;
-        float *b = state.input + i*layer.c*layer.h*layer.w;
-        float *c = layer.col_image;
+    int m = l.size*l.size*l.n;
+    int n = l.h*l.w;
+    int k = l.c;
+    bias_output(l.output, l.biases, l.batch, l.n, size);
+    for(i = 0; i < l.batch; ++i){
+        float *a = l.filters;
+        float *b = state.input + i*l.c*l.h*l.w;
+        float *c = l.col_image;
         gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
-        col2im_cpu(c, layer.n, out_h, out_w, layer.size, layer.stride, 0, layer.output+i*layer.n*size);
+        col2im_cpu(c, l.n, out_h, out_w, l.size, l.stride, 0, l.output+i*l.n*size);
     }
-    activate_array(layer.output, layer.batch*layer.n*size, layer.activation);
+    activate_array(l.output, l.batch*l.n*size, l.activation);
 }

-void backward_deconvolutional_layer(deconvolutional_layer layer, network_state state)
+void backward_deconvolutional_layer(deconvolutional_layer l, network_state state)
 {
-    float alpha = 1./layer.batch;
-    int out_h = deconvolutional_out_height(layer);
-    int out_w = deconvolutional_out_width(layer);
+    float alpha = 1./l.batch;
+    int out_h = deconvolutional_out_height(l);
+    int out_w = deconvolutional_out_width(l);
     int size = out_h*out_w;
     int i;
-    gradient_array(layer.output, size*layer.n*layer.batch, layer.activation, layer.delta);
-    backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, size);
-    if(state.delta) memset(state.delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
-    for(i = 0; i < layer.batch; ++i){
-        int m = layer.c;
-        int n = layer.size*layer.size*layer.n;
-        int k = layer.h*layer.w;
+    gradient_array(l.output, size*l.n*l.batch, l.activation, l.delta);
+    backward_bias(l.bias_updates, l.delta, l.batch, l.n, size);
+    if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
+    for(i = 0; i < l.batch; ++i){
+        int m = l.c;
+        int n = l.size*l.size*l.n;
+        int k = l.h*l.w;
         float *a = state.input + i*m*n;
-        float *b = layer.col_image;
-        float *c = layer.filter_updates;
-        im2col_cpu(layer.delta + i*layer.n*size, layer.n, out_h, out_w,
-                layer.size, layer.stride, 0, b);
+        float *b = l.col_image;
+        float *c = l.filter_updates;
+        im2col_cpu(l.delta + i*l.n*size, l.n, out_h, out_w,
+                l.size, l.stride, 0, b);
         gemm(0,1,m,n,k,alpha,a,k,b,k,1,c,n);
         if(state.delta){
-            int m = layer.c;
-            int n = layer.h*layer.w;
-            int k = layer.size*layer.size*layer.n;
-            float *a = layer.filters;
-            float *b = layer.col_image;
+            int m = l.c;
+            int n = l.h*l.w;
+            int k = l.size*l.size*l.n;
+            float *a = l.filters;
+            float *b = l.col_image;
             float *c = state.delta + i*n*m;
             gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
@@ -181,15 +188,15 @@ void backward_deconvolutional_layer(deconvolutional_layer layer, network_state s
         }
     }
 }

-void update_deconvolutional_layer(deconvolutional_layer layer, float learning_rate, float momentum, float decay)
+void update_deconvolutional_layer(deconvolutional_layer l, float learning_rate, float momentum, float decay)
 {
-    int size = layer.size*layer.size*layer.c*layer.n;
-    axpy_cpu(layer.n, learning_rate, layer.bias_updates, 1, layer.biases, 1);
-    scal_cpu(layer.n, momentum, layer.bias_updates, 1);
-    axpy_cpu(size, -decay, layer.filters, 1, layer.filter_updates, 1);
-    axpy_cpu(size, learning_rate, layer.filter_updates, 1, layer.filters, 1);
-    scal_cpu(size, momentum, layer.filter_updates, 1);
+    int size = l.size*l.size*l.c*l.n;
+    axpy_cpu(l.n, learning_rate, l.bias_updates, 1, l.biases, 1);
+    scal_cpu(l.n, momentum, l.bias_updates, 1);
+    axpy_cpu(size, -decay, l.filters, 1, l.filter_updates, 1);
+    axpy_cpu(size, learning_rate, l.filter_updates, 1, l.filters, 1);
+    scal_cpu(size, momentum, l.filter_updates, 1);
 }
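
The deconvolutional output size inverts the unpadded convolution relation; a hypothetical example:

/* deconvolutional_out_height: out_h = stride*(h - 1) + size
 * e.g. h=13, stride=2, size=4  ->  2*12 + 4 = 28
 * (inverts the valid-convolution relation h = (out_h - size)/stride + 1) */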

src/deconvolutional_layer.h

@@ -5,37 +5,9 @@
 #include "params.h"
 #include "image.h"
 #include "activations.h"
+#include "layer.h"

-typedef struct {
-    int batch;
-    int h,w,c;
-    int n;
-    int size;
-    int stride;
-    float *filters;
-    float *filter_updates;
-    float *biases;
-    float *bias_updates;
-    float *col_image;
-    float *delta;
-    float *output;
-#ifdef GPU
-    float * filters_gpu;
-    float * filter_updates_gpu;
-    float * biases_gpu;
-    float * bias_updates_gpu;
-    float * col_image_gpu;
-    float * delta_gpu;
-    float * output_gpu;
-#endif
-    ACTIVATION activation;
-} deconvolutional_layer;
+typedef layer deconvolutional_layer;

 #ifdef GPU
 void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, network_state state);
@@ -45,7 +17,7 @@ void push_deconvolutional_layer(deconvolutional_layer layer);
 void pull_deconvolutional_layer(deconvolutional_layer layer);
 #endif

-deconvolutional_layer *make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation);
+deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation);
 void resize_deconvolutional_layer(deconvolutional_layer *layer, int h, int w);
 void forward_deconvolutional_layer(const deconvolutional_layer layer, network_state state);
 void update_deconvolutional_layer(deconvolutional_layer layer, float learning_rate, float momentum, float decay);

src/detection.c

@@ -115,6 +115,7 @@ void train_localization(char *cfgfile, char *weightfile)
         time=clock();
         float loss = train_network(net, train);
+        //TODO
         float *out = get_network_output_gpu(net);
         image im = float_to_image(net.w, net.h, 3, train.X.vals[127]);
         image copy = copy_image(im);
@@ -149,7 +150,7 @@ void train_detection_teststuff(char *cfgfile, char *weightfile)
     if(weightfile){
         load_weights(&net, weightfile);
     }
-    detection_layer *layer = get_network_detection_layer(net);
+    detection_layer layer = get_network_detection_layer(net);
     net.learning_rate = 0;
     net.decay = 0;
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
@@ -157,9 +158,9 @@ void train_detection_teststuff(char *cfgfile, char *weightfile)
     int i = net.seen/imgs;
     data train, buffer;

-    int classes = layer->classes;
-    int background = layer->background;
-    int side = sqrt(get_detection_layer_locations(*layer));
+    int classes = layer.classes;
+    int background = layer.background;
+    int side = sqrt(get_detection_layer_locations(layer));

     char **paths;
     list *plist;
@@ -174,7 +175,7 @@ void train_detection_teststuff(char *cfgfile, char *weightfile)
     paths = (char **)list_to_array(plist);
     pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
     clock_t time;
-    cost_layer clayer = *((cost_layer *)net.layers[net.n-1]);
+    cost_layer clayer = net.layers[net.n-1];
     while(1){
         i += 1;
         time=clock();
@@ -235,15 +236,15 @@ void train_detection(char *cfgfile, char *weightfile)
     if(weightfile){
         load_weights(&net, weightfile);
     }
-    detection_layer *layer = get_network_detection_layer(net);
+    detection_layer layer = get_network_detection_layer(net);
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     int imgs = 128;
     int i = net.seen/imgs;
     data train, buffer;

-    int classes = layer->classes;
-    int background = layer->background;
-    int side = sqrt(get_detection_layer_locations(*layer));
+    int classes = layer.classes;
+    int background = layer.background;
+    int side = sqrt(get_detection_layer_locations(layer));

     char **paths;
     list *plist;
@@ -325,7 +326,7 @@ void validate_detection(char *cfgfile, char *weightfile)
     if(weightfile){
         load_weights(&net, weightfile);
     }
-    detection_layer *layer = get_network_detection_layer(net);
+    detection_layer layer = get_network_detection_layer(net);
     fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     srand(time(0));
@@ -336,10 +337,10 @@ void validate_detection(char *cfgfile, char *weightfile)
     //list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
     char **paths = (char **)list_to_array(plist);

-    int classes = layer->classes;
-    int nuisance = layer->nuisance;
-    int background = (layer->background && !nuisance);
-    int num_boxes = sqrt(get_detection_layer_locations(*layer));
+    int classes = layer.classes;
+    int nuisance = layer.nuisance;
+    int background = (layer.background && !nuisance);
+    int num_boxes = sqrt(get_detection_layer_locations(layer));

     int per_box = 4+classes+background+nuisance;
     int num_output = num_boxes*num_boxes*per_box;
@@ -393,7 +394,7 @@ void validate_detection_post(char *cfgfile, char *weightfile)
     load_weights(&post, "/home/pjreddie/imagenet_backup/localize_1000.weights");
     set_batch_network(&post, 1);

-    detection_layer *layer = get_network_detection_layer(net);
+    detection_layer layer = get_network_detection_layer(net);
     fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     srand(time(0));
@@ -404,10 +405,10 @@ void validate_detection_post(char *cfgfile, char *weightfile)
     //list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
     char **paths = (char **)list_to_array(plist);

-    int classes = layer->classes;
-    int nuisance = layer->nuisance;
-    int background = (layer->background && !nuisance);
-    int num_boxes = sqrt(get_detection_layer_locations(*layer));
+    int classes = layer.classes;
+    int nuisance = layer.nuisance;
+    int background = (layer.background && !nuisance);
+    int num_boxes = sqrt(get_detection_layer_locations(layer));

     int per_box = 4+classes+background+nuisance;
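
The dropped casts in this file follow from the struct unification: net.layers can presumably hold `layer` values directly now (the network definition itself is outside this excerpt, so that representation is an assumption), giving a before/after access pattern like:

/* Before: entries were heap pointers behind a generic pointer, each use cast:
 *     cost_layer clayer = *((cost_layer *)net.layers[net.n-1]);
 * After: cost_layer is a typedef of layer, so (assuming `layer *layers;` in network):
 *     cost_layer clayer = net.layers[net.n-1]; */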

src/detection_layer.c

@ -8,47 +8,49 @@
#include <string.h> #include <string.h>
#include <stdlib.h> #include <stdlib.h>
int get_detection_layer_locations(detection_layer layer) int get_detection_layer_locations(detection_layer l)
{ {
return layer.inputs / (layer.classes+layer.coords+layer.rescore+layer.background); return l.inputs / (l.classes+l.coords+l.rescore+l.background);
} }
int get_detection_layer_output_size(detection_layer layer) int get_detection_layer_output_size(detection_layer l)
{ {
return get_detection_layer_locations(layer)*(layer.background + layer.classes + layer.coords); return get_detection_layer_locations(l)*(l.background + l.classes + l.coords);
} }
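The two helpers above pin down the geometry: each grid location consumes classes+coords+rescore+background input slots, so dividing the layer's inputs by that record width recovers the location count, and the callers above take sqrt of it on the assumption of a square grid. A quick worked check with illustrative numbers of my own (not taken from this commit):

#include <math.h>
#include <stdio.h>

int main()
{
    /* hypothetical sizing for illustration only */
    int classes = 20, coords = 4, rescore = 1, background = 0;
    int inputs = 49 * (classes + coords + rescore + background);      /* 1225 */
    int locations = inputs / (classes + coords + rescore + background);
    printf("locations = %d, side = %d\n", locations, (int)sqrt(locations)); /* 49, 7 */
    return 0;
}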
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance) detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance)
{ {
detection_layer *layer = calloc(1, sizeof(detection_layer)); detection_layer l = {0};
l.type = DETECTION;
layer->batch = batch; l.batch = batch;
layer->inputs = inputs; l.inputs = inputs;
layer->classes = classes; l.classes = classes;
layer->coords = coords; l.coords = coords;
layer->rescore = rescore; l.rescore = rescore;
layer->nuisance = nuisance; l.nuisance = nuisance;
layer->cost = calloc(1, sizeof(float)); l.cost = calloc(1, sizeof(float));
layer->does_cost=1; l.does_cost=1;
layer->background = background; l.background = background;
int outputs = get_detection_layer_output_size(*layer); int outputs = get_detection_layer_output_size(l);
layer->output = calloc(batch*outputs, sizeof(float)); l.outputs = outputs;
layer->delta = calloc(batch*outputs, sizeof(float)); l.output = calloc(batch*outputs, sizeof(float));
l.delta = calloc(batch*outputs, sizeof(float));
#ifdef GPU #ifdef GPU
layer->output_gpu = cuda_make_array(0, batch*outputs); l.output_gpu = cuda_make_array(0, batch*outputs);
layer->delta_gpu = cuda_make_array(0, batch*outputs); l.delta_gpu = cuda_make_array(0, batch*outputs);
#endif #endif
fprintf(stderr, "Detection Layer\n"); fprintf(stderr, "Detection Layer\n");
srand(0); srand(0);
return layer; return l;
} }
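The constructor now builds the layer on the stack with detection_layer l = {0}, which zero-initializes every field of the shared struct, sets only what this type uses, and returns the struct by value. A minimal sketch of the same pattern with a toy two-field struct (names are mine, not darknet's):

#include <stdlib.h>

typedef struct { int batch; float *output; } toy_layer;

toy_layer make_toy_layer(int batch, int outputs)
{
    toy_layer l = {0};     /* every field starts zeroed, like `= {0}` above */
    l.batch = batch;
    l.output = calloc(batch*outputs, sizeof(float));
    return l;              /* copied out by value; heap buffers are shared */
}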
void dark_zone(detection_layer layer, int class, int start, network_state state) void dark_zone(detection_layer l, int class, int start, network_state state)
{ {
int index = start+layer.background+class; int index = start+l.background+class;
int size = layer.classes+layer.coords+layer.background; int size = l.classes+l.coords+l.background;
int location = (index%(7*7*size)) / size ; int location = (index%(7*7*size)) / size ;
int r = location / 7; int r = location / 7;
int c = location % 7; int c = location % 7;
@@ -60,9 +62,9 @@ void dark_zone(detection_layer layer, int class, int start, network_state state)
if((c + dc) > 6 || (c + dc) < 0) continue; if((c + dc) > 6 || (c + dc) < 0) continue;
int di = (dr*7 + dc) * size; int di = (dr*7 + dc) * size;
if(state.truth[index+di]) continue; if(state.truth[index+di]) continue;
layer.output[index + di] = 0; l.output[index + di] = 0;
//if(!state.truth[start+di]) continue; //if(!state.truth[start+di]) continue;
//layer.output[start + di] = 1; //l.output[start + di] = 1;
} }
} }
} }
@@ -299,47 +301,47 @@ dbox diou(box a, box b)
return dd; return dd;
} }
void forward_detection_layer(const detection_layer layer, network_state state) void forward_detection_layer(const detection_layer l, network_state state)
{ {
int in_i = 0; int in_i = 0;
int out_i = 0; int out_i = 0;
int locations = get_detection_layer_locations(layer); int locations = get_detection_layer_locations(l);
int i,j; int i,j;
for(i = 0; i < layer.batch*locations; ++i){ for(i = 0; i < l.batch*locations; ++i){
int mask = (!state.truth || state.truth[out_i + layer.background + layer.classes + 2]); int mask = (!state.truth || state.truth[out_i + l.background + l.classes + 2]);
float scale = 1; float scale = 1;
if(layer.rescore) scale = state.input[in_i++]; if(l.rescore) scale = state.input[in_i++];
else if(layer.nuisance){ else if(l.nuisance){
layer.output[out_i++] = 1-state.input[in_i++]; l.output[out_i++] = 1-state.input[in_i++];
scale = mask; scale = mask;
} }
else if(layer.background) layer.output[out_i++] = scale*state.input[in_i++]; else if(l.background) l.output[out_i++] = scale*state.input[in_i++];
for(j = 0; j < layer.classes; ++j){ for(j = 0; j < l.classes; ++j){
layer.output[out_i++] = scale*state.input[in_i++]; l.output[out_i++] = scale*state.input[in_i++];
} }
if(layer.nuisance){ if(l.nuisance){
}else if(layer.background){ }else if(l.background){
softmax_array(layer.output + out_i - layer.classes-layer.background, layer.classes+layer.background, layer.output + out_i - layer.classes-layer.background); softmax_array(l.output + out_i - l.classes-l.background, l.classes+l.background, l.output + out_i - l.classes-l.background);
activate_array(state.input+in_i, layer.coords, LOGISTIC); activate_array(state.input+in_i, l.coords, LOGISTIC);
} }
for(j = 0; j < layer.coords; ++j){ for(j = 0; j < l.coords; ++j){
layer.output[out_i++] = mask*state.input[in_i++]; l.output[out_i++] = mask*state.input[in_i++];
} }
} }
if(layer.does_cost && state.train && 0){ if(l.does_cost && state.train && 0){
int count = 0; int count = 0;
float avg = 0; float avg = 0;
*(layer.cost) = 0; *(l.cost) = 0;
int size = get_detection_layer_output_size(layer) * layer.batch; int size = get_detection_layer_output_size(l) * l.batch;
memset(layer.delta, 0, size * sizeof(float)); memset(l.delta, 0, size * sizeof(float));
for (i = 0; i < layer.batch*locations; ++i) { for (i = 0; i < l.batch*locations; ++i) {
int classes = layer.nuisance+layer.classes; int classes = l.nuisance+l.classes;
int offset = i*(classes+layer.coords); int offset = i*(classes+l.coords);
for (j = offset; j < offset+classes; ++j) { for (j = offset; j < offset+classes; ++j) {
*(layer.cost) += pow(state.truth[j] - layer.output[j], 2); *(l.cost) += pow(state.truth[j] - l.output[j], 2);
layer.delta[j] = state.truth[j] - layer.output[j]; l.delta[j] = state.truth[j] - l.output[j];
} }
box truth; box truth;
truth.x = state.truth[j+0]; truth.x = state.truth[j+0];
@@ -347,17 +349,17 @@ void forward_detection_layer(const detection_layer layer, network_state state)
truth.w = state.truth[j+2]; truth.w = state.truth[j+2];
truth.h = state.truth[j+3]; truth.h = state.truth[j+3];
box out; box out;
out.x = layer.output[j+0]; out.x = l.output[j+0];
out.y = layer.output[j+1]; out.y = l.output[j+1];
out.w = layer.output[j+2]; out.w = l.output[j+2];
out.h = layer.output[j+3]; out.h = l.output[j+3];
if(!(truth.w*truth.h)) continue; if(!(truth.w*truth.h)) continue;
//printf("iou: %f\n", iou); //printf("iou: %f\n", iou);
dbox d = diou(out, truth); dbox d = diou(out, truth);
layer.delta[j+0] = d.dx; l.delta[j+0] = d.dx;
layer.delta[j+1] = d.dy; l.delta[j+1] = d.dy;
layer.delta[j+2] = d.dw; l.delta[j+2] = d.dw;
layer.delta[j+3] = d.dh; l.delta[j+3] = d.dh;
int sqr = 1; int sqr = 1;
if(sqr){ if(sqr){
@@ -367,7 +369,7 @@ void forward_detection_layer(const detection_layer layer, network_state state)
out.h *= out.h; out.h *= out.h;
} }
float iou = box_iou(truth, out); float iou = box_iou(truth, out);
*(layer.cost) += pow((1-iou), 2); *(l.cost) += pow((1-iou), 2);
avg += iou; avg += iou;
++count; ++count;
} }
@@ -375,24 +377,24 @@ void forward_detection_layer(const detection_layer layer, network_state state)
} }
/* /*
int count = 0; int count = 0;
for(i = 0; i < layer.batch*locations; ++i){ for(i = 0; i < l.batch*locations; ++i){
for(j = 0; j < layer.classes+layer.background; ++j){ for(j = 0; j < l.classes+l.background; ++j){
printf("%f, ", layer.output[count++]); printf("%f, ", l.output[count++]);
} }
printf("\n"); printf("\n");
for(j = 0; j < layer.coords; ++j){ for(j = 0; j < l.coords; ++j){
printf("%f, ", layer.output[count++]); printf("%f, ", l.output[count++]);
} }
printf("\n"); printf("\n");
} }
*/ */
/* /*
if(layer.background || 1){ if(l.background || 1){
for(i = 0; i < layer.batch*locations; ++i){ for(i = 0; i < l.batch*locations; ++i){
int index = i*(layer.classes+layer.coords+layer.background); int index = i*(l.classes+l.coords+l.background);
for(j= 0; j < layer.classes; ++j){ for(j= 0; j < l.classes; ++j){
if(state.truth[index+j+layer.background]){ if(state.truth[index+j+l.background]){
//dark_zone(layer, j, index, state); //dark_zone(l, j, index, state);
} }
} }
} }
@@ -400,66 +402,66 @@ void forward_detection_layer(const detection_layer layer, network_state state)
*/ */
} }
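Per location, the forward pass consumes an optional rescore/background slot, then the class scores (softmaxed together with the background slot when background is set), then the coords (squashed with LOGISTIC in the background case and zeroed by the mask when no truth object is present). A small helper sketching the resulting record layout, assuming background on and nuisance/rescore off (the helper name is mine):

/* Offset of one location's record in l.output under that assumption:
 *   [ background | class_0 ... class_{C-1} | x y w h ]                   */
int record_offset(int location, int classes, int coords, int background)
{
    return location * (background + classes + coords);
}
/* class j of a location: record_offset(...) + background + j
 * coord k of a location: record_offset(...) + background + classes + k   */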
void backward_detection_layer(const detection_layer layer, network_state state) void backward_detection_layer(const detection_layer l, network_state state)
{ {
int locations = get_detection_layer_locations(layer); int locations = get_detection_layer_locations(l);
int i,j; int i,j;
int in_i = 0; int in_i = 0;
int out_i = 0; int out_i = 0;
for(i = 0; i < layer.batch*locations; ++i){ for(i = 0; i < l.batch*locations; ++i){
float scale = 1; float scale = 1;
float latent_delta = 0; float latent_delta = 0;
if(layer.rescore) scale = state.input[in_i++]; if(l.rescore) scale = state.input[in_i++];
else if (layer.nuisance) state.delta[in_i++] = -layer.delta[out_i++]; else if (l.nuisance) state.delta[in_i++] = -l.delta[out_i++];
else if (layer.background) state.delta[in_i++] = scale*layer.delta[out_i++]; else if (l.background) state.delta[in_i++] = scale*l.delta[out_i++];
for(j = 0; j < layer.classes; ++j){ for(j = 0; j < l.classes; ++j){
latent_delta += state.input[in_i]*layer.delta[out_i]; latent_delta += state.input[in_i]*l.delta[out_i];
state.delta[in_i++] = scale*layer.delta[out_i++]; state.delta[in_i++] = scale*l.delta[out_i++];
} }
if (layer.nuisance) { if (l.nuisance) {
}else if (layer.background) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i); }else if (l.background) gradient_array(l.output + out_i, l.coords, LOGISTIC, l.delta + out_i);
for(j = 0; j < layer.coords; ++j){ for(j = 0; j < l.coords; ++j){
state.delta[in_i++] = layer.delta[out_i++]; state.delta[in_i++] = l.delta[out_i++];
} }
if(layer.rescore) state.delta[in_i-layer.coords-layer.classes-layer.rescore-layer.background] = latent_delta; if(l.rescore) state.delta[in_i-l.coords-l.classes-l.rescore-l.background] = latent_delta;
} }
} }
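The rescore branch is the one subtle part of this backward pass: on the forward side each class output is the product of the scale input and a class input, so by the chain rule the gradient belonging to the scale slot is exactly the sum accumulated in latent_delta and written back at the end of the record. In LaTeX, with s the scale and x_j the class inputs:

o_j = s\,x_j \qquad\Rightarrow\qquad \frac{\partial L}{\partial s} = \sum_j \frac{\partial L}{\partial o_j}\,x_j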
#ifdef GPU #ifdef GPU
void forward_detection_layer_gpu(const detection_layer layer, network_state state) void forward_detection_layer_gpu(const detection_layer l, network_state state)
{ {
int outputs = get_detection_layer_output_size(layer); int outputs = get_detection_layer_output_size(l);
float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
float *truth_cpu = 0; float *truth_cpu = 0;
if(state.truth){ if(state.truth){
truth_cpu = calloc(layer.batch*outputs, sizeof(float)); truth_cpu = calloc(l.batch*outputs, sizeof(float));
cuda_pull_array(state.truth, truth_cpu, layer.batch*outputs); cuda_pull_array(state.truth, truth_cpu, l.batch*outputs);
} }
cuda_pull_array(state.input, in_cpu, layer.batch*layer.inputs); cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
network_state cpu_state; network_state cpu_state;
cpu_state.train = state.train; cpu_state.train = state.train;
cpu_state.truth = truth_cpu; cpu_state.truth = truth_cpu;
cpu_state.input = in_cpu; cpu_state.input = in_cpu;
forward_detection_layer(layer, cpu_state); forward_detection_layer(l, cpu_state);
cuda_push_array(layer.output_gpu, layer.output, layer.batch*outputs); cuda_push_array(l.output_gpu, l.output, l.batch*outputs);
cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*outputs); cuda_push_array(l.delta_gpu, l.delta, l.batch*outputs);
free(cpu_state.input); free(cpu_state.input);
if(cpu_state.truth) free(cpu_state.truth); if(cpu_state.truth) free(cpu_state.truth);
} }
void backward_detection_layer_gpu(detection_layer layer, network_state state) void backward_detection_layer_gpu(detection_layer l, network_state state)
{ {
int outputs = get_detection_layer_output_size(layer); int outputs = get_detection_layer_output_size(l);
float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
float *delta_cpu = calloc(layer.batch*layer.inputs, sizeof(float)); float *delta_cpu = calloc(l.batch*l.inputs, sizeof(float));
float *truth_cpu = 0; float *truth_cpu = 0;
if(state.truth){ if(state.truth){
truth_cpu = calloc(layer.batch*outputs, sizeof(float)); truth_cpu = calloc(l.batch*outputs, sizeof(float));
cuda_pull_array(state.truth, truth_cpu, layer.batch*outputs); cuda_pull_array(state.truth, truth_cpu, l.batch*outputs);
} }
network_state cpu_state; network_state cpu_state;
cpu_state.train = state.train; cpu_state.train = state.train;
@@ -467,10 +469,10 @@ void backward_detection_layer_gpu(detection_layer layer, network_state state)
cpu_state.truth = truth_cpu; cpu_state.truth = truth_cpu;
cpu_state.delta = delta_cpu; cpu_state.delta = delta_cpu;
cuda_pull_array(state.input, in_cpu, layer.batch*layer.inputs); cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*outputs); cuda_pull_array(l.delta_gpu, l.delta, l.batch*outputs);
backward_detection_layer(layer, cpu_state); backward_detection_layer(l, cpu_state);
cuda_push_array(state.delta, delta_cpu, layer.batch*layer.inputs); cuda_push_array(state.delta, delta_cpu, l.batch*l.inputs);
free(in_cpu); free(in_cpu);
free(delta_cpu); free(delta_cpu);


@@ -2,34 +2,19 @@
#define DETECTION_LAYER_H #define DETECTION_LAYER_H
#include "params.h" #include "params.h"
#include "layer.h"
typedef struct { typedef layer detection_layer;
int batch;
int inputs;
int classes;
int coords;
int background;
int rescore;
int nuisance;
int does_cost;
float *cost;
float *output;
float *delta;
#ifdef GPU
float * output_gpu;
float * delta_gpu;
#endif
} detection_layer;
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance); detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance);
void forward_detection_layer(const detection_layer layer, network_state state); void forward_detection_layer(const detection_layer l, network_state state);
void backward_detection_layer(const detection_layer layer, network_state state); void backward_detection_layer(const detection_layer l, network_state state);
int get_detection_layer_output_size(detection_layer layer); int get_detection_layer_output_size(detection_layer l);
int get_detection_layer_locations(detection_layer layer); int get_detection_layer_locations(detection_layer l);
#ifdef GPU #ifdef GPU
void forward_detection_layer_gpu(const detection_layer layer, network_state state); void forward_detection_layer_gpu(const detection_layer l, network_state state);
void backward_detection_layer_gpu(detection_layer layer, network_state state); void backward_detection_layer_gpu(detection_layer l, network_state state);
#endif #endif
#endif #endif
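This header change is the heart of the refactor: the per-type struct is deleted and detection_layer becomes a plain alias of the generic layer, so every layer type shares one field set and can live in a flat array. A toy sketch of the idea (fields and names are illustrative, not the real layer definition):

typedef enum { TOY_CONV, TOY_DETECT, TOY_DROPOUT } toy_type;

typedef struct {
    toy_type type;                 /* dispatch tag */
    int batch, inputs, outputs;
    float *output, *delta;         /* fields a given type doesn't use stay zeroed */
} toy_generic;

typedef toy_generic toy_detection_layer;   /* aliases, as in these headers */
typedef toy_generic toy_dropout_layer;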


@@ -5,51 +5,53 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdio.h> #include <stdio.h>
dropout_layer *make_dropout_layer(int batch, int inputs, float probability) dropout_layer make_dropout_layer(int batch, int inputs, float probability)
{ {
fprintf(stderr, "Dropout Layer: %d inputs, %f probability\n", inputs, probability); fprintf(stderr, "Dropout Layer: %d inputs, %f probability\n", inputs, probability);
dropout_layer *layer = calloc(1, sizeof(dropout_layer)); dropout_layer l = {0};
layer->probability = probability; l.type = DROPOUT;
layer->inputs = inputs; l.probability = probability;
layer->batch = batch; l.inputs = inputs;
layer->rand = calloc(inputs*batch, sizeof(float)); l.outputs = inputs;
layer->scale = 1./(1.-probability); l.batch = batch;
l.rand = calloc(inputs*batch, sizeof(float));
l.scale = 1./(1.-probability);
#ifdef GPU #ifdef GPU
layer->rand_gpu = cuda_make_array(layer->rand, inputs*batch); l.rand_gpu = cuda_make_array(l.rand, inputs*batch);
#endif #endif
return layer; return l;
} }
void resize_dropout_layer(dropout_layer *layer, int inputs) void resize_dropout_layer(dropout_layer *l, int inputs)
{ {
layer->rand = realloc(layer->rand, layer->inputs*layer->batch*sizeof(float)); l->rand = realloc(l->rand, l->inputs*l->batch*sizeof(float));
#ifdef GPU #ifdef GPU
cuda_free(layer->rand_gpu); cuda_free(l->rand_gpu);
layer->rand_gpu = cuda_make_array(layer->rand, inputs*layer->batch); l->rand_gpu = cuda_make_array(l->rand, inputs*l->batch);
#endif #endif
} }
void forward_dropout_layer(dropout_layer layer, network_state state) void forward_dropout_layer(dropout_layer l, network_state state)
{ {
int i; int i;
if (!state.train) return; if (!state.train) return;
for(i = 0; i < layer.batch * layer.inputs; ++i){ for(i = 0; i < l.batch * l.inputs; ++i){
float r = rand_uniform(); float r = rand_uniform();
layer.rand[i] = r; l.rand[i] = r;
if(r < layer.probability) state.input[i] = 0; if(r < l.probability) state.input[i] = 0;
else state.input[i] *= layer.scale; else state.input[i] *= l.scale;
} }
} }
void backward_dropout_layer(dropout_layer layer, network_state state) void backward_dropout_layer(dropout_layer l, network_state state)
{ {
int i; int i;
if(!state.delta) return; if(!state.delta) return;
for(i = 0; i < layer.batch * layer.inputs; ++i){ for(i = 0; i < l.batch * l.inputs; ++i){
float r = layer.rand[i]; float r = l.rand[i];
if(r < layer.probability) state.delta[i] = 0; if(r < l.probability) state.delta[i] = 0;
else state.delta[i] *= layer.scale; else state.delta[i] *= l.scale;
} }
} }
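This is inverted dropout: a unit is zeroed with probability `probability` and survivors are scaled up by 1/(1-probability) at train time, so the expected activation is unchanged and nothing needs rescaling at test time (the forward pass returns early when !state.train). A quick numeric check of that expectation, with values of my own:

#include <stdio.h>

int main()
{
    float p = 0.5f, x = 2.0f;                    /* illustrative values */
    float scale = 1.f/(1.f - p);
    float expected = (1.f - p)*x*scale + p*0.f;  /* survivors + dropped */
    printf("E[out] = %f, x = %f\n", expected, x);   /* both 2.0 */
    return 0;
}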


@@ -1,27 +1,20 @@
#ifndef DROPOUT_LAYER_H #ifndef DROPOUT_LAYER_H
#define DROPOUT_LAYER_H #define DROPOUT_LAYER_H
#include "params.h" #include "params.h"
#include "layer.h"
typedef struct{ typedef layer dropout_layer;
int batch;
int inputs;
float probability;
float scale;
float *rand;
#ifdef GPU
float * rand_gpu;
#endif
} dropout_layer;
dropout_layer *make_dropout_layer(int batch, int inputs, float probability); dropout_layer make_dropout_layer(int batch, int inputs, float probability);
void forward_dropout_layer(dropout_layer layer, network_state state); void forward_dropout_layer(dropout_layer l, network_state state);
void backward_dropout_layer(dropout_layer layer, network_state state); void backward_dropout_layer(dropout_layer l, network_state state);
void resize_dropout_layer(dropout_layer *layer, int inputs); void resize_dropout_layer(dropout_layer *l, int inputs);
#ifdef GPU #ifdef GPU
void forward_dropout_layer_gpu(dropout_layer layer, network_state state); void forward_dropout_layer_gpu(dropout_layer l, network_state state);
void backward_dropout_layer_gpu(dropout_layer layer, network_state state); void backward_dropout_layer_gpu(dropout_layer l, network_state state);
#endif #endif
#endif #endif


@@ -2,109 +2,115 @@
#include "cuda.h" #include "cuda.h"
#include <stdio.h> #include <stdio.h>
image get_maxpool_image(maxpool_layer layer) image get_maxpool_image(maxpool_layer l)
{ {
int h = (layer.h-1)/layer.stride + 1; int h = (l.h-1)/l.stride + 1;
int w = (layer.w-1)/layer.stride + 1; int w = (l.w-1)/l.stride + 1;
int c = layer.c; int c = l.c;
return float_to_image(w,h,c,layer.output); return float_to_image(w,h,c,l.output);
} }
image get_maxpool_delta(maxpool_layer layer) image get_maxpool_delta(maxpool_layer l)
{ {
int h = (layer.h-1)/layer.stride + 1; int h = (l.h-1)/l.stride + 1;
int w = (layer.w-1)/layer.stride + 1; int w = (l.w-1)/l.stride + 1;
int c = layer.c; int c = l.c;
return float_to_image(w,h,c,layer.delta); return float_to_image(w,h,c,l.delta);
} }
maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int stride) maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride)
{ {
fprintf(stderr, "Maxpool Layer: %d x %d x %d image, %d size, %d stride\n", h,w,c,size,stride); fprintf(stderr, "Maxpool Layer: %d x %d x %d image, %d size, %d stride\n", h,w,c,size,stride);
maxpool_layer *layer = calloc(1, sizeof(maxpool_layer)); maxpool_layer l = {0};
layer->batch = batch; l.type = MAXPOOL;
layer->h = h; l.batch = batch;
layer->w = w; l.h = h;
layer->c = c; l.w = w;
layer->size = size; l.c = c;
layer->stride = stride; l.out_h = h;
l.out_w = w;
l.out_c = c;
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = l.outputs;
l.size = size;
l.stride = stride;
int output_size = ((h-1)/stride+1) * ((w-1)/stride+1) * c * batch; int output_size = ((h-1)/stride+1) * ((w-1)/stride+1) * c * batch;
layer->indexes = calloc(output_size, sizeof(int)); l.indexes = calloc(output_size, sizeof(int));
layer->output = calloc(output_size, sizeof(float)); l.output = calloc(output_size, sizeof(float));
layer->delta = calloc(output_size, sizeof(float)); l.delta = calloc(output_size, sizeof(float));
#ifdef GPU #ifdef GPU
layer->indexes_gpu = cuda_make_int_array(output_size); l.indexes_gpu = cuda_make_int_array(output_size);
layer->output_gpu = cuda_make_array(layer->output, output_size); l.output_gpu = cuda_make_array(l.output, output_size);
layer->delta_gpu = cuda_make_array(layer->delta, output_size); l.delta_gpu = cuda_make_array(l.delta, output_size);
#endif #endif
return layer; return l;
} }
void resize_maxpool_layer(maxpool_layer *layer, int h, int w) void resize_maxpool_layer(maxpool_layer *l, int h, int w)
{ {
layer->h = h; l->h = h;
layer->w = w; l->w = w;
int output_size = ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * layer->c * layer->batch; int output_size = ((h-1)/l->stride+1) * ((w-1)/l->stride+1) * l->c * l->batch;
layer->output = realloc(layer->output, output_size * sizeof(float)); l->output = realloc(l->output, output_size * sizeof(float));
layer->delta = realloc(layer->delta, output_size * sizeof(float)); l->delta = realloc(l->delta, output_size * sizeof(float));
#ifdef GPU #ifdef GPU
cuda_free((float *)layer->indexes_gpu); cuda_free((float *)l->indexes_gpu);
cuda_free(layer->output_gpu); cuda_free(l->output_gpu);
cuda_free(layer->delta_gpu); cuda_free(l->delta_gpu);
layer->indexes_gpu = cuda_make_int_array(output_size); l->indexes_gpu = cuda_make_int_array(output_size);
layer->output_gpu = cuda_make_array(layer->output, output_size); l->output_gpu = cuda_make_array(l->output, output_size);
layer->delta_gpu = cuda_make_array(layer->delta, output_size); l->delta_gpu = cuda_make_array(l->delta, output_size);
#endif #endif
} }
void forward_maxpool_layer(const maxpool_layer layer, network_state state) void forward_maxpool_layer(const maxpool_layer l, network_state state)
{ {
int b,i,j,k,l,m; int b,i,j,k,m,n;
int w_offset = (-layer.size-1)/2 + 1; int w_offset = (-l.size-1)/2 + 1;
int h_offset = (-layer.size-1)/2 + 1; int h_offset = (-l.size-1)/2 + 1;
int h = (layer.h-1)/layer.stride + 1; int h = (l.h-1)/l.stride + 1;
int w = (layer.w-1)/layer.stride + 1; int w = (l.w-1)/l.stride + 1;
int c = layer.c; int c = l.c;
for(b = 0; b < layer.batch; ++b){ for(b = 0; b < l.batch; ++b){
for(k = 0; k < c; ++k){ for(k = 0; k < c; ++k){
for(i = 0; i < h; ++i){ for(i = 0; i < h; ++i){
for(j = 0; j < w; ++j){ for(j = 0; j < w; ++j){
int out_index = j + w*(i + h*(k + c*b)); int out_index = j + w*(i + h*(k + c*b));
float max = -FLT_MAX; float max = -FLT_MAX;
int max_i = -1; int max_i = -1;
for(l = 0; l < layer.size; ++l){ for(n = 0; n < l.size; ++n){
for(m = 0; m < layer.size; ++m){ for(m = 0; m < l.size; ++m){
int cur_h = h_offset + i*layer.stride + l; int cur_h = h_offset + i*l.stride + n;
int cur_w = w_offset + j*layer.stride + m; int cur_w = w_offset + j*l.stride + m;
int index = cur_w + layer.w*(cur_h + layer.h*(k + b*layer.c)); int index = cur_w + l.w*(cur_h + l.h*(k + b*l.c));
int valid = (cur_h >= 0 && cur_h < layer.h && int valid = (cur_h >= 0 && cur_h < l.h &&
cur_w >= 0 && cur_w < layer.w); cur_w >= 0 && cur_w < l.w);
float val = (valid != 0) ? state.input[index] : -FLT_MAX; float val = (valid != 0) ? state.input[index] : -FLT_MAX;
max_i = (val > max) ? index : max_i; max_i = (val > max) ? index : max_i;
max = (val > max) ? val : max; max = (val > max) ? val : max;
} }
} }
layer.output[out_index] = max; l.output[out_index] = max;
layer.indexes[out_index] = max_i; l.indexes[out_index] = max_i;
} }
} }
} }
} }
} }
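Two details worth noting in the pooling loop: the inner counters were renamed to m,n because l now names the layer, and the pooled size uses the ceiling-style expression (h-1)/stride + 1, which keeps partial windows at the border. A small check with numbers of my own:

#include <stdio.h>

int main()
{
    int sizes[] = {224, 7, 13};
    int stride = 2;
    for (int i = 0; i < 3; ++i)
        printf("%d -> %d\n", sizes[i], (sizes[i] - 1)/stride + 1);  /* 112, 4, 7 */
    return 0;
}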
void backward_maxpool_layer(const maxpool_layer layer, network_state state) void backward_maxpool_layer(const maxpool_layer l, network_state state)
{ {
int i; int i;
int h = (layer.h-1)/layer.stride + 1; int h = (l.h-1)/l.stride + 1;
int w = (layer.w-1)/layer.stride + 1; int w = (l.w-1)/l.stride + 1;
int c = layer.c; int c = l.c;
memset(state.delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float)); memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
for(i = 0; i < h*w*c*layer.batch; ++i){ for(i = 0; i < h*w*c*l.batch; ++i){
int index = layer.indexes[i]; int index = l.indexes[i];
state.delta[index] += layer.delta[i]; state.delta[index] += l.delta[i];
} }
} }


@@ -4,31 +4,19 @@
#include "image.h" #include "image.h"
#include "params.h" #include "params.h"
#include "cuda.h" #include "cuda.h"
#include "layer.h"
typedef struct { typedef layer maxpool_layer;
int batch;
int h,w,c;
int stride;
int size;
int *indexes;
float *delta;
float *output;
#ifdef GPU
int *indexes_gpu;
float *output_gpu;
float *delta_gpu;
#endif
} maxpool_layer;
image get_maxpool_image(maxpool_layer layer); image get_maxpool_image(maxpool_layer l);
maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int stride); maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride);
void resize_maxpool_layer(maxpool_layer *layer, int h, int w); void resize_maxpool_layer(maxpool_layer *l, int h, int w);
void forward_maxpool_layer(const maxpool_layer layer, network_state state); void forward_maxpool_layer(const maxpool_layer l, network_state state);
void backward_maxpool_layer(const maxpool_layer layer, network_state state); void backward_maxpool_layer(const maxpool_layer l, network_state state);
#ifdef GPU #ifdef GPU
void forward_maxpool_layer_gpu(maxpool_layer layer, network_state state); void forward_maxpool_layer_gpu(maxpool_layer l, network_state state);
void backward_maxpool_layer_gpu(maxpool_layer layer, network_state state); void backward_maxpool_layer_gpu(maxpool_layer l, network_state state);
#endif #endif
#endif #endif


@@ -12,7 +12,6 @@
#include "detection_layer.h" #include "detection_layer.h"
#include "maxpool_layer.h" #include "maxpool_layer.h"
#include "cost_layer.h" #include "cost_layer.h"
#include "normalization_layer.h"
#include "softmax_layer.h" #include "softmax_layer.h"
#include "dropout_layer.h" #include "dropout_layer.h"
#include "route_layer.h" #include "route_layer.h"
@@ -32,8 +31,6 @@ char *get_layer_string(LAYER_TYPE a)
return "softmax"; return "softmax";
case DETECTION: case DETECTION:
return "detection"; return "detection";
case NORMALIZATION:
return "normalization";
case DROPOUT: case DROPOUT:
return "dropout"; return "dropout";
case CROP: case CROP:
@@ -50,16 +47,9 @@ char *get_layer_string(LAYER_TYPE a)
network make_network(int n) network make_network(int n)
{ {
network net; network net = {0};
net.n = n; net.n = n;
net.layers = calloc(net.n, sizeof(void *)); net.layers = calloc(net.n, sizeof(layer));
net.types = calloc(net.n, sizeof(LAYER_TYPE));
net.outputs = 0;
net.output = 0;
net.seen = 0;
net.batch = 0;
net.inputs = 0;
net.h = net.w = net.c = 0;
#ifdef GPU #ifdef GPU
net.input_gpu = calloc(1, sizeof(float *)); net.input_gpu = calloc(1, sizeof(float *));
net.truth_gpu = calloc(1, sizeof(float *)); net.truth_gpu = calloc(1, sizeof(float *));
@@ -71,40 +61,29 @@ void forward_network(network net, network_state state)
{ {
int i; int i;
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){ layer l = net.layers[i];
forward_convolutional_layer(*(convolutional_layer *)net.layers[i], state); if(l.type == CONVOLUTIONAL){
forward_convolutional_layer(l, state);
} else if(l.type == DECONVOLUTIONAL){
forward_deconvolutional_layer(l, state);
} else if(l.type == DETECTION){
forward_detection_layer(l, state);
} else if(l.type == CONNECTED){
forward_connected_layer(l, state);
} else if(l.type == CROP){
forward_crop_layer(l, state);
} else if(l.type == COST){
forward_cost_layer(l, state);
} else if(l.type == SOFTMAX){
forward_softmax_layer(l, state);
} else if(l.type == MAXPOOL){
forward_maxpool_layer(l, state);
} else if(l.type == DROPOUT){
forward_dropout_layer(l, state);
} else if(l.type == ROUTE){
forward_route_layer(l, net);
} }
else if(net.types[i] == DECONVOLUTIONAL){ state.input = l.output;
forward_deconvolutional_layer(*(deconvolutional_layer *)net.layers[i], state);
}
else if(net.types[i] == DETECTION){
forward_detection_layer(*(detection_layer *)net.layers[i], state);
}
else if(net.types[i] == CONNECTED){
forward_connected_layer(*(connected_layer *)net.layers[i], state);
}
else if(net.types[i] == CROP){
forward_crop_layer(*(crop_layer *)net.layers[i], state);
}
else if(net.types[i] == COST){
forward_cost_layer(*(cost_layer *)net.layers[i], state);
}
else if(net.types[i] == SOFTMAX){
forward_softmax_layer(*(softmax_layer *)net.layers[i], state);
}
else if(net.types[i] == MAXPOOL){
forward_maxpool_layer(*(maxpool_layer *)net.layers[i], state);
}
else if(net.types[i] == NORMALIZATION){
forward_normalization_layer(*(normalization_layer *)net.layers[i], state);
}
else if(net.types[i] == DROPOUT){
forward_dropout_layer(*(dropout_layer *)net.layers[i], state);
}
else if(net.types[i] == ROUTE){
forward_route_layer(*(route_layer *)net.layers[i], net);
}
state.input = get_network_output_layer(net, i);
} }
} }
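The forward pass now reads as a single dispatch-and-chain loop: each branch runs one layer type on state.input, and the trailing assignment points state.input at that layer's output so the next iteration consumes it, replacing the old get_network_output_layer lookup. A stripped-down sketch of the pattern with toy types (not darknet's real ones):

typedef enum { TOY_A, TOY_B } toy_kind;
typedef struct { toy_kind type; float *output; } toy_layer;
typedef struct { float *input; } toy_state;

void toy_forward(toy_layer *layers, int n, toy_state state)
{
    for (int i = 0; i < n; ++i) {
        toy_layer l = layers[i];
        if (l.type == TOY_A) { /* run A's forward on state.input */ }
        else if (l.type == TOY_B) { /* run B's forward on state.input */ }
        state.input = l.output;   /* chain into the next layer */
    }
}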
@@ -113,99 +92,35 @@ void update_network(network net)
int i; int i;
int update_batch = net.batch*net.subdivisions; int update_batch = net.batch*net.subdivisions;
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){ layer l = net.layers[i];
convolutional_layer layer = *(convolutional_layer *)net.layers[i]; if(l.type == CONVOLUTIONAL){
update_convolutional_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay); update_convolutional_layer(l, update_batch, net.learning_rate, net.momentum, net.decay);
} } else if(l.type == DECONVOLUTIONAL){
else if(net.types[i] == DECONVOLUTIONAL){ update_deconvolutional_layer(l, net.learning_rate, net.momentum, net.decay);
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; } else if(l.type == CONNECTED){
update_deconvolutional_layer(layer, net.learning_rate, net.momentum, net.decay); update_connected_layer(l, update_batch, net.learning_rate, net.momentum, net.decay);
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
update_connected_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
} }
} }
} }
float *get_network_output_layer(network net, int i)
{
if(net.types[i] == CONVOLUTIONAL){
return ((convolutional_layer *)net.layers[i]) -> output;
} else if(net.types[i] == DECONVOLUTIONAL){
return ((deconvolutional_layer *)net.layers[i]) -> output;
} else if(net.types[i] == MAXPOOL){
return ((maxpool_layer *)net.layers[i]) -> output;
} else if(net.types[i] == DETECTION){
return ((detection_layer *)net.layers[i]) -> output;
} else if(net.types[i] == SOFTMAX){
return ((softmax_layer *)net.layers[i]) -> output;
} else if(net.types[i] == DROPOUT){
return get_network_output_layer(net, i-1);
} else if(net.types[i] == CONNECTED){
return ((connected_layer *)net.layers[i]) -> output;
} else if(net.types[i] == CROP){
return ((crop_layer *)net.layers[i]) -> output;
} else if(net.types[i] == NORMALIZATION){
return ((normalization_layer *)net.layers[i]) -> output;
} else if(net.types[i] == ROUTE){
return ((route_layer *)net.layers[i]) -> output;
}
return 0;
}
float *get_network_output(network net) float *get_network_output(network net)
{ {
int i; int i;
for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break;
return get_network_output_layer(net, i); return net.layers[i].output;
}
float *get_network_delta_layer(network net, int i)
{
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
return layer.delta;
} else if(net.types[i] == DECONVOLUTIONAL){
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
return layer.delta;
} else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
return layer.delta;
} else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.delta;
} else if(net.types[i] == DETECTION){
detection_layer layer = *(detection_layer *)net.layers[i];
return layer.delta;
} else if(net.types[i] == DROPOUT){
if(i == 0) return 0;
return get_network_delta_layer(net, i-1);
} else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.delta;
} else if(net.types[i] == ROUTE){
return ((route_layer *)net.layers[i]) -> delta;
}
return 0;
} }
float get_network_cost(network net) float get_network_cost(network net)
{ {
if(net.types[net.n-1] == COST){ if(net.layers[net.n-1].type == COST){
return ((cost_layer *)net.layers[net.n-1])->output[0]; return net.layers[net.n-1].output[0];
} }
if(net.types[net.n-1] == DETECTION){ if(net.layers[net.n-1].type == DETECTION){
return ((detection_layer *)net.layers[net.n-1])->cost[0]; return net.layers[net.n-1].cost[0];
} }
return 0; return 0;
} }
float *get_network_delta(network net)
{
return get_network_delta_layer(net, net.n-1);
}
int get_predicted_class_network(network net) int get_predicted_class_network(network net)
{ {
float *out = get_network_output(net); float *out = get_network_output(net);
@@ -222,46 +137,29 @@ void backward_network(network net, network_state state)
state.input = original_input; state.input = original_input;
state.delta = 0; state.delta = 0;
}else{ }else{
state.input = get_network_output_layer(net, i-1); layer prev = net.layers[i-1];
state.delta = get_network_delta_layer(net, i-1); state.input = prev.output;
state.delta = prev.delta;
} }
layer l = net.layers[i];
if(net.types[i] == CONVOLUTIONAL){ if(l.type == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i]; backward_convolutional_layer(l, state);
backward_convolutional_layer(layer, state); } else if(l.type == DECONVOLUTIONAL){
} else if(net.types[i] == DECONVOLUTIONAL){ backward_deconvolutional_layer(l, state);
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; } else if(l.type == MAXPOOL){
backward_deconvolutional_layer(layer, state); if(i != 0) backward_maxpool_layer(l, state);
} } else if(l.type == DROPOUT){
else if(net.types[i] == MAXPOOL){ backward_dropout_layer(l, state);
maxpool_layer layer = *(maxpool_layer *)net.layers[i]; } else if(l.type == DETECTION){
if(i != 0) backward_maxpool_layer(layer, state); backward_detection_layer(l, state);
} } else if(l.type == SOFTMAX){
else if(net.types[i] == DROPOUT){ if(i != 0) backward_softmax_layer(l, state);
dropout_layer layer = *(dropout_layer *)net.layers[i]; } else if(l.type == CONNECTED){
backward_dropout_layer(layer, state); backward_connected_layer(l, state);
} } else if(l.type == COST){
else if(net.types[i] == DETECTION){ backward_cost_layer(l, state);
detection_layer layer = *(detection_layer *)net.layers[i]; } else if(l.type == ROUTE){
backward_detection_layer(layer, state); backward_route_layer(l, net);
}
else if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
if(i != 0) backward_normalization_layer(layer, state);
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
if(i != 0) backward_softmax_layer(layer, state);
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
backward_connected_layer(layer, state);
} else if(net.types[i] == COST){
cost_layer layer = *(cost_layer *)net.layers[i];
backward_cost_layer(layer, state);
} else if(net.types[i] == ROUTE){
route_layer layer = *(route_layer *)net.layers[i];
backward_route_layer(layer, net);
} }
} }
} }
@@ -347,127 +245,11 @@ void set_batch_network(network *net, int b)
net->batch = b; net->batch = b;
int i; int i;
for(i = 0; i < net->n; ++i){ for(i = 0; i < net->n; ++i){
if(net->types[i] == CONVOLUTIONAL){ net->layers[i].batch = b;
convolutional_layer *layer = (convolutional_layer *)net->layers[i];
layer->batch = b;
}else if(net->types[i] == DECONVOLUTIONAL){
deconvolutional_layer *layer = (deconvolutional_layer *)net->layers[i];
layer->batch = b;
}
else if(net->types[i] == MAXPOOL){
maxpool_layer *layer = (maxpool_layer *)net->layers[i];
layer->batch = b;
}
else if(net->types[i] == CONNECTED){
connected_layer *layer = (connected_layer *)net->layers[i];
layer->batch = b;
} else if(net->types[i] == DROPOUT){
dropout_layer *layer = (dropout_layer *) net->layers[i];
layer->batch = b;
} else if(net->types[i] == DETECTION){
detection_layer *layer = (detection_layer *) net->layers[i];
layer->batch = b;
}
else if(net->types[i] == SOFTMAX){
softmax_layer *layer = (softmax_layer *)net->layers[i];
layer->batch = b;
}
else if(net->types[i] == COST){
cost_layer *layer = (cost_layer *)net->layers[i];
layer->batch = b;
}
else if(net->types[i] == CROP){
crop_layer *layer = (crop_layer *)net->layers[i];
layer->batch = b;
}
else if(net->types[i] == ROUTE){
route_layer *layer = (route_layer *)net->layers[i];
layer->batch = b;
}
} }
} }
/*
int get_network_input_size_layer(network net, int i)
{
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
return layer.h*layer.w*layer.c;
}
if(net.types[i] == DECONVOLUTIONAL){
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
return layer.h*layer.w*layer.c;
}
else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
return layer.h*layer.w*layer.c;
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.inputs;
} else if(net.types[i] == DROPOUT){
dropout_layer layer = *(dropout_layer *) net.layers[i];
return layer.inputs;
} else if(net.types[i] == DETECTION){
detection_layer layer = *(detection_layer *) net.layers[i];
return layer.inputs;
} else if(net.types[i] == CROP){
crop_layer layer = *(crop_layer *) net.layers[i];
return layer.c*layer.h*layer.w;
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.inputs;
}
fprintf(stderr, "Can't find input size\n");
return 0;
}
int get_network_output_size_layer(network net, int i)
{
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
image output = get_convolutional_image(layer);
return output.h*output.w*output.c;
}
else if(net.types[i] == DECONVOLUTIONAL){
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
image output = get_deconvolutional_image(layer);
return output.h*output.w*output.c;
}
else if(net.types[i] == DETECTION){
detection_layer layer = *(detection_layer *)net.layers[i];
return get_detection_layer_output_size(layer);
}
else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
image output = get_maxpool_image(layer);
return output.h*output.w*output.c;
}
else if(net.types[i] == CROP){
crop_layer layer = *(crop_layer *) net.layers[i];
return layer.c*layer.crop_height*layer.crop_width;
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.outputs;
}
else if(net.types[i] == DROPOUT){
dropout_layer layer = *(dropout_layer *) net.layers[i];
return layer.inputs;
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.inputs;
}
else if(net.types[i] == ROUTE){
route_layer layer = *(route_layer *)net.layers[i];
return layer.outputs;
}
fprintf(stderr, "Can't find output size\n");
return 0;
}
int resize_network(network net, int h, int w, int c) int resize_network(network net, int h, int w, int c)
{ {
fprintf(stderr, "Might be broken, careful!!"); fprintf(stderr, "Might be broken, careful!!");
@@ -497,74 +279,47 @@ int resize_network(network net, int h, int w, int c)
}else if(net.types[i] == DROPOUT){ }else if(net.types[i] == DROPOUT){
dropout_layer *layer = (dropout_layer *)net.layers[i]; dropout_layer *layer = (dropout_layer *)net.layers[i];
resize_dropout_layer(layer, h*w*c); resize_dropout_layer(layer, h*w*c);
}else if(net.types[i] == NORMALIZATION){
normalization_layer *layer = (normalization_layer *)net.layers[i];
resize_normalization_layer(layer, h, w);
image output = get_normalization_image(*layer);
h = output.h;
w = output.w;
c = output.c;
}else{ }else{
error("Cannot resize this type of layer"); error("Cannot resize this type of layer");
} }
} }
return 0; return 0;
} }
*/
int get_network_output_size(network net) int get_network_output_size(network net)
{ {
int i; int i;
for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break;
return get_network_output_size_layer(net, i); return net.layers[i].outputs;
} }
int get_network_input_size(network net) int get_network_input_size(network net)
{ {
return get_network_input_size_layer(net, 0); return net.layers[0].inputs;
} }
detection_layer *get_network_detection_layer(network net) detection_layer get_network_detection_layer(network net)
{ {
int i; int i;
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
if(net.types[i] == DETECTION){ if(net.layers[i].type == DETECTION){
detection_layer *layer = (detection_layer *)net.layers[i]; return net.layers[i];
return layer;
} }
} }
return 0; fprintf(stderr, "Detection layer not found!!\n");
detection_layer l = {0};
return l;
} }
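Since the lookup now returns the struct by value, there is no NULL to signal absence; the fallback is a zeroed layer, so a caller can test the type tag instead of a pointer. A hypothetical usage fragment (assumes darknet's headers are included, so not standalone):

detection_layer l = get_network_detection_layer(net);
if (l.type != DETECTION) {
    /* zeroed fallback: this network has no detection layer */
}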
image get_network_image_layer(network net, int i) image get_network_image_layer(network net, int i)
{ {
if(net.types[i] == CONVOLUTIONAL){ layer l = net.layers[i];
convolutional_layer layer = *(convolutional_layer *)net.layers[i]; if (l.out_w && l.out_h && l.out_c){
return get_convolutional_image(layer); return float_to_image(l.out_w, l.out_h, l.out_c, l.output);
} }
else if(net.types[i] == DECONVOLUTIONAL){ image def = {0};
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; return def;
return get_deconvolutional_image(layer);
}
else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
return get_maxpool_image(layer);
}
else if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
return get_normalization_image(layer);
}
else if(net.types[i] == DROPOUT){
return get_network_image_layer(net, i-1);
}
else if(net.types[i] == CROP){
crop_layer layer = *(crop_layer *)net.layers[i];
return get_crop_image(layer);
}
else if(net.types[i] == ROUTE){
route_layer layer = *(route_layer *)net.layers[i];
return get_network_image_layer(net, layer.input_layers[0]);
}
return make_empty_image(0,0,0);
} }
image get_network_image(network net) image get_network_image(network net)
@@ -574,7 +329,8 @@ image get_network_image(network net)
image m = get_network_image_layer(net, i); image m = get_network_image_layer(net, i);
if(m.h != 0) return m; if(m.h != 0) return m;
} }
return make_empty_image(0,0,0); image def = {0};
return def;
} }
void visualize_network(network net) void visualize_network(network net)
@@ -582,16 +338,11 @@ void visualize_network(network net)
image *prev = 0; image *prev = 0;
int i; int i;
char buff[256]; char buff[256];
//show_image(get_network_image_layer(net, 0), "Crop");
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
sprintf(buff, "Layer %d", i); sprintf(buff, "Layer %d", i);
if(net.types[i] == CONVOLUTIONAL){ layer l = net.layers[i];
convolutional_layer layer = *(convolutional_layer *)net.layers[i]; if(l.type == CONVOLUTIONAL){
prev = visualize_convolutional_layer(layer, buff, prev); prev = visualize_convolutional_layer(l, buff, prev);
}
if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
visualize_normalization_layer(layer, buff);
} }
} }
} }
@@ -672,36 +423,9 @@ void print_network(network net)
{ {
int i,j; int i,j;
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
float *output = 0; layer l = net.layers[i];
int n = 0; float *output = l.output;
if(net.types[i] == CONVOLUTIONAL){ int n = l.outputs;
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
output = layer.output;
image m = get_convolutional_image(layer);
n = m.h*m.w*m.c;
}
else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
output = layer.output;
image m = get_maxpool_image(layer);
n = m.h*m.w*m.c;
}
else if(net.types[i] == CROP){
crop_layer layer = *(crop_layer *)net.layers[i];
output = layer.output;
image m = get_crop_image(layer);
n = m.h*m.w*m.c;
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
output = layer.output;
n = layer.outputs;
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
output = layer.output;
n = layer.inputs;
}
float mean = mean_array(output, n); float mean = mean_array(output, n);
float vari = variance_array(output, n); float vari = variance_array(output, n);
fprintf(stderr, "Layer %d - Mean: %f, Variance: %f\n",i,mean, vari); fprintf(stderr, "Layer %d - Mean: %f, Variance: %f\n",i,mean, vari);


@@ -4,22 +4,9 @@
#include "image.h" #include "image.h"
#include "detection_layer.h" #include "detection_layer.h"
#include "layer.h"
#include "data.h" #include "data.h"
typedef enum {
CONVOLUTIONAL,
DECONVOLUTIONAL,
CONNECTED,
MAXPOOL,
SOFTMAX,
DETECTION,
NORMALIZATION,
DROPOUT,
CROP,
ROUTE,
COST
} LAYER_TYPE;
typedef struct { typedef struct {
int n; int n;
int batch; int batch;
@@ -28,8 +15,7 @@ typedef struct {
float learning_rate; float learning_rate;
float momentum; float momentum;
float decay; float decay;
void **layers; layer *layers;
LAYER_TYPE *types;
int outputs; int outputs;
float *output; float *output;
@@ -83,7 +69,7 @@ int resize_network(network net, int h, int w, int c);
void set_batch_network(network *net, int b); void set_batch_network(network *net, int b);
int get_network_input_size(network net); int get_network_input_size(network net);
float get_network_cost(network net); float get_network_cost(network net);
detection_layer *get_network_detection_layer(network net); detection_layer get_network_detection_layer(network net);
int get_network_nuisance(network net); int get_network_nuisance(network net);
int get_network_background(network net); int get_network_background(network net);


@@ -15,7 +15,6 @@ extern "C" {
#include "deconvolutional_layer.h" #include "deconvolutional_layer.h"
#include "maxpool_layer.h" #include "maxpool_layer.h"
#include "cost_layer.h" #include "cost_layer.h"
#include "normalization_layer.h"
#include "softmax_layer.h" #include "softmax_layer.h"
#include "dropout_layer.h" #include "dropout_layer.h"
#include "route_layer.h" #include "route_layer.h"
@@ -29,37 +28,29 @@ void forward_network_gpu(network net, network_state state)
{ {
int i; int i;
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){ layer l = net.layers[i];
forward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state); if(l.type == CONVOLUTIONAL){
forward_convolutional_layer_gpu(l, state);
} else if(l.type == DECONVOLUTIONAL){
forward_deconvolutional_layer_gpu(l, state);
} else if(l.type == DETECTION){
forward_detection_layer_gpu(l, state);
} else if(l.type == CONNECTED){
forward_connected_layer_gpu(l, state);
} else if(l.type == CROP){
forward_crop_layer_gpu(l, state);
} else if(l.type == COST){
forward_cost_layer_gpu(l, state);
} else if(l.type == SOFTMAX){
forward_softmax_layer_gpu(l, state);
} else if(l.type == MAXPOOL){
forward_maxpool_layer_gpu(l, state);
} else if(l.type == DROPOUT){
forward_dropout_layer_gpu(l, state);
} else if(l.type == ROUTE){
forward_route_layer_gpu(l, net);
} }
else if(net.types[i] == DECONVOLUTIONAL){ state.input = l.output_gpu;
forward_deconvolutional_layer_gpu(*(deconvolutional_layer *)net.layers[i], state);
}
else if(net.types[i] == COST){
forward_cost_layer_gpu(*(cost_layer *)net.layers[i], state);
}
else if(net.types[i] == CONNECTED){
forward_connected_layer_gpu(*(connected_layer *)net.layers[i], state);
}
else if(net.types[i] == DETECTION){
forward_detection_layer_gpu(*(detection_layer *)net.layers[i], state);
}
else if(net.types[i] == MAXPOOL){
forward_maxpool_layer_gpu(*(maxpool_layer *)net.layers[i], state);
}
else if(net.types[i] == SOFTMAX){
forward_softmax_layer_gpu(*(softmax_layer *)net.layers[i], state);
}
else if(net.types[i] == DROPOUT){
forward_dropout_layer_gpu(*(dropout_layer *)net.layers[i], state);
}
else if(net.types[i] == CROP){
forward_crop_layer_gpu(*(crop_layer *)net.layers[i], state);
}
else if(net.types[i] == ROUTE){
forward_route_layer_gpu(*(route_layer *)net.layers[i], net);
}
state.input = get_network_output_gpu_layer(net, i);
} }
} }
@@ -68,40 +59,33 @@ void backward_network_gpu(network net, network_state state)
int i; int i;
float * original_input = state.input; float * original_input = state.input;
for(i = net.n-1; i >= 0; --i){ for(i = net.n-1; i >= 0; --i){
layer l = net.layers[i];
if(i == 0){ if(i == 0){
state.input = original_input; state.input = original_input;
state.delta = 0; state.delta = 0;
}else{ }else{
state.input = get_network_output_gpu_layer(net, i-1); layer prev = net.layers[i-1];
state.delta = get_network_delta_gpu_layer(net, i-1); state.input = prev.output_gpu;
state.delta = prev.delta_gpu;
} }
if(l.type == CONVOLUTIONAL){
if(net.types[i] == CONVOLUTIONAL){ backward_convolutional_layer_gpu(l, state);
backward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state); } else if(l.type == DECONVOLUTIONAL){
} backward_deconvolutional_layer_gpu(l, state);
else if(net.types[i] == DECONVOLUTIONAL){ } else if(l.type == MAXPOOL){
backward_deconvolutional_layer_gpu(*(deconvolutional_layer *)net.layers[i], state); if(i != 0) backward_maxpool_layer_gpu(l, state);
} } else if(l.type == DROPOUT){
else if(net.types[i] == COST){ backward_dropout_layer_gpu(l, state);
backward_cost_layer_gpu(*(cost_layer *)net.layers[i], state); } else if(l.type == DETECTION){
} backward_detection_layer_gpu(l, state);
else if(net.types[i] == CONNECTED){ } else if(l.type == SOFTMAX){
backward_connected_layer_gpu(*(connected_layer *)net.layers[i], state); if(i != 0) backward_softmax_layer_gpu(l, state);
} } else if(l.type == CONNECTED){
else if(net.types[i] == DETECTION){ backward_connected_layer_gpu(l, state);
backward_detection_layer_gpu(*(detection_layer *)net.layers[i], state); } else if(l.type == COST){
} backward_cost_layer_gpu(l, state);
else if(net.types[i] == MAXPOOL){ } else if(l.type == ROUTE){
backward_maxpool_layer_gpu(*(maxpool_layer *)net.layers[i], state); backward_route_layer_gpu(l, net);
}
else if(net.types[i] == DROPOUT){
backward_dropout_layer_gpu(*(dropout_layer *)net.layers[i], state);
}
else if(net.types[i] == SOFTMAX){
backward_softmax_layer_gpu(*(softmax_layer *)net.layers[i], state);
}
else if(net.types[i] == ROUTE){
backward_route_layer_gpu(*(route_layer *)net.layers[i], net);
} }
} }
} }
@@ -111,89 +95,17 @@ void update_network_gpu(network net)
int i; int i;
int update_batch = net.batch*net.subdivisions; int update_batch = net.batch*net.subdivisions;
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){ layer l = net.layers[i];
convolutional_layer layer = *(convolutional_layer *)net.layers[i]; if(l.type == CONVOLUTIONAL){
update_convolutional_layer_gpu(layer, update_batch, net.learning_rate, net.momentum, net.decay); update_convolutional_layer_gpu(l, update_batch, net.learning_rate, net.momentum, net.decay);
} } else if(l.type == DECONVOLUTIONAL){
else if(net.types[i] == DECONVOLUTIONAL){ update_deconvolutional_layer_gpu(l, net.learning_rate, net.momentum, net.decay);
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; } else if(l.type == CONNECTED){
update_deconvolutional_layer_gpu(layer, net.learning_rate, net.momentum, net.decay); update_connected_layer_gpu(l, update_batch, net.learning_rate, net.momentum, net.decay);
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
update_connected_layer_gpu(layer, update_batch, net.learning_rate, net.momentum, net.decay);
} }
} }
} }
float * get_network_output_gpu_layer(network net, int i)
{
if(net.types[i] == CONVOLUTIONAL){
return ((convolutional_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == DECONVOLUTIONAL){
return ((deconvolutional_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == DETECTION){
return ((detection_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == CONNECTED){
return ((connected_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == MAXPOOL){
return ((maxpool_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == CROP){
return ((crop_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == SOFTMAX){
return ((softmax_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == ROUTE){
return ((route_layer *)net.layers[i]) -> output_gpu;
}
else if(net.types[i] == DROPOUT){
return get_network_output_gpu_layer(net, i-1);
}
return 0;
}
float * get_network_delta_gpu_layer(network net, int i)
{
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
return layer.delta_gpu;
}
else if(net.types[i] == DETECTION){
detection_layer layer = *(detection_layer *)net.layers[i];
return layer.delta_gpu;
}
else if(net.types[i] == DECONVOLUTIONAL){
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
return layer.delta_gpu;
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.delta_gpu;
}
else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
return layer.delta_gpu;
}
else if(net.types[i] == ROUTE){
route_layer layer = *(route_layer *)net.layers[i];
return layer.delta_gpu;
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.delta_gpu;
} else if(net.types[i] == DROPOUT){
if(i == 0) return 0;
return get_network_delta_gpu_layer(net, i-1);
}
return 0;
}
float train_network_datum_gpu(network net, float *x, float *y) float train_network_datum_gpu(network net, float *x, float *y)
{ {
network_state state; network_state state;
@@ -219,33 +131,22 @@ float train_network_datum_gpu(network net, float *x, float *y)
float *get_network_output_layer_gpu(network net, int i) float *get_network_output_layer_gpu(network net, int i)
{ {
if(net.types[i] == CONVOLUTIONAL){ layer l = net.layers[i];
convolutional_layer layer = *(convolutional_layer *)net.layers[i]; if(l.type == CONVOLUTIONAL){
return layer.output; return l.output;
} } else if(l.type == DECONVOLUTIONAL){
else if(net.types[i] == DECONVOLUTIONAL){ return l.output;
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i]; } else if(l.type == CONNECTED){
return layer.output; cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
} return l.output;
else if(net.types[i] == CONNECTED){ } else if(l.type == DETECTION){
connected_layer layer = *(connected_layer *)net.layers[i]; cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
cuda_pull_array(layer.output_gpu, layer.output, layer.outputs*layer.batch); return l.output;
return layer.output; } else if(l.type == MAXPOOL){
} return l.output;
else if(net.types[i] == DETECTION){ } else if(l.type == SOFTMAX){
detection_layer layer = *(detection_layer *)net.layers[i]; pull_softmax_layer_output(l);
int outputs = get_detection_layer_output_size(layer); return l.output;
cuda_pull_array(layer.output_gpu, layer.output, outputs*layer.batch);
return layer.output;
}
else if(net.types[i] == MAXPOOL){
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
return layer.output;
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
pull_softmax_layer_output(layer);
return layer.output;
} }
return 0; return 0;
} }
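The GPU variant differs from the CPU one mainly in that device-resident outputs must be copied to the host before CPU code can read them, which is what cuda_pull_array does above. As I understand it that helper is a thin device-to-host copy; a stand-in sketch (function name is mine):

#include <cuda_runtime.h>

/* stand-in for darknet's cuda_pull_array: copy n floats device -> host */
void pull_array_sketch(float *x_gpu, float *x_cpu, size_t n)
{
    cudaMemcpy(x_cpu, x_gpu, n*sizeof(float), cudaMemcpyDeviceToHost);
}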
@@ -253,7 +154,7 @@ float *get_network_output_layer_gpu(network net, int i)
float *get_network_output_gpu(network net) float *get_network_output_gpu(network net)
{ {
int i; int i;
for(i = net.n-1; i > 0; --i) if(net.types[i] != COST) break; for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break;
return get_network_output_layer_gpu(net, i); return get_network_output_layer_gpu(net, i);
} }


@@ -1,96 +0,0 @@
#include "normalization_layer.h"
#include <stdio.h>
image get_normalization_image(normalization_layer layer)
{
int h = layer.h;
int w = layer.w;
int c = layer.c;
return float_to_image(w,h,c,layer.output);
}
image get_normalization_delta(normalization_layer layer)
{
int h = layer.h;
int w = layer.w;
int c = layer.c;
return float_to_image(w,h,c,layer.delta);
}
normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa)
{
fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size);
normalization_layer *layer = calloc(1, sizeof(normalization_layer));
layer->batch = batch;
layer->h = h;
layer->w = w;
layer->c = c;
layer->kappa = kappa;
layer->size = size;
layer->alpha = alpha;
layer->beta = beta;
layer->output = calloc(h * w * c * batch, sizeof(float));
layer->delta = calloc(h * w * c * batch, sizeof(float));
layer->sums = calloc(h*w, sizeof(float));
return layer;
}
void resize_normalization_layer(normalization_layer *layer, int h, int w)
{
layer->h = h;
layer->w = w;
layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
layer->sums = realloc(layer->sums, h*w * sizeof(float));
}
void add_square_array(float *src, float *dest, int n)
{
int i;
for(i = 0; i < n; ++i){
dest[i] += src[i]*src[i];
}
}
void sub_square_array(float *src, float *dest, int n)
{
int i;
for(i = 0; i < n; ++i){
dest[i] -= src[i]*src[i];
}
}
void forward_normalization_layer(const normalization_layer layer, network_state state)
{
int i,j,k;
memset(layer.sums, 0, layer.h*layer.w*sizeof(float));
int imsize = layer.h*layer.w;
for(j = 0; j < layer.size/2; ++j){
if(j < layer.c) add_square_array(state.input+j*imsize, layer.sums, imsize);
}
for(k = 0; k < layer.c; ++k){
int next = k+layer.size/2;
int prev = k-layer.size/2-1;
if(next < layer.c) add_square_array(state.input+next*imsize, layer.sums, imsize);
if(prev > 0) sub_square_array(state.input+prev*imsize, layer.sums, imsize);
for(i = 0; i < imsize; ++i){
layer.output[k*imsize + i] = state.input[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta);
}
}
}
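For the record, the forward pass above is plain local response normalization with an incrementally maintained channel window; as a sketch of the invariant the loop keeps:

/* For each pixel i and channel k, over a window of `size` channels:
 *   sums[i] = sum over j in [k - size/2, k + size/2] of input[j*imsize + i]^2
 *   output[k*imsize + i] = input[k*imsize + i] / pow(kappa + alpha*sums[i], beta)
 * Each step in k adds the channel entering the window (next) and removes the
 * one leaving it (prev), so the sum costs O(1) per channel per pixel. Note
 * that `prev > 0` never subtracts channel 0; a strict window would test
 * `prev >= 0`. */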
void backward_normalization_layer(const normalization_layer layer, network_state state)
{
// TODO!
// OR NOT TODO!!
}
void visualize_normalization_layer(normalization_layer layer, char *window)
{
image delta = get_normalization_image(layer);
image dc = collapse_image_layers(delta, 1);
char buff[256];
sprintf(buff, "%s: Output", window);
show_image(dc, buff);
save_image(dc, buff);
free_image(dc);
}

src/normalization_layer.h

@@ -1,27 +0,0 @@
#ifndef NORMALIZATION_LAYER_H
#define NORMALIZATION_LAYER_H
#include "image.h"
#include "params.h"
typedef struct {
int batch;
int h,w,c;
int size;
float alpha;
float beta;
float kappa;
float *delta;
float *output;
float *sums;
} normalization_layer;
image get_normalization_image(normalization_layer layer);
normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa);
void resize_normalization_layer(normalization_layer *layer, int h, int w);
void forward_normalization_layer(const normalization_layer layer, network_state state);
void backward_normalization_layer(const normalization_layer layer, network_state state);
void visualize_normalization_layer(normalization_layer layer, char *window);
#endif

src/old.c

@@ -1,3 +1,254 @@
void save_network(network net, char *filename)
{
FILE *fp = fopen(filename, "w");
if(!fp) file_error(filename);
int i;
for(i = 0; i < net.n; ++i)
{
if(net.types[i] == CONVOLUTIONAL)
print_convolutional_cfg(fp, (convolutional_layer *)net.layers[i], net, i);
else if(net.types[i] == DECONVOLUTIONAL)
print_deconvolutional_cfg(fp, (deconvolutional_layer *)net.layers[i], net, i);
else if(net.types[i] == CONNECTED)
print_connected_cfg(fp, (connected_layer *)net.layers[i], net, i);
else if(net.types[i] == CROP)
print_crop_cfg(fp, (crop_layer *)net.layers[i], net, i);
else if(net.types[i] == MAXPOOL)
print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], net, i);
else if(net.types[i] == DROPOUT)
print_dropout_cfg(fp, (dropout_layer *)net.layers[i], net, i);
else if(net.types[i] == SOFTMAX)
print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i);
else if(net.types[i] == DETECTION)
print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i);
else if(net.types[i] == COST)
print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i);
}
fclose(fp);
}
void print_convolutional_cfg(FILE *fp, convolutional_layer *l, network net, int count)
{
#ifdef GPU
if(gpu_index >= 0) pull_convolutional_layer(*l);
#endif
int i;
fprintf(fp, "[convolutional]\n");
fprintf(fp, "filters=%d\n"
"size=%d\n"
"stride=%d\n"
"pad=%d\n"
"activation=%s\n",
l->n, l->size, l->stride, l->pad,
get_activation_string(l->activation));
fprintf(fp, "biases=");
for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
fprintf(fp, "\n");
fprintf(fp, "weights=");
for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
fprintf(fp, "\n\n");
}
void print_deconvolutional_cfg(FILE *fp, deconvolutional_layer *l, network net, int count)
{
#ifdef GPU
if(gpu_index >= 0) pull_deconvolutional_layer(*l);
#endif
int i;
fprintf(fp, "[deconvolutional]\n");
fprintf(fp, "filters=%d\n"
"size=%d\n"
"stride=%d\n"
"activation=%s\n",
l->n, l->size, l->stride,
get_activation_string(l->activation));
fprintf(fp, "biases=");
for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
fprintf(fp, "\n");
fprintf(fp, "weights=");
for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
fprintf(fp, "\n\n");
}
void print_dropout_cfg(FILE *fp, dropout_layer *l, network net, int count)
{
fprintf(fp, "[dropout]\n");
fprintf(fp, "probability=%g\n\n", l->probability);
}
void print_connected_cfg(FILE *fp, connected_layer *l, network net, int count)
{
#ifdef GPU
if(gpu_index >= 0) pull_connected_layer(*l);
#endif
int i;
fprintf(fp, "[connected]\n");
fprintf(fp, "output=%d\n"
"activation=%s\n",
l->outputs,
get_activation_string(l->activation));
fprintf(fp, "biases=");
for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]);
fprintf(fp, "\n");
fprintf(fp, "weights=");
for(i = 0; i < l->outputs*l->inputs; ++i) fprintf(fp, "%g,", l->weights[i]);
fprintf(fp, "\n\n");
}
void print_crop_cfg(FILE *fp, crop_layer *l, network net, int count)
{
fprintf(fp, "[crop]\n");
fprintf(fp, "crop_height=%d\ncrop_width=%d\nflip=%d\n\n", l->crop_height, l->crop_width, l->flip);
}
void print_maxpool_cfg(FILE *fp, maxpool_layer *l, network net, int count)
{
fprintf(fp, "[maxpool]\n");
fprintf(fp, "size=%d\nstride=%d\n\n", l->size, l->stride);
}
void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count)
{
fprintf(fp, "[softmax]\n");
fprintf(fp, "\n");
}
void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count)
{
fprintf(fp, "[detection]\n");
fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\nnuisance=%d\n", l->classes, l->coords, l->rescore, l->nuisance);
fprintf(fp, "\n");
}
void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count)
{
fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type));
fprintf(fp, "\n");
}
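These archived printers serialized weights inline in the cfg text, one `%g,` value at a time; a [connected] section they emit looks roughly like this (values illustrative):

[connected]
output=3
activation=relu
biases=0.25,-0.1,0.5,
weights=0.01,-0.02,0.03,0.04,-0.05,0.06,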
#ifndef NORMALIZATION_LAYER_H
#define NORMALIZATION_LAYER_H
#include "image.h"
#include "params.h"
typedef struct {
int batch;
int h,w,c;
int size;
float alpha;
float beta;
float kappa;
float *delta;
float *output;
float *sums;
} normalization_layer;
image get_normalization_image(normalization_layer layer);
normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa);
void resize_normalization_layer(normalization_layer *layer, int h, int w);
void forward_normalization_layer(const normalization_layer layer, network_state state);
void backward_normalization_layer(const normalization_layer layer, network_state state);
void visualize_normalization_layer(normalization_layer layer, char *window);
#endif
#include "normalization_layer.h"
#include <stdio.h>
image get_normalization_image(normalization_layer layer)
{
int h = layer.h;
int w = layer.w;
int c = layer.c;
return float_to_image(w,h,c,layer.output);
}
image get_normalization_delta(normalization_layer layer)
{
int h = layer.h;
int w = layer.w;
int c = layer.c;
return float_to_image(w,h,c,layer.delta);
}
normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa)
{
fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size);
normalization_layer *layer = calloc(1, sizeof(normalization_layer));
layer->batch = batch;
layer->h = h;
layer->w = w;
layer->c = c;
layer->kappa = kappa;
layer->size = size;
layer->alpha = alpha;
layer->beta = beta;
layer->output = calloc(h * w * c * batch, sizeof(float));
layer->delta = calloc(h * w * c * batch, sizeof(float));
layer->sums = calloc(h*w, sizeof(float));
return layer;
}
void resize_normalization_layer(normalization_layer *layer, int h, int w)
{
layer->h = h;
layer->w = w;
layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
layer->sums = realloc(layer->sums, h*w * sizeof(float));
}
void add_square_array(float *src, float *dest, int n)
{
int i;
for(i = 0; i < n; ++i){
dest[i] += src[i]*src[i];
}
}
void sub_square_array(float *src, float *dest, int n)
{
int i;
for(i = 0; i < n; ++i){
dest[i] -= src[i]*src[i];
}
}
void forward_normalization_layer(const normalization_layer layer, network_state state)
{
int i,j,k;
memset(layer.sums, 0, layer.h*layer.w*sizeof(float));
int imsize = layer.h*layer.w;
for(j = 0; j < layer.size/2; ++j){
if(j < layer.c) add_square_array(state.input+j*imsize, layer.sums, imsize);
}
for(k = 0; k < layer.c; ++k){
int next = k+layer.size/2;
int prev = k-layer.size/2-1;
if(next < layer.c) add_square_array(state.input+next*imsize, layer.sums, imsize);
if(prev > 0) sub_square_array(state.input+prev*imsize, layer.sums, imsize);
for(i = 0; i < imsize; ++i){
layer.output[k*imsize + i] = state.input[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta);
}
}
}
void backward_normalization_layer(const normalization_layer layer, network_state state)
{
// TODO!
// OR NOT TODO!!
}
void visualize_normalization_layer(normalization_layer layer, char *window)
{
image delta = get_normalization_image(layer);
image dc = collapse_image_layers(delta, 1);
char buff[256];
sprintf(buff, "%s: Output", window);
show_image(dc, buff);
save_image(dc, buff);
free_image(dc);
}
void test_load()
{

src/parser.c

@@ -10,7 +10,6 @@
#include "deconvolutional_layer.h"
#include "connected_layer.h"
#include "maxpool_layer.h"
-#include "normalization_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
@@ -34,7 +33,6 @@ int is_softmax(section *s);
int is_crop(section *s);
int is_cost(section *s);
int is_detection(section *s);
-int is_normalization(section *s);
int is_route(section *s);
list *read_cfg(char *filename);
@@ -78,7 +76,7 @@ typedef struct size_params{
    int c;
} size_params;
-deconvolutional_layer *parse_deconvolutional(list *options, size_params params)
+deconvolutional_layer parse_deconvolutional(list *options, size_params params)
{
    int n = option_find_int(options, "filters",1);
    int size = option_find_int(options, "size",1);
@@ -93,20 +91,20 @@ deconvolutional_layer *parse_deconvolutional(list *options, size_params params)
    batch=params.batch;
    if(!(h && w && c)) error("Layer before deconvolutional layer must output image.");
-    deconvolutional_layer *layer = make_deconvolutional_layer(batch,h,w,c,n,size,stride,activation);
+    deconvolutional_layer layer = make_deconvolutional_layer(batch,h,w,c,n,size,stride,activation);
    char *weights = option_find_str(options, "weights", 0);
    char *biases = option_find_str(options, "biases", 0);
-    parse_data(weights, layer->filters, c*n*size*size);
-    parse_data(biases, layer->biases, n);
+    parse_data(weights, layer.filters, c*n*size*size);
+    parse_data(biases, layer.biases, n);
#ifdef GPU
-    if(weights || biases) push_deconvolutional_layer(*layer);
+    if(weights || biases) push_deconvolutional_layer(layer);
#endif
    option_unused(options);
    return layer;
}
-convolutional_layer *parse_convolutional(list *options, size_params params)
+convolutional_layer parse_convolutional(list *options, size_params params)
{
    int n = option_find_int(options, "filters",1);
    int size = option_find_int(options, "size",1);
@@ -122,68 +120,68 @@ convolutional_layer *parse_convolutional(list *options, size_params params)
    batch=params.batch;
    if(!(h && w && c)) error("Layer before convolutional layer must output image.");
-    convolutional_layer *layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation);
+    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation);
    char *weights = option_find_str(options, "weights", 0);
    char *biases = option_find_str(options, "biases", 0);
-    parse_data(weights, layer->filters, c*n*size*size);
-    parse_data(biases, layer->biases, n);
+    parse_data(weights, layer.filters, c*n*size*size);
+    parse_data(biases, layer.biases, n);
#ifdef GPU
-    if(weights || biases) push_convolutional_layer(*layer);
+    if(weights || biases) push_convolutional_layer(layer);
#endif
    option_unused(options);
    return layer;
}
-connected_layer *parse_connected(list *options, size_params params)
+connected_layer parse_connected(list *options, size_params params)
{
    int output = option_find_int(options, "output",1);
    char *activation_s = option_find_str(options, "activation", "logistic");
    ACTIVATION activation = get_activation(activation_s);
-    connected_layer *layer = make_connected_layer(params.batch, params.inputs, output, activation);
+    connected_layer layer = make_connected_layer(params.batch, params.inputs, output, activation);
    char *weights = option_find_str(options, "weights", 0);
    char *biases = option_find_str(options, "biases", 0);
-    parse_data(biases, layer->biases, output);
-    parse_data(weights, layer->weights, params.inputs*output);
+    parse_data(biases, layer.biases, output);
+    parse_data(weights, layer.weights, params.inputs*output);
#ifdef GPU
-    if(weights || biases) push_connected_layer(*layer);
+    if(weights || biases) push_connected_layer(layer);
#endif
    option_unused(options);
    return layer;
}
-softmax_layer *parse_softmax(list *options, size_params params)
+softmax_layer parse_softmax(list *options, size_params params)
{
    int groups = option_find_int(options, "groups",1);
-    softmax_layer *layer = make_softmax_layer(params.batch, params.inputs, groups);
+    softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups);
    option_unused(options);
    return layer;
}
-detection_layer *parse_detection(list *options, size_params params)
+detection_layer parse_detection(list *options, size_params params)
{
    int coords = option_find_int(options, "coords", 1);
    int classes = option_find_int(options, "classes", 1);
    int rescore = option_find_int(options, "rescore", 1);
    int nuisance = option_find_int(options, "nuisance", 0);
    int background = option_find_int(options, "background", 1);
-    detection_layer *layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore, background, nuisance);
+    detection_layer layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore, background, nuisance);
    option_unused(options);
    return layer;
}
-cost_layer *parse_cost(list *options, size_params params)
+cost_layer parse_cost(list *options, size_params params)
{
    char *type_s = option_find_str(options, "type", "sse");
    COST_TYPE type = get_cost_type(type_s);
-    cost_layer *layer = make_cost_layer(params.batch, params.inputs, type);
+    cost_layer layer = make_cost_layer(params.batch, params.inputs, type);
    option_unused(options);
    return layer;
}
-crop_layer *parse_crop(list *options, size_params params)
+crop_layer parse_crop(list *options, size_params params)
{
    int crop_height = option_find_int(options, "crop_height",1);
    int crop_width = option_find_int(options, "crop_width",1);
@@ -199,12 +197,12 @@ crop_layer *parse_crop(list *options, size_params params)
    batch=params.batch;
    if(!(h && w && c)) error("Layer before crop layer must output image.");
-    crop_layer *layer = make_crop_layer(batch,h,w,c,crop_height,crop_width,flip, angle, saturation, exposure);
+    crop_layer l = make_crop_layer(batch,h,w,c,crop_height,crop_width,flip, angle, saturation, exposure);
    option_unused(options);
-    return layer;
+    return l;
}
-maxpool_layer *parse_maxpool(list *options, size_params params)
+maxpool_layer parse_maxpool(list *options, size_params params)
{
    int stride = option_find_int(options, "stride",1);
    int size = option_find_int(options, "size",stride);
@@ -216,39 +214,20 @@ maxpool_layer *parse_maxpool(list *options, size_params params)
    batch=params.batch;
    if(!(h && w && c)) error("Layer before maxpool layer must output image.");
-    maxpool_layer *layer = make_maxpool_layer(batch,h,w,c,size,stride);
+    maxpool_layer layer = make_maxpool_layer(batch,h,w,c,size,stride);
    option_unused(options);
    return layer;
}
-dropout_layer *parse_dropout(list *options, size_params params)
+dropout_layer parse_dropout(list *options, size_params params)
{
    float probability = option_find_float(options, "probability", .5);
-    dropout_layer *layer = make_dropout_layer(params.batch, params.inputs, probability);
+    dropout_layer layer = make_dropout_layer(params.batch, params.inputs, probability);
    option_unused(options);
    return layer;
}
-normalization_layer *parse_normalization(list *options, size_params params)
-{
-    int size = option_find_int(options, "size",1);
-    float alpha = option_find_float(options, "alpha", 0.);
-    float beta = option_find_float(options, "beta", 1.);
-    float kappa = option_find_float(options, "kappa", 1.);
-    int batch,h,w,c;
-    h = params.h;
-    w = params.w;
-    c = params.c;
-    batch=params.batch;
-    if(!(h && w && c)) error("Layer before normalization layer must output image.");
-    normalization_layer *layer = make_normalization_layer(batch,h,w,c,size, alpha, beta, kappa);
-    option_unused(options);
-    return layer;
-}
-route_layer *parse_route(list *options, size_params params, network net)
+route_layer parse_route(list *options, size_params params, network net)
{
    char *l = option_find(options, "layers");
    int len = strlen(l);
@@ -265,11 +244,26 @@ route_layer *parse_route(list *options, size_params params, network net)
        int index = atoi(l);
        l = strchr(l, ',')+1;
        layers[i] = index;
-        sizes[i] = get_network_output_size_layer(net, index);
+        sizes[i] = net.layers[index].outputs;
    }
    int batch = params.batch;
-    route_layer *layer = make_route_layer(batch, n, layers, sizes);
+    route_layer layer = make_route_layer(batch, n, layers, sizes);
+    convolutional_layer first = net.layers[layers[0]];
+    layer.out_w = first.out_w;
+    layer.out_h = first.out_h;
+    layer.out_c = first.out_c;
+    for(i = 1; i < n; ++i){
+        int index = layers[i];
+        convolutional_layer next = net.layers[index];
+        if(next.out_w == first.out_w && next.out_h == first.out_h){
+            layer.out_c += next.out_c;
+        }else{
+            layer.out_h = layer.out_w = layer.out_c = 0;
+        }
+    }
    option_unused(options);
    return layer;
}
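parse_route reads a comma-separated `layers` option; a minimal cfg section it accepts (indices illustrative) is:

[route]
layers=1,4

The new bookkeeping then concatenates channels when the listed layers agree on out_w and out_h, and zeroes out_h/out_w/out_c otherwise, since the concatenation no longer describes an image.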
@@ -318,61 +312,44 @@ network parse_network_cfg(char *filename)
        fprintf(stderr, "%d: ", count);
        s = (section *)n->val;
        options = s->options;
+        layer l = {0};
        if(is_convolutional(s)){
-            convolutional_layer *layer = parse_convolutional(options, params);
-            net.types[count] = CONVOLUTIONAL;
-            net.layers[count] = layer;
+            l = parse_convolutional(options, params);
        }else if(is_deconvolutional(s)){
-            deconvolutional_layer *layer = parse_deconvolutional(options, params);
-            net.types[count] = DECONVOLUTIONAL;
-            net.layers[count] = layer;
+            l = parse_deconvolutional(options, params);
        }else if(is_connected(s)){
-            connected_layer *layer = parse_connected(options, params);
-            net.types[count] = CONNECTED;
-            net.layers[count] = layer;
+            l = parse_connected(options, params);
        }else if(is_crop(s)){
-            crop_layer *layer = parse_crop(options, params);
-            net.types[count] = CROP;
-            net.layers[count] = layer;
+            l = parse_crop(options, params);
        }else if(is_cost(s)){
-            cost_layer *layer = parse_cost(options, params);
-            net.types[count] = COST;
-            net.layers[count] = layer;
+            l = parse_cost(options, params);
        }else if(is_detection(s)){
-            detection_layer *layer = parse_detection(options, params);
-            net.types[count] = DETECTION;
-            net.layers[count] = layer;
+            l = parse_detection(options, params);
        }else if(is_softmax(s)){
-            softmax_layer *layer = parse_softmax(options, params);
-            net.types[count] = SOFTMAX;
-            net.layers[count] = layer;
+            l = parse_softmax(options, params);
        }else if(is_maxpool(s)){
-            maxpool_layer *layer = parse_maxpool(options, params);
-            net.types[count] = MAXPOOL;
-            net.layers[count] = layer;
-        }else if(is_normalization(s)){
-            normalization_layer *layer = parse_normalization(options, params);
-            net.types[count] = NORMALIZATION;
-            net.layers[count] = layer;
+            l = parse_maxpool(options, params);
        }else if(is_route(s)){
-            route_layer *layer = parse_route(options, params, net);
-            net.types[count] = ROUTE;
-            net.layers[count] = layer;
+            l = parse_route(options, params, net);
        }else if(is_dropout(s)){
-            dropout_layer *layer = parse_dropout(options, params);
-            net.types[count] = DROPOUT;
-            net.layers[count] = layer;
+            l = parse_dropout(options, params);
+            l.output = net.layers[count-1].output;
+            l.delta = net.layers[count-1].delta;
+#ifdef GPU
+            l.output_gpu = net.layers[count-1].output_gpu;
+            l.delta_gpu = net.layers[count-1].delta_gpu;
+#endif
        }else{
            fprintf(stderr, "Type not recognized: %s\n", s->type);
        }
+        net.layers[count] = l;
        free_section(s);
        n = n->next;
        if(n){
-            image im = get_network_image_layer(net, count);
-            params.h = im.h;
-            params.w = im.w;
-            params.c = im.c;
-            params.inputs = get_network_output_size_layer(net, count);
+            params.h = l.out_h;
+            params.w = l.out_w;
+            params.c = l.out_c;
+            params.inputs = l.outputs;
        }
        ++count;
    }
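The rewritten loop presumes net.layers now holds `layer` values rather than void pointers. Roughly the shape implied by this diff (field names taken from the hunks here; the real definitions live in layer.h and network.h):

typedef struct layer {
    LAYER_TYPE type;              /* CONVOLUTIONAL, CONNECTED, ROUTE, ... */
    int batch, inputs, outputs;
    int out_h, out_w, out_c;
    float *output, *delta;
#ifdef GPU
    float *output_gpu, *delta_gpu;
#endif
    /* ...per-type fields: n, size, stride, filters, weights, biases... */
} layer;

typedef struct network {
    int n;
    layer *layers;                /* one tagged struct per layer, no casts */
    /* ... */
} network;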
@@ -429,11 +406,6 @@ int is_softmax(section *s)
    return (strcmp(s->type, "[soft]")==0
            || strcmp(s->type, "[softmax]")==0);
}
-int is_normalization(section *s)
-{
-    return (strcmp(s->type, "[lrnorm]")==0
-            || strcmp(s->type, "[localresponsenormalization]")==0);
-}
int is_route(section *s)
{
    return (strcmp(s->type, "[route]")==0);
@@ -492,114 +464,6 @@ list *read_cfg(char *filename)
    return sections;
}
void print_convolutional_cfg(FILE *fp, convolutional_layer *l, network net, int count)
{
#ifdef GPU
if(gpu_index >= 0) pull_convolutional_layer(*l);
#endif
int i;
fprintf(fp, "[convolutional]\n");
fprintf(fp, "filters=%d\n"
"size=%d\n"
"stride=%d\n"
"pad=%d\n"
"activation=%s\n",
l->n, l->size, l->stride, l->pad,
get_activation_string(l->activation));
fprintf(fp, "biases=");
for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
fprintf(fp, "\n");
fprintf(fp, "weights=");
for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
fprintf(fp, "\n\n");
}
void print_deconvolutional_cfg(FILE *fp, deconvolutional_layer *l, network net, int count)
{
#ifdef GPU
if(gpu_index >= 0) pull_deconvolutional_layer(*l);
#endif
int i;
fprintf(fp, "[deconvolutional]\n");
fprintf(fp, "filters=%d\n"
"size=%d\n"
"stride=%d\n"
"activation=%s\n",
l->n, l->size, l->stride,
get_activation_string(l->activation));
fprintf(fp, "biases=");
for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
fprintf(fp, "\n");
fprintf(fp, "weights=");
for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
fprintf(fp, "\n\n");
}
void print_dropout_cfg(FILE *fp, dropout_layer *l, network net, int count)
{
fprintf(fp, "[dropout]\n");
fprintf(fp, "probability=%g\n\n", l->probability);
}
void print_connected_cfg(FILE *fp, connected_layer *l, network net, int count)
{
#ifdef GPU
if(gpu_index >= 0) pull_connected_layer(*l);
#endif
int i;
fprintf(fp, "[connected]\n");
fprintf(fp, "output=%d\n"
"activation=%s\n",
l->outputs,
get_activation_string(l->activation));
fprintf(fp, "biases=");
for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]);
fprintf(fp, "\n");
fprintf(fp, "weights=");
for(i = 0; i < l->outputs*l->inputs; ++i) fprintf(fp, "%g,", l->weights[i]);
fprintf(fp, "\n\n");
}
void print_crop_cfg(FILE *fp, crop_layer *l, network net, int count)
{
fprintf(fp, "[crop]\n");
fprintf(fp, "crop_height=%d\ncrop_width=%d\nflip=%d\n\n", l->crop_height, l->crop_width, l->flip);
}
void print_maxpool_cfg(FILE *fp, maxpool_layer *l, network net, int count)
{
fprintf(fp, "[maxpool]\n");
fprintf(fp, "size=%d\nstride=%d\n\n", l->size, l->stride);
}
void print_normalization_cfg(FILE *fp, normalization_layer *l, network net, int count)
{
fprintf(fp, "[localresponsenormalization]\n");
fprintf(fp, "size=%d\n"
"alpha=%g\n"
"beta=%g\n"
"kappa=%g\n\n", l->size, l->alpha, l->beta, l->kappa);
}
void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count)
{
fprintf(fp, "[softmax]\n");
fprintf(fp, "\n");
}
void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count)
{
fprintf(fp, "[detection]\n");
fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\nnuisance=%d\n", l->classes, l->coords, l->rescore, l->nuisance);
fprintf(fp, "\n");
}
void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count)
{
fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type));
fprintf(fp, "\n");
}
void save_weights(network net, char *filename)
{
    fprintf(stderr, "Saving weights to %s\n", filename);
@@ -613,37 +477,35 @@ void save_weights(network net, char *filename)
    int i;
    for(i = 0; i < net.n; ++i){
-        if(net.types[i] == CONVOLUTIONAL){
-            convolutional_layer layer = *(convolutional_layer *) net.layers[i];
+        layer l = net.layers[i];
+        if(l.type == CONVOLUTIONAL){
#ifdef GPU
            if(gpu_index >= 0){
-                pull_convolutional_layer(layer);
+                pull_convolutional_layer(l);
            }
#endif
-            int num = layer.n*layer.c*layer.size*layer.size;
-            fwrite(layer.biases, sizeof(float), layer.n, fp);
-            fwrite(layer.filters, sizeof(float), num, fp);
+            int num = l.n*l.c*l.size*l.size;
+            fwrite(l.biases, sizeof(float), l.n, fp);
+            fwrite(l.filters, sizeof(float), num, fp);
        }
-        if(net.types[i] == DECONVOLUTIONAL){
-            deconvolutional_layer layer = *(deconvolutional_layer *) net.layers[i];
+        if(l.type == DECONVOLUTIONAL){
#ifdef GPU
            if(gpu_index >= 0){
-                pull_deconvolutional_layer(layer);
+                pull_deconvolutional_layer(l);
            }
#endif
-            int num = layer.n*layer.c*layer.size*layer.size;
-            fwrite(layer.biases, sizeof(float), layer.n, fp);
-            fwrite(layer.filters, sizeof(float), num, fp);
+            int num = l.n*l.c*l.size*l.size;
+            fwrite(l.biases, sizeof(float), l.n, fp);
+            fwrite(l.filters, sizeof(float), num, fp);
        }
-        if(net.types[i] == CONNECTED){
-            connected_layer layer = *(connected_layer *) net.layers[i];
+        if(l.type == CONNECTED){
#ifdef GPU
            if(gpu_index >= 0){
-                pull_connected_layer(layer);
+                pull_connected_layer(l);
            }
#endif
-            fwrite(layer.biases, sizeof(float), layer.outputs, fp);
-            fwrite(layer.weights, sizeof(float), layer.outputs*layer.inputs, fp);
+            fwrite(l.biases, sizeof(float), l.outputs, fp);
+            fwrite(l.weights, sizeof(float), l.outputs*l.inputs, fp);
        }
    }
    fclose(fp);
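So the weights file is nothing more than float arrays concatenated in network order, and load_weights_upto below must walk the identical sequence. Schematically, per parameterized layer:

/* CONVOLUTIONAL / DECONVOLUTIONAL: n biases, then n*c*size*size filters
 * CONNECTED:                       outputs biases, then outputs*inputs weights */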
@@ -663,35 +525,33 @@ void load_weights_upto(network *net, char *filename, int cutoff)
    int i;
    for(i = 0; i < net->n && i < cutoff; ++i){
-        if(net->types[i] == CONVOLUTIONAL){
-            convolutional_layer layer = *(convolutional_layer *) net->layers[i];
-            int num = layer.n*layer.c*layer.size*layer.size;
-            fread(layer.biases, sizeof(float), layer.n, fp);
-            fread(layer.filters, sizeof(float), num, fp);
+        layer l = net->layers[i];
+        if(l.type == CONVOLUTIONAL){
+            int num = l.n*l.c*l.size*l.size;
+            fread(l.biases, sizeof(float), l.n, fp);
+            fread(l.filters, sizeof(float), num, fp);
#ifdef GPU
            if(gpu_index >= 0){
-                push_convolutional_layer(layer);
+                push_convolutional_layer(l);
            }
#endif
        }
-        if(net->types[i] == DECONVOLUTIONAL){
-            deconvolutional_layer layer = *(deconvolutional_layer *) net->layers[i];
-            int num = layer.n*layer.c*layer.size*layer.size;
-            fread(layer.biases, sizeof(float), layer.n, fp);
-            fread(layer.filters, sizeof(float), num, fp);
+        if(l.type == DECONVOLUTIONAL){
+            int num = l.n*l.c*l.size*l.size;
+            fread(l.biases, sizeof(float), l.n, fp);
+            fread(l.filters, sizeof(float), num, fp);
#ifdef GPU
            if(gpu_index >= 0){
-                push_deconvolutional_layer(layer);
+                push_deconvolutional_layer(l);
            }
#endif
        }
-        if(net->types[i] == CONNECTED){
-            connected_layer layer = *(connected_layer *) net->layers[i];
-            fread(layer.biases, sizeof(float), layer.outputs, fp);
-            fread(layer.weights, sizeof(float), layer.outputs*layer.inputs, fp);
+        if(l.type == CONNECTED){
+            fread(l.biases, sizeof(float), l.outputs, fp);
+            fread(l.weights, sizeof(float), l.outputs*l.inputs, fp);
#ifdef GPU
            if(gpu_index >= 0){
-                push_connected_layer(layer);
+                push_connected_layer(l);
            }
#endif
        }
@@ -704,34 +564,3 @@ void load_weights(network *net, char *filename)
    load_weights_upto(net, filename, net->n);
}
void save_network(network net, char *filename)
{
FILE *fp = fopen(filename, "w");
if(!fp) file_error(filename);
int i;
for(i = 0; i < net.n; ++i)
{
if(net.types[i] == CONVOLUTIONAL)
print_convolutional_cfg(fp, (convolutional_layer *)net.layers[i], net, i);
else if(net.types[i] == DECONVOLUTIONAL)
print_deconvolutional_cfg(fp, (deconvolutional_layer *)net.layers[i], net, i);
else if(net.types[i] == CONNECTED)
print_connected_cfg(fp, (connected_layer *)net.layers[i], net, i);
else if(net.types[i] == CROP)
print_crop_cfg(fp, (crop_layer *)net.layers[i], net, i);
else if(net.types[i] == MAXPOOL)
print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], net, i);
else if(net.types[i] == DROPOUT)
print_dropout_cfg(fp, (dropout_layer *)net.layers[i], net, i);
else if(net.types[i] == NORMALIZATION)
print_normalization_cfg(fp, (normalization_layer *)net.layers[i], net, i);
else if(net.types[i] == SOFTMAX)
print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i);
else if(net.types[i] == DETECTION)
print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i);
else if(net.types[i] == COST)
print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i);
}
fclose(fp);
}

src/route_layer.c

@@ -3,83 +3,89 @@
#include "blas.h"
#include <stdio.h>
-route_layer *make_route_layer(int batch, int n, int *input_layers, int *input_sizes)
+route_layer make_route_layer(int batch, int n, int *input_layers, int *input_sizes)
{
-    printf("Route Layer:");
-    route_layer *layer = calloc(1, sizeof(route_layer));
-    layer->batch = batch;
-    layer->n = n;
-    layer->input_layers = input_layers;
-    layer->input_sizes = input_sizes;
+    fprintf(stderr,"Route Layer:");
+    route_layer l = {0};
+    l.type = ROUTE;
+    l.batch = batch;
+    l.n = n;
+    l.input_layers = input_layers;
+    l.input_sizes = input_sizes;
    int i;
    int outputs = 0;
    for(i = 0; i < n; ++i){
-        printf(" %d", input_layers[i]);
+        fprintf(stderr," %d", input_layers[i]);
        outputs += input_sizes[i];
    }
-    printf("\n");
-    layer->outputs = outputs;
-    layer->delta = calloc(outputs*batch, sizeof(float));
-    layer->output = calloc(outputs*batch, sizeof(float));
+    fprintf(stderr, "\n");
+    l.outputs = outputs;
+    l.inputs = outputs;
+    l.delta = calloc(outputs*batch, sizeof(float));
+    l.output = calloc(outputs*batch, sizeof(float));
#ifdef GPU
-    layer->delta_gpu = cuda_make_array(0, outputs*batch);
-    layer->output_gpu = cuda_make_array(0, outputs*batch);
+    l.delta_gpu = cuda_make_array(0, outputs*batch);
+    l.output_gpu = cuda_make_array(0, outputs*batch);
#endif
-    return layer;
+    return l;
}
-void forward_route_layer(const route_layer layer, network net)
+void forward_route_layer(const route_layer l, network net)
{
    int i, j;
    int offset = 0;
-    for(i = 0; i < layer.n; ++i){
-        float *input = get_network_output_layer(net, layer.input_layers[i]);
-        int input_size = layer.input_sizes[i];
-        for(j = 0; j < layer.batch; ++j){
-            copy_cpu(input_size, input + j*input_size, 1, layer.output + offset + j*layer.outputs, 1);
+    for(i = 0; i < l.n; ++i){
+        int index = l.input_layers[i];
+        float *input = net.layers[index].output;
+        int input_size = l.input_sizes[i];
+        for(j = 0; j < l.batch; ++j){
+            copy_cpu(input_size, input + j*input_size, 1, l.output + offset + j*l.outputs, 1);
        }
        offset += input_size;
    }
}
-void backward_route_layer(const route_layer layer, network net)
+void backward_route_layer(const route_layer l, network net)
{
    int i, j;
    int offset = 0;
-    for(i = 0; i < layer.n; ++i){
-        float *delta = get_network_delta_layer(net, layer.input_layers[i]);
-        int input_size = layer.input_sizes[i];
-        for(j = 0; j < layer.batch; ++j){
-            copy_cpu(input_size, layer.delta + offset + j*layer.outputs, 1, delta + j*input_size, 1);
+    for(i = 0; i < l.n; ++i){
+        int index = l.input_layers[i];
+        float *delta = net.layers[index].delta;
+        int input_size = l.input_sizes[i];
+        for(j = 0; j < l.batch; ++j){
+            copy_cpu(input_size, l.delta + offset + j*l.outputs, 1, delta + j*input_size, 1);
        }
        offset += input_size;
    }
}
#ifdef GPU
-void forward_route_layer_gpu(const route_layer layer, network net)
+void forward_route_layer_gpu(const route_layer l, network net)
{
    int i, j;
    int offset = 0;
-    for(i = 0; i < layer.n; ++i){
-        float *input = get_network_output_gpu_layer(net, layer.input_layers[i]);
-        int input_size = layer.input_sizes[i];
-        for(j = 0; j < layer.batch; ++j){
-            copy_ongpu(input_size, input + j*input_size, 1, layer.output_gpu + offset + j*layer.outputs, 1);
+    for(i = 0; i < l.n; ++i){
+        int index = l.input_layers[i];
+        float *input = net.layers[index].output_gpu;
+        int input_size = l.input_sizes[i];
+        for(j = 0; j < l.batch; ++j){
+            copy_ongpu(input_size, input + j*input_size, 1, l.output_gpu + offset + j*l.outputs, 1);
        }
        offset += input_size;
    }
}
-void backward_route_layer_gpu(const route_layer layer, network net)
+void backward_route_layer_gpu(const route_layer l, network net)
{
    int i, j;
    int offset = 0;
-    for(i = 0; i < layer.n; ++i){
-        float *delta = get_network_delta_gpu_layer(net, layer.input_layers[i]);
-        int input_size = layer.input_sizes[i];
-        for(j = 0; j < layer.batch; ++j){
-            copy_ongpu(input_size, layer.delta_gpu + offset + j*layer.outputs, 1, delta + j*input_size, 1);
+    for(i = 0; i < l.n; ++i){
+        int index = l.input_layers[i];
+        float *delta = net.layers[index].delta_gpu;
+        int input_size = l.input_sizes[i];
+        for(j = 0; j < l.batch; ++j){
+            copy_ongpu(input_size, l.delta_gpu + offset + j*l.outputs, 1, delta + j*input_size, 1);
        }
        offset += input_size;
    }
}
#endif
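The offset arithmetic in all four loops interleaves sources per batch row; a layout sketch with hypothetical sizes:

/* n = 2 inputs of 4 and 6 floats, batch = 2, so l.outputs = 10 and
 *   l.output = [ A0(4) | B0(6) | A1(4) | B1(6) ]
 * The j loop copies one batch row at a time while `offset` advances by each
 * input's size, so the sources land side by side within every row. */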

src/route_layer.h

@@ -1,28 +1,17 @@
#ifndef ROUTE_LAYER_H
#define ROUTE_LAYER_H
#include "network.h"
+#include "layer.h"
-typedef struct {
-    int batch;
-    int outputs;
-    int n;
-    int * input_layers;
-    int * input_sizes;
-    float * delta;
-    float * output;
-    #ifdef GPU
-    float * delta_gpu;
-    float * output_gpu;
-    #endif
-} route_layer;
+typedef layer route_layer;
-route_layer *make_route_layer(int batch, int n, int *input_layers, int *input_size);
-void forward_route_layer(const route_layer layer, network net);
-void backward_route_layer(const route_layer layer, network net);
+route_layer make_route_layer(int batch, int n, int *input_layers, int *input_size);
+void forward_route_layer(const route_layer l, network net);
+void backward_route_layer(const route_layer l, network net);
#ifdef GPU
-void forward_route_layer_gpu(const route_layer layer, network net);
-void backward_route_layer_gpu(const route_layer layer, network net);
+void forward_route_layer_gpu(const route_layer l, network net);
+void backward_route_layer_gpu(const route_layer l, network net);
#endif
#endif
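With the typedef, route_layer is only an alias: headers keep their old names while every layer is the same struct, and a parsed layer drops straight into the network array. A usage sketch (sizes hypothetical; note make_route_layer keeps the pointers it is given, so the arrays must outlive the layer):

int *layers = calloc(2, sizeof(int));
int *sizes  = calloc(2, sizeof(int));
layers[0] = 1;   layers[1] = 4;
sizes[0]  = 256; sizes[1]  = 512;
route_layer r = make_route_layer(4, 2, layers, sizes);  /* batch = 4 */
net.layers[count] = r;   /* plain struct assignment, no cast */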

src/softmax_layer.c

@@ -7,21 +7,23 @@
#include <stdio.h>
#include <assert.h>
-softmax_layer *make_softmax_layer(int batch, int inputs, int groups)
+softmax_layer make_softmax_layer(int batch, int inputs, int groups)
{
    assert(inputs%groups == 0);
    fprintf(stderr, "Softmax Layer: %d inputs\n", inputs);
-    softmax_layer *layer = calloc(1, sizeof(softmax_layer));
-    layer->batch = batch;
-    layer->groups = groups;
-    layer->inputs = inputs;
-    layer->output = calloc(inputs*batch, sizeof(float));
-    layer->delta = calloc(inputs*batch, sizeof(float));
+    softmax_layer l = {0};
+    l.type = SOFTMAX;
+    l.batch = batch;
+    l.groups = groups;
+    l.inputs = inputs;
+    l.outputs = inputs;
+    l.output = calloc(inputs*batch, sizeof(float));
+    l.delta = calloc(inputs*batch, sizeof(float));
#ifdef GPU
-    layer->output_gpu = cuda_make_array(layer->output, inputs*batch);
-    layer->delta_gpu = cuda_make_array(layer->delta, inputs*batch);
+    l.output_gpu = cuda_make_array(l.output, inputs*batch);
+    l.delta_gpu = cuda_make_array(l.delta, inputs*batch);
#endif
-    return layer;
+    return l;
}
void softmax_array(float *input, int n, float *output)
@@ -42,21 +44,21 @@ void softmax_array(float *input, int n, float *output)
    }
}
-void forward_softmax_layer(const softmax_layer layer, network_state state)
+void forward_softmax_layer(const softmax_layer l, network_state state)
{
    int b;
-    int inputs = layer.inputs / layer.groups;
-    int batch = layer.batch * layer.groups;
+    int inputs = l.inputs / l.groups;
+    int batch = l.batch * l.groups;
    for(b = 0; b < batch; ++b){
-        softmax_array(state.input+b*inputs, inputs, layer.output+b*inputs);
+        softmax_array(state.input+b*inputs, inputs, l.output+b*inputs);
    }
}
-void backward_softmax_layer(const softmax_layer layer, network_state state)
+void backward_softmax_layer(const softmax_layer l, network_state state)
{
    int i;
-    for(i = 0; i < layer.inputs*layer.batch; ++i){
-        state.delta[i] = layer.delta[i];
+    for(i = 0; i < l.inputs*l.batch; ++i){
+        state.delta[i] = l.delta[i];
    }
}
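forward_softmax_layer treats `groups` as independent sub-softmaxes, which is why the constructor asserts inputs%groups == 0. A toy illustration:

/* inputs = 6, groups = 2, batch = 1:
 *   state.input = [ a0 a1 a2 | b0 b1 b2 ]
 * is processed as batch*groups = 2 rows of inputs/groups = 3 values, so
 * softmax_array runs once on (a0,a1,a2) and once on (b0,b1,b2). */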

src/softmax_layer.h

@@ -1,28 +1,19 @@
#ifndef SOFTMAX_LAYER_H
#define SOFTMAX_LAYER_H
#include "params.h"
+#include "layer.h"
-typedef struct {
-    int inputs;
-    int batch;
-    int groups;
-    float *delta;
-    float *output;
-    #ifdef GPU
-    float * delta_gpu;
-    float * output_gpu;
-    #endif
-} softmax_layer;
+typedef layer softmax_layer;
void softmax_array(float *input, int n, float *output);
-softmax_layer *make_softmax_layer(int batch, int inputs, int groups);
-void forward_softmax_layer(const softmax_layer layer, network_state state);
-void backward_softmax_layer(const softmax_layer layer, network_state state);
+softmax_layer make_softmax_layer(int batch, int inputs, int groups);
+void forward_softmax_layer(const softmax_layer l, network_state state);
+void backward_softmax_layer(const softmax_layer l, network_state state);
#ifdef GPU
-void pull_softmax_layer_output(const softmax_layer layer);
-void forward_softmax_layer_gpu(const softmax_layer layer, network_state state);
-void backward_softmax_layer_gpu(const softmax_layer layer, network_state state);
+void pull_softmax_layer_output(const softmax_layer l);
+void forward_softmax_layer_gpu(const softmax_layer l, network_state state);
+void backward_softmax_layer_gpu(const softmax_layer l, network_state state);
#endif
#endif