batch inference refactoring

This commit is contained in:
enes
2019-10-19 16:18:44 +03:00
parent c435fca738
commit b392621e2e
4 changed files with 27 additions and 31 deletions

View File

@ -212,10 +212,10 @@ predict_image_letterbox = lib.network_predict_image_letterbox
predict_image_letterbox.argtypes = [c_void_p, IMAGE]
predict_image_letterbox.restype = POINTER(c_float)
network_predict_custom = lib.network_predict_custom
network_predict_custom.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
network_predict_batch = lib.network_predict_batch
network_predict_batch.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
c_float, c_float, POINTER(c_int), c_int, c_int]
network_predict_custom.restype = POINTER(DETNUMPAIR)
network_predict_batch.restype = POINTER(DETNUMPAIR)
def array_to_image(arr):
import numpy as np
@ -460,9 +460,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg']
image_list = [cv2.imread(k) for k in img_samples]
if len(image_list) > batch_size:
raise ValueError(
"Please check if batch size is equal to the number of images passed to the function")
net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size)
meta = load_meta(metaPath.encode('utf-8'))
pred_height, pred_width, c = image_list[0].shape
@ -480,7 +477,7 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
data = arr.ctypes.data_as(POINTER(c_float))
im = IMAGE(net_width, net_height, c, data)
batch_dets = network_predict_custom(net, im, batch_size, pred_width,
batch_dets = network_predict_batch(net, im, batch_size, pred_width,
pred_height, thresh, hier_thresh, None, 0, 0)
batch_boxes = []
batch_scores = []
@ -521,6 +518,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
return batch_boxes, batch_scores, batch_classes
if __name__ == "__main__":
print(performDetect())
# Uncomment the following line to see batch inference working
#print(performBatchDetect())
#print(performDetect())
#Uncomment the following line to see batch inference working
print(performBatchDetect())