🔥 🔥 yolo v2 🔥 🔥

This commit is contained in:
Joseph Redmon 2016-11-17 12:18:19 -08:00
parent c71bff69ea
commit c6afc7ff14
24 changed files with 761 additions and 173 deletions

8
cfg/coco.data Normal file
View File

@ -0,0 +1,8 @@
classes= 80
train = /home/pjreddie/data/coco/trainvalno5k.txt
valid = coco_testdev
#valid = data/coco_val_5k.list
names = data/coco.names
backup = /home/pjreddie/backup/
eval=coco

6
cfg/voc.data Normal file
View File

@ -0,0 +1,6 @@
classes= 20
train = /home/pjreddie/data/voc/train.txt
valid = /home/pjreddie/data/voc/2007_test.txt
names = data/pascal.names
backup = /home/pjreddie/backup/

View File

@ -1,36 +1,25 @@
[net]
batch=1
subdivisions=1
height=448
width=448
batch=64
subdivisions=8
height=416
width=416
channels=3
momentum=0.9
decay=0.0005
saturation=1.5
exposure=1.5
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.0005
learning_rate=0.001
max_batches = 120000
policy=steps
steps=200,400,600,20000,30000
scales=2.5,2,2,.1,.1
max_batches = 40000
steps=-1,100,80000,100000
scales=.1,10,.1,.1
[convolutional]
batch_normalize=1
filters=64
size=7
stride=2
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=192
filters=32
size=3
stride=1
pad=1
@ -40,6 +29,54 @@ activation=leaky
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
@ -56,6 +93,34 @@ stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
@ -76,78 +141,6 @@ activation=leaky
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
@ -156,10 +149,6 @@ stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
@ -192,6 +181,7 @@ stride=1
pad=1
activation=leaky
#######
[convolutional]
@ -205,10 +195,19 @@ activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[route]
layers=-9
[reorg]
stride=2
pad=1
filters=1024
activation=leaky
[route]
layers=-1,-3
[convolutional]
batch_normalize=1
@ -219,39 +218,27 @@ filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
size=1
stride=1
pad=1
filters=1024
activation=leaky
[local]
size=3
stride=1
pad=1
filters=256
activation=leaky
[dropout]
probability=.5
[connected]
output= 1715
filters=425
activation=linear
[detection]
classes=20
[region]
anchors = 0.738768,0.874946, 2.42204,2.65704, 4.30971,7.04493, 10.246,4.59428, 12.6868,11.8741
bias_match=1
classes=80
coords=4
rescore=1
side=7
num=3
softmax=0
sqrt=1
num=5
softmax=1
jitter=.2
rescore=1
object_scale=1
noobject_scale=.5
object_scale=5
noobject_scale=1
class_scale=1
coord_scale=5
coord_scale=1
absolute=1
thresh = .6
random=0

244
cfg/yolo_voc.cfg Normal file
View File

@ -0,0 +1,244 @@
[net]
batch=64
subdivisions=8
height=416
width=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.0001
max_batches = 45000
policy=steps
steps=100,25000,35000
scales=10,.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
#######
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[route]
layers=-9
[reorg]
stride=2
[route]
layers=-1,-3
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=125
activation=linear
[region]
anchors = 1.08,1.19, 3.42,4.41, 6.63,11.38, 9.42,5.11, 16.62,10.52
bias_match=1
classes=20
coords=4
num=5
softmax=1
jitter=.2
rescore=1
object_scale=5
noobject_scale=1
class_scale=1
coord_scale=1
absolute=1
thresh = .6
random=0

257
cfg/yolov1/yolo.cfg Normal file
View File

@ -0,0 +1,257 @@
[net]
batch=1
subdivisions=1
height=448
width=448
channels=3
momentum=0.9
decay=0.0005
saturation=1.5
exposure=1.5
hue=.1
learning_rate=0.0005
policy=steps
steps=200,400,600,20000,30000
scales=2.5,2,2,.1,.1
max_batches = 40000
[convolutional]
batch_normalize=1
filters=64
size=7
stride=2
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=192
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
#######
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=2
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[local]
size=3
stride=1
pad=1
filters=256
activation=leaky
[dropout]
probability=.5
[connected]
output= 1715
activation=linear
[detection]
classes=20
coords=4
rescore=1
side=7
num=3
softmax=0
sqrt=1
jitter=.2
object_scale=1
noobject_scale=.5
class_scale=1
coord_scale=5

80
data/coco.names Normal file
View File

@ -0,0 +1,80 @@
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

View File

