idk man 🐍 stuff

This commit is contained in:
Joseph Redmon
2017-07-27 01:28:57 -07:00
parent 7a223d8591
commit 2f212a4742
7 changed files with 89 additions and 54 deletions

View File

@ -108,22 +108,6 @@ float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, si
return p;
}
void reset_rnn_state(network net, int b)
{
int i;
for (i = 0; i < net.n; ++i) {
#ifdef GPU
layer l = net.layers[i];
if(l.state_gpu){
fill_gpu(l.outputs, 0, l.state_gpu + l.outputs*b, 1);
}
if(l.h_gpu){
fill_gpu(l.outputs, 0, l.h_gpu + l.outputs*b, 1);
}
#endif
}
}
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
{
srand(time(0));
@ -194,7 +178,7 @@ void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear,
if(rand()%64 == 0){
//fprintf(stderr, "Reset\n");
offsets[j] = rand_size_t()%size;
reset_rnn_state(net, j);
reset_network_state(net, j);
}
}
@ -304,7 +288,7 @@ void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp,
float *out = 0;
while(1){
reset_rnn_state(net, 0);
reset_network_state(net, 0);
while((c = getc(stdin)) != EOF && c != 0){
input[c] = 1;
out = network_predict(net, input);
@ -482,7 +466,7 @@ void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
int i;
char *line;
while((line=fgetl(stdin)) != 0){
reset_rnn_state(net, 0);
reset_network_state(net, 0);
for(i = 0; i < seed_len; ++i){
c = seed[i];
input[(int)c] = 1;