mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
add LSTM layer
This commit is contained in:
35
src/parser.c
35
src/parser.c
@ -18,6 +18,7 @@
|
||||
#include "gru_layer.h"
|
||||
#include "list.h"
|
||||
#include "local_layer.h"
|
||||
#include "lstm_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
#include "option_list.h"
|
||||
@ -58,6 +59,7 @@ LAYER_TYPE string_to_layer_type(char * type)
|
||||
|| strcmp(type, "[network]")==0) return NETWORK;
|
||||
if (strcmp(type, "[crnn]")==0) return CRNN;
|
||||
if (strcmp(type, "[gru]")==0) return GRU;
|
||||
if (strcmp(type, "[lstm]")==0) return LSTM;
|
||||
if (strcmp(type, "[rnn]")==0) return RNN;
|
||||
if (strcmp(type, "[conn]")==0
|
||||
|| strcmp(type, "[connected]")==0) return CONNECTED;
|
||||
@ -219,6 +221,16 @@ layer parse_gru(list *options, size_params params)
|
||||
return l;
|
||||
}
|
||||
|
||||
layer parse_lstm(list *options, size_params params)
|
||||
{
|
||||
int output = option_find_int(options, "output",1);
|
||||
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
|
||||
|
||||
layer l = make_lstm_layer(params.batch, params.inputs, output, params.time_steps, batch_normalize);
|
||||
|
||||
return l;
|
||||
}
|
||||
|
||||
connected_layer parse_connected(list *options, size_params params)
|
||||
{
|
||||
int output = option_find_int(options, "output",1);
|
||||
@ -755,6 +767,8 @@ network parse_network_cfg_custom(char *filename, int batch)
|
||||
l = parse_rnn(options, params);
|
||||
}else if(lt == GRU){
|
||||
l = parse_gru(options, params);
|
||||
}else if(lt == LSTM){
|
||||
l = parse_lstm(options, params);
|
||||
}else if(lt == CRNN){
|
||||
l = parse_crnn(options, params);
|
||||
}else if(lt == CONNECTED){
|
||||
@ -1025,6 +1039,15 @@ void save_weights_upto(network net, char *filename, int cutoff)
|
||||
save_connected_weights(*(l.state_z_layer), fp);
|
||||
save_connected_weights(*(l.state_r_layer), fp);
|
||||
save_connected_weights(*(l.state_h_layer), fp);
|
||||
} if(l.type == LSTM){
|
||||
save_connected_weights(*(l.wf), fp);
|
||||
save_connected_weights(*(l.wi), fp);
|
||||
save_connected_weights(*(l.wg), fp);
|
||||
save_connected_weights(*(l.wo), fp);
|
||||
save_connected_weights(*(l.uf), fp);
|
||||
save_connected_weights(*(l.ui), fp);
|
||||
save_connected_weights(*(l.ug), fp);
|
||||
save_connected_weights(*(l.uo), fp);
|
||||
} if(l.type == CRNN){
|
||||
save_convolutional_weights(*(l.input_layer), fp);
|
||||
save_convolutional_weights(*(l.self_layer), fp);
|
||||
@ -1236,6 +1259,16 @@ void load_weights_upto(network *net, char *filename, int cutoff)
|
||||
load_connected_weights(*(l.state_r_layer), fp, transpose);
|
||||
load_connected_weights(*(l.state_h_layer), fp, transpose);
|
||||
}
|
||||
if(l.type == LSTM){
|
||||
load_connected_weights(*(l.wf), fp, transpose);
|
||||
load_connected_weights(*(l.wi), fp, transpose);
|
||||
load_connected_weights(*(l.wg), fp, transpose);
|
||||
load_connected_weights(*(l.wo), fp, transpose);
|
||||
load_connected_weights(*(l.uf), fp, transpose);
|
||||
load_connected_weights(*(l.ui), fp, transpose);
|
||||
load_connected_weights(*(l.ug), fp, transpose);
|
||||
load_connected_weights(*(l.uo), fp, transpose);
|
||||
}
|
||||
if(l.type == LOCAL){
|
||||
int locations = l.out_w*l.out_h;
|
||||
int size = l.size*l.size*l.c*l.n*locations;
|
||||
@ -1281,4 +1314,4 @@ network *load_network(char *cfg, char *weights, int clear)
|
||||
}
|
||||
if (clear) (*net->seen) = 0;
|
||||
return net;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user