faster nms and stuff

This commit is contained in:
Joseph Redmon 2018-03-15 15:23:14 -07:00
parent 0b64cb4dd3
commit 0f110834f4
10 changed files with 126 additions and 93 deletions

View File

@ -146,7 +146,7 @@ void validate_coco(char *cfg, char *weights)
FILE *fp = fopen(buff, "w");
fprintf(fp, "[\n");
detection *dets = make_network_boxes(net);
detection *dets = make_network_boxes(net, 0);
int m = plist->size;
int i=0;
@ -231,7 +231,7 @@ void validate_coco_recall(char *cfgfile, char *weightfile)
snprintf(buff, 1024, "%s%s.txt", base, coco_classes[j]);
fps[j] = fopen(buff, "w");
}
detection *dets = make_network_boxes(net);
detection *dets = make_network_boxes(net, 0);
int m = plist->size;
int i=0;
@ -302,7 +302,7 @@ void test_coco(char *cfgfile, char *weightfile, char *filename, float thresh)
clock_t time;
char buff[256];
char *input = buff;
detection *dets = make_network_boxes(net);
detection *dets = make_network_boxes(net, 0);
while(1){
if(filename){
strncpy(input, filename, 256);

View File

@ -279,8 +279,6 @@ void validate_detector_flip(char *datacfg, char *cfgfile, char *weightfile, char
}
}
detection *dets = make_network_boxes(net);
int m = plist->size;
int i=0;
int t;
@ -333,15 +331,17 @@ void validate_detector_flip(char *datacfg, char *cfgfile, char *weightfile, char
network_predict(net, input.data);
int w = val[t].w;
int h = val[t].h;
fill_network_boxes(net, w, h, thresh, .5, map, 0, dets);
if (nms) do_nms_sort(dets, l.w*l.h*l.n, classes, nms);
int num = 0;
detection *dets = get_network_boxes(net, w, h, thresh, .5, map, 0, &num);
if (nms) do_nms_sort(dets, num, classes, nms);
if (coco){
print_cocos(fp, path, dets, l.w*l.h*l.n, classes, w, h);
print_cocos(fp, path, dets, num, classes, w, h);
} else if (imagenet){
print_imagenet_detections(fp, i+t-nthreads+1, dets, l.w*l.h*l.n, classes, w, h);
print_imagenet_detections(fp, i+t-nthreads+1, dets, num, classes, w, h);
} else {
print_detector_detections(fps, id, dets, l.w*l.h*l.n, classes, w, h);
print_detector_detections(fps, id, dets, num, classes, w, h);
}
free_detections(dets, num);
free(id);
free_image(val[t]);
free_image(val_resized[t]);
@ -409,8 +409,6 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out
}
}
detection *dets = make_network_boxes(net);
int nboxes = num_boxes(net);
int m = plist->size;
int i=0;
@ -459,7 +457,8 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out
network_predict(net, X);
int w = val[t].w;
int h = val[t].h;
fill_network_boxes(net, w, h, thresh, .5, map, 0, dets);
int nboxes = 0;
detection *dets = get_network_boxes(net, w, h, thresh, .5, map, 0, &nboxes);
if (nms) do_nms_sort(dets, nboxes, classes, nms);
if (coco){
print_cocos(fp, path, dets, nboxes, classes, w, h);
@ -497,7 +496,6 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
layer l = net->layers[net->n-1];
int j, k;
detection *dets = make_network_boxes(net);
int m = plist->size;
int i=0;
@ -510,7 +508,6 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
int correct = 0;
int proposals = 0;
float avg_iou = 0;
int nboxes = num_boxes(net);
for(i = 0; i < m; ++i){
char *path = paths[i];
@ -518,7 +515,8 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
image sized = resize_image(orig, net->w, net->h);
char *id = basecfg(path);
network_predict(net, sized.data);
fill_network_boxes(net, sized.w, sized.h, thresh, .5, 0, 1, dets);
int nboxes = 0;
detection *dets = get_network_boxes(net, sized.w, sized.h, thresh, .5, 0, 1, &nboxes);
if (nms) do_nms_obj(dets, nboxes, 1, nms);
char labelpath[4096];
@ -590,18 +588,18 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
//resize_network(net, sized.w, sized.h);
layer l = net->layers[net->n-1];
int nboxes = num_boxes(net);
printf("%d\n", nboxes);
float *X = sized.data;
time=what_time_is_it_now();
network_predict(net, X);
printf("%s: Predicted in %f seconds.\n", input, what_time_is_it_now()-time);
detection *dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, 0, 1);
int nboxes = 0;
detection *dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes);
printf("%d\n", nboxes);
//if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
draw_detections(im, dets, nboxes, thresh, names, alphabet, l.classes);
free_detections(dets, num_boxes(net));
free_detections(dets, nboxes);
if(outfile){
save_image(im, outfile);
}
@ -673,11 +671,10 @@ void censor_detector(char *datacfg, char *cfgfile, char *weightfile, int cam_ind
image in_s = letterbox_image(in, net->w, net->h);
layer l = net->layers[net->n-1];
int nboxes = num_boxes(net);
float *X = in_s.data;
network_predict(net, X);
detection *dets = get_network_boxes(net, in.w, in.h, thresh, 0, 0, 0);
int nboxes = 0;
detection *dets = get_network_boxes(net, in.w, in.h, thresh, 0, 0, 0, &nboxes);
//if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
@ -691,7 +688,7 @@ void censor_detector(char *datacfg, char *cfgfile, char *weightfile, int cam_ind
}
show_image(in, base);
cvWaitKey(10);
free_detections(dets, num_boxes(net));
free_detections(dets, nboxes);
free_image(in_s);
@ -756,12 +753,12 @@ void extract_detector(char *datacfg, char *cfgfile, char *weightfile, int cam_in
image in_s = letterbox_image(in, net->w, net->h);
layer l = net->layers[net->n-1];
int nboxes = num_boxes(net);
show_image(in, base);
int nboxes = 0;
float *X = in_s.data;
network_predict(net, X);
detection *dets = get_network_boxes(net, in.w, in.h, thresh, 0, 0, 1);
detection *dets = get_network_boxes(net, in.w, in.h, thresh, 0, 0, 1, &nboxes);
//if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
@ -779,7 +776,7 @@ void extract_detector(char *datacfg, char *cfgfile, char *weightfile, int cam_in
free_image(bim);
}
}
free_detections(dets, num_boxes(net));
free_detections(dets, nboxes);
free_image(in_s);
@ -795,6 +792,7 @@ void extract_detector(char *datacfg, char *cfgfile, char *weightfile, int cam_in
}
}
/*
void network_detect(network *net, image im, float thresh, float hier_thresh, float nms, detection *dets)
{
network_predict_image(net, im);
@ -803,6 +801,7 @@ void network_detect(network *net, image im, float thresh, float hier_thresh, flo
fill_network_boxes(net, im.w, im.h, thresh, hier_thresh, 0, 0, dets);
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
}
*/
void run_detector(int argc, char **argv)
{

View File

@ -133,7 +133,7 @@ void validate_yolo(char *cfg, char *weights)
image *buf = calloc(nthreads, sizeof(image));
image *buf_resized = calloc(nthreads, sizeof(image));
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
detection *dets = make_network_boxes(net);
detection *dets = make_network_boxes(net, 0);
load_args args = {0};
args.w = net->w;
@ -200,7 +200,7 @@ void validate_yolo_recall(char *cfg, char *weights)
snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
fps[j] = fopen(buff, "w");
}
detection *dets = make_network_boxes(net);
detection *dets = make_network_boxes(net, 0);
int m = plist->size;
int i=0;
@ -271,7 +271,7 @@ void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
char buff[256];
char *input = buff;
float nms=.4;
detection *dets = make_network_boxes(net);
detection *dets = make_network_boxes(net, 0);
while(1){
if(filename){
strncpy(input, filename, 256);

View File

@ -682,7 +682,7 @@ void save_weights_upto(network *net, char *filename, int cutoff);
void load_weights_upto(network *net, char *filename, int start, int cutoff);
void zero_objectness(layer l);
void get_region_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, float tree_thresh, int relative, detection *dets);
int get_region_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, float tree_thresh, int relative, detection *dets);
void free_network(network *net);
void set_batch_network(network *net, int b);
void set_temp_network(network *net, float t);
@ -739,10 +739,7 @@ int network_width(network *net);
int network_height(network *net);
float *network_predict_image(network *net, image im);
void network_detect(network *net, image im, float thresh, float hier_thresh, float nms, detection *dets);
int num_boxes(network *net);
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative);
void fill_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets);
detection *make_network_boxes(network *net);
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num);
void free_detections(detection *dets, int n);
void reset_network_state(network *net, int b);

