redid the demo for TED, little faster

This commit is contained in:
Joseph Redmon 2017-04-30 13:54:40 -07:00
parent 9726f1e89c
commit 72a2fe93f9
11 changed files with 221 additions and 172 deletions

View File

@ -5,8 +5,8 @@ subdivisions=1
# Training
# batch=64
# subdivisions=8
width=608
height=608
width=416
height=416
channels=3
momentum=0.9
decay=0.0005

View File

@ -698,7 +698,7 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
float *X = r.data;
time=clock();
float *predictions = network_predict(net, X);
if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 0, 1);
if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 1, 1);
top_k(predictions, net.outputs, top, indexes);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
for(i = 0; i < top; ++i){

View File

@ -376,9 +376,10 @@ void run_coco(int argc, char **argv)
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
char *filename = (argc > 5) ? argv[5]: 0;
int avg = find_int_arg(argc, argv, "-avg", 1);
if(0==strcmp(argv[2], "test")) test_coco(cfg, weights, filename, thresh);
else if(0==strcmp(argv[2], "train")) train_coco(cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_coco(cfg, weights);
else if(0==strcmp(argv[2], "recall")) validate_coco_recall(cfg, weights);
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, coco_classes, 80, frame_skip, prefix, .5, 0,0,0,0);
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, coco_classes, 80, frame_skip, prefix, avg, .5, 0,0,0,0);
}

View File

@ -9,7 +9,6 @@
#include "demo.h"
#include <sys/time.h>
#define FRAMES 3
#define DEMO 1
#ifdef OPENCV
@ -21,65 +20,26 @@ static int demo_classes;
static float **probs;
static box *boxes;
static network net;
static image in ;
static image in_s ;
static image det ;
static image det_s;
static image disp = {0};
static image buff [3];
static image buff_letter[3];
static int buff_index = 0;
static CvCapture * cap;
static IplImage * ipl;
static float fps = 0;
static float demo_thresh = 0;
static float demo_hier = .5;
static int running = 0;
static float *predictions[FRAMES];
static int demo_delay = 0;
static int demo_frame = 5;
static int demo_detections = 0;
static float **predictions;
static int demo_index = 0;
static image images[FRAMES];
static int demo_done = 0;
static float *last_avg2;
static float *last_avg;
static float *avg;
void *fetch_in_thread(void *ptr)
{
in = get_image_from_stream(cap);
if(!in.data){
error("Stream closed.");
}
in_s = letterbox_image(in, net.w, net.h);
return 0;
}
void *detect_in_thread(void *ptr)
{
float nms = .4;
layer l = net.layers[net.n-1];
float *X = det_s.data;
float *prediction = network_predict(net, X);
memcpy(predictions[demo_index], prediction, l.outputs*sizeof(float));
mean_arrays(predictions, FRAMES, l.outputs, avg);
l.output = avg;
free_image(det_s);
if(l.type == DETECTION){
get_detection_boxes(l, 1, 1, demo_thresh, probs, boxes, 0);
} else if (l.type == REGION){
get_region_boxes(l, in.w, in.h, net.w, net.h, demo_thresh, probs, boxes, 0, 0, demo_hier, 1);
} else {
error("Last layer must produce detections\n");
}
if (nms > 0) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
printf("\033[2J");
printf("\033[1;1H");
printf("\nFPS:%.1f\n",fps);
printf("Objects:\n\n");
images[demo_index] = det;
det = images[(demo_index + FRAMES/2 + 1)%FRAMES];
demo_index = (demo_index + 1)%FRAMES;
draw_detections(det, l.w*l.h*l.n, demo_thresh, boxes, probs, demo_names, demo_alphabet, demo_classes);
return 0;
}
double demo_time;
double get_wall_time()
{
@ -90,11 +50,95 @@ double get_wall_time()
return (double)time.tv_sec + (double)time.tv_usec * .000001;
}
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier, int w, int h, int frames, int fullscreen)
// Detection stage of the demo: runs the network on one letterboxed frame,
// decodes and filters the resulting boxes, prints a status header, and draws
// the detections onto the full-size frame stored at the same buffer slot.
// `ptr` is unused. Always returns 0.
void *detect_in_thread(void *ptr)
{
//skip = frame_skip;
// `running` is set for the duration of this call and cleared before returning.
running = 1;
// Overlap threshold handed to do_nms_obj below.
float nms = .4;
// Local copy of the last layer; the l.output assignments below change only
// this copy, not net.layers.
layer l = net.layers[net.n-1];
// Input is the letterboxed frame two slots ahead of buff_index (mod 3).
float *X = buff_letter[(buff_index+2)%3].data;
float *prediction = network_predict(net, X);
// Store this frame's raw output in the slot selected by demo_index, then
// combine the demo_frame stored outputs into `avg` (presumably an
// element-wise mean -- confirm against mean_arrays).
memcpy(predictions[demo_index], prediction, l.outputs*sizeof(float));
mean_arrays(predictions, demo_frame, l.outputs, avg);
// Decode from last_avg2 by default; when demo_delay is 0 decode from the
// freshly computed avg instead.
l.output = last_avg2;
if(demo_delay == 0) l.output = avg;
// Box decoding depends on the type of the final layer; any other layer
// type is reported through error().
if(l.type == DETECTION){
get_detection_boxes(l, 1, 1, demo_thresh, probs, boxes, 0);
} else if (l.type == REGION){
get_region_boxes(l, buff[0].w, buff[0].h, net.w, net.h, demo_thresh, probs, boxes, 0, 0, demo_hier, 1);
} else {
error("Last layer must produce detections\n");
}
if (nms > 0) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
// ANSI escape sequences: erase the screen, then move the cursor to row 1,
// column 1, so the status text is redrawn in place each frame.
printf("\033[2J");
printf("\033[1;1H");
printf("\nFPS:%.1f\n",fps);
printf("Objects:\n\n");
// Draw onto the full-size frame at the same slot index as the letterboxed
// input used above.
image display = buff[(buff_index+2) % 3];
draw_detections(display, demo_detections, demo_thresh, boxes, probs, demo_names, demo_alphabet, demo_classes);
// Advance the slot that the next call's prediction will be stored in.
demo_index = (demo_index + 1)%demo_frame;
running = 0;
return 0;
}
/*
 * Capture stage of the demo: reads the next frame from `cap` into
 * buff[buff_index] and rebuilds the letterboxed copy in
 * buff_letter[buff_index]. The letterbox step runs even when the read
 * fails; in that case demo_done is set. `ptr` is unused. Always returns 0.
 */