@ -5,6 +5,28 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
void reorg_cpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
{
int b,i,j,k;
int out_c = c/(stride*stride);
for(b = 0; b < batch; ++b){
for(k = 0; k < c; ++k){
for(j = 0; j < h; ++j){
for(i = 0; i < w; ++i){
int in_index = i + w*(j + h*(k + c*b));
int c2 = k % out_c;
int offset = k / out_c;
int w2 = i*stride + offset % stride;
int h2 = j*stride + offset / stride;
int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b));
if(forward) out[out_index] = x[in_index];
else out[in_index] = x[out_index];
}
}
}
}
}
void flatten(float *x, int size, int layers, int batch, int forward)
{

View File

@ -4,6 +4,7 @@ void flatten(float *x, int size, int layers, int batch, int forward);
void pm(int M, int N, float *A);
float *random_matrix(int rows, int cols);
void time_random_matrix(int TA, int TB, int m, int k, int n);
void reorg_cpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
void test_blas();

View File

@ -13,6 +13,7 @@
#endif
extern void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top);
extern void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh);
extern void run_voxel(int argc, char **argv);
extern void run_yolo(int argc, char **argv);
extern void run_detector(int argc, char **argv);
@ -379,6 +380,10 @@ int main(int argc, char **argv)
run_super(argc, argv);
} else if (0 == strcmp(argv[1], "detector")){
run_detector(argc, argv);
} else if (0 == strcmp(argv[1], "detect")){
float thresh = find_float_arg(argc, argv, "-thresh", .25);
char *filename = (argc > 4) ? argv[4]: 0;
test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh);
} else if (0 == strcmp(argv[1], "cifar")){
run_cifar(argc, argv);
} else if (0 == strcmp(argv[1], "go")){
@ -390,7 +395,7 @@ int main(int argc, char **argv)
} else if (0 == strcmp(argv[1], "coco")){
run_coco(argc, argv);
} else if (0 == strcmp(argv[1], "classify")){
predict_classifier("cfg/imagenet1k.dataset", argv[2], argv[3], argv[4], 5);
predict_classifier("cfg/imagenet1k.data", argv[2], argv[3], argv[4], 5);
} else if (0 == strcmp(argv[1], "classifier")){
run_classifier(argc, argv);
} else if (0 == strcmp(argv[1], "art")){

View File

@ -110,6 +110,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
srand(2222222);
if(filename){
printf("video file: %s\n", filename);
cap = cvCaptureFromFile(filename);
}else{
cap = cvCaptureFromCAM(cam_index);

View File

@ -490,7 +490,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
void run_detector(int argc, char **argv)
{
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
float thresh = find_float_arg(argc, argv, "-thresh", .2);
float thresh = find_float_arg(argc, argv, "-thresh", .25);
int cam_index = find_int_arg(argc, argv, "-c", 0);
int frame_skip = find_int_arg(argc, argv, "-s", 0);
if(argc < 4){

View File

@ -185,10 +185,16 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs,
int class = max_index(probs[i], classes);
float prob = probs[i][class];
if(prob > thresh){
//int width = pow(prob, 1./2.)*30+1;
int width = im.h * .012;
if(0){
width = pow(prob, 1./2.)*10+1;
alphabet = 0;
}
printf("%s: %.0f%%\n", names[class], prob*100);
int offset = class*1 % classes;
int offset = class*123457 % classes;
float red = get_color(2,offset,classes);
float green = get_color(1,offset,classes);
float blue = get_color(0,offset,classes);

View File

@ -238,9 +238,6 @@ layer parse_region(list *options, size_params params)
int classes = option_find_int(options, "classes", 20);
int num = option_find_int(options, "num", 1);
params.w = option_find_int(options, "side", params.w);
params.h = option_find_int(options, "side", params.h);
layer l = make_region_layer(params.batch, params.w, params.h, num, classes, coords);
assert(l.outputs == params.inputs);

View File

@ -44,7 +44,7 @@ region_layer make_region_layer(int batch, int w, int h, int n, int classes, int
l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
#endif
fprintf(stderr, "Region Layer\n");
fprintf(stderr, "detection\n");
srand(0);
return l;

View File

@ -23,7 +23,7 @@ layer make_reorg_layer(int batch, int h, int w, int c, int stride, int reverse)
l.out_c = c*(stride*stride);
}
l.reverse = reverse;
fprintf(stderr, "Reorg Layer: %d x %d x %d image -> %d x %d x %d image, \n", w,h,c,l.out_w, l.out_h, l.out_c);
fprintf(stderr, "reorg /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", stride, w, h, c, l.out_w, l.out_h, l.out_c);
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = h*w*c;
int output_size = l.out_h * l.out_w * l.out_c * batch;
@ -77,45 +77,19 @@ void resize_reorg_layer(layer *l, int w, int h)
void forward_reorg_layer(const layer l, network_state state)
{
int b,i,j,k;
for(b = 0; b < l.batch; ++b){
for(k = 0; k < l.c; ++k){
for(j = 0; j < l.h; ++j){
for(i = 0; i < l.w; ++i){
int in_index = i + l.w*(j + l.h*(k + l.c*b));
int c2 = k % l.out_c;
int offset = k / l.out_c;
int w2 = i*l.stride + offset % l.stride;
int h2 = j*l.stride + offset / l.stride;
int out_index = w2 + l.out_w*(h2 + l.out_h*(c2 + l.out_c*b));
l.output[out_index] = state.input[in_index];
}
}
}
if(l.reverse){
reorg_cpu(state.input, l.w, l.h, l.c, l.batch, l.stride, 1, l.output);
}else {
reorg_cpu(state.input, l.w, l.h, l.c, l.batch, l.stride, 0, l.output);
}
}
void backward_reorg_layer(const layer l, network_state state)
{
int b,i,j,k;
for(b = 0; b < l.batch; ++b){
for(k = 0; k < l.c; ++k){
for(j = 0; j < l.h; ++j){
for(i = 0; i < l.w; ++i){
int in_index = i + l.w*(j + l.h*(k + l.c*b));
int c2 = k % l.out_c;
int offset = k / l.out_c;
int w2 = i*l.stride + offset % l.stride;
int h2 = j*l.stride + offset / l.stride;
int out_index = w2 + l.out_w*(h2 + l.out_h*(c2 + l.out_c*b));
state.delta[in_index] = l.delta[out_index];
}
}
}
if(l.reverse){
reorg_cpu(l.delta, l.w, l.h, l.c, l.batch, l.stride, 0, state.delta);
}else{
reorg_cpu(l.delta, l.w, l.h, l.c, l.batch, l.stride, 1, state.delta);
}
}

View File

@ -5,7 +5,7 @@
route_layer make_route_layer(int batch, int n, int *input_layers, int *input_sizes)
{
fprintf(stderr,"Route Layer:");
fprintf(stderr,"route ");
route_layer l = {0};
l.type = ROUTE;
l.batch = batch;