View File

@ -63,7 +63,7 @@ make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE
get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int]
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
get_network_boxes.restype = POINTER(DETECTION)
make_network_boxes = lib.make_network_boxes
@ -76,10 +76,6 @@ free_detections.argtypes = [POINTER(DETECTION), c_int]
free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]
num_boxes = lib.num_boxes
num_boxes.argtypes = [c_void_p]
num_boxes.restype = c_int
network_predict = lib.network_predict
network_predict.argtypes = [c_void_p, POINTER(c_float)]
@ -128,9 +124,11 @@ def classify(net, meta, im):
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
im = load_image(image, 0, 0)
num = num_boxes(net)
num = c_int(0)
pnum = pointer(num)
predict_image(net, im)
dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0)
dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
num = pnum[0]
if (nms): do_nms_obj(dets, num, meta.classes, nms);
res = []

View File

@ -21,6 +21,17 @@ int nms_comparator(const void *pa, const void *pb)
void do_nms_obj(detection *dets, int total, int classes, float thresh)
{
int i, j, k;
k = total-1;
for(i = 0; i <= k; ++i){
if(dets[i].objectness == 0){
detection swap = dets[i];
dets[i] = dets[k];
dets[k] = swap;
--k;
--i;
}
}
total = k+1;
for(i = 0; i < total; ++i){
dets[i].sort_class = -1;

View File

@ -30,17 +30,12 @@ static int running = 0;
static int demo_frame = 3;
static int demo_index = 0;
static int demo_detections = 0;
//static float **predictions;
static detection **dets;
static detection *avg;
//static float *avg;
static int demo_done = 0;
double demo_time;
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative);
detection *make_network_boxes(network *net);
void fill_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets);
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num);
void *detect_in_thread(void *ptr)
{
@ -55,12 +50,15 @@ void *detect_in_thread(void *ptr)
if(l.type == DETECTION){
get_detection_boxes(l, 1, 1, demo_thresh, probs, boxes, 0);
} else */
detection *dets;
int nboxes = 0;
if (l.type == REGION){
fill_network_boxes(net, buff[0].w, buff[0].h, demo_thresh, demo_hier, 0, 1, dets[demo_index]);
dets = get_network_boxes(net, buff[0].w, buff[0].h, demo_thresh, demo_hier, 0, 1, &nboxes);
} else {
error("Last layer must produce detections\n");
}
/*
int i,j;
box zero = {0};
int classes = l.classes;
@ -79,15 +77,17 @@ void *detect_in_thread(void *ptr)
//copy_cpu(classes, dets[0][i].prob, 1, avg[i].prob, 1);
//avg[i].objectness = dets[0][i].objectness;
}
*/
if (nms > 0) do_nms_obj(avg, demo_detections, l.classes, nms);
if (nms > 0) do_nms_obj(dets, nboxes, l.classes, nms);
printf("\033[2J");
printf("\033[1;1H");
printf("\nFPS:%.1f\n",fps);
printf("Objects:\n\n");
image display = buff[(buff_index+2) % 3];
draw_detections(display, avg, demo_detections, demo_thresh, demo_names, demo_alphabet, demo_classes);
draw_detections(display, dets, nboxes, demo_thresh, demo_names, demo_alphabet, demo_classes);
free_detections(dets, nboxes);
demo_index = (demo_index + 1)%demo_frame;
running = 0;
@ -174,11 +174,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
if(!cap) error("Couldn't connect to webcam.\n");
demo_detections = num_boxes(net);
avg = make_network_boxes(net);
dets = calloc(demo_frame, sizeof(detection*));
int i;
for(i = 0; i < demo_frame; ++i) dets[i] = make_network_boxes(net);
buff[0] = get_image_from_stream(cap);
buff[1] = copy_image(buff[0]);

View File

@ -502,24 +502,28 @@ float *network_predict(network *net, float *input)
return out;
}
int num_boxes(network *net)
int num_detections(network *net, float thresh)
{
int i;
int s = 0;
for(i = 0; i < net->n; ++i){
layer l = net->layers[i];
if(l.type == REGION || l.type == DETECTION){
if(l.type == REGION){
s += region_num_detections(l, thresh);
}
if(l.type == DETECTION){
s += l.w*l.h*l.n;
}
}
return s;
}
detection *make_network_boxes(network *net)
detection *make_network_boxes(network *net, float thresh, int *num)
{
layer l = net->layers[net->n - 1];
int i;
int nboxes = num_boxes(net);
int nboxes = num_detections(net, thresh);
if(num) *num = nboxes;
detection *dets = calloc(nboxes, sizeof(detection));
for(i = 0; i < nboxes; ++i){
dets[i].prob = calloc(l.classes, sizeof(float));
@ -529,14 +533,15 @@ detection *make_network_boxes(network *net)
}
return dets;
}
void fill_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets)
{
int j;
for(j = 0; j < net->n; ++j){
layer l = net->layers[j];
if(l.type == REGION){
get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets);
dets += l.w*l.h*l.n;
int count = get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets);
dets += count;
}
if(l.type == DETECTION){
get_detection_detections(l, w, h, thresh, dets);
@ -545,9 +550,9 @@ void fill_network_boxes(network *net, int w, int h, float thresh, float hier, in
}
}
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative)
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num)
{
detection *dets = make_network_boxes(net);
detection *dets = make_network_boxes(net, thresh, num);
fill_network_boxes(net, w, h, thresh, hier, map, relative, dets);
return dets;
}

