mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
GUYS I KNOW HOW TO MULTITHREAD :SNAKE:
This commit is contained in:
@ -279,6 +279,54 @@ void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float t
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
|
||||
{
|
||||
char **tokens = 0;
|
||||
if(token_file){
|
||||
size_t n;
|
||||
tokens = read_tokens(token_file, &n);
|
||||
}
|
||||
|
||||
srand(rseed);
|
||||
char *base = basecfg(cfgfile);
|
||||
fprintf(stderr, "%s\n", base);
|
||||
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
int inputs = net.inputs;
|
||||
|
||||
int i, j;
|
||||
for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
|
||||
int c = 0;
|
||||
float *input = calloc(inputs, sizeof(float));
|
||||
float *out = 0;
|
||||
|
||||
while(1){
|
||||
reset_rnn_state(net, 0);
|
||||
while((c = getc(stdin)) != EOF && c != 0){
|
||||
input[c] = 1;
|
||||
out = network_predict(net, input);
|
||||
input[c] = 0;
|
||||
}
|
||||
for(i = 0; i < num; ++i){
|
||||
for(j = 0; j < inputs; ++j){
|
||||
if (out[j] < .0001) out[j] = 0;
|
||||
}
|
||||
int next = sample_array(out, inputs);
|
||||
if(c == '.' && next == '\n') break;
|
||||
c = next;
|
||||
print_symbol(c, tokens);
|
||||
|
||||
input[c] = 1;
|
||||
out = network_predict(net, input);
|
||||
input[c] = 0;
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
|
||||
{
|
||||
char **tokens = 0;
|
||||
|
Reference in New Issue
Block a user