fix GRU, add LSTM

This commit is contained in:
Yao Lu 2017-06-06 17:16:13 -07:00
parent e9f3b79776
commit 604a760637
5 changed files with 70 additions and 70 deletions

View File

@ -1,4 +1,4 @@
GPU=1 GPU=0
CUDNN=0 CUDNN=0
OPENCV=0 OPENCV=0
DEBUG=0 DEBUG=0

View File

@ -546,6 +546,7 @@ list *read_cfg(char *filename);
#include "dropout_layer.h" #include "dropout_layer.h"
#include "gemm.h" #include "gemm.h"
#include "gru_layer.h" #include "gru_layer.h"
#include "lstm_layer.h"
#include "im2col.h" #include "im2col.h"
#include "image.h" #include "image.h"
#include "layer.h" #include "layer.h"

View File

@ -185,7 +185,6 @@ void forward_gru_layer_gpu(layer l, network state)
activate_array_ongpu(l.hh_gpu, l.outputs*l.batch, TANH); activate_array_ongpu(l.hh_gpu, l.outputs*l.batch, TANH);
weighted_sum_gpu(l.h_gpu, l.hh_gpu, l.z_gpu, l.outputs*l.batch, l.output_gpu); weighted_sum_gpu(l.h_gpu, l.hh_gpu, l.z_gpu, l.outputs*l.batch, l.output_gpu);
//ht = z .* ht-1 + (1-z) .* hh
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.h_gpu, 1); copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.h_gpu, 1);
state.input += l.inputs*l.batch; state.input += l.inputs*l.batch;

View File

@ -8,12 +8,12 @@
layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize); layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize);
void forward_lstm_layer(layer l, network state); void forward_lstm_layer(layer l, network net);
void update_lstm_layer(layer l, int batch, float learning, float momentum, float decay); void update_lstm_layer(layer l, int batch, float learning, float momentum, float decay);
#ifdef GPU #ifdef GPU
void forward_lstm_layer_gpu(layer l, network state); void forward_lstm_layer_gpu(layer l, network net);
void backward_lstm_layer_gpu(layer l, network state); void backward_lstm_layer_gpu(layer l, network net);
void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay); void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
#endif #endif