mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Fixed RNN (RNN, GRU, LSTM) with cuDNN (batch-norm)
This commit is contained in:
@ -10,15 +10,16 @@
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, int batch_normalize)
|
||||
connected_layer make_connected_layer(int batch, int steps, int inputs, int outputs, ACTIVATION activation, int batch_normalize)
|
||||
{
|
||||
int total_batch = batch*steps;
|
||||
int i;
|
||||
connected_layer l = {0};
|
||||
l.type = CONNECTED;
|
||||
|
||||
l.inputs = inputs;
|
||||
l.outputs = outputs;
|
||||
l.batch=batch;
|
||||
l.batch= batch;
|
||||
l.batch_normalize = batch_normalize;
|
||||
l.h = 1;
|
||||
l.w = 1;
|
||||
@ -27,8 +28,8 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
|
||||
l.out_w = 1;
|
||||
l.out_c = outputs;
|
||||
|
||||
l.output = calloc(batch*outputs, sizeof(float));
|
||||
l.delta = calloc(batch*outputs, sizeof(float));
|
||||
l.output = calloc(total_batch*outputs, sizeof(float));
|
||||
l.delta = calloc(total_batch*outputs, sizeof(float));
|
||||
|
||||
l.weight_updates = calloc(inputs*outputs, sizeof(float));
|
||||
l.bias_updates = calloc(outputs, sizeof(float));
|
||||
@ -65,8 +66,8 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
|
||||
l.rolling_mean = calloc(outputs, sizeof(float));
|
||||
l.rolling_variance = calloc(outputs, sizeof(float));
|
||||
|
||||
l.x = calloc(batch*outputs, sizeof(float));
|
||||
l.x_norm = calloc(batch*outputs, sizeof(float));
|
||||
l.x = calloc(total_batch*outputs, sizeof(float));
|
||||
l.x_norm = calloc(total_batch*outputs, sizeof(float));
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
@ -80,8 +81,8 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
|
||||
l.weight_updates_gpu = cuda_make_array(l.weight_updates, outputs*inputs);
|
||||
l.bias_updates_gpu = cuda_make_array(l.bias_updates, outputs);
|
||||
|
||||
l.output_gpu = cuda_make_array(l.output, outputs*batch);
|
||||
l.delta_gpu = cuda_make_array(l.delta, outputs*batch);
|
||||
l.output_gpu = cuda_make_array(l.output, outputs*total_batch);
|
||||
l.delta_gpu = cuda_make_array(l.delta, outputs*total_batch);
|
||||
if(batch_normalize){
|
||||
l.scales_gpu = cuda_make_array(l.scales, outputs);
|
||||
l.scale_updates_gpu = cuda_make_array(l.scale_updates, outputs);
|
||||
@ -95,13 +96,15 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
|
||||
l.mean_delta_gpu = cuda_make_array(l.mean, outputs);
|
||||
l.variance_delta_gpu = cuda_make_array(l.variance, outputs);
|
||||
|
||||
l.x_gpu = cuda_make_array(l.output, l.batch*outputs);
|
||||
l.x_norm_gpu = cuda_make_array(l.output, l.batch*outputs);
|
||||
l.x_gpu = cuda_make_array(l.output, total_batch*outputs);
|
||||
l.x_norm_gpu = cuda_make_array(l.output, total_batch*outputs);
|
||||
#ifdef CUDNN
|
||||
cudnnCreateTensorDescriptor(&l.normDstTensorDesc);
|
||||
cudnnCreateTensorDescriptor(&l.normTensorDesc);
|
||||
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
|
||||
cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
|
||||
cudnnSetTensor4dDescriptor(l.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l.out_c, 1, 1);
|
||||
cudnnSetTensor4dDescriptor(l.normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
@ -147,7 +150,7 @@ void forward_connected_layer(connected_layer l, network_state state)
|
||||
axpy_cpu(l.outputs, .05, l.variance, 1, l.rolling_variance, 1);
|
||||
|
||||
copy_cpu(l.outputs*l.batch, l.output, 1, l.x, 1);
|
||||
normalize_cpu(l.output, l.mean, l.variance, l.batch, l.outputs, 1);
|
||||
normalize_cpu(l.output, l.mean, l.variance, l.batch, l.outputs, 1);
|
||||
copy_cpu(l.outputs*l.batch, l.output, 1, l.x_norm, 1);
|
||||
} else {
|
||||
normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.outputs, 1);
|
||||
|
Reference in New Issue
Block a user