mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
fix GRU, add LSTM
This commit is contained in:
parent
e9f3b79776
commit
604a760637
@ -63,7 +63,7 @@ typedef enum {
|
|||||||
ACTIVE,
|
ACTIVE,
|
||||||
RNN,
|
RNN,
|
||||||
GRU,
|
GRU,
|
||||||
LSTM,
|
LSTM,
|
||||||
CRNN,
|
CRNN,
|
||||||
BATCHNORM,
|
BATCHNORM,
|
||||||
NETWORK,
|
NETWORK,
|
||||||
@ -253,20 +253,20 @@ struct layer{
|
|||||||
struct layer *input_h_layer;
|
struct layer *input_h_layer;
|
||||||
struct layer *state_h_layer;
|
struct layer *state_h_layer;
|
||||||
|
|
||||||
struct layer *wz;
|
struct layer *wz;
|
||||||
struct layer *uz;
|
struct layer *uz;
|
||||||
struct layer *wr;
|
struct layer *wr;
|
||||||
struct layer *ur;
|
struct layer *ur;
|
||||||
struct layer *wh;
|
struct layer *wh;
|
||||||
struct layer *uh;
|
struct layer *uh;
|
||||||
struct layer *uo;
|
struct layer *uo;
|
||||||
struct layer *wo;
|
struct layer *wo;
|
||||||
struct layer *uf;
|
struct layer *uf;
|
||||||
struct layer *wf;
|
struct layer *wf;
|
||||||
struct layer *ui;
|
struct layer *ui;
|
||||||
struct layer *wi;
|
struct layer *wi;
|
||||||
struct layer *ug;
|
struct layer *ug;
|
||||||
struct layer *wg;
|
struct layer *wg;
|
||||||
|
|
||||||
tree *softmax_tree;
|
tree *softmax_tree;
|
||||||
|
|
||||||
@ -279,20 +279,20 @@ struct layer{
|
|||||||
float *r_gpu;
|
float *r_gpu;
|
||||||
float *h_gpu;
|
float *h_gpu;
|
||||||
|
|
||||||
float *temp_gpu;
|
float *temp_gpu;
|
||||||
float *temp2_gpu;
|
float *temp2_gpu;
|
||||||
float *temp3_gpu;
|
float *temp3_gpu;
|
||||||
|
|
||||||
float *dh_gpu;
|
float *dh_gpu;
|
||||||
float *hh_gpu;
|
float *hh_gpu;
|
||||||
float *prev_cell_gpu;
|
float *prev_cell_gpu;
|
||||||
float *cell_gpu;
|
float *cell_gpu;
|
||||||
float *f_gpu;
|
float *f_gpu;
|
||||||
float *i_gpu;
|
float *i_gpu;
|
||||||
float *g_gpu;
|
float *g_gpu;
|
||||||
float *o_gpu;
|
float *o_gpu;
|
||||||
float *c_gpu;
|
float *c_gpu;
|
||||||
float *dc_gpu;
|
float *dc_gpu;
|
||||||
|
|
||||||
float *m_gpu;
|
float *m_gpu;
|
||||||
float *v_gpu;
|
float *v_gpu;
|
||||||
@ -546,6 +546,7 @@ list *read_cfg(char *filename);
|
|||||||
#include "dropout_layer.h"
|
#include "dropout_layer.h"
|
||||||
#include "gemm.h"
|
#include "gemm.h"
|
||||||
#include "gru_layer.h"
|
#include "gru_layer.h"
|
||||||
|
#include "lstm_layer.h"
|
||||||
#include "im2col.h"
|
#include "im2col.h"
|
||||||
#include "image.h"
|
#include "image.h"
|
||||||
#include "layer.h"
|
#include "layer.h"
|
||||||
|
@ -185,7 +185,6 @@ void forward_gru_layer_gpu(layer l, network state)
|
|||||||
activate_array_ongpu(l.hh_gpu, l.outputs*l.batch, TANH);
|
activate_array_ongpu(l.hh_gpu, l.outputs*l.batch, TANH);
|
||||||
|
|
||||||
weighted_sum_gpu(l.h_gpu, l.hh_gpu, l.z_gpu, l.outputs*l.batch, l.output_gpu);
|
weighted_sum_gpu(l.h_gpu, l.hh_gpu, l.z_gpu, l.outputs*l.batch, l.output_gpu);
|
||||||
//ht = z .* ht-1 + (1-z) .* hh
|
|
||||||
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.h_gpu, 1);
|
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.h_gpu, 1);
|
||||||
|
|
||||||
state.input += l.inputs*l.batch;
|
state.input += l.inputs*l.batch;
|
||||||
|
@ -8,12 +8,12 @@
|
|||||||
|
|
||||||
layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize);
|
layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize);
|
||||||
|
|
||||||
void forward_lstm_layer(layer l, network state);
|
void forward_lstm_layer(layer l, network net);
|
||||||
void update_lstm_layer(layer l, int batch, float learning, float momentum, float decay);
|
void update_lstm_layer(layer l, int batch, float learning, float momentum, float decay);
|
||||||
|
|
||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
void forward_lstm_layer_gpu(layer l, network state);
|
void forward_lstm_layer_gpu(layer l, network net);
|
||||||
void backward_lstm_layer_gpu(layer l, network state);
|
void backward_lstm_layer_gpu(layer l, network net);
|
||||||
void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
|
void update_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
74
src/parser.c
74
src/parser.c
@ -57,7 +57,7 @@ LAYER_TYPE string_to_layer_type(char * type)
|
|||||||
|| strcmp(type, "[network]")==0) return NETWORK;
|
|| strcmp(type, "[network]")==0) return NETWORK;
|
||||||
if (strcmp(type, "[crnn]")==0) return CRNN;
|
if (strcmp(type, "[crnn]")==0) return CRNN;
|
||||||
if (strcmp(type, "[gru]")==0) return GRU;
|
if (strcmp(type, "[gru]")==0) return GRU;
|
||||||
if (strcmp(type, "[lstm]") == 0) return LSTM;
|
if (strcmp(type, "[lstm]") == 0) return LSTM;
|
||||||
if (strcmp(type, "[rnn]")==0) return RNN;
|
if (strcmp(type, "[rnn]")==0) return RNN;
|
||||||
if (strcmp(type, "[conn]")==0
|
if (strcmp(type, "[conn]")==0
|
||||||
|| strcmp(type, "[connected]")==0) return CONNECTED;
|
|| strcmp(type, "[connected]")==0) return CONNECTED;
|
||||||
@ -678,8 +678,8 @@ network parse_network_cfg(char *filename)
|
|||||||
l = parse_rnn(options, params);
|
l = parse_rnn(options, params);
|
||||||
}else if(lt == GRU){
|
}else if(lt == GRU){
|
||||||
l = parse_gru(options, params);
|
l = parse_gru(options, params);
|
||||||
}else if (lt == LSTM) {
|
}else if (lt == LSTM) {
|
||||||
l = parse_lstm(options, params);
|
l = parse_lstm(options, params);
|
||||||
}else if(lt == CRNN){
|
}else if(lt == CRNN){
|
||||||
l = parse_crnn(options, params);
|
l = parse_crnn(options, params);
|
||||||
}else if(lt == CONNECTED){
|
}else if(lt == CONNECTED){
|
||||||
@ -921,22 +921,22 @@ void save_weights_upto(network net, char *filename, int cutoff)
|
|||||||
save_connected_weights(*(l.self_layer), fp);
|
save_connected_weights(*(l.self_layer), fp);
|
||||||
save_connected_weights(*(l.output_layer), fp);
|
save_connected_weights(*(l.output_layer), fp);
|
||||||
} if (l.type == LSTM) {
|
} if (l.type == LSTM) {
|
||||||
save_connected_weights(*(l.wi), fp);
|
save_connected_weights(*(l.wi), fp);
|
||||||
save_connected_weights(*(l.wf), fp);
|
save_connected_weights(*(l.wf), fp);
|
||||||
save_connected_weights(*(l.wo), fp);
|
save_connected_weights(*(l.wo), fp);
|
||||||
save_connected_weights(*(l.wg), fp);
|
save_connected_weights(*(l.wg), fp);
|
||||||
save_connected_weights(*(l.ui), fp);
|
save_connected_weights(*(l.ui), fp);
|
||||||
save_connected_weights(*(l.uf), fp);
|
save_connected_weights(*(l.uf), fp);
|
||||||
save_connected_weights(*(l.uo), fp);
|
save_connected_weights(*(l.uo), fp);
|
||||||
save_connected_weights(*(l.ug), fp);
|
save_connected_weights(*(l.ug), fp);
|
||||||
} if (l.type == GRU) {
|
} if (l.type == GRU) {
|
||||||
save_connected_weights(*(l.wz), fp);
|
save_connected_weights(*(l.wz), fp);
|
||||||
save_connected_weights(*(l.wr), fp);
|
save_connected_weights(*(l.wr), fp);
|
||||||
save_connected_weights(*(l.wh), fp);
|
save_connected_weights(*(l.wh), fp);
|
||||||
save_connected_weights(*(l.uz), fp);
|
save_connected_weights(*(l.uz), fp);
|
||||||
save_connected_weights(*(l.ur), fp);
|
save_connected_weights(*(l.ur), fp);
|
||||||
save_connected_weights(*(l.uh), fp);
|
save_connected_weights(*(l.uh), fp);
|
||||||
} if(l.type == CRNN){
|
} if(l.type == CRNN){
|
||||||
save_convolutional_weights(*(l.input_layer), fp);
|
save_convolutional_weights(*(l.input_layer), fp);
|
||||||
save_convolutional_weights(*(l.self_layer), fp);
|
save_convolutional_weights(*(l.self_layer), fp);
|
||||||
save_convolutional_weights(*(l.output_layer), fp);
|
save_convolutional_weights(*(l.output_layer), fp);
|
||||||
@ -1128,24 +1128,24 @@ void load_weights_upto(network *net, char *filename, int start, int cutoff)
|
|||||||
load_connected_weights(*(l.self_layer), fp, transpose);
|
load_connected_weights(*(l.self_layer), fp, transpose);
|
||||||
load_connected_weights(*(l.output_layer), fp, transpose);
|
load_connected_weights(*(l.output_layer), fp, transpose);
|
||||||
}
|
}
|
||||||
if (l.type == LSTM) {
|
if (l.type == LSTM) {
|
||||||
load_connected_weights(*(l.wi), fp, transpose);
|
load_connected_weights(*(l.wi), fp, transpose);
|
||||||
load_connected_weights(*(l.wf), fp, transpose);
|
load_connected_weights(*(l.wf), fp, transpose);
|
||||||
load_connected_weights(*(l.wo), fp, transpose);
|
load_connected_weights(*(l.wo), fp, transpose);
|
||||||
load_connected_weights(*(l.wg), fp, transpose);
|
load_connected_weights(*(l.wg), fp, transpose);
|
||||||
load_connected_weights(*(l.ui), fp, transpose);
|
load_connected_weights(*(l.ui), fp, transpose);
|
||||||
load_connected_weights(*(l.uf), fp, transpose);
|
load_connected_weights(*(l.uf), fp, transpose);
|
||||||
load_connected_weights(*(l.uo), fp, transpose);
|
load_connected_weights(*(l.uo), fp, transpose);
|
||||||
load_connected_weights(*(l.ug), fp, transpose);
|
load_connected_weights(*(l.ug), fp, transpose);
|
||||||
}
|
}
|
||||||
if (l.type == GRU) {
|
if (l.type == GRU) {
|
||||||
load_connected_weights(*(l.wz), fp, transpose);
|
load_connected_weights(*(l.wz), fp, transpose);
|
||||||
load_connected_weights(*(l.wr), fp, transpose);
|
load_connected_weights(*(l.wr), fp, transpose);
|
||||||
load_connected_weights(*(l.wh), fp, transpose);
|
load_connected_weights(*(l.wh), fp, transpose);
|
||||||
load_connected_weights(*(l.uz), fp, transpose);
|
load_connected_weights(*(l.uz), fp, transpose);
|
||||||
load_connected_weights(*(l.ur), fp, transpose);
|
load_connected_weights(*(l.ur), fp, transpose);
|
||||||
load_connected_weights(*(l.uh), fp, transpose);
|
load_connected_weights(*(l.uh), fp, transpose);
|
||||||
}
|
}
|
||||||
if(l.type == LOCAL){
|
if(l.type == LOCAL){
|
||||||
int locations = l.out_w*l.out_h;
|
int locations = l.out_w*l.out_h;
|
||||||
int size = l.size*l.size*l.c*l.n*locations;
|
int size = l.size*l.size*l.c*l.n*locations;
|
||||||
|
Loading…
Reference in New Issue
Block a user