shortcut layers, msr networks

This commit is contained in:
Joseph Redmon 2015-12-14 11:57:10 -08:00
parent 892923514f
commit db0397cfaa
35 changed files with 2635 additions and 56 deletions

View File

@ -34,7 +34,7 @@ CFLAGS+= -DGPU
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o
ifeq ($(GPU), 1)
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o yolo_kernels.o coco_kernels.o
endif

1408
cfg/msr_152.cfg Normal file

File diff suppressed because it is too large Load Diff

371
cfg/msr_34.cfg Normal file
View File

@ -0,0 +1,371 @@
[net]
batch=128
subdivisions=1
height=256
width=256
channels=3
momentum=0.9
decay=0.0005
learning_rate=0.1
policy=poly
power=4
max_batches=500000
#policy=sigmoid
#gamma=.00008
#step=100000
#max_batches=200000
[crop]
crop_height=224
crop_width=224
flip=1
saturation=1
exposure=1
angle=0
[convolutional]
batch_normalize=1
filters=64
size=7
stride=2
pad=1
activation=leaky
[maxpool]
size=3
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from = -3
[avgpool]
[connected]
output=1000
activation=leaky
[softmax]
groups=1
[cost]
type=sse

490
cfg/msr_50.cfg Normal file
View File

@ -0,0 +1,490 @@
[net]
batch=128
subdivisions=4
height=256
width=256
channels=3
momentum=0.9
decay=0.0005
learning_rate=0.01
[crop]
crop_height=224
crop_width=224
flip=1
saturation=1
exposure=1
angle=0
##### Conv 1 #####
[convolutional]
batch_normalize=1
filters=64
size=7
stride=2
pad=1
activation=leaky
[maxpool]
size=3
stride=2
##### Conv 2_x #####
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
##### Conv 3_x #####
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
##### Conv 4_x #####
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
##### Conv 5_x #####
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=2048
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=2048
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=2048
size=1
stride=1
pad=1
activation=leaky
[shortcut]
from = -4
[avgpool]
[connected]
output=1000
activation=leaky
[softmax]
groups=1
[cost]
type=sse

View File

@ -2,9 +2,9 @@
#define AVGPOOL_LAYER_H
#include "image.h"
#include "params.h"
#include "cuda.h"
#include "layer.h"
#include "network.h"
typedef layer avgpool_layer;

View File

@ -1,6 +1,22 @@
#include "blas.h"
#include "math.h"
void shortcut_cpu(float *out, int w, int h, int c, int batch, int sample, float *add, int stride, int c2)
{
int i,j,k,b;
for(b = 0; b < batch; ++b){
for(k = 0; k < c && k < c2; ++k){
for(j = 0; j < h/sample; ++j){
for(i = 0; i < w/sample; ++i){
int out_index = i*sample + w*(j*sample + h*(k + c*b));
int add_index = b*w*stride/sample*h*stride/sample*c2 + i*stride + w*stride/sample*(j*stride + h*stride/sample*k);
out[out_index] += add[add_index];
}
}
}
}
}
void mean_cpu(float *x, int batch, int filters, int spatial, float *mean)
{
float scale = 1./(batch * spatial);

View File

@ -15,6 +15,7 @@ void copy_cpu(int N, float *X, int INCX, float *Y, int INCY);
void scal_cpu(int N, float ALPHA, float *X, int INCX);
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
void test_gpu_blas();
void shortcut_cpu(float *out, int w, int h, int c, int batch, int sample, float *add, int stride, int c2);
void mean_cpu(float *x, int batch, int filters, int spatial, float *mean);
void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
@ -43,5 +44,6 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc
void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
void shortcut_gpu(float *out, int w, int h, int c, int batch, int sample, float *add, int stride, int c2);
#endif
#endif

View File

@ -228,6 +228,7 @@ __global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
if(i < N) Y[i*INCY] *= X[i*INCX];
}
extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
size_t N = batch*filters*spatial;
@ -372,3 +373,27 @@ extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
}
__global__ void shortcut_kernel(int size, float *out, int w, int h, int c, int batch, int sample, float *add, int stride, int c2, int min_c)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id >= size) return;
int i = id % (w/sample);
id /= (w/sample);
int j = id % (h/sample);
id /= (h/sample);
int k = id % min_c;
id /= min_c;
int b = id;
int out_index = i*sample + w*(j*sample + h*(k + c*b));
int add_index = b*w*stride/sample*h*stride/sample*c2 + i*stride + w*stride/sample*(j*stride + h*stride/sample*k);
out[out_index] += add[add_index];
}
extern "C" void shortcut_gpu(float *out, int w, int h, int c, int batch, int sample, float *add, int stride, int c2)
{
int min_c = (c < c2) ? c : c2;
int size = batch * w/sample * h/sample * min_c;
shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, out, w, h, c, batch, sample, add, stride, c2, min_c);
check_error(cudaPeekAtLastError());
}