void *fetch_in_thread(void *ptr)
{
    int got_frame = fill_image_from_stream(cap, buff[buff_index]);

    letterbox_image_into(buff[buff_index], net.w, net.h, buff_letter[buff_index]);

    if (!got_frame) {
        /* fill_image_from_stream reported 0: no frame was available. */
        demo_done = 1;
    }
    return 0;
}
/*
 * Display stage of the demo: shows the frame one slot ahead of buff_index
 * (mod 3) in the "Demo" window, then polls the keyboard once and applies
 * the matching adjustment. `ptr` is unused. Always returns 0.
 *
 * Key codes are compared after reduction modulo 256:
 *   10 -> cycle demo_delay through 0 -> 60 -> 5 -> 0 (any other value -> 0)
 *   27 -> set demo_done
 *   82 -> demo_thresh += .02
 *   84 -> demo_thresh -= .02, floored at .02
 *   83 -> demo_hier   += .02
 *   81 -> demo_hier   -= .02, floored at .0
 * NOTE(review): 10/27 are presumably Enter/Esc and 81-84 the arrow keys --
 * confirm against the windowing backend in use.
 */
void *display_in_thread(void *ptr)
{
    show_image_cv(buff[(buff_index + 1)%3], "Demo", ipl);

    int key = cvWaitKey(1);
    if (key != -1) key = key%256;

    switch (key) {
    case 10:
        if (demo_delay == 0) {
            demo_delay = 60;
        } else if (demo_delay == 60) {
            demo_delay = 5;
        } else {
            /* Covers demo_delay == 5 and every other value. */
            demo_delay = 0;
        }
        break;
    case 27:
        demo_done = 1;
        break;
    case 82:
        demo_thresh += .02;
        break;
    case 84:
        demo_thresh -= .02;
        if (demo_thresh <= .02) demo_thresh = .02;
        break;
    case 83:
        demo_hier += .02;
        break;
    case 81:
        demo_hier -= .02;
        if (demo_hier <= .0) demo_hier = .0;
        break;
    default:
        /* No key pressed (-1) or an unbound key: nothing to do. */
        break;
    }
    return 0;
}
/*
 * Calls display_in_thread(0) repeatedly, forever. The loop has no exit
 * condition, so this function never returns. `ptr` is unused.
 */
