From 624a59307568212b7aecd9ae617bbcf4d94b8cec Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Thu, 27 Jul 2017 01:29:30 -0700 Subject: [PATCH] forgot a :snake: --- python/proverbot.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 python/proverbot.py diff --git a/python/proverbot.py b/python/proverbot.py new file mode 100644 index 00000000..095aae8f --- /dev/null +++ b/python/proverbot.py @@ -0,0 +1,37 @@ +from darknet import * + +def predict_tactic(net, s): + prob = 0 + d = c_array(c_float, [0.0]*256) + tac = '' + if not len(s): + s = '\n' + for c in s[:-1]: + d[ord(c)] = 1 + pred = predict(net, d) + d[ord(c)] = 0 + c = s[-1] + while 1: + d[ord(c)] = 1 + pred = predict(net, d) + d[ord(c)] = 0 + pred = [pred[i] for i in range(256)] + ind = sample(pred) + c = chr(ind) + prob += math.log(pred[ind]) + if len(tac) and tac[-1] == '.': + break + tac = tac + c + return (tac, prob) + +def predict_tactics(net, s, n): + tacs = [] + for i in range(n): + reset_rnn(net) + tacs.append(predict_tactic(net, s)) + tacs = sorted(tacs, key=lambda x: -x[1]) + return tacs + +net = load_net("cfg/coq.test.cfg", "/home/pjreddie/backup/coq.backup", 0) +t = predict_tactics(net, "+++++\n", 10) +print t