mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
added real numpy interface to detect
This commit is contained in:
parent
911663e082
commit
fee68e39ce
@ -13,9 +13,21 @@ def sample(probs):
|
||||
return len(probs)-1
|
||||
|
||||
def c_array(ctype, values):
|
||||
arr = (ctype*len(values))()
|
||||
arr[:] = values
|
||||
return arr
|
||||
new_values = values.ctypes.data_as(POINTER(ctype))
|
||||
return new_values
|
||||
|
||||
def array_to_image(arr):
|
||||
import numpy as np
|
||||
# need to return old values to avoid python freeing memory
|
||||
arr = arr.transpose(2,0,1)
|
||||
c = arr.shape[0]
|
||||
h = arr.shape[1]
|
||||
w = arr.shape[2]
|
||||
arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0
|
||||
data = arr.ctypes.data_as(POINTER(c_float))
|
||||
im = IMAGE(w,h,c,data)
|
||||
return im, arr
|
||||
|
||||
|
||||
class BOX(Structure):
|
||||
_fields_ = [("x", c_float),
|
||||
@ -42,7 +54,6 @@ class METADATA(Structure):
|
||||
_fields_ = [("classes", c_int),
|
||||
("names", POINTER(c_char_p))]
|
||||
|
||||
|
||||
|
||||
#lib = CDLL("/home/pjreddie/documents/darknet/libdarknet.so", RTLD_GLOBAL)
|
||||
lib = CDLL("libdarknet.so", RTLD_GLOBAL)
|
||||
@ -141,7 +152,27 @@ def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
|
||||
free_image(im)
|
||||
free_detections(dets, num)
|
||||
return res
|
||||
|
||||
|
||||
def detect_numpy(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
|
||||
im, arr = array_to_image(image)
|
||||
num = c_int(0)
|
||||
pnum = pointer(num)
|
||||
predict_image(net, im)
|
||||
dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
|
||||
num = pnum[0]
|
||||
if (nms): do_nms_obj(dets, num, meta.classes, nms);
|
||||
|
||||
res = []
|
||||
for j in range(num):
|
||||
for i in range(meta.classes):
|
||||
if dets[j].prob[i] > 0:
|
||||
b = dets[j].bbox
|
||||
res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
|
||||
res = sorted(res, key=lambda x: -x[1])
|
||||
free_detections(dets, num)
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#net = load_net("cfg/densenet201.cfg", "/home/pjreddie/trained/densenet201.weights", 0)
|
||||
#im = load_image("data/wolf.jpg", 0, 0)
|
||||
@ -150,7 +181,27 @@ if __name__ == "__main__":
|
||||
#print r[:10]
|
||||
net = load_net("cfg/tiny-yolo.cfg", "tiny-yolo.weights", 0)
|
||||
meta = load_meta("cfg/coco.data")
|
||||
r = detect(net, meta, "data/dog.jpg")
|
||||
import scipy.misc
|
||||
import time
|
||||
'''
|
||||
t_start = time.time()
|
||||
for ii in range(100):
|
||||
r = detect(net, meta, 'data/dog.jpg')
|
||||
print(time.time() - t_start)
|
||||
print(r)
|
||||
|
||||
image = scipy.misc.imread('data/dog.jpg')
|
||||
for ii in range(100):
|
||||
scipy.misc.imsave('/tmp/image.jpg', image)
|
||||
r = detect(net, meta, '/tmp/image.jpg')
|
||||
print(time.time() - t_start)
|
||||
print(r)
|
||||
'''
|
||||
|
||||
image = scipy.misc.imread('data/dog.jpg')
|
||||
t_start = time.time()
|
||||
for ii in range(100):
|
||||
r = detect_numpy(net, meta, image)
|
||||
print(time.time() - t_start)
|
||||
print(r)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user