void *display_loop(void *ptr)
{
    for (;;) {
        display_in_thread(0);
    }
}
/*
 * Calls detect_in_thread(0) repeatedly, forever. The loop has no exit
 * condition, so this function never returns. `ptr` is unused.
 */
void *detect_loop(void *ptr)
{
    for (;;) {
        detect_in_thread(0);
    }
}
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int delay, char *prefix, int avg_frames, float hier, int w, int h, int frames, int fullscreen)
{
demo_delay = delay;
demo_frame = avg_frames;
predictions = calloc(demo_frame, sizeof(float*));
image **alphabet = load_alphabet();
int delay = frame_skip;
demo_names = names;
demo_alphabet = alphabet;
demo_classes = classes;
@ -106,6 +150,8 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
pthread_t detect_thread;
pthread_t fetch_thread;
srand(2222222);
@ -129,36 +175,25 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
if(!cap) error("Couldn't connect to webcam.\n");
layer l = net.layers[net.n-1];
demo_detections = l.n*l.w*l.h;
int j;
avg = (float *) calloc(l.outputs, sizeof(float));
for(j = 0; j < FRAMES; ++j) predictions[j] = (float *) calloc(l.outputs, sizeof(float));
for(j = 0; j < FRAMES; ++j) images[j] = make_image(1,1,3);
last_avg = (float *) calloc(l.outputs, sizeof(float));
last_avg2 = (float *) calloc(l.outputs, sizeof(float));
for(j = 0; j < demo_frame; ++j) predictions[j] = (float *) calloc(l.outputs, sizeof(float));
boxes = (box *)calloc(l.w*l.h*l.n, sizeof(box));
probs = (float **)calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = (float *)calloc(l.classes+1, sizeof(float));
pthread_t fetch_thread;
pthread_t detect_thread;
fetch_in_thread(0);
det = in;
det_s = in_s;
fetch_in_thread(0);
detect_in_thread(0);
disp = det;
det = in;
det_s = in_s;
for(j = 0; j < FRAMES/2; ++j){
fetch_in_thread(0);
detect_in_thread(0);
disp = det;
det = in;
det_s = in_s;
}
buff[0] = get_image_from_stream(cap);
buff[1] = copy_image(buff[0]);
buff[2] = copy_image(buff[0]);
buff_letter[0] = letterbox_image(buff[0], net.w, net.h);
buff_letter[1] = letterbox_image(buff[0], net.w, net.h);
buff_letter[2] = letterbox_image(buff[0], net.w, net.h);
ipl = cvCreateImage(cvSize(buff[0].w,buff[0].h), IPL_DEPTH_8U, buff[0].c);
int count = 0;
if(!prefix){
@ -171,76 +206,34 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
}
}
double before = get_wall_time();
demo_time = get_wall_time();
while(1){
while(!demo_done){
buff_index = (buff_index + 1) %3;
if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed");
if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed");
if(!prefix){
if(count % (demo_delay+1) == 0){
fps = 1./(get_wall_time() - demo_time);
demo_time = get_wall_time();
float *swap = last_avg;
last_avg = last_avg2;
last_avg2 = swap;
memcpy(last_avg, avg, l.outputs*sizeof(float));
}
display_in_thread(0);
}else{
char name[256];
sprintf(name, "%s_%08d", prefix, count);
save_image(buff[(buff_index + 1)%3], name);
}
pthread_join(fetch_thread, 0);
pthread_join(detect_thread, 0);
++count;
if(1){
if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed");
if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed");
if(!prefix){
show_image(disp, "Demo");
int c = cvWaitKey(1);
if (c != -1) c = c%256;
if (c == 10){
if(frame_skip == 0) frame_skip = 60;
else if(frame_skip == 4) frame_skip = 0;
else if(frame_skip == 60) frame_skip = 4;
else frame_skip = 0;
} else if (c == 27) {
return;
} else if (c == 82) {
demo_thresh += .02;
} else if (c == 84) {
demo_thresh -= .02;
if(demo_thresh <= .02) demo_thresh = .02;
} else if (c == 83) {
demo_hier += .02;
} else if (c == 81) {
demo_hier -= .02;
if(demo_hier <= .0) demo_hier = .0;
}
}else{
char buff[256];
sprintf(buff, "%s_%08d", prefix, count);
save_image(disp, buff);
}
pthread_join(fetch_thread, 0);
pthread_join(detect_thread, 0);
if(delay == 0){
free_image(disp);
disp = det;
}
det = in;
det_s = in_s;
}else {
fetch_in_thread(0);
det = in;
det_s = in_s;
detect_in_thread(0);
if(delay == 0) {
free_image(disp);
disp = det;
}
show_image(disp, "Demo");
cvWaitKey(1);
}
--delay;
if(delay < 0){
delay = frame_skip;
double after = get_wall_time();
float curr = 1./(after - before);
fps = curr;
before = after;
}
}
}
#else
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh, int w, int h, int fps, int fullscreen)
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int delay, char *prefix, int avg, float hier, int w, int h, int frames, int fullscreen)
{
fprintf(stderr, "Demo needs OpenCV for webcam images.\n");
}

