GUYS I KNOW HOW TO MULTITHREAD :SNAKE:

This commit is contained in:
Joseph Redmon
2017-07-11 16:44:09 -07:00
parent 59ed1719d4
commit 616e6305e2
18 changed files with 319 additions and 151 deletions

View File

@ -279,6 +279,54 @@ void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float t
printf("\n");
}
void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
{
char **tokens = 0;
if(token_file){
size_t n;
tokens = read_tokens(token_file, &n);
}
srand(rseed);
char *base = basecfg(cfgfile);
fprintf(stderr, "%s\n", base);
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
int inputs = net.inputs;
int i, j;
for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
int c = 0;
float *input = calloc(inputs, sizeof(float));
float *out = 0;
while(1){
reset_rnn_state(net, 0);
while((c = getc(stdin)) != EOF && c != 0){
input[c] = 1;
out = network_predict(net, input);
input[c] = 0;
}
for(i = 0; i < num; ++i){
for(j = 0; j < inputs; ++j){
if (out[j] < .0001) out[j] = 0;
}
int next = sample_array(out, inputs);
if(c == '.' && next == '\n') break;
c = next;
print_symbol(c, tokens);
input[c] = 1;
out = network_predict(net, input);
input[c] = 0;
}
printf("\n");
}
}
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
{
char **tokens = 0;