detection better?

This commit is contained in:
Joseph Redmon 2015-04-07 15:25:30 -07:00
parent a05c4bd2e9
commit 1fd10265f8
5 changed files with 65 additions and 29 deletions

View File

@ -108,8 +108,9 @@ void randomize_boxes(box *b, int n)
void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, int flip, int background, float dx, float dy, float sx, float sy) void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, int flip, int background, float dx, float dy, float sx, float sy)
{ {
char *labelpath = find_replace(path, "VOC2012/JPEGImages", "labels"); char *labelpath = find_replace(path, "detection_images", "labels");
labelpath = find_replace(labelpath, ".jpg", ".txt"); labelpath = find_replace(labelpath, ".jpg", ".txt");
labelpath = find_replace(labelpath, ".JPEG", ".txt");
int count = 0; int count = 0;
box *boxes = read_boxes(labelpath, &count); box *boxes = read_boxes(labelpath, &count);
randomize_boxes(boxes, count); randomize_boxes(boxes, count);
@ -293,8 +294,6 @@ void free_data(data d)
data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int h, int w, int num_boxes, int background) data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int h, int w, int num_boxes, int background)
{ {
//float minscale = 0.85;
//float maxscale = 1.15;
char **random_paths = get_random_paths(paths, n, m); char **random_paths = get_random_paths(paths, n, m);
int i; int i;
data d; data d;
@ -310,10 +309,14 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes,
image orig = load_image_color(random_paths[i], 0, 0); image orig = load_image_color(random_paths[i], 0, 0);
int oh = orig.h; int oh = orig.h;
int ow = orig.w; int ow = orig.w;
int pleft = (rand_uniform() * 64. - 32.);
int pright = (rand_uniform() * 64. - 32.); int dw = ow/10;
int ptop = (rand_uniform() * 64. - 32.); int dh = oh/10;
int pbot = (rand_uniform() * 64. - 32.);
int pleft = (rand_uniform() * 2*dw - dw);
int pright = (rand_uniform() * 2*dw - dw);
int ptop = (rand_uniform() * 2*dh - dh);
int pbot = (rand_uniform() * 2*dh - dh);
int swidth = ow - pleft - pright; int swidth = ow - pleft - pright;
int sheight = oh - ptop - pbot; int sheight = oh - ptop - pbot;

View File

@ -1,13 +1,15 @@
#include "network.h" #include "network.h"
#include "detection_layer.h"
#include "utils.h" #include "utils.h"
#include "parser.h" #include "parser.h"
char *class_names[] = {"bg", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; char *class_names[] = {"bg", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
char *inet_class_names[] = {"bg", "accordion", "airplane", "ant", "antelope", "apple", "armadillo", "artichoke", "axe", "baby bed", "backpack", "bagel", "balance beam", "banana", "band aid", "banjo", "baseball", "basketball", "bathing cap", "beaker", "bear", "bee", "bell pepper", "bench", "bicycle", "binder", "bird", "bookshelf", "bow tie", "bow", "bowl", "brassiere", "burrito", "bus", "butterfly", "camel", "can opener", "car", "cart", "cattle", "cello", "centipede", "chain saw", "chair", "chime", "cocktail shaker", "coffee maker", "computer keyboard", "computer mouse", "corkscrew", "cream", "croquet ball", "crutch", "cucumber", "cup or mug", "diaper", "digital clock", "dishwasher", "dog", "domestic cat", "dragonfly", "drum", "dumbbell", "electric fan", "elephant", "face powder", "fig", "filing cabinet", "flower pot", "flute", "fox", "french horn", "frog", "frying pan", "giant panda", "goldfish", "golf ball", "golfcart", "guacamole", "guitar", "hair dryer", "hair spray", "hamburger", "hammer", "hamster", "harmonica", "harp", "hat with a wide brim", "head cabbage", "helmet", "hippopotamus", "horizontal bar", "horse", "hotdog", "iPod", "isopod", "jellyfish", "koala bear", "ladle", "ladybug", "lamp", "laptop", "lemon", "lion", "lipstick", "lizard", "lobster", "maillot", "maraca", "microphone", "microwave", "milk can", "miniskirt", "monkey", "motorcycle", "mushroom", "nail", "neck brace", "oboe", "orange", "otter", "pencil box", "pencil sharpener", "perfume", "person", "piano", "pineapple", "ping-pong ball", "pitcher", "pizza", "plastic bag", "plate rack", "pomegranate", "popsicle", "porcupine", "power drill", "pretzel", "printer", "puck", "punching bag", "purse", "rabbit", "racket", "ray", "red panda", "refrigerator", "remote control", "rubber eraser", "rugby ball", "ruler", "salt or pepper shaker", "saxophone", "scorpion", "screwdriver", "seal", "sheep", "ski", "skunk", "snail", "snake", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sofa", "spatula", "squirrel", "starfish", "stethoscope", "stove", "strainer", "strawberry", "stretcher", "sunglasses", "swimming trunks", "swine", "syringe", "table", "tape player", "tennis ball", "tick", "tie", "tiger", "toaster", "traffic light", "train", "trombone", "trumpet", "turtle", "tv or monitor", "unicycle", "vacuum", "violin", "volleyball", "waffle iron", "washer", "water bottle", "watercraft", "whale", "wine bottle", "zebra"};
#define AMNT 3 #define AMNT 3
void draw_detection(image im, float *box, int side) void draw_detection(image im, float *box, int side)
{ {
int classes = 21; int classes = 201;
int elems = 4+classes; int elems = 4+classes;
int j; int j;
int r, c; int r, c;
@ -21,7 +23,7 @@ void draw_detection(image im, float *box, int side)
if(box[j+class] > .02 || 1){ if(box[j+class] > .02 || 1){
//int z; //int z;
//for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]); //for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]);
printf("%f %s\n", box[j+class], class_names[class]); printf("%f %s\n", box[j+class], inet_class_names[class]);
float red = get_color(0,class,classes); float red = get_color(0,class,classes);
float green = get_color(1,class,classes); float green = get_color(1,class,classes);
float blue = get_color(2,class,classes); float blue = get_color(2,class,classes);
@ -35,6 +37,8 @@ void draw_detection(image im, float *box, int side)
y = (y+r)/side; y = (y+r)/side;
float h = box[j+2]; //*maxheight; float h = box[j+2]; //*maxheight;
float w = box[j+3]; //*maxwidth; float w = box[j+3]; //*maxwidth;
h = h*h;
w = w*w;
//printf("coords %f %f %f %f\n", x, y, w, h); //printf("coords %f %f %f %f\n", x, y, w, h);
int left = (x-w/2)*im.w; int left = (x-w/2)*im.w;
@ -52,6 +56,8 @@ void draw_detection(image im, float *box, int side)
void train_detection(char *cfgfile, char *weightfile) void train_detection(char *cfgfile, char *weightfile)
{ {
srand(time(0));
int imgnet = 0;
char *base = basecfg(cfgfile); char *base = basecfg(cfgfile);
printf("%s\n", base); printf("%s\n", base);
float avg_loss = -1; float avg_loss = -1;
@ -59,30 +65,37 @@ void train_detection(char *cfgfile, char *weightfile)
if(weightfile){ if(weightfile){
load_weights(&net, weightfile); load_weights(&net, weightfile);
} }
//net.seen = 0; detection_layer *layer = get_network_detection_layer(net);
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 128; int imgs = 128;
srand(time(0));
//srand(23410);
int i = net.seen/imgs; int i = net.seen/imgs;
list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
char **paths = (char **)list_to_array(plist);
printf("%d\n", plist->size);
data train, buffer; data train, buffer;
int im_dim = 448;
int classes = 20; int classes = layer->classes;
int background = 1; int background = layer->background;
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, background, &buffer); int side = sqrt(get_detection_layer_locations(*layer));
char **paths;
list *plist;
if (imgnet){
plist = get_paths("/home/pjreddie/data/imagenet/det.train.list");
}else{
plist = get_paths("/home/pjreddie/data/voc/trainall.txt");
}
paths = (char **)list_to_array(plist);
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.h, net.w, side, side, background, &buffer);
clock_t time; clock_t time;
while(1){ while(1){
i += 1; i += 1;
time=clock(); time=clock();
pthread_join(load_thread, 0); pthread_join(load_thread, 0);
train = buffer; train = buffer;
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, background, &buffer); load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.h, net.w, side, side, background, &buffer);
//image im = float_to_image(im_dim, im_dim, 3, train.X.vals[114]); /*
//draw_detection(im, train.y.vals[114], 7); image im = float_to_image(im_dim, im_dim, 3, train.X.vals[114]);
draw_detection(im, train.y.vals[114], 7);
*/
printf("Loaded: %lf seconds\n", sec(clock()-time)); printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock(); time=clock();
@ -106,17 +119,19 @@ void validate_detection(char *cfgfile, char *weightfile)
if(weightfile){ if(weightfile){
load_weights(&net, weightfile); load_weights(&net, weightfile);
} }
detection_layer *layer = get_network_detection_layer(net);
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0)); srand(time(0));
list *plist = get_paths("/home/pjreddie/data/voc/val.txt"); list *plist = get_paths("/home/pjreddie/data/voc/val.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/train.txt"); //list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
char **paths = (char **)list_to_array(plist); char **paths = (char **)list_to_array(plist);
int im_size = 448;
int classes = 20; int classes = layer->classes;
int background = 0; int nuisance = layer->nuisance;
int nuisance = 1; int background = (layer->background && !nuisance);
int num_boxes = 7; int num_boxes = sqrt(get_detection_layer_locations(*layer));
int per_box = 4+classes+background+nuisance; int per_box = 4+classes+background+nuisance;
int num_output = num_boxes*num_boxes*per_box; int num_output = num_boxes*num_boxes*per_box;
@ -127,7 +142,7 @@ void validate_detection(char *cfgfile, char *weightfile)
fprintf(stderr, "%d\n", m); fprintf(stderr, "%d\n", m);
data val, buffer; data val, buffer;
pthread_t load_thread = load_data_thread(paths, num, 0, 0, num_output, im_size, im_size, &buffer); pthread_t load_thread = load_data_thread(paths, num, 0, 0, num_output, net.h, net.w, &buffer);
clock_t time; clock_t time;
for(i = 1; i <= splits; ++i){ for(i = 1; i <= splits; ++i){
time=clock(); time=clock();
@ -136,7 +151,7 @@ void validate_detection(char *cfgfile, char *weightfile)
num = (i+1)*m/splits - i*m/splits; num = (i+1)*m/splits - i*m/splits;
char **part = paths+(i*m/splits); char **part = paths+(i*m/splits);
if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, im_size, im_size, &buffer); if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, net.h, net.w, &buffer);
fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time)); fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
matrix pred = network_predict_data(net, val); matrix pred = network_predict_data(net, val);

View File

@ -23,6 +23,7 @@ detection_layer *make_detection_layer(int batch, int inputs, int classes, int co
void forward_detection_layer(const detection_layer layer, network_state state); void forward_detection_layer(const detection_layer layer, network_state state);
void backward_detection_layer(const detection_layer layer, network_state state); void backward_detection_layer(const detection_layer layer, network_state state);
int get_detection_layer_output_size(detection_layer layer); int get_detection_layer_output_size(detection_layer layer);
int get_detection_layer_locations(detection_layer layer);
#ifdef GPU #ifdef GPU
void forward_detection_layer_gpu(const detection_layer layer, network_state state); void forward_detection_layer_gpu(const detection_layer layer, network_state state);

View File

@ -500,6 +500,18 @@ int get_network_input_size(network net)
return get_network_input_size_layer(net, 0); return get_network_input_size_layer(net, 0);
} }
detection_layer *get_network_detection_layer(network net)
{
int i;
for(i = 0; i < net.n; ++i){
if(net.types[i] == DETECTION){
detection_layer *layer = (detection_layer *)net.layers[i];
return layer;
}
}
return 0;
}
image get_network_image_layer(network net, int i) image get_network_image_layer(network net, int i)
{ {
if(net.types[i] == CONVOLUTIONAL){ if(net.types[i] == CONVOLUTIONAL){

View File

@ -3,6 +3,7 @@
#define NETWORK_H #define NETWORK_H
#include "image.h" #include "image.h"
#include "detection_layer.h"
#include "params.h" #include "params.h"
#include "data.h" #include "data.h"
@ -81,6 +82,10 @@ int resize_network(network net, int h, int w, int c);
void set_batch_network(network *net, int b); void set_batch_network(network *net, int b);
int get_network_input_size(network net); int get_network_input_size(network net);
float get_network_cost(network net); float get_network_cost(network net);
detection_layer *get_network_detection_layer(network net);
int get_network_nuisance(network net);
int get_network_background(network net);
#endif #endif