View File

@ -2,6 +2,6 @@
#define DEMO_H
#include "image.h"
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh, int w, int h, int fps, int fullscreen);
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, int avg, float hier_thresh, int w, int h, int fps, int fullscreen);
#endif

View File

@ -259,8 +259,8 @@ void forward_detection_layer_gpu(const detection_layer l, network net)
return;
}
float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
float *truth_cpu = 0;
//float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
//float *truth_cpu = 0;
forward_detection_layer(l, net);
cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);

View File

@ -292,7 +292,7 @@ void validate_detector_flip(char *datacfg, char *cfgfile, char *weightfile, char
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes+1, sizeof(float *));
int m = plist->size;
int i=0;
@ -428,7 +428,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes+1, sizeof(float *));
int m = plist->size;
int i=0;
@ -521,7 +521,7 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
int j, k;
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes+1, sizeof(float *));
int m = plist->size;
int i=0;
@ -659,6 +659,7 @@ void run_detector(int argc, char **argv)
float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
int cam_index = find_int_arg(argc, argv, "-c", 0);
int frame_skip = find_int_arg(argc, argv, "-s", 0);
int avg = find_int_arg(argc, argv, "-avg", 3);
if(argc < 4){
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
return;
@ -707,6 +708,6 @@ void run_detector(int argc, char **argv)
int classes = option_find_int(options, "classes", 20);
char *name_list = option_find_str(options, "names", "data/names.list");
char **names = get_labels(name_list);
demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, hier_thresh, width, height, fps, fullscreen);
demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, avg, hier_thresh, width, height, fps, fullscreen);
}
}

View File

@ -216,6 +216,7 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs,
if (alphabet) {
image label = get_label(alphabet, names[class], (im.h*.03)/10);
draw_label(im, top + width, left, label, rgb);
free_image(label);
}
}
}
@ -394,6 +395,11 @@ void normalize_image2(image p)
free(max);
}
/*
 * Copies the pixel data of `src` into the existing buffer of `dest`.
 *
 * Exactly src.h * src.w * src.c floats are copied. No size check is made:
 * dest.data is assumed to hold at least that many floats and must not
 * overlap src.data (memcpy requirement). Only the data is copied; the
 * caller's dest.w/h/c are not touched (dest is passed by value).
 */
