mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
forgot a 🐍
This commit is contained in:
parent
2f212a4742
commit
624a593075
37
python/proverbot.py
Normal file
37
python/proverbot.py
Normal file
@ -0,0 +1,37 @@
|
||||
from darknet import *
|
||||
|
||||
def predict_tactic(net, s):
|
||||
prob = 0
|
||||
d = c_array(c_float, [0.0]*256)
|
||||
tac = ''
|
||||
if not len(s):
|
||||
s = '\n'
|
||||
for c in s[:-1]:
|
||||
d[ord(c)] = 1
|
||||
pred = predict(net, d)
|
||||
d[ord(c)] = 0
|
||||
c = s[-1]
|
||||
while 1:
|
||||
d[ord(c)] = 1
|
||||
pred = predict(net, d)
|
||||
d[ord(c)] = 0
|
||||
pred = [pred[i] for i in range(256)]
|
||||
ind = sample(pred)
|
||||
c = chr(ind)
|
||||
prob += math.log(pred[ind])
|
||||
if len(tac) and tac[-1] == '.':
|
||||
break
|
||||
tac = tac + c
|
||||
return (tac, prob)
|
||||
|
||||
def predict_tactics(net, s, n):
|
||||
tacs = []
|
||||
for i in range(n):
|
||||
reset_rnn(net)
|
||||
tacs.append(predict_tactic(net, s))
|
||||
tacs = sorted(tacs, key=lambda x: -x[1])
|
||||
return tacs
|
||||
|
||||
net = load_net("cfg/coq.test.cfg", "/home/pjreddie/backup/coq.backup", 0)
|
||||
t = predict_tactics(net, "+++++\n", 10)
|
||||
print t
|
Loading…
Reference in New Issue
Block a user