:charmandra: 🔥 🔥 🔥

This commit is contained in:
Joseph Redmon
2017-06-09 16:41:00 -07:00
parent c3e0d90e9f
commit d8c5cfd6c6
11 changed files with 635 additions and 292 deletions

View File

@ -150,7 +150,7 @@ void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear,
}
int inputs = net.inputs;
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d\n", net.learning_rate, net.momentum, net.decay, inputs);
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n", net.learning_rate, net.momentum, net.decay, inputs, net.batch, net.time_steps);
int batch = net.batch;
int steps = net.time_steps;
if(clear) *net.seen = 0;
@ -174,8 +174,8 @@ void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear,
p = get_rnn_data(text, offsets, inputs, size, streams, steps);
}
memcpy(net.input, p.x, net.inputs*net.batch);
memcpy(net.truth, p.y, net.truths*net.batch);
copy_cpu(net.inputs*net.batch, p.x, 1, net.input, 1);
copy_cpu(net.truths*net.batch, p.y, 1, net.truth, 1);
float loss = train_network_datum(net) / (batch);
free(p.x);
free(p.y);

View File

@ -99,8 +99,8 @@ void train_vid_rnn(char *cfgfile, char *weightfile)
time=clock();
float_pair p = get_rnn_vid_data(extractor, paths, N, batch, steps);
memcpy(net.input, p.x, net.inputs*net.batch);
memcpy(net.truth, p.y, net.truths*net.batch);
copy_cpu(net.inputs*net.batch, p.x, 1, net.input, 1);
copy_cpu(net.truths*net.batch, p.y, 1, net.truth, 1);
float loss = train_network_datum(net) / (net.batch);