View File

@ -131,7 +131,7 @@ void validate_classifier(char *datacfg, char *filename, char *weightfile)
char *label_list = option_find_str(options, "labels", "data/labels.list");
char *valid_list = option_find_str(options, "valid", "data/train.list");
int classes = option_find_int(options, "classes", 2);
int topk = option_find_int(options, "topk", 1);
int topk = option_find_int(options, "top", 1);
char **labels = get_labels(label_list);
list *plist = get_paths(valid_list);
@ -194,11 +194,12 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
list *options = read_data_cfg(datacfg);
char *label_list = option_find_str(options, "labels", "data/labels.list");
char *name_list = option_find_str(options, "names", 0);
if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
int top = option_find_int(options, "top", 1);
int i = 0;
char **names = get_labels(label_list);
char **names = get_labels(name_list);
clock_t time;
int indexes[10];
char buff[256];

View File

@ -25,13 +25,13 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
l.weight_updates = calloc(inputs*outputs, sizeof(float));
l.bias_updates = calloc(outputs, sizeof(float));
l.weights = calloc(inputs*outputs, sizeof(float));
l.weights = calloc(outputs*inputs, sizeof(float));
l.biases = calloc(outputs, sizeof(float));
//float scale = 1./sqrt(inputs);
float scale = sqrt(2./inputs);
for(i = 0; i < inputs*outputs; ++i){
for(i = 0; i < outputs*inputs; ++i){
l.weights[i] = 2*scale*rand_uniform() - scale;
}
@ -40,10 +40,10 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
}
#ifdef GPU
l.weights_gpu = cuda_make_array(l.weights, inputs*outputs);
l.weights_gpu = cuda_make_array(l.weights, outputs*inputs);
l.biases_gpu = cuda_make_array(l.biases, outputs);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, inputs*outputs);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, outputs*inputs);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, outputs);
l.output_gpu = cuda_make_array(l.output, outputs*batch);
@ -76,7 +76,7 @@ void forward_connected_layer(connected_layer l, network_state state)
float *a = state.input;
float *b = l.weights;
float *c = l.output;
gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
activate_array(l.output, l.outputs*l.batch, l.activation);
}
@ -87,11 +87,11 @@ void backward_connected_layer(connected_layer l, network_state state)
for(i = 0; i < l.batch; ++i){
axpy_cpu(l.outputs, 1, l.delta + i*l.outputs, 1, l.bias_updates, 1);
}
int m = l.inputs;
int m = l.outputs;
int k = l.batch;
int n = l.outputs;
float *a = state.input;
float *b = l.delta;
int n = l.inputs;
float *a = l.delta;
float *b = state.input;
float *c = l.weight_updates;
gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);
@ -103,7 +103,7 @@ void backward_connected_layer(connected_layer l, network_state state)
b = l.weights;
c = state.delta;
if(c) gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
if(c) gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
#ifdef GPU
@ -146,7 +146,7 @@ void forward_connected_layer_gpu(connected_layer l, network_state state)
float * a = state.input;
float * b = l.weights_gpu;
float * c = l.output_gpu;
gemm_ongpu(0,0,m,n,k,1,a,k,b,n,1,c,n);
gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
/*
@ -163,11 +163,11 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
for(i = 0; i < l.batch; ++i){
axpy_ongpu_offset(l.outputs, 1, l.delta_gpu, i*l.outputs, 1, l.bias_updates_gpu, 0, 1);
}
int m = l.inputs;
int m = l.outputs;
int k = l.batch;
int n = l.outputs;
float * a = state.input;
float * b = l.delta_gpu;
int n = l.inputs;
float * a = l.delta_gpu;
float * b = state.input;
float * c = l.weight_updates_gpu;
gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n);
@ -179,6 +179,6 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
b = l.weights_gpu;
c = state.delta;
if(c) gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
if(c) gemm_ongpu(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
#endif

View File

@ -2,8 +2,8 @@
#define CONNECTED_LAYER_H
#include "activations.h"
#include "params.h"
#include "layer.h"
#include "network.h"
typedef layer connected_layer;

View File

@ -6,6 +6,7 @@
#include "image.h"
#include "activations.h"
#include "layer.h"
#include "network.h"
typedef layer convolutional_layer;

View File

@ -1,7 +1,7 @@
#ifndef COST_LAYER_H
#define COST_LAYER_H
#include "params.h"
#include "layer.h"
#include "network.h"
typedef layer cost_layer;

View File

@ -4,6 +4,7 @@
#include "image.h"
#include "params.h"
#include "layer.h"
#include "network.h"
typedef layer crop_layer;

View File

@ -67,6 +67,7 @@ float *cuda_make_array(float *x, int n)
status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
check_error(status);
}
if(!x_gpu) error("Cuda malloc failed\n");
return x_gpu;
}

View File

@ -149,6 +149,43 @@ void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float
}
}
void fill_truth_swag(char *path, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
{
char *labelpath = find_replace(path, "images", "labels");
labelpath = find_replace(labelpath, "JPEGImages", "labels");
labelpath = find_replace(labelpath, ".jpg", ".txt");
labelpath = find_replace(labelpath, ".JPG", ".txt");
labelpath = find_replace(labelpath, ".JPEG", ".txt");
int count = 0;
box_label *boxes = read_boxes(labelpath, &count);
randomize_boxes(boxes, count);
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
float x,y,w,h;
int id;
int i;
for (i = 0; i < count && i < 30; ++i) {
x = boxes[i].x;
y = boxes[i].y;
w = boxes[i].w;
h = boxes[i].h;
id = boxes[i].id;
if (w < .0 || h < .0) continue;
int index = (4+classes) * i;
truth[index++] = x;
truth[index++] = y;
truth[index++] = w;
truth[index++] = h;
if (id < classes) truth[index+id] = 1;
}
free(boxes);
}
void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int flip, float dx, float dy, float sx, float sy)
{
char *labelpath = find_replace(path, "images", "labels");
@ -482,6 +519,59 @@ data load_data_compare(int n, char **paths, int m, int classes, int w, int h)
return d;
}
data load_data_swag(char **paths, int n, int classes, float jitter)
{
int index = rand_r(&data_seed)%n;
char *random_path = paths[index];
image orig = load_image_color(random_path, 0, 0);
int h = orig.h;
int w = orig.w;
data d;
d.shallow = 0;
d.w = w;
d.h = h;
d.X.rows = 1;
d.X.vals = calloc(d.X.rows, sizeof(float*));
d.X.cols = h*w*3;
int k = (4+classes)*30;
d.y = make_matrix(1, k);
int dw = w*jitter;
int dh = h*jitter;
int pleft = (rand_uniform() * 2*dw - dw);
int pright = (rand_uniform() * 2*dw - dw);
int ptop = (rand_uniform() * 2*dh - dh);
int pbot = (rand_uniform() * 2*dh - dh);
int swidth = w - pleft - pright;
int sheight = h - ptop - pbot;
float sx = (float)swidth / w;
float sy = (float)sheight / h;
int flip = rand_r(&data_seed)%2;
image cropped = crop_image(orig, pleft, ptop, swidth, sheight);
float dx = ((float)pleft/w)/sx;
float dy = ((float)ptop /h)/sy;
image sized = resize_image(cropped, w, h);
if(flip) flip_image(sized);
d.X.vals[0] = sized.data;
fill_truth_swag(random_path, d.y.vals[0], classes, flip, dx, dy, 1./sx, 1./sy);
free_image(orig);
free_image(cropped);
return d;
}
data load_data_detection(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background)
{
char **random_paths = get_random_paths(paths, n, m);
@ -559,6 +649,8 @@ void *load_thread(void *ptr)
*a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
} else if (a.type == REGION_DATA){
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter);
} else if (a.type == SWAG_DATA){
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
} else if (a.type == COMPARE_DATA){
*a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
} else if (a.type == IMAGE_DATA){

View File

@ -20,13 +20,14 @@ static inline float distance_from_edge(int x, int max)
}
typedef struct{
int w, h;
matrix X;
matrix y;
int shallow;
} data;
typedef enum {
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA
} data_type;
typedef struct load_args{

View File

@ -6,6 +6,7 @@
#include "image.h"
#include "activations.h"
#include "layer.h"
#include "network.h"
typedef layer deconvolutional_layer;

View File

@ -146,7 +146,7 @@ void forward_detection_layer(const detection_layer l, network_state state)
}
float iou = box_iou(out, truth);
//printf("%d", best_index);
//printf("%d,", best_index);
int p_index = index + locations*l.classes + i*l.n + best_index;
*(l.cost) -= l.noobject_scale * pow(l.output[p_index], 2);
*(l.cost) += l.object_scale * pow(1-l.output[p_index], 2);

View File

@ -1,8 +1,8 @@
#ifndef REGION_LAYER_H
#define REGION_LAYER_H
#include "params.h"
#include "layer.h"
#include "network.h"
typedef layer detection_layer;

View File

@ -3,6 +3,7 @@
#include "params.h"
#include "layer.h"
#include "network.h"
typedef layer dropout_layer;

View File

@ -1,5 +1,6 @@
#include "image.h"
#include "utils.h"
#include "blas.h"
#include <stdio.h>
#include <math.h>
@ -708,8 +709,14 @@ void test_resize(char *filename)
image exp5 = copy_image(im);
exposure_image(exp5, .5);
image r = resize_image(im, im.w/2, im.h/2);
image black = make_image(im.w, im.h, im.c);
shortcut_cpu(black.data, im.w, im.h, im.c, 1, 2, r.data, 1, r.c);
show_image(im, "Original");
show_image(gray, "Gray");
show_image(black, "Black");
show_image(sat2, "Saturation-2");
show_image(sat5, "Saturation-.5");
show_image(exp2, "Exposure-2");

View File

@ -3,6 +3,9 @@
#include "activations.h"
struct layer;
typedef struct layer layer;
typedef enum {
CONVOLUTIONAL,
DECONVOLUTIONAL,
@ -16,20 +19,22 @@ typedef enum {
COST,
NORMALIZATION,
AVGPOOL,
LOCAL
LOCAL,
SHORTCUT
} LAYER_TYPE;
typedef enum{
SSE, MASKED
} COST_TYPE;
typedef struct {
struct layer{
LAYER_TYPE type;
ACTIVATION activation;
COST_TYPE cost_type;
int batch_normalize;
int batch;
int forced;
int flipped;
int inputs;
int outputs;
int truths;
@ -45,6 +50,7 @@ typedef struct {
int crop_height;
int sqrt;
int flip;
int index;
float angle;
float jitter;
float saturation;
@ -144,7 +150,7 @@ typedef struct {
float * squared_gpu;
float * norms_gpu;
#endif
} layer;
};
void free_layer(layer);

View File

@ -2,10 +2,10 @@
#define LOCAL_LAYER_H
#include "cuda.h"
#include "params.h"
#include "image.h"
#include "activations.h"
#include "layer.h"
#include "network.h"
typedef layer local_layer;

View File

@ -5,6 +5,7 @@
#include "params.h"
#include "cuda.h"
#include "layer.h"
#include "network.h"
typedef layer maxpool_layer;

View File

@ -19,6 +19,7 @@
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
int get_current_batch(network net)
{
@ -94,6 +95,8 @@ char *get_layer_string(LAYER_TYPE a)
return "cost";
case ROUTE:
return "route";
case SHORTCUT:
return "shortcut";
case NORMALIZATION:
return "normalization";
default:
@ -119,6 +122,7 @@ void forward_network(network net, network_state state)
{
int i;
for(i = 0; i < net.n; ++i){
state.index = i;
layer l = net.layers[i];
if(l.delta){
scal_cpu(l.outputs * l.batch, 0, l.delta, 1);
@ -149,6 +153,8 @@ void forward_network(network net, network_state state)
forward_dropout_layer(l, state);
} else if(l.type == ROUTE){
forward_route_layer(l, net);
} else if(l.type == SHORTCUT){
forward_shortcut_layer(l, state);
}
state.input = l.output;
}
@ -211,6 +217,7 @@ void backward_network(network net, network_state state)
float *original_input = state.input;
float *original_delta = state.delta;
for(i = net.n-1; i >= 0; --i){
state.index = i;
if(i == 0){
state.input = original_input;
state.delta = original_delta;
@ -244,6 +251,8 @@ void backward_network(network net, network_state state)
backward_cost_layer(l, state);
} else if(l.type == ROUTE){
backward_route_layer(l, net);
} else if(l.type == SHORTCUT){
backward_shortcut_layer(l, state);
}
}
}
@ -255,6 +264,8 @@ float train_network_datum(network net, float *x, float *y)
if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
#endif
network_state state;
state.index = 0;
state.net = net;
state.input = x;
state.delta = 0;
state.truth = y;
@ -307,6 +318,8 @@ float train_network_batch(network net, data d, int n)
{
int i,j;
network_state state;
state.index = 0;
state.net = net;
state.train = 1;
state.delta = 0;
float sum = 0;
@ -443,6 +456,8 @@ float *network_predict(network net, float *input)
#endif
network_state state;
state.net = net;
state.index = 0;
state.input = input;
state.truth = 0;
state.train = 0;

View File

@ -3,15 +3,15 @@
#define NETWORK_H
#include "image.h"
#include "detection_layer.h"
#include "layer.h"
#include "data.h"
#include "params.h"
typedef enum {
CONSTANT, STEP, EXP, POLY, STEPS, SIG
} learning_rate_policy;
typedef struct {
typedef struct network{
int n;
int batch;
int *seen;
@ -43,6 +43,15 @@ typedef struct {
#endif
} network;
typedef struct network_state {
float *truth;
float *input;
float *delta;
int train;
int index;
network net;
} network_state;
#ifdef GPU
float train_network_datum_gpu(network net, float *x, float *y);
float *network_predict_gpu(network net, float *input);

View File

@ -27,6 +27,7 @@ extern "C" {
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "blas.h"
}
@ -38,6 +39,7 @@ void forward_network_gpu(network net, network_state state)
{
int i;
for(i = 0; i < net.n; ++i){
state.index = i;
layer l = net.layers[i];
if(l.delta_gpu){
fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
@ -68,6 +70,8 @@ void forward_network_gpu(network net, network_state state)
forward_dropout_layer_gpu(l, state);
} else if(l.type == ROUTE){
forward_route_layer_gpu(l, net);
} else if(l.type == SHORTCUT){
forward_shortcut_layer_gpu(l, state);
}
state.input = l.output_gpu;
}
@ -79,6 +83,7 @@ void backward_network_gpu(network net, network_state state)
float * original_input = state.input;
float * original_delta = state.delta;
for(i = net.n-1; i >= 0; --i){
state.index = i;
layer l = net.layers[i];
if(i == 0){
state.input = original_input;
@ -112,6 +117,8 @@ void backward_network_gpu(network net, network_state state)
backward_cost_layer_gpu(l, state);
} else if(l.type == ROUTE){
backward_route_layer_gpu(l, net);
} else if(l.type == SHORTCUT){
backward_shortcut_layer_gpu(l, state);
}
}
}
@ -138,6 +145,8 @@ void update_network_gpu(network net)
float train_network_datum_gpu(network net, float *x, float *y)
{
network_state state;
state.index = 0;
state.net = net;
int x_size = get_network_input_size(net)*net.batch;
int y_size = get_network_output_size(net)*net.batch;
if(net.layers[net.n-1].type == DETECTION) y_size = net.layers[net.n-1].truths*net.batch;
@ -178,6 +187,8 @@ float *network_predict_gpu(network net, float *input)
{
int size = get_network_input_size(net) * net.batch;
network_state state;
state.index = 0;
state.net = net;
state.input = cuda_make_array(input, size);
state.truth = 0;
state.train = 0;

View File

@ -3,7 +3,7 @@
#include "image.h"
#include "layer.h"
#include "params.h"
#include "network.h"
layer make_normalization_layer(int batch, int w, int h, int c, int size, float alpha, float beta, float kappa);
void resize_normalization_layer(layer *layer, int h, int w);

View File

@ -1,12 +1 @@
#ifndef PARAMS_H
#define PARAMS_H
typedef struct {
float *truth;
float *input;
float *delta;
int train;
} network_state;
#endif

View File

@ -17,6 +17,7 @@
#include "avgpool_layer.h"
#include "local_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "list.h"
#include "option_list.h"
#include "utils.h"
@ -37,6 +38,7 @@ int is_dropout(section *s);
int is_softmax(section *s);
int is_normalization(section *s);
int is_crop(section *s);
int is_shortcut(section *s);
int is_cost(section *s);
int is_detection(section *s);
int is_route(section *s);
@ -80,6 +82,7 @@ typedef struct size_params{
int h;
int w;
int c;
int index;
} size_params;
deconvolutional_layer parse_deconvolutional(list *options, size_params params)
@ -148,6 +151,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation, batch_normalize);
layer.flipped = option_find_int_quiet(options, "flipped", 0);
char *weights = option_find_str(options, "weights", 0);
char *biases = option_find_str(options, "biases", 0);
@ -287,6 +291,20 @@ layer parse_normalization(list *options, size_params params)
return l;
}
layer parse_shortcut(list *options, size_params params, network net)
{
char *l = option_find(options, "from");
int index = atoi(l);
if(index < 0) index = params.index + index;
int batch = params.batch;
layer from = net.layers[index];
layer s = make_shortcut_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c);
return s;
}
route_layer parse_route(list *options, size_params params, network net)
{
char *l = option_find(options, "layers");
@ -303,13 +321,14 @@ route_layer parse_route(list *options, size_params params, network net)
for(i = 0; i < n; ++i){
int index = atoi(l);
l = strchr(l, ',')+1;
if(index < 0) index = params.index + index;
layers[i] = index;
sizes[i] = net.layers[index].outputs;
}
int batch = params.batch;
route_layer layer = make_route_layer(batch, n, layers, sizes);
convolutional_layer first = net.layers[layers[0]];
layer.out_w = first.out_w;
layer.out_h = first.out_h;
@ -419,6 +438,7 @@ network parse_network_cfg(char *filename)
int count = 0;
free_section(s);
while(n){
params.index = count;
fprintf(stderr, "%d: ", count);
s = (section *)n->val;
options = s->options;
@ -447,6 +467,8 @@ network parse_network_cfg(char *filename)
l = parse_avgpool(options, params);
}else if(is_route(s)){
l = parse_route(options, params, net);
}else if(is_shortcut(s)){
l = parse_shortcut(options, params, net);
}else if(is_dropout(s)){
l = parse_dropout(options, params);
l.output = net.layers[count-1].output;
@ -464,13 +486,13 @@ network parse_network_cfg(char *filename)
net.layers[count] = l;
free_section(s);
n = n->next;
++count;
if(n){
params.h = l.out_h;
params.w = l.out_w;
params.c = l.out_c;
params.inputs = l.outputs;
}
++count;
}
free_list(sections);
net.outputs = get_network_output_size(net);
@ -478,6 +500,10 @@ network parse_network_cfg(char *filename)
return net;
}
int is_shortcut(section *s)
{
return (strcmp(s->type, "[shortcut]")==0);
}
int is_crop(section *s)
{
return (strcmp(s->type, "[crop]")==0);
@ -625,9 +651,12 @@ void save_weights_upto(network net, char *filename, int cutoff)
FILE *fp = fopen(filename, "w");
if(!fp) file_error(filename);
fwrite(&net.learning_rate, sizeof(float), 1, fp);
fwrite(&net.momentum, sizeof(float), 1, fp);
fwrite(&net.decay, sizeof(float), 1, fp);
int major = 0;
int minor = 1;
int revision = 0;
fwrite(&major, sizeof(int), 1, fp);
fwrite(&minor, sizeof(int), 1, fp);
fwrite(&revision, sizeof(int), 1, fp);
fwrite(net.seen, sizeof(int), 1, fp);
int i;
@ -674,6 +703,19 @@ void save_weights(network net, char *filename)
save_weights_upto(net, filename, net.n);
}
void transpose_matrix(float *a, int rows, int cols)
{
float *transpose = calloc(rows*cols, sizeof(float));
int x, y;
for(x = 0; x < rows; ++x){
for(y = 0; y < cols; ++y){
transpose[y*rows + x] = a[x*cols + y];
}
}
memcpy(a, transpose, rows*cols*sizeof(float));
free(transpose);
}
void load_weights_upto(network *net, char *filename, int cutoff)
{
fprintf(stderr, "Loading weights from %s...", filename);
@ -681,10 +723,12 @@ void load_weights_upto(network *net, char *filename, int cutoff)
FILE *fp = fopen(filename, "r");
if(!fp) file_error(filename);
float garbage;
fread(&garbage, sizeof(float), 1, fp);
fread(&garbage, sizeof(float), 1, fp);
fread(&garbage, sizeof(float), 1, fp);
int major;
int minor;
int revision;
fread(&major, sizeof(int), 1, fp);
fread(&minor, sizeof(int), 1, fp);
fread(&revision, sizeof(int), 1, fp);
fread(net->seen, sizeof(int), 1, fp);
int i;
@ -700,6 +744,9 @@ void load_weights_upto(network *net, char *filename, int cutoff)
fread(l.rolling_variance, sizeof(float), l.n, fp);
}
fread(l.filters, sizeof(float), num, fp);
if (l.flipped) {
transpose_matrix(l.filters, l.c*l.size*l.size, l.n);
}
#ifdef GPU
if(gpu_index >= 0){
push_convolutional_layer(l);
@ -719,6 +766,9 @@ void load_weights_upto(network *net, char *filename, int cutoff)
if(l.type == CONNECTED){
fread(l.biases, sizeof(float), l.outputs, fp);
fread(l.weights, sizeof(float), l.outputs*l.inputs, fp);
if(major > 1000 || minor > 1000){
transpose_matrix(l.weights, l.inputs, l.outputs);
}
#ifdef GPU
if(gpu_index >= 0){
push_connected_layer(l);

64
src/shortcut_layer.c Normal file
View File

@ -0,0 +1,64 @@
#include "shortcut_layer.h"
#include "cuda.h"
#include "blas.h"
#include <stdio.h>
#include <assert.h>
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2)
{
fprintf(stderr,"Shortcut Layer: %d\n", index);
layer l = {0};
l.type = SHORTCUT;
l.batch = batch;
l.w = w;
l.h = h;
l.c = c;
l.out_w = w;
l.out_h = h;
l.out_c = c;
l.outputs = w*h*c;
l.inputs = w*h*c;
int stride = w2 / w;
assert(stride * w == w2);
assert(stride * h == h2);
assert(c >= c2);
l.stride = stride;
l.n = c2;
l.index = index;
l.delta = calloc(l.outputs*batch, sizeof(float));
l.output = calloc(l.outputs*batch, sizeof(float));;
#ifdef GPU
l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
#endif
return l;
}
void forward_shortcut_layer(const layer l, network_state state)
{
copy_cpu(l.outputs*l.batch, state.input, 1, l.output, 1);
shortcut_cpu(l.output, l.w, l.h, l.c, l.batch, 1, state.net.layers[l.index].output, l.stride, l.n);
}
void backward_shortcut_layer(const layer l, network_state state)
{
copy_cpu(l.outputs*l.batch, l.delta, 1, state.delta, 1);
shortcut_cpu(state.net.layers[l.index].delta, l.w*l.stride, l.h*l.stride, l.n, l.batch, l.stride, l.delta, 1, l.c);
}
#ifdef GPU
void forward_shortcut_layer_gpu(const layer l, network_state state)
{
copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
shortcut_gpu(l.output_gpu, l.w, l.h, l.c, l.batch, 1, state.net.layers[l.index].output_gpu, l.stride, l.n);
}
void backward_shortcut_layer_gpu(const layer l, network_state state)
{
copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
shortcut_gpu(state.net.layers[l.index].delta_gpu, l.w*l.stride, l.h*l.stride, l.n, l.batch, l.stride, l.delta_gpu, 1, l.c);
}
#endif

16
src/shortcut_layer.h Normal file
View File

@ -0,0 +1,16 @@
#ifndef SHORTCUT_LAYER_H
#define SHORTCUT_LAYER_H
#include "layer.h"
#include "network.h"
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2);
void forward_shortcut_layer(const layer l, network_state state);
void backward_shortcut_layer(const layer l, network_state state);
#ifdef GPU
void forward_shortcut_layer_gpu(const layer l, network_state state);
void backward_shortcut_layer_gpu(const layer l, network_state state);
#endif
#endif

View File

@ -2,6 +2,7 @@
#define SOFTMAX_LAYER_H
#include "params.h"
#include "layer.h"
#include "network.h"
typedef layer softmax_layer;

View File

@ -255,9 +255,8 @@ void validate_yolo_recall(char *cfgfile, char *weightfile)
int i=0;
float thresh = .001;
int nms = 0;
float iou_thresh = .5;
float nms_thresh = .5;
float nms = 0;
int total = 0;
int correct = 0;
@ -271,7 +270,7 @@ void validate_yolo_recall(char *cfgfile, char *weightfile)
char *id = basecfg(path);
float *predictions = network_predict(net, sized.data);
convert_yolo_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1);
if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh);
if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms);
char *labelpath = find_replace(path, "images", "labels");
labelpath = find_replace(labelpath, "JPEGImages", "labels");