View File

@ -412,44 +412,69 @@ void correct_region_boxes(detection *dets, int n, int w, int h, int netw, int ne
}
}
void get_region_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, float tree_thresh, int relative, detection *dets)
int region_num_detections(layer l, float thresh)
{
int i,j,n,z;
float *predictions = l.output;
if (l.batch == 2) {
float *flip = l.output + l.outputs;
for (j = 0; j < l.h; ++j) {
for (i = 0; i < l.w/2; ++i) {
for (n = 0; n < l.n; ++n) {
for(z = 0; z < l.classes + l.coords + 1; ++z){
int i1 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + i;
int i2 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + (l.w - i - 1);
float swap = flip[i1];
flip[i1] = flip[i2];
flip[i2] = swap;
if(z == 0){
flip[i1] = -flip[i1];
flip[i2] = -flip[i2];
}
}
}
}
}
for(i = 0; i < l.outputs; ++i){
l.output[i] = (l.output[i] + flip[i])/2.;
}
}
int i, n;
int count = 0;
for (i = 0; i < l.w*l.h; ++i){
int row = i / l.w;
int col = i % l.w;
for(n = 0; n < l.n; ++n){
int index = n*l.w*l.h + i;
int obj_index = entry_index(l, 0, n*l.w*l.h + i, l.coords);
if(l.output[obj_index] > thresh){
++count;
}
}
}
return count;
}
void avg_flipped_region(layer l)
{
int i,j,n,z;
float *flip = l.output + l.outputs;
for (j = 0; j < l.h; ++j) {
for (i = 0; i < l.w/2; ++i) {
for (n = 0; n < l.n; ++n) {
for(z = 0; z < l.classes + l.coords + 1; ++z){
int i1 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + i;
int i2 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + (l.w - i - 1);
float swap = flip[i1];
flip[i1] = flip[i2];
flip[i2] = swap;
if(z == 0){
flip[i1] = -flip[i1];
flip[i2] = -flip[i2];
}
}
}
}
}
for(i = 0; i < l.outputs; ++i){
l.output[i] = (l.output[i] + flip[i])/2.;
}
}
int get_region_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, float tree_thresh, int relative, detection *dets)
{
int i,j,n,z;
float *predictions = l.output;
if (l.batch == 2) avg_flipped_region(l);
int count = 0;
for (i = 0; i < l.w*l.h; ++i){
int row = i / l.w;
int col = i % l.w;
for(n = 0; n < l.n; ++n){
int obj_index = entry_index(l, 0, n*l.w*l.h + i, l.coords);
if(predictions[obj_index] <= thresh) continue;
int index = count;
++count;
int box_index = entry_index(l, 0, n*l.w*l.h + i, 0);
int mask_index = entry_index(l, 0, n*l.w*l.h + i, 4);
for (j = 0; j < l.classes; ++j) {
dets[index].prob[j] = 0;
}
int obj_index = entry_index(l, 0, n*l.w*l.h + i, l.coords);
int box_index = entry_index(l, 0, n*l.w*l.h + i, 0);
int mask_index = entry_index(l, 0, n*l.w*l.h + i, 4);
float scale = l.background ? 1 : predictions[obj_index];
dets[index].bbox = get_region_box(predictions, l.biases, l.mask[n], box_index, col, row, l.w, l.h, netw, neth, l.w*l.h);
dets[index].objectness = scale > thresh ? scale : 0;
@ -485,7 +510,8 @@ void get_region_detections(layer l, int w, int h, int netw, int neth, float thre
}
}
}
correct_region_boxes(dets, l.w*l.h*l.n, w, h, netw, neth, relative);
correct_region_boxes(dets, count, w, h, netw, neth, relative);
return count;
}
#ifdef GPU

View File

@ -9,6 +9,7 @@ layer make_region_layer(int batch, int h, int w, int n, int total, int *mask, in
void forward_region_layer(const layer l, network net);
void backward_region_layer(const layer l, network net);
void resize_region_layer(layer *l, int w, int h);
int region_num_detections(layer l, float thresh);
#ifdef GPU
void forward_region_layer_gpu(const layer l, network net);