From 604a7606372714647ee46b0fc89091073b0cc7c2 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Tue, 6 Jun 2017 17:16:13 -0700 Subject: [PATCH] fix GRU, add LSTM --- Makefile | 2 +- include/darknet.h | 57 ++++++++++++++++++------------------ src/gru_layer.c | 1 - src/lstm_layer.h | 6 ++-- src/parser.c | 74 +++++++++++++++++++++++------------------------ 5 files changed, 70 insertions(+), 70 deletions(-) diff --git a/Makefile b/Makefile index 9ef36b84..d4a78aad 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -GPU=1 +GPU=0 CUDNN=0 OPENCV=0 DEBUG=0 diff --git a/include/darknet.h b/include/darknet.h index f2ef660a..06d426cb 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -63,7 +63,7 @@ typedef enum { ACTIVE, RNN, GRU, - LSTM, + LSTM, CRNN, BATCHNORM, NETWORK, @@ -253,20 +253,20 @@ struct layer{ struct layer *input_h_layer; struct layer *state_h_layer; - struct layer *wz; - struct layer *uz; - struct layer *wr; - struct layer *ur; - struct layer *wh; - struct layer *uh; - struct layer *uo; - struct layer *wo; - struct layer *uf; - struct layer *wf; - struct layer *ui; - struct layer *wi; - struct layer *ug; - struct layer *wg; + struct layer *wz; + struct layer *uz; + struct layer *wr; + struct layer *ur; + struct layer *wh; + struct layer *uh; + struct layer *uo; + struct layer *wo; + struct layer *uf; + struct layer *wf; + struct layer *ui; + struct layer *wi; + struct layer *ug; + struct layer *wg; tree *softmax_tree; @@ -279,20 +279,20 @@ struct layer{ float *r_gpu; float *h_gpu; - float *temp_gpu; - float *temp2_gpu; - float *temp3_gpu; + float *temp_gpu; + float *temp2_gpu; + float *temp3_gpu; - float *dh_gpu; - float *hh_gpu; - float *prev_cell_gpu; - float *cell_gpu; - float *f_gpu; - float *i_gpu; - float *g_gpu; - float *o_gpu; - float *c_gpu; - float *dc_gpu; + float *dh_gpu; + float *hh_gpu; + float *prev_cell_gpu; + float *cell_gpu; + float *f_gpu; + float *i_gpu; + float *g_gpu; + float *o_gpu; + float *c_gpu; + float *dc_gpu; float *m_gpu; float *v_gpu; @@ -546,6 +546,7 @@ list *read_cfg(char *filename); #include "dropout_layer.h" #include "gemm.h" #include "gru_layer.h" +#include "lstm_layer.h" #include "im2col.h" #include "image.h" #include "layer.h" diff --git a/src/gru_layer.c b/src/gru_layer.c index 78964817..917c36f9 100644 --- a/src/gru_layer.c +++ b/src/gru_layer.c @@ -185,7 +185,6 @@ void forward_gru_layer_gpu(layer l, network state) activate_array_ongpu(l.hh_gpu, l.outputs*l.batch, TANH); weighted_sum_gpu(l.h_gpu, l.hh_gpu, l.z_gpu, l.outputs*l.batch, l.output_gpu); - //ht = z .* ht-1 + (1-z) .* hh copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.h_gpu, 1); state.input += l.inputs*l.batch; diff --git a/src/lstm_layer.h b/src/lstm_layer.h index a9ed792d..8ed387af 100644 --- a/src/lstm_layer.h +++ b/src/lstm_layer.h @@ -8,12 +8,12 @@ layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize); -void forward_lstm_layer(layer l, network state); +void forward_lstm_layer(layer l, network net); void update_lstm_layer(layer l, int batch, float learning, float momentum, float decay); #ifdef GPU -void forward_lstm_layer_gpu(layer l, network state); -void backward_lstm_layer_gpu(layer l, network state); +void forward_lstm_layer_gpu(layer l, network net); +void backward_lstm_layer_gpu(layer l, network net); void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay); #endif diff --git a/src/parser.c b/src/parser.c index 499be075..16012212 100644 --- a/src/parser.c +++ b/src/parser.c @@ -57,7 +57,7 @@ LAYER_TYPE string_to_layer_type(char * type) || strcmp(type, "[network]")==0) return NETWORK; if (strcmp(type, "[crnn]")==0) return CRNN; if (strcmp(type, "[gru]")==0) return GRU; - if (strcmp(type, "[lstm]") == 0) return LSTM; + if (strcmp(type, "[lstm]") == 0) return LSTM; if (strcmp(type, "[rnn]")==0) return RNN; if (strcmp(type, "[conn]")==0 || strcmp(type, "[connected]")==0) return CONNECTED; @@ -678,8 +678,8 @@ network parse_network_cfg(char *filename) l = parse_rnn(options, params); }else if(lt == GRU){ l = parse_gru(options, params); - }else if (lt == LSTM) { - l = parse_lstm(options, params); + }else if (lt == LSTM) { + l = parse_lstm(options, params); }else if(lt == CRNN){ l = parse_crnn(options, params); }else if(lt == CONNECTED){ @@ -921,22 +921,22 @@ void save_weights_upto(network net, char *filename, int cutoff) save_connected_weights(*(l.self_layer), fp); save_connected_weights(*(l.output_layer), fp); } if (l.type == LSTM) { - save_connected_weights(*(l.wi), fp); - save_connected_weights(*(l.wf), fp); - save_connected_weights(*(l.wo), fp); - save_connected_weights(*(l.wg), fp); - save_connected_weights(*(l.ui), fp); - save_connected_weights(*(l.uf), fp); - save_connected_weights(*(l.uo), fp); - save_connected_weights(*(l.ug), fp); - } if (l.type == GRU) { - save_connected_weights(*(l.wz), fp); - save_connected_weights(*(l.wr), fp); - save_connected_weights(*(l.wh), fp); - save_connected_weights(*(l.uz), fp); - save_connected_weights(*(l.ur), fp); - save_connected_weights(*(l.uh), fp); - } if(l.type == CRNN){ + save_connected_weights(*(l.wi), fp); + save_connected_weights(*(l.wf), fp); + save_connected_weights(*(l.wo), fp); + save_connected_weights(*(l.wg), fp); + save_connected_weights(*(l.ui), fp); + save_connected_weights(*(l.uf), fp); + save_connected_weights(*(l.uo), fp); + save_connected_weights(*(l.ug), fp); + } if (l.type == GRU) { + save_connected_weights(*(l.wz), fp); + save_connected_weights(*(l.wr), fp); + save_connected_weights(*(l.wh), fp); + save_connected_weights(*(l.uz), fp); + save_connected_weights(*(l.ur), fp); + save_connected_weights(*(l.uh), fp); + } if(l.type == CRNN){ save_convolutional_weights(*(l.input_layer), fp); save_convolutional_weights(*(l.self_layer), fp); save_convolutional_weights(*(l.output_layer), fp); @@ -1128,24 +1128,24 @@ void load_weights_upto(network *net, char *filename, int start, int cutoff) load_connected_weights(*(l.self_layer), fp, transpose); load_connected_weights(*(l.output_layer), fp, transpose); } - if (l.type == LSTM) { - load_connected_weights(*(l.wi), fp, transpose); - load_connected_weights(*(l.wf), fp, transpose); - load_connected_weights(*(l.wo), fp, transpose); - load_connected_weights(*(l.wg), fp, transpose); - load_connected_weights(*(l.ui), fp, transpose); - load_connected_weights(*(l.uf), fp, transpose); - load_connected_weights(*(l.uo), fp, transpose); - load_connected_weights(*(l.ug), fp, transpose); - } - if (l.type == GRU) { - load_connected_weights(*(l.wz), fp, transpose); - load_connected_weights(*(l.wr), fp, transpose); - load_connected_weights(*(l.wh), fp, transpose); - load_connected_weights(*(l.uz), fp, transpose); - load_connected_weights(*(l.ur), fp, transpose); - load_connected_weights(*(l.uh), fp, transpose); - } + if (l.type == LSTM) { + load_connected_weights(*(l.wi), fp, transpose); + load_connected_weights(*(l.wf), fp, transpose); + load_connected_weights(*(l.wo), fp, transpose); + load_connected_weights(*(l.wg), fp, transpose); + load_connected_weights(*(l.ui), fp, transpose); + load_connected_weights(*(l.uf), fp, transpose); + load_connected_weights(*(l.uo), fp, transpose); + load_connected_weights(*(l.ug), fp, transpose); + } + if (l.type == GRU) { + load_connected_weights(*(l.wz), fp, transpose); + load_connected_weights(*(l.wr), fp, transpose); + load_connected_weights(*(l.wh), fp, transpose); + load_connected_weights(*(l.uz), fp, transpose); + load_connected_weights(*(l.ur), fp, transpose); + load_connected_weights(*(l.uh), fp, transpose); + } if(l.type == LOCAL){ int locations = l.out_w*l.out_h; int size = l.size*l.size*l.c*l.n*locations;