snipplets.dev/projects/OpenVINO/C++/infer.cc
2024-08-24 15:24:01 +03:00

137 lines
4.8 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "infer.hpp"
// Build a detector for the model at `model_path` with the default 640x640
// network input size. Delegates to the overload that takes an explicit input
// size, which stores the thresholds and loads/compiles the model.
Inf::Inf(const std::string &model_path, const float &model_probability, const float &model_NMS)
    : Inf(model_path, cv::Size(640, 640), model_probability, model_NMS) {}
// Build a detector for the model at `model_path`.
//   model_input_shape - requested network input size (width x height); used to
//                       reshape models with dynamic dimensions in init().
//   model_probability - minimum class score for a candidate to be kept.
//   model_NMS         - IoU threshold passed to non-maximum suppression.
Inf::Inf(const std::string &model_path, const cv::Size model_input_shape, const float &model_probability,
         const float &model_NMS)
    : input_shape(model_input_shape), probability(model_probability), NMS(model_NMS) {
    init(model_path);
}
void Inf::init(const std::string &model_path) {
ov::Core core;
std::shared_ptr<ov::Model> model = core.read_model(model_path);
// Если модель имеет динамические формы,
// изменяем модель в соответствиии с указанной формой
if (model->is_dynamic()) {
model->reshape({1, 3, static_cast<long int>(input_shape.height), static_cast<long int>(input_shape.width)});
}
// Настройка предварительной обработки для модели
ov::preprocess::PrePostProcessor ppp = ov::preprocess::PrePostProcessor(model);
ppp.input().tensor().set_element_type(ov::element::u8).set_layout("NHWC").set_color_format(ov::preprocess::ColorFormat::BGR);
ppp.input()
.preprocess()
.convert_element_type(ov::element::f32)
.convert_color(ov::preprocess::ColorFormat::RGB)
.scale({255, 255, 255});
ppp.input().model().set_layout("NCHW");
ppp.output().tensor().set_element_type(ov::element::f32);
model = ppp.build();
compiled_model = core.compile_model(model, "AUTO");
inference_request = compiled_model.create_infer_request();
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
const ov::Shape in_shape = inputs[0].get_shape();
input_shape = cv::Size2f(in_shape[2], in_shape[1]);
const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
const ov::Shape out_shape = outputs[0].get_shape();
output_shape = cv::Size(out_shape[2], out_shape[1]);
};
// Resize `frame` to the network input size and load it into the inference
// request, recording the frame-to-model scale factors for GetBoundingBox().
//
// The resized pixels are written straight into the request's own input tensor.
// After the preprocessing configured in init(), that tensor is u8, NHWC,
// [1, H, W, 3]. Because of this, the tensor never points at a buffer that is
// freed when this function returns.
//
// Throws cv::Exception if `frame` is not an 8-bit, 3-channel image, because the
// input tensor holds exactly H * W * 3 bytes.
void Inf::pre(const cv::Mat &frame) {
    CV_Assert(frame.type() == CV_8UC3);

    const cv::Size model_size = input_shape;

    // Ratio between frame and model dimensions, per axis. Divide in float so
    // the ratio is never truncated when the operands are integers.
    scale_factor.x = static_cast<float>(frame.cols) / static_cast<float>(model_size.width);
    scale_factor.y = static_cast<float>(frame.rows) / static_cast<float>(model_size.height);

    // Wrap the request-owned tensor memory in a cv::Mat header (no copy) and
    // resize directly into it. cv::resize keeps the existing buffer because its
    // size and type already match.
    ov::Tensor input_tensor = inference_request.get_input_tensor();
    cv::Mat model_input(model_size, CV_8UC3, input_tensor.data());
    cv::resize(frame, model_input, model_size, 0, 0, cv::INTER_AREA);
}
// Decode the raw network output, apply non-maximum suppression and draw the
// surviving detections onto `frame`.
//
// The output tensor is viewed as an output_shape.height x output_shape.width
// float matrix in which every column is one candidate:
//   - rows 0-3 hold the box as (centre x, centre y, width, height) in
//     model-input pixels;
//   - rows 4.. hold the per-class scores.
// A candidate is kept when its best class score exceeds `probability`. NMS then
// runs with IoU threshold `NMS`. Each survivor is mapped back to frame
// coordinates with GetBoundingBox() and drawn with DrawDetectedObject().
void Inf::post(cv::Mat &frame) {
    std::vector<int> class_list;
    std::vector<float> confidence_list;
    std::vector<cv::Rect> box_list;   // centre-based boxes, the format GetBoundingBox() takes
    std::vector<cv::Rect> nms_boxes;  // the same boxes with x/y at the top-left corner

    const float *detections = inference_request.get_output_tensor().data<const float>();
    // cv::Mat has no const-data constructor; the matrix is only read below.
    const cv::Mat detection_outputs(output_shape, CV_32F, const_cast<float *>(detections));

    for (int i = 0; i < detection_outputs.cols; ++i) {
        const cv::Mat classes_scores = detection_outputs.col(i).rowRange(4, detection_outputs.rows);
        cv::Point best_class;
        double score = 0.0;
        cv::minMaxLoc(classes_scores, nullptr, &score, nullptr, &best_class);
        if (score > probability) {
            const float cx = detection_outputs.at<float>(0, i);
            const float cy = detection_outputs.at<float>(1, i);
            const float w = detection_outputs.at<float>(2, i);
            const float h = detection_outputs.at<float>(3, i);

            class_list.push_back(best_class.y);  // row within the score column = class index
            confidence_list.push_back(static_cast<float>(score));
            box_list.emplace_back(static_cast<int>(cx), static_cast<int>(cy), static_cast<int>(w),
                                  static_cast<int>(h));
            // NMSBoxes() treats Rect::x/y as the top-left corner, so IoU must be
            // computed on corner-based boxes rather than centre-based ones.
            nms_boxes.emplace_back(static_cast<int>(cx - 0.5f * w), static_cast<int>(cy - 0.5f * h),
                                   static_cast<int>(w), static_cast<int>(h));
        }
    }

    std::vector<int> NMS_result;
    cv::dnn::NMSBoxes(nms_boxes, confidence_list, probability, NMS, NMS_result);
    for (const int id : NMS_result) {
        Detection result;
        result.class_id = class_list[id];
        result.probability = confidence_list[id];
        result.box = GetBoundingBox(box_list[id]);
        DrawDetectedObject(frame, result);
    }
}
// Run one full detection pass on `frame`. The frame is preprocessed, the
// network runs synchronously, and the detections are decoded and drawn back
// onto the same `frame`.
void Inf::inference(cv::Mat &frame) {
    this->pre(frame);                 // fill the model input from the frame
    this->inference_request.infer();  // synchronous inference
    this->post(frame);                // decode, NMS, draw
}
// Map a centre-based box to a top-left-based box in original-frame pixels.
// The input box has x/y at its centre and is in model-input pixels; the
// conversion uses the per-axis `scale_factor`.
//
// All arithmetic is done in float and converted to int once per field, with
// truncation toward zero. Halving the size in float means an odd width or
// height is not rounded down before being multiplied by the scale factor.
cv::Rect Inf::GetBoundingBox(const cv::Rect &src) const {
    const float width = static_cast<float>(src.width);
    const float height = static_cast<float>(src.height);
    const float left = static_cast<float>(src.x) - 0.5f * width;
    const float top = static_cast<float>(src.y) - 0.5f * height;

    cv::Rect box;
    box.x = static_cast<int>(left * scale_factor.x);
    box.y = static_cast<int>(top * scale_factor.y);
    box.width = static_cast<int>(width * scale_factor.x);
    box.height = static_cast<int>(height * scale_factor.y);
    return box;
}
// Draw `detection.box` onto `frame` as a 3-pixel-thick rectangle.
// The colour is Scalar(0, 0, 180), i.e. red when the frame is BGR.
// The rectangle runs from (x, y) to (x + width, y + height).
// NOTE: only the box is rendered; detection.class_id and detection.probability
// are not drawn.
void Inf::DrawDetectedObject(cv::Mat &frame, const Detection &detection) const {
    const cv::Rect &box = detection.box;
    const cv::Scalar color(0, 0, 180);
    cv::rectangle(frame, cv::Point(box.x, box.y), cv::Point(box.x + box.width, box.y + box.height), color, 3);
}