mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
38 lines
900 B
Python
38 lines
900 B
Python
|
from darknet import *
|
||
|
|
||
|
def predict_tactic(net, s):
    """Sample one tactic string from ``net`` after priming it with ``s``.

    Every character is fed to the network as a 256-slot float vector with a
    single slot (``ord(char)``) set to 1. All characters of ``s`` except the
    last are fed only for their effect on the network; the last character
    seeds generation. Characters are then sampled one at a time until a
    ``'.'`` is produced.

    Args:
        net: network handle accepted by darknet's ``predict``.
        s: priming text. An empty (or otherwise falsy) value is replaced by
            ``'\\n'``.

    Returns:
        A ``(tactic, log_prob)`` tuple: the generated text, ending with
        ``'.'``, and the sum of ``math.log`` of the predicted value of every
        generated character.
    """
    # NOTE(review): `math` is not imported in this file; it is presumably
    # re-exported by `from darknet import *` -- confirm.
    d = c_array(c_float, [0.0]*256)
    if not s:
        s = '\n'

    # Priming pass: the outputs are discarded; the calls are kept because
    # they presumably advance the network's recurrent state (confirm against
    # darknet's predict).
    for c in s[:-1]:
        d[ord(c)] = 1
        predict(net, d)
        d[ord(c)] = 0

    c = s[-1]
    tac = ''
    prob = 0
    # NOTE(review): the loop only exits once '.' is sampled; there is no
    # length cap.
    while True:
        d[ord(c)] = 1
        pred = predict(net, d)
        d[ord(c)] = 0
        # Copy the 256 outputs into a plain list before sampling.
        pred = [pred[i] for i in range(256)]
        ind = sample(pred)
        c = chr(ind)
        prob += math.log(pred[ind])
        tac += c
        # Stop right after the terminator is appended. Checking at this point
        # (rather than at the start of the next round) keeps `prob` limited
        # to characters that are actually part of `tac`.
        if c == '.':
            break
    return (tac, prob)
|
||
|
|
||
|
def predict_tactics(net, s, n):
    """Draw ``n`` tactics for prompt ``s`` and rank them by score.

    Args:
        net: network handle passed to ``reset_rnn`` and ``predict_tactic``.
        s: priming text forwarded unchanged to ``predict_tactic``.
        n: number of samples to draw.

    Returns:
        A list of the ``predict_tactic`` results, ordered so that the entry
        with the largest second element (the score) comes first.
    """
    def draw_once():
        # reset_rnn runs before every draw -- presumably it clears the
        # recurrent state so each sample starts fresh (confirm in darknet).
        reset_rnn(net)
        return predict_tactic(net, s)

    ranked = [draw_once() for _ in range(n)]
    ranked.sort(key=lambda pair: -pair[1])
    return ranked
|
||
|
|
||
|
# Demo driver: load the network from the cfg/weights pair below, sample ten
# tactics for the prompt "+++++\n", and print the ranked result list.
# NOTE(review): the weights path is an absolute path on the original author's
# machine; adjust it before running elsewhere.
net = load_net("cfg/coq.test.cfg", "/home/pjreddie/backup/coq.backup", 0)
t = predict_tactics(net, "+++++\n", 10)
# Parenthesised form parses on both Python 2 and Python 3; with a single
# argument the Python 2 output is unchanged.
print(t)
|