diff --git a/src/yolo_console_dll.cpp b/src/yolo_console_dll.cpp index adbe873b..2e05f7ac 100644 --- a/src/yolo_console_dll.cpp +++ b/src/yolo_console_dll.cpp @@ -76,18 +76,19 @@ int main() cv::Mat frame, prev_frame; std::vector result_vec, thread_result_vec; detector.nms = 0.02; // comment it - if track_id is not required + std::thread td([]() {}); for (cv::VideoCapture cap(filename); cap >> frame, cap.isOpened();) { - auto image_ptr = detector.mat_to_image(frame); - std::thread td([&]() { thread_result_vec = detector.detect(*image_ptr, 0.2); }); + td.join(); + result_vec = thread_result_vec; + cv::Mat det_frame = frame; + td = std::thread([&]() { thread_result_vec = detector.detect(det_frame, 0.2); }); if (!prev_frame.empty()) { result_vec = detector.tracking(result_vec); // comment it - if track_id is not required draw_boxes(prev_frame, result_vec, obj_names, 3); show_result(result_vec, obj_names); } - td.join(); prev_frame = frame; - result_vec = thread_result_vec; } } else { // image file diff --git a/src/yolo_v2_class.cpp b/src/yolo_v2_class.cpp index 31f623cf..e8be427e 100644 --- a/src/yolo_v2_class.cpp +++ b/src/yolo_v2_class.cpp @@ -102,6 +102,15 @@ YOLODLL_API Detector::~Detector() #endif } +YOLODLL_API int Detector::get_net_width() { + detector_gpu_t &detector_gpu = *reinterpret_cast(detector_gpu_ptr.get()); + return detector_gpu.net.w; +} +YOLODLL_API int Detector::get_net_height() { + detector_gpu_t &detector_gpu = *reinterpret_cast(detector_gpu_ptr.get()); + return detector_gpu.net.h; +} + YOLODLL_API std::vector Detector::detect(std::string image_filename, float thresh) { diff --git a/src/yolo_v2_class.hpp b/src/yolo_v2_class.hpp index 9a8baa62..c6cad84b 100644 --- a/src/yolo_v2_class.hpp +++ b/src/yolo_v2_class.hpp @@ -51,6 +51,8 @@ public: YOLODLL_API std::vector detect(image_t img, float thresh = 0.2); static YOLODLL_API image_t load_image(std::string image_filename); static YOLODLL_API void free_image(image_t m); + YOLODLL_API int get_net_width(); + YOLODLL_API int get_net_height(); YOLODLL_API std::vector tracking(std::vector cur_bbox_vec, int const frames_story = 4); @@ -59,8 +61,13 @@ public: { if(mat.data == NULL) throw std::runtime_error("file not found"); - auto image_ptr = mat_to_image(mat); - return detect(*image_ptr, thresh); + cv::Mat det_mat; + cv::resize(mat, det_mat, cv::Size(get_net_width(), get_net_height())); + auto image_ptr = mat_to_image(det_mat); + auto detection_boxes = detect(*image_ptr, thresh); + float wk = (float)mat.cols / det_mat.cols, hk = (float)mat.rows / det_mat.rows; + for (auto &i : detection_boxes) i.x*=wk, i.w*= wk, i.y*=hk, i.h*=hk; + return detection_boxes; } static std::shared_ptr mat_to_image(cv::Mat img)