fix GRU, add LSTM

Yao Lu 2017-06-06 17:16:13 -07:00
parent e9f3b79776
commit 604a760637
5 changed files with 70 additions and 70 deletions


@@ -1,4 +1,4 @@
-GPU=1
+GPU=0
 CUDNN=0
 OPENCV=0
 DEBUG=0
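(Flipping GPU back to 0 makes the default build CPU-only; the *_gpu code paths touched below are compiled only when GPU=1, guarded by #ifdef GPU as in the lstm_layer.h hunk further down.)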


@@ -63,7 +63,7 @@ typedef enum {
     ACTIVE,
     RNN,
     GRU,
     LSTM,
     CRNN,
     BATCHNORM,
     NETWORK,
@@ -253,20 +253,20 @@ struct layer{
     struct layer *input_h_layer;
     struct layer *state_h_layer;

     struct layer *wz;
     struct layer *uz;
     struct layer *wr;
     struct layer *ur;
     struct layer *wh;
     struct layer *uh;

     struct layer *uo;
     struct layer *wo;
     struct layer *uf;
     struct layer *wf;
     struct layer *ui;
     struct layer *wi;
     struct layer *ug;
     struct layer *wg;

     tree *softmax_tree;
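For orientation: in these member names, the w* sub-layers multiply the current input x_t and the u* sub-layers multiply the previous hidden state h_{t-1}. The textbook gate equations they correspond to (a reference sketch, not taken from this diff; the z convention matches the "ht = z .* ht-1 + (1-z) .* hh" comment removed in the GRU hunk below):

    \begin{aligned}
    \text{GRU:}\quad  & z_t = \sigma(W_z x_t + U_z h_{t-1}), \qquad r_t = \sigma(W_r x_t + U_r h_{t-1}),\\
                      & \tilde h_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1})), \qquad h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde h_t.\\
    \text{LSTM:}\quad & f_t = \sigma(W_f x_t + U_f h_{t-1}), \qquad i_t = \sigma(W_i x_t + U_i h_{t-1}),\\
                      & o_t = \sigma(W_o x_t + U_o h_{t-1}), \qquad g_t = \tanh(W_g x_t + U_g h_{t-1}),\\
                      & c_t = f_t \odot c_{t-1} + i_t \odot g_t, \qquad h_t = o_t \odot \tanh(c_t).
    \end{aligned}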
@@ -279,20 +279,20 @@ struct layer{
     float *r_gpu;
     float *h_gpu;
     float *temp_gpu;
     float *temp2_gpu;
     float *temp3_gpu;
     float *dh_gpu;
     float *hh_gpu;

     float *prev_cell_gpu;
     float *cell_gpu;
     float *f_gpu;
     float *i_gpu;
     float *g_gpu;
     float *o_gpu;
     float *c_gpu;
     float *dc_gpu;

     float *m_gpu;
     float *v_gpu;
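The f/i/g/o and cell buffers above hold the per-element gate activations and cell state. A minimal, self-contained CPU sketch of one LSTM step using that naming (assumed semantics; in the layer itself these arrays live on the GPU and the gate pre-activations come from the w*/u* connected sub-layers):

    #include <math.h>

    static float sigmoid(float x) { return 1.f / (1.f + expf(-x)); }

    /* One LSTM time step over n elements, given gate pre-activations. */
    void lstm_step(const float *f_pre, const float *i_pre, const float *g_pre,
                   const float *o_pre, float *cell, float *h, int n)
    {
        for (int k = 0; k < n; ++k) {
            float f = sigmoid(f_pre[k]);   /* forget gate       */
            float i = sigmoid(i_pre[k]);   /* input gate        */
            float g = tanhf(g_pre[k]);     /* candidate values  */
            float o = sigmoid(o_pre[k]);   /* output gate       */
            cell[k] = f * cell[k] + i * g; /* c_t = f.*c_{t-1} + i.*g */
            h[k]    = o * tanhf(cell[k]);  /* h_t = o.*tanh(c_t)      */
        }
    }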
@@ -546,6 +546,7 @@ list *read_cfg(char *filename);
 #include "dropout_layer.h"
 #include "gemm.h"
 #include "gru_layer.h"
+#include "lstm_layer.h"
 #include "im2col.h"
 #include "image.h"
 #include "layer.h"


@@ -185,7 +185,6 @@ void forward_gru_layer_gpu(layer l, network state)
     activate_array_ongpu(l.hh_gpu, l.outputs*l.batch, TANH);
     weighted_sum_gpu(l.h_gpu, l.hh_gpu, l.z_gpu, l.outputs*l.batch, l.output_gpu);
-    //ht = z .* ht-1 + (1-z) .* hh
     copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.h_gpu, 1);
     state.input += l.inputs*l.batch;
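The deleted comment was the only documentation of what weighted_sum_gpu computes at this call site. For reference, a CPU sketch of that operation, with the (a, b, s, n, c) argument order assumed from the call above:

    /* c = s .* a + (1 - s) .* b, i.e. ht = z .* ht-1 + (1-z) .* hh */
    void weighted_sum_sketch(const float *a, const float *b, const float *s,
                             int n, float *c)
    {
        for (int k = 0; k < n; ++k) {
            c[k] = s[k] * a[k] + (1.f - s[k]) * b[k];
        }
    }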


@@ -8,12 +8,12 @@
 layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize);

-void forward_lstm_layer(layer l, network state);
+void forward_lstm_layer(layer l, network net);
 void update_lstm_layer(layer l, int batch, float learning, float momentum, float decay);

 #ifdef GPU
-void forward_lstm_layer_gpu(layer l, network state);
-void backward_lstm_layer_gpu(layer l, network state);
+void forward_lstm_layer_gpu(layer l, network net);
+void backward_lstm_layer_gpu(layer l, network net);
 void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
 #endif
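A minimal usage sketch against these prototypes (the sizes are placeholders, not values from this diff):

    #include "lstm_layer.h"

    void example(void)
    {
        /* Placeholder sizes: batch 1, 256 inputs, 256 outputs,
         * unrolled over 16 time steps, batch normalization off. */
        layer l = make_lstm_layer(1, 256, 256, 16, 0);
        (void)l;
    }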


@@ -57,7 +57,7 @@ LAYER_TYPE string_to_layer_type(char * type)
             || strcmp(type, "[network]")==0) return NETWORK;
     if (strcmp(type, "[crnn]")==0) return CRNN;
     if (strcmp(type, "[gru]")==0) return GRU;
     if (strcmp(type, "[lstm]") == 0) return LSTM;
     if (strcmp(type, "[rnn]")==0) return RNN;
     if (strcmp(type, "[conn]")==0
             || strcmp(type, "[connected]")==0) return CONNECTED;
@@ -678,8 +678,8 @@ network parse_network_cfg(char *filename)
         l = parse_rnn(options, params);
     }else if(lt == GRU){
         l = parse_gru(options, params);
     }else if (lt == LSTM) {
         l = parse_lstm(options, params);
     }else if(lt == CRNN){
         l = parse_crnn(options, params);
     }else if(lt == CONNECTED){
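With both the string_to_layer_type mapping and this dispatch in place, an [lstm] section in a .cfg file now reaches parse_lstm. A plausible stanza, assuming the layer reads output and batch_normalize like the other recurrent layers (the option names are not shown in this diff):

    [lstm]
    output=256
    batch_normalize=1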
@@ -921,22 +921,22 @@ void save_weights_upto(network net, char *filename, int cutoff)
         save_connected_weights(*(l.self_layer), fp);
         save_connected_weights(*(l.output_layer), fp);
     } if (l.type == LSTM) {
         save_connected_weights(*(l.wi), fp);
         save_connected_weights(*(l.wf), fp);
         save_connected_weights(*(l.wo), fp);
         save_connected_weights(*(l.wg), fp);
         save_connected_weights(*(l.ui), fp);
         save_connected_weights(*(l.uf), fp);
         save_connected_weights(*(l.uo), fp);
         save_connected_weights(*(l.ug), fp);
     } if (l.type == GRU) {
         save_connected_weights(*(l.wz), fp);
         save_connected_weights(*(l.wr), fp);
         save_connected_weights(*(l.wh), fp);
         save_connected_weights(*(l.uz), fp);
         save_connected_weights(*(l.ur), fp);
         save_connected_weights(*(l.uh), fp);
     } if(l.type == CRNN){
         save_convolutional_weights(*(l.input_layer), fp);
         save_convolutional_weights(*(l.self_layer), fp);
         save_convolutional_weights(*(l.output_layer), fp);
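Note the fixed serialization order for LSTM (wi, wf, wo, wg, ui, uf, uo, ug) and GRU (wz, wr, wh, uz, ur, uh): load_weights_upto in the next hunk reads the sub-layers back in exactly the same order, so any reordering here would silently misassign the weights of saved models.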
@@ -1128,24 +1128,24 @@ void load_weights_upto(network *net, char *filename, int start, int cutoff)
         load_connected_weights(*(l.self_layer), fp, transpose);
         load_connected_weights(*(l.output_layer), fp, transpose);
     }
     if (l.type == LSTM) {
         load_connected_weights(*(l.wi), fp, transpose);
         load_connected_weights(*(l.wf), fp, transpose);
         load_connected_weights(*(l.wo), fp, transpose);
         load_connected_weights(*(l.wg), fp, transpose);
         load_connected_weights(*(l.ui), fp, transpose);
         load_connected_weights(*(l.uf), fp, transpose);
         load_connected_weights(*(l.uo), fp, transpose);
         load_connected_weights(*(l.ug), fp, transpose);
     }
     if (l.type == GRU) {
         load_connected_weights(*(l.wz), fp, transpose);
         load_connected_weights(*(l.wr), fp, transpose);
         load_connected_weights(*(l.wh), fp, transpose);
         load_connected_weights(*(l.uz), fp, transpose);
         load_connected_weights(*(l.ur), fp, transpose);
         load_connected_weights(*(l.uh), fp, transpose);
     }
     if(l.type == LOCAL){
         int locations = l.out_w*l.out_h;
         int size = l.size*l.size*l.c*l.n*locations;