From e34f0416f507499e9dbbc2557430850ba3a022ab Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Sat, 5 Aug 2017 01:47:58 +0300 Subject: [PATCH] Added detection on images from the txt list file by using SO/DLL. --- src/yolo_console_dll.cpp | 16 +++++++++++++--- src/yolo_v2_class.cpp | 21 ++++++++++++++++----- src/yolo_v2_class.hpp | 6 +++--- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/src/yolo_console_dll.cpp b/src/yolo_console_dll.cpp index 2e05f7ac..e70f1deb 100644 --- a/src/yolo_console_dll.cpp +++ b/src/yolo_console_dll.cpp @@ -73,15 +73,15 @@ int main() #ifdef OPENCV std::string const file_ext = filename.substr(filename.find_last_of(".") + 1); if (file_ext == "avi" || file_ext == "mp4" || file_ext == "mjpg" || file_ext == "mov") { // video file - cv::Mat frame, prev_frame; + cv::Mat frame, prev_frame, det_frame; std::vector result_vec, thread_result_vec; detector.nms = 0.02; // comment it - if track_id is not required std::thread td([]() {}); for (cv::VideoCapture cap(filename); cap >> frame, cap.isOpened();) { td.join(); result_vec = thread_result_vec; - cv::Mat det_frame = frame; - td = std::thread([&]() { thread_result_vec = detector.detect(det_frame, 0.2); }); + det_frame = frame; + td = std::thread([&]() { thread_result_vec = detector.detect(det_frame, 0.2, true); }); if (!prev_frame.empty()) { result_vec = detector.tracking(result_vec); // comment it - if track_id is not required @@ -91,6 +91,16 @@ int main() prev_frame = frame; } } + else if (file_ext == "txt") { // list of image files + std::ifstream file(filename); + if (!file.is_open()) std::cout << "File not found! \n"; + else + for (std::string line; file >> line;) { + std::cout << line << std::endl; + show_result(detector.detect(cv::imread(line)), obj_names); + } + + } else { // image file cv::Mat mat_img = cv::imread(filename); std::vector result_vec = detector.detect(mat_img); diff --git a/src/yolo_v2_class.cpp b/src/yolo_v2_class.cpp index 813a24f2..a2fabcdd 100644 --- a/src/yolo_v2_class.cpp +++ b/src/yolo_v2_class.cpp @@ -29,6 +29,7 @@ struct detector_gpu_t{ image images[FRAMES]; float *avg; float *predictions[FRAMES]; + int demo_index; }; @@ -112,11 +113,11 @@ YOLODLL_API int Detector::get_net_height() { } -YOLODLL_API std::vector Detector::detect(std::string image_filename, float thresh) +YOLODLL_API std::vector Detector::detect(std::string image_filename, float thresh, bool use_mean) { std::shared_ptr image_ptr(new image_t, [](image_t *img) { if (img->data) free(img->data); delete img; }); *image_ptr = load_image(image_filename); - return detect(*image_ptr, thresh); + return detect(*image_ptr, thresh, use_mean); } static image load_image_stb(char *filename, int channels) @@ -163,7 +164,7 @@ YOLODLL_API void Detector::free_image(image_t m) } } -YOLODLL_API std::vector Detector::detect(image_t img, float thresh) +YOLODLL_API std::vector Detector::detect(image_t img, float thresh, bool use_mean) { detector_gpu_t &detector_gpu = *reinterpret_cast(detector_gpu_ptr.get()); @@ -196,7 +197,14 @@ YOLODLL_API std::vector Detector::detect(image_t img, float thresh) float *X = sized.data; - network_predict(net, X); + float *prediction = network_predict(net, X); + + if (use_mean) { + memcpy(detector_gpu.predictions[detector_gpu.demo_index], prediction, l.outputs * sizeof(float)); + mean_arrays(detector_gpu.predictions, FRAMES, l.outputs, detector_gpu.avg); + l.output = detector_gpu.avg; + detector_gpu.demo_index = (detector_gpu.demo_index + 1) % FRAMES; + } get_region_boxes(l, 1, 1, thresh, detector_gpu.probs, detector_gpu.boxes, 0, 0); if (nms) do_nms_sort(detector_gpu.boxes, detector_gpu.probs, l.w*l.h*l.n, l.classes, nms); @@ -269,8 +277,11 @@ YOLODLL_API std::vector Detector::tracking(std::vector cur_bbox_ bool track_id_absent = !std::any_of(cur_bbox_vec.begin(), cur_bbox_vec.end(), [&](bbox_t const& b) { return b.track_id == i.track_id; }); - if (cur_index >= 0 && track_id_absent) + if (cur_index >= 0 && track_id_absent) { cur_bbox_vec[cur_index].track_id = i.track_id; + cur_bbox_vec[cur_index].w = (cur_bbox_vec[cur_index].w + i.w) / 2; + cur_bbox_vec[cur_index].h = (cur_bbox_vec[cur_index].h + i.h) / 2; + } } } diff --git a/src/yolo_v2_class.hpp b/src/yolo_v2_class.hpp index c6cad84b..4d4960c8 100644 --- a/src/yolo_v2_class.hpp +++ b/src/yolo_v2_class.hpp @@ -47,8 +47,8 @@ public: YOLODLL_API Detector(std::string cfg_filename, std::string weight_filename, int gpu_id = 0); YOLODLL_API ~Detector(); - YOLODLL_API std::vector detect(std::string image_filename, float thresh = 0.2); - YOLODLL_API std::vector detect(image_t img, float thresh = 0.2); + YOLODLL_API std::vector detect(std::string image_filename, float thresh = 0.2, bool use_mean = false); + YOLODLL_API std::vector detect(image_t img, float thresh = 0.2, bool use_mean = false); static YOLODLL_API image_t load_image(std::string image_filename); static YOLODLL_API void free_image(image_t m); YOLODLL_API int get_net_width(); @@ -57,7 +57,7 @@ public: YOLODLL_API std::vector tracking(std::vector cur_bbox_vec, int const frames_story = 4); #ifdef OPENCV - std::vector detect(cv::Mat mat, float thresh = 0.2) + std::vector detect(cv::Mat mat, float thresh = 0.2, bool use_mean = false) { if(mat.data == NULL) throw std::runtime_error("file not found");