idk man 🐍 stuff

This commit is contained in:
Joseph Redmon
2017-07-27 01:28:57 -07:00
parent 7a223d8591
commit 2f212a4742
7 changed files with 89 additions and 54 deletions

View File

@ -1,4 +1,19 @@
from ctypes import *
import math
import random
def sample(probs):
s = sum(probs)
probs = [a/s for a in probs]
r = random.uniform(0, 1)
for i in range(len(probs)):
r = r - probs[i]
if r <= 0:
return i
return len(probs)-1
def c_array(ctype, values):
return (ctype * len(values))(*values)
class IMAGE(Structure):
_fields_ = [("w", c_int),
@ -10,43 +25,42 @@ class METADATA(Structure):
_fields_ = [("classes", c_int),
("names", POINTER(c_char_p))]
lib = CDLL("/home/pjreddie/documents/darknet/libdarknet.so", RTLD_GLOBAL)
#lib = CDLL("/home/pjreddie/documents/darknet/libdarknet.so", RTLD_GLOBAL)
lib = CDLL("libdarknet.so", RTLD_GLOBAL)
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int
def load_meta(f):
lib.get_metadata.argtypes = [c_char_p]
lib.get_metadata.restype = METADATA
return lib.get_metadata(f)
predict = lib.network_predict_p
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)
def load_net(cfg, weights):
load_network = lib.load_network_p
load_network.argtypes = [c_char_p, c_char_p, c_int]
load_network.restype = c_void_p
return load_network(cfg, weights, 0)
reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]
def load_img(f):
load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE
return load_image(f, 0, 0)
load_net = lib.load_network_p
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p
def letterbox_img(im, w, h):
letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE
return letterbox_image(im, w, h)
letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE
def predict(net, im):
pred = lib.network_predict_image
pred.argtypes = [c_void_p, IMAGE]
pred.restype = POINTER(c_float)
return pred(net, im)
load_meta = lib.get_metadata
lib.get_metadata.argtypes = [c_char_p]
lib.get_metadata.restype = METADATA
load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE
predict_image = lib.network_predict_image
predict_image.argtypes = [c_void_p, IMAGE]
predict_image.restype = POINTER(c_float)
def classify(net, meta, im):
out = predict(net, im)
out = predict_image(net, im)
res = []
for i in range(meta.classes):
res.append((meta.names[i], out[i]))
@ -54,17 +68,19 @@ def classify(net, meta, im):
return res
def detect(net, meta, im):
out = predict(net, im)
out = predict_image(net, im)
res = []
for i in range(meta.classes):
res.append((meta.names[i], out[i]))
res = sorted(res, key=lambda x: -x[1])
return res
if __name__ == "__main__":
net = load_net("cfg/densenet.cfg", "/home/pjreddie/trained/densenet201.weights")
im = load_img("data/wolf.jpg")
net = load_net("cfg/densenet201.cfg", "/home/pjreddie/trained/densenet201.weights", 0)
im = load_image("data/wolf.jpg", 0, 0)
meta = load_meta("cfg/imagenet1k.data")
r = classify(net, meta, im)
print r[:10]