mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
batch inference refactoring
This commit is contained in:
17
darknet.py
17
darknet.py
@ -212,10 +212,10 @@ predict_image_letterbox = lib.network_predict_image_letterbox
|
||||
predict_image_letterbox.argtypes = [c_void_p, IMAGE]
|
||||
predict_image_letterbox.restype = POINTER(c_float)
|
||||
|
||||
network_predict_custom = lib.network_predict_custom
|
||||
network_predict_custom.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
|
||||
network_predict_batch = lib.network_predict_batch
|
||||
network_predict_batch.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
|
||||
c_float, c_float, POINTER(c_int), c_int, c_int]
|
||||
network_predict_custom.restype = POINTER(DETNUMPAIR)
|
||||
network_predict_batch.restype = POINTER(DETNUMPAIR)
|
||||
|
||||
def array_to_image(arr):
|
||||
import numpy as np
|
||||
@ -460,9 +460,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
|
||||
img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg']
|
||||
image_list = [cv2.imread(k) for k in img_samples]
|
||||
|
||||
if len(image_list) > batch_size:
|
||||
raise ValueError(
|
||||
"Please check if batch size is equal to the number of images passed to the function")
|
||||
net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size)
|
||||
meta = load_meta(metaPath.encode('utf-8'))
|
||||
pred_height, pred_width, c = image_list[0].shape
|
||||
@ -480,7 +477,7 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
|
||||
data = arr.ctypes.data_as(POINTER(c_float))
|
||||
im = IMAGE(net_width, net_height, c, data)
|
||||
|
||||
batch_dets = network_predict_custom(net, im, batch_size, pred_width,
|
||||
batch_dets = network_predict_batch(net, im, batch_size, pred_width,
|
||||
pred_height, thresh, hier_thresh, None, 0, 0)
|
||||
batch_boxes = []
|
||||
batch_scores = []
|
||||
@ -521,6 +518,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
|
||||
return batch_boxes, batch_scores, batch_classes
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(performDetect())
|
||||
# Uncomment the following line to see batch inference working
|
||||
#print(performBatchDetect())
|
||||
#print(performDetect())
|
||||
#Uncomment the following line to see batch inference working
|
||||
print(performBatchDetect())
|
Reference in New Issue
Block a user