add LSTM layer

David Smith
2019-01-23 22:02:09 -06:00
parent 17019854c3
commit 5e778cd91e
5 changed files with 761 additions and 1 deletion


@@ -18,6 +18,7 @@
#include "gru_layer.h"
#include "list.h"
#include "local_layer.h"
#include "lstm_layer.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "option_list.h"
@@ -58,6 +59,7 @@ LAYER_TYPE string_to_layer_type(char * type)
        || strcmp(type, "[network]")==0) return NETWORK;
    if (strcmp(type, "[crnn]")==0) return CRNN;
    if (strcmp(type, "[gru]")==0) return GRU;
    if (strcmp(type, "[lstm]")==0) return LSTM;
    if (strcmp(type, "[rnn]")==0) return RNN;
    if (strcmp(type, "[conn]")==0
        || strcmp(type, "[connected]")==0) return CONNECTED;
@@ -219,6 +221,16 @@ layer parse_gru(list *options, size_params params)
    return l;
}

layer parse_lstm(list *options, size_params params)
{
    int output = option_find_int(options, "output",1);
    int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
    layer l = make_lstm_layer(params.batch, params.inputs, output, params.time_steps, batch_normalize);
    return l;
}

connected_layer parse_connected(list *options, size_params params)
{
    int output = option_find_int(options, "output",1);
@@ -755,6 +767,8 @@ network parse_network_cfg_custom(char *filename, int batch)
            l = parse_rnn(options, params);
        }else if(lt == GRU){
            l = parse_gru(options, params);
        }else if(lt == LSTM){
            l = parse_lstm(options, params);
        }else if(lt == CRNN){
            l = parse_crnn(options, params);
        }else if(lt == CONNECTED){
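
With these parser hooks in place, a network .cfg file can declare an LSTM section directly. A minimal sketch of such a section, based on the options parse_lstm reads above (the values are placeholders, not recommendations):

    [lstm]
    output=256
    batch_normalize=1

Note that params.time_steps is not read from the [lstm] block itself; it comes from the parser state, which is normally set by the time_steps option of the [net] section.
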
@@ -1025,6 +1039,15 @@ void save_weights_upto(network net, char *filename, int cutoff)
            save_connected_weights(*(l.state_z_layer), fp);
            save_connected_weights(*(l.state_r_layer), fp);
            save_connected_weights(*(l.state_h_layer), fp);
        } if(l.type == LSTM){
            save_connected_weights(*(l.wf), fp);
            save_connected_weights(*(l.wi), fp);
            save_connected_weights(*(l.wg), fp);
            save_connected_weights(*(l.wo), fp);
            save_connected_weights(*(l.uf), fp);
            save_connected_weights(*(l.ui), fp);
            save_connected_weights(*(l.ug), fp);
            save_connected_weights(*(l.uo), fp);
        } if(l.type == CRNN){
            save_convolutional_weights(*(l.input_layer), fp);
            save_convolutional_weights(*(l.self_layer), fp);
@@ -1236,6 +1259,16 @@ void load_weights_upto(network *net, char *filename, int cutoff)
            load_connected_weights(*(l.state_r_layer), fp, transpose);
            load_connected_weights(*(l.state_h_layer), fp, transpose);
        }
        if(l.type == LSTM){
            load_connected_weights(*(l.wf), fp, transpose);
            load_connected_weights(*(l.wi), fp, transpose);
            load_connected_weights(*(l.wg), fp, transpose);
            load_connected_weights(*(l.wo), fp, transpose);
            load_connected_weights(*(l.uf), fp, transpose);
            load_connected_weights(*(l.ui), fp, transpose);
            load_connected_weights(*(l.ug), fp, transpose);
            load_connected_weights(*(l.uo), fp, transpose);
        }
        if(l.type == LOCAL){
            int locations = l.out_w*l.out_h;
            int size = l.size*l.size*l.c*l.n*locations;
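
The eight connected sub-layers written and read above map onto the standard LSTM cell: wf/wi/wg/wo transform the layer input and uf/ui/ug/uo transform the previous hidden state, one pair per gate. The save and load order must match exactly, since the weight file carries no per-layer tags. As a reference for what those weights compute, here is a minimal single-unit sketch of the per-step update (scalar weights, biases omitted, illustrative only; this is not darknet's actual forward pass):

#include <math.h>

static float sigmoidf(float x) { return 1.f/(1.f + expf(-x)); }

/* One step of a single LSTM unit. wf..wo weight the input x,
   uf..uo weight the previous hidden state *h. */
static void lstm_step(float x, float *h, float *c,
                      float wf, float wi, float wg, float wo,
                      float uf, float ui, float ug, float uo)
{
    float f = sigmoidf(wf*x + uf*(*h));  /* forget gate */
    float i = sigmoidf(wi*x + ui*(*h));  /* input gate */
    float g = tanhf(wg*x + ug*(*h));     /* candidate cell value */
    float o = sigmoidf(wo*x + uo*(*h));  /* output gate */
    *c = f*(*c) + i*g;                   /* new cell state */
    *h = o*tanhf(*c);                    /* new hidden state */
}

In the layer itself each of these scalars is a full connected layer applied across the batch, which is why they are serialized with save_connected_weights and load_connected_weights.
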
@@ -1281,4 +1314,4 @@ network *load_network(char *cfg, char *weights, int clear)
    }
    if (clear) (*net->seen) = 0;
    return net;
}
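
Once a model containing LSTM sub-layers has been saved, it is loaded like any other darknet network. A minimal usage sketch (file names are placeholders and the header path is an assumption about the local build):

#include "darknet.h"   /* assumed public header; adjust to your build */

int main(void)
{
    /* parse the cfg, then read the weight file written by save_weights_upto;
       clear=0 keeps the stored *net->seen counter */
    network *net = load_network("lstm.cfg", "lstm.weights", 0);
    return net ? 0 : 1;
}
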