Fixed LSTM-layer

This commit is contained in:
AlexeyAB
2019-01-28 20:22:14 +03:00
parent 85b99872cb
commit 110b5240a4
14 changed files with 130 additions and 20 deletions

View File

@ -716,10 +716,10 @@ int is_network(section *s)
network parse_network_cfg(char *filename)
{
return parse_network_cfg_custom(filename, 0);
return parse_network_cfg_custom(filename, 0, 0);
}
network parse_network_cfg_custom(char *filename, int batch)
network parse_network_cfg_custom(char *filename, int batch, int time_steps)
{
list *sections = read_cfg(filename);
node *n = sections->front;
@ -738,6 +738,7 @@ network parse_network_cfg_custom(char *filename, int batch)
params.c = net.c;
params.inputs = net.inputs;
if (batch > 0) net.batch = batch;
if (time_steps > 0) net.time_steps = time_steps;
params.batch = net.batch;
params.time_steps = net.time_steps;
params.net = net;
@ -1300,7 +1301,7 @@ network *load_network_custom(char *cfg, char *weights, int clear, int batch)
{
printf(" Try to load cfg: %s, weights: %s, clear = %d \n", cfg, weights, clear);
network *net = calloc(1, sizeof(network));
*net = parse_network_cfg_custom(cfg, batch);
*net = parse_network_cfg_custom(cfg, batch, 0);
if (weights && weights[0] != 0) {
load_weights(net, weights);
}