From 9c9344a1ff4541499c8f69ea997be5145b4e1de3 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Tue, 21 May 2019 23:47:47 +0300 Subject: [PATCH] Added flag -letter_box for Detection --- include/darknet.h | 2 +- src/coco.c | 2 +- src/darknet.c | 2 +- src/demo.c | 7 ++++--- src/demo.h | 2 +- src/detector.c | 15 ++++++++------- src/yolo.c | 2 +- 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/include/darknet.h b/include/darknet.h index 68ab1fc7..0eb52b37 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -796,7 +796,7 @@ LIB_API float *network_predict_image(network *net, image im); LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net); LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs); LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, - float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile); + float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box); LIB_API int network_width(network *net); LIB_API int network_height(network *net); LIB_API void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh, int norm); diff --git a/src/coco.c b/src/coco.c index 931c406c..c1535a35 100644 --- a/src/coco.c +++ b/src/coco.c @@ -384,5 +384,5 @@ void run_coco(int argc, char **argv) else if(0==strcmp(argv[2], "valid")) validate_coco(cfg, weights); else if(0==strcmp(argv[2], "recall")) validate_coco_recall(cfg, weights); else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, hier_thresh, cam_index, filename, coco_classes, 80, frame_skip, - prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output); + prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output, 0); } diff --git a/src/darknet.c b/src/darknet.c index 06092d48..f7ac6593 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -476,7 +476,7 @@ int main(int argc, char **argv) float thresh = find_float_arg(argc, argv, "-thresh", .24); int ext_output = find_arg(argc, argv, "-ext_output"); char *filename = (argc > 4) ? argv[4]: 0; - test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL); + test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL, 0); } else if (0 == strcmp(argv[1], "cifar")){ run_cifar(argc, argv); } else if (0 == strcmp(argv[1], "go")){ diff --git a/src/demo.c b/src/demo.c index 4262b03f..6c7f5d39 100644 --- a/src/demo.c +++ b/src/demo.c @@ -50,7 +50,7 @@ mat_cv* det_img; mat_cv* show_img; static volatile int flag_exit; -static const int letter_box = 0; +static int letter_box = 0; void *fetch_in_thread(void *ptr) { @@ -104,8 +104,9 @@ double get_wall_time() } void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int cam_index, const char *filename, char **names, int classes, - int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output) + int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output, int letter_box_in) { + letter_box = letter_box_in; in_img = det_img = show_img = NULL; //skip = frame_skip; image **alphabet = load_alphabet(); @@ -321,7 +322,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int } #else void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int cam_index, const char *filename, char **names, int classes, - int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output) + int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output, int letter_box_in) { fprintf(stderr, "Demo needs OpenCV for webcam images.\n"); } diff --git a/src/demo.h b/src/demo.h index b26b9592..1f749b89 100644 --- a/src/demo.h +++ b/src/demo.h @@ -6,7 +6,7 @@ extern "C" { #endif void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int cam_index, const char *filename, char **names, int classes, - int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output); + int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output, int letter_box_in); #ifdef __cplusplus } #endif diff --git a/src/detector.c b/src/detector.c index 9199673c..d7ed7af6 100644 --- a/src/detector.c +++ b/src/detector.c @@ -1256,7 +1256,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, - float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile) + float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box) { list *options = read_data_cfg(datacfg); char *name_list = option_find_str(options, "names", "data/names.list"); @@ -1304,9 +1304,9 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam //image im; //image sized = load_image_resize(input, net.w, net.h, net.c, &im); image im = load_image(input, 0, 0, net.c); - image sized = resize_image(im, net.w, net.h); - int letterbox = 0; - //image sized = letterbox_image(im, net.w, net.h); letterbox = 1; + image sized; + if(letter_box) sized = letterbox_image(im, net.w, net.h); + else sized = resize_image(im, net.w, net.h); layer l = net.layers[net.n - 1]; //box *boxes = calloc(l.w*l.h*l.n, sizeof(box)); @@ -1323,7 +1323,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam //printf("%s: Predicted in %f seconds.\n", input, (what_time_is_it_now()-time)); int nboxes = 0; - detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox); + detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letter_box); if (nms) do_nms_sort(dets, nboxes, l.classes, nms); draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes, ext_output); save_image(im, "predictions"); @@ -1409,6 +1409,7 @@ void run_detector(int argc, char **argv) { int dont_show = find_arg(argc, argv, "-dont_show"); int show = find_arg(argc, argv, "-show"); + int letter_box = find_arg(argc, argv, "-letter_box"); int calc_map = find_arg(argc, argv, "-map"); int map_points = find_int_arg(argc, argv, "-points", 0); check_mistakes = find_arg(argc, argv, "-check_mistakes"); @@ -1467,7 +1468,7 @@ void run_detector(int argc, char **argv) if (strlen(weights) > 0) if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0; char *filename = (argc > 6) ? argv[6] : 0; - if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile); + if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile, letter_box); else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs); else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile); else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); @@ -1482,7 +1483,7 @@ void run_detector(int argc, char **argv) if (strlen(filename) > 0) if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0; demo(cfg, weights, thresh, hier_thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename, - mjpeg_port, json_port, dont_show, ext_output); + mjpeg_port, json_port, dont_show, ext_output, letter_box); free_list_contents_kvp(options); free_list(options); diff --git a/src/yolo.c b/src/yolo.c index 07a2092c..711470ea 100644 --- a/src/yolo.c +++ b/src/yolo.c @@ -351,5 +351,5 @@ void run_yolo(int argc, char **argv) else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights); else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights); else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, hier_thresh, cam_index, filename, voc_names, 20, frame_skip, - prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output); + prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output, 0); }