void copy_image_into(image src, image dest)
{
    /* Do the multiplication in size_t so a very large image cannot
       overflow int before the byte count is formed. */
    size_t n_floats = (size_t)src.h * (size_t)src.w * (size_t)src.c;
    memcpy(dest.data, src.data, n_floats*sizeof(float));
}
image copy_image(image p)
{
image copy = p;
@ -413,19 +419,16 @@ void rgbgr_image(image im)
}
#ifdef OPENCV
void show_image_cv(image p, const char *name)
void show_image_cv(image p, const char *name, IplImage *disp)
{
int x,y,k;
image copy = copy_image(p);
constrain_image(copy);
if(p.c == 3) rgbgr_image(copy);
if(p.c == 3) rgbgr_image(p);
//normalize_image(copy);
char buff[256];
//sprintf(buff, "%s (%d)", name, windows);
sprintf(buff, "%s", name);
IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c);
int step = disp->widthStep;
cvNamedWindow(buff, CV_WINDOW_NORMAL);
//cvMoveWindow(buff, 100*(windows%10) + 200*(windows/10), 100*(windows%10));
@ -433,11 +436,10 @@ void show_image_cv(image p, const char *name)
for(y = 0; y < p.h; ++y){
for(x = 0; x < p.w; ++x){
for(k= 0; k < p.c; ++k){
disp->imageData[y*step + x*p.c + k] = (unsigned char)(get_pixel(copy,x,y,k)*255);
disp->imageData[y*step + x*p.c + k] = (unsigned char)(get_pixel(p,x,y,k)*255);
}
}
}
free_image(copy);
if(0){
int w = 448;
int h = w*p.h/p.w;
@ -451,14 +453,18 @@ void show_image_cv(image p, const char *name)
cvReleaseImage(&buffer);
}
cvShowImage(buff, disp);
cvReleaseImage(&disp);
}
#endif
void show_image(image p, const char *name)
{
#ifdef OPENCV
show_image_cv(p, name);
IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c);
image copy = copy_image(p);
constrain_image(copy);
show_image_cv(copy, name, disp);
free_image(copy);
cvReleaseImage(&disp);
#else
fprintf(stderr, "Not compiled with OpenCV, saving to %s.png instead\n", name);
save_image(p, name);
@ -467,23 +473,31 @@ void show_image(image p, const char *name)
#ifdef OPENCV
image ipl_to_image(IplImage* src)
void ipl_into_image(IplImage* src, image im)
{
unsigned char *data = (unsigned char *)src->imageData;
int h = src->height;
int w = src->width;
int c = src->nChannels;
int step = src->widthStep;
image out = make_image(w, h, c);
int i, j, k, count=0;;
int i, j, k;
for(k= 0; k < c; ++k){
for(i = 0; i < h; ++i){
for(i = 0; i < h; ++i){
for(k= 0; k < c; ++k){
for(j = 0; j < w; ++j){
out.data[count++] = data[i*step + j*c + k]/255.;
im.data[k*w*h + i*w + j] = data[i*step + j*c + k]/255.;
}
}
}
}
/*
 * Allocates a new image with the width, height and channel count of `src`
 * (via make_image) and fills it from `src` with ipl_into_image.
 * Returns the new image; the caller owns it.
 */
image ipl_to_image(IplImage* src)
{
    image converted = make_image(src->width, src->height, src->nChannels);
    ipl_into_image(src, converted);
    return converted;
}
@ -513,6 +527,14 @@ image load_image_cv(char *filename, int channels)
return out;
}
/*
 * Calls cvQueryFrame on `cap` n times and discards each result.
 * Does nothing when n <= 0.
 */
void flush_stream_buffer(CvCapture *cap, int n)
{
    int remaining = n;
    while (remaining > 0) {
        cvQueryFrame(cap);
        --remaining;
    }
}
image get_image_from_stream(CvCapture *cap)
{
IplImage* src = cvQueryFrame(cap);
@ -522,6 +544,15 @@ image get_image_from_stream(CvCapture *cap)
return im;
}
/*
 * Queries the next frame from `cap` and writes it into the existing
 * buffer of `im`, then applies rgbgr_image to it (presumably a red/blue
 * channel swap -- confirm against rgbgr_image).
 *
 * Returns 1 when a frame was read, 0 when cvQueryFrame returned NULL
 * (in which case `im` is left untouched).
 * NOTE(review): `im` is assumed to be pre-allocated with the frame's
 * dimensions; nothing here checks that.
 */
int fill_image_from_stream(CvCapture *cap, image im)
{
    IplImage *frame = cvQueryFrame(cap);
    if (frame) {
        ipl_into_image(frame, im);
        rgbgr_image(im);
        return 1;
    }
    return 0;
}
void save_image_jpg(image p, const char *name)
{
image copy = copy_image(p);
@ -794,6 +825,22 @@ void composite_3d(char *f1, char *f2, char *out, int delta)
#endif
}
/*
 * Resizes `im` to fit inside a w x h box while keeping its aspect ratio,
 * then embeds the result, centred, into the caller-supplied `boxed`.
 *
 * The side with the smaller scale factor is matched exactly (w or h); the
 * other side is scaled with integer arithmetic. The temporary resized
 * image is freed before returning.
 * NOTE(review): `boxed` is presumably a w x h image whose border pixels
 * were filled earlier; this function only embeds the resized frame.
 */
void letterbox_image_into(image im, int w, int h, image boxed)
{
    int fit_w;
    int fit_h;

    if (((float)w/im.w) < ((float)h/im.h)) {
        /* Width is the limiting dimension. */
        fit_w = w;
        fit_h = (im.h * w)/im.w;
    } else {
        /* Height is the limiting dimension. */
        fit_h = h;
        fit_w = (im.w * h)/im.h;
    }

    int offset_x = (w-fit_w)/2;
    int offset_y = (h-fit_h)/2;

    image scaled = resize_image(im, fit_w, fit_h);
    embed_image(scaled, boxed, offset_x, offset_y);
    free_image(scaled);
}
image letterbox_image(image im, int w, int h)
{
int new_w = im.w;

View File

@ -29,7 +29,11 @@ typedef struct {
#ifndef __cplusplus
#ifdef OPENCV
image get_image_from_stream(CvCapture *cap);
int fill_image_from_stream(CvCapture *cap, image im);
image ipl_to_image(IplImage* src);
void ipl_into_image(IplImage* src, image im);
void flush_stream_buffer(CvCapture *cap, int n);
void show_image_cv(image p, const char *name, IplImage *disp);
#endif
#endif
@ -49,6 +53,7 @@ image random_crop_image(image im, int w, int h);
image random_augment_image(image im, float angle, float aspect, int low, int high, int size);
void random_distort_image(image im, float hue, float saturation, float exposure);
image letterbox_image(image im, int w, int h);
void letterbox_image_into(image im, int w, int h, image boxed);
image resize_image(image im, int w, int h);
image resize_min(image im, int min);
image resize_max(image im, int max);
@ -96,6 +101,7 @@ image make_random_image(int w, int h, int c);
image make_empty_image(int w, int h, int c);
image float_to_image(int w, int h, int c, float *data);
image copy_image(image p);
void copy_image_into(image src, image dest);
image load_image(char *filename, int w, int h, int c);
image load_image_color(char *filename, int w, int h);
image **load_alphabet();

View File

@ -406,7 +406,7 @@ void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, f
probs[index][j] = (prob > thresh) ? prob : 0;
if(prob > max) max = prob;
// TODO REMOVE
// if (j != 15 && j != 16) probs[index][j] = 0;
// if (j == 56 ) probs[index][j] = 0;
/*
if (j != 0) probs[index][j] = 0;
int blacklist[] = {121, 497, 482, 504, 122, 518,481, 418, 542, 491, 914, 478, 120, 510,500};

View File

@ -340,6 +340,7 @@ void run_yolo(int argc, char **argv)
return;
}
int avg = find_int_arg(argc, argv, "-avg", 1);
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
char *filename = (argc > 5) ? argv[5]: 0;
@ -347,5 +348,5 @@ void run_yolo(int argc, char **argv)
else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);
else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix, .5, 0,0,0,0);
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix, avg, .5, 0,0,0,0);
}