mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Fixed LSTM-layer
This commit is contained in:
@ -716,10 +716,10 @@ int is_network(section *s)
|
||||
|
||||
network parse_network_cfg(char *filename)
|
||||
{
|
||||
return parse_network_cfg_custom(filename, 0);
|
||||
return parse_network_cfg_custom(filename, 0, 0);
|
||||
}
|
||||
|
||||
network parse_network_cfg_custom(char *filename, int batch)
|
||||
network parse_network_cfg_custom(char *filename, int batch, int time_steps)
|
||||
{
|
||||
list *sections = read_cfg(filename);
|
||||
node *n = sections->front;
|
||||
@ -738,6 +738,7 @@ network parse_network_cfg_custom(char *filename, int batch)
|
||||
params.c = net.c;
|
||||
params.inputs = net.inputs;
|
||||
if (batch > 0) net.batch = batch;
|
||||
if (time_steps > 0) net.time_steps = time_steps;
|
||||
params.batch = net.batch;
|
||||
params.time_steps = net.time_steps;
|
||||
params.net = net;
|
||||
@ -1300,7 +1301,7 @@ network *load_network_custom(char *cfg, char *weights, int clear, int batch)
|
||||
{
|
||||
printf(" Try to load cfg: %s, weights: %s, clear = %d \n", cfg, weights, clear);
|
||||
network *net = calloc(1, sizeof(network));
|
||||
*net = parse_network_cfg_custom(cfg, batch);
|
||||
*net = parse_network_cfg_custom(cfg, batch, 0);
|
||||
if (weights && weights[0] != 0) {
|
||||
load_weights(net, weights);
|
||||
}
|
||||
|
Reference in New Issue
Block a user