mirror of https://github.com/pjreddie/darknet.git (synced 2023-08-10 21:13:14 +03:00)
LSTM, RNN, GRU: use a connected_layer that uses cuDNN. Fixed CRNN for the conv-layer with cuDNN.
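For context: a fully connected layer with `inputs` inputs and `outputs` outputs is mathematically the same as a convolution with `outputs` filters of size 1x1 applied to a tensor of shape inputs x 1 x 1, which is what lets the recurrent layers' connected layers reuse the convolutional cuDNN path below. A minimal CPU sketch of the equivalence (illustration only, not code from this commit; assumes darknet's row-major weights[o*inputs + i] layout):

/* A fully connected layer computed as `outputs` 1x1 filters over an
 * inputs x 1 x 1 tensor gives the same numbers as the usual dot products. */
void fc_as_1x1_conv(const float *input, const float *weights,
                    float *output, int inputs, int outputs)
{
    for (int o = 0; o < outputs; ++o) {
        float sum = 0;
        for (int i = 0; i < inputs; ++i)
            sum += weights[o*inputs + i] * input[i];
        output[o] = sum;  /* == filter o applied at the single 1x1 position */
    }
}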
@@ -1,5 +1,6 @@
 #include "connected_layer.h"
 #include "batchnorm_layer.h"
+#include "convolutional_layer.h"
 #include "utils.h"
 #include "cuda.h"
 #include "blas.h"
@@ -10,6 +11,41 @@
 #include <stdlib.h>
 #include <string.h>
 
+static size_t get_connected_workspace_size(layer l) {
+#ifdef CUDNN
+    if (gpu_index >= 0) {
+        size_t most = 0;
+        size_t s = 0;
+        CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.weightDesc,
+            l.convDesc,
+            l.dstTensorDesc,
+            l.fw_algo,
+            &s));
+        if (s > most) most = s;
+        CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dweightDesc,
+            l.bf_algo,
+            &s));
+        if (s > most) most = s;
+        CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+            l.weightDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dsrcTensorDesc,
+            l.bd_algo,
+            &s));
+        if (s > most) most = s;
+        return most;
+    }
+#endif
+    return 0;
+}
+
 connected_layer make_connected_layer(int batch, int steps, int inputs, int outputs, ACTIVATION activation, int batch_normalize)
 {
     int total_batch = batch*steps;
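The new function takes the maximum of the three per-algorithm workspace queries because a single scratch buffer (state.workspace) is handed to the forward, backward-filter, and backward-data calls alike, so it must fit whichever algorithm runs. Darknet applies the same max-reduction once more across layers when allocating the shared buffer; a self-contained sketch of that pattern (illustration only, standing in for darknet's per-layer workspace_size fields):

#include <stddef.h>

/* One shared scratch buffer, sized for the most demanding layer. */
size_t shared_workspace_bytes(const size_t *sizes, int layer_count)
{
    size_t most = 0;
    for (int i = 0; i < layer_count; ++i)
        if (sizes[i] > most) most = sizes[i];
    return most;  /* allocate once, pass as state.workspace to every layer */
}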
@@ -27,6 +63,10 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu
     l.out_h = 1;
     l.out_w = 1;
     l.out_c = outputs;
+    l.n = l.out_c;
+    l.size = 1;
+    l.stride = 1;
+    l.pad = 0;
 
     l.output = calloc(total_batch*outputs, sizeof(float));
     l.delta = calloc(total_batch*outputs, sizeof(float));
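Setting n, size, stride, and pad turns the connected layer into something the shared convolutional cuDNN setup can describe: `outputs` 1x1 filters applied with stride 1 and no padding to a 1x1 spatial grid. A quick check with the standard convolution output-size formula (illustration only, not code from this commit):

#include <stdio.h>

int conv_out_dim(int in, int size, int stride, int pad)
{
    return (in + 2*pad - size) / stride + 1;
}

int main(void)
{
    /* h = w = 1 for a connected layer viewed as a conv layer */
    printf("out_h = %d, out_w = %d\n",
           conv_out_dim(1, 1, 1, 0), conv_out_dim(1, 1, 1, 0));  /* 1, 1 */
    return 0;
}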
@@ -83,7 +123,7 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu
 
     l.output_gpu = cuda_make_array(l.output, outputs*total_batch);
     l.delta_gpu = cuda_make_array(l.delta, outputs*total_batch);
-    if(batch_normalize){
+    if (batch_normalize) {
         l.scales_gpu = cuda_make_array(l.scales, outputs);
         l.scale_updates_gpu = cuda_make_array(l.scale_updates, outputs);
 
@@ -98,16 +138,13 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu
 
         l.x_gpu = cuda_make_array(l.output, total_batch*outputs);
         l.x_norm_gpu = cuda_make_array(l.output, total_batch*outputs);
-#ifdef CUDNN
-        cudnnCreateTensorDescriptor(&l.normDstTensorDesc);
-        cudnnCreateTensorDescriptor(&l.normTensorDesc);
-        cudnnCreateTensorDescriptor(&l.dstTensorDesc);
-        cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
-        cudnnSetTensor4dDescriptor(l.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l.out_c, 1, 1);
-        cudnnSetTensor4dDescriptor(l.normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
-#endif
     }
-#endif
+#ifdef CUDNN
+    create_convolutional_cudnn_tensors(&l);
+    cudnn_convolutional_setup(&l, cudnn_fastest);  // cudnn_fastest, cudnn_smallest
+    l.workspace_size = get_connected_workspace_size(l);
+#endif  // CUDNN
+#endif  // GPU
     l.activation = activation;
     fprintf(stderr, "connected %4d -> %4d\n", inputs, outputs);
     return l;
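The hand-written descriptor code is replaced by the convolutional layer's helpers, and per the inline comment cudnn_convolutional_setup can alternatively be asked for the smallest-workspace algorithms (cudnn_smallest) when GPU memory is tight. What create_convolutional_cudnn_tensors sets up is not shown in this diff; a hypothetical sketch, using only the descriptor names that this file references (the real body lives in convolutional_layer.c and may differ):

#include <cudnn.h>
#include "layer.h"  /* darknet's layer type */

/* Hypothetical: create the descriptors that get_connected_workspace_size queries. */
void create_convolutional_cudnn_tensors_sketch(layer *l)
{
    cudnnCreateTensorDescriptor(&l->srcTensorDesc);
    cudnnCreateTensorDescriptor(&l->dstTensorDesc);
    cudnnCreateTensorDescriptor(&l->dsrcTensorDesc);
    cudnnCreateTensorDescriptor(&l->ddstTensorDesc);
    cudnnCreateFilterDescriptor(&l->weightDesc);
    cudnnCreateFilterDescriptor(&l->dweightDesc);
    cudnnCreateConvolutionDescriptor(&l->convDesc);
}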
@@ -288,7 +325,27 @@ void forward_connected_layer_gpu(connected_layer l, network_state state)
     float * a = state.input;
     float * b = l.weights_gpu;
     float * c = l.output_gpu;
+#ifdef CUDNN
+    float one = 1;  // alpha[0], beta[0]
+    float alpha = 1, beta = 0;
+
+    CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(),
+        &alpha, //&one,
+        l.srcTensorDesc,
+        state.input,
+        l.weightDesc,
+        l.weights_gpu,
+        l.convDesc,
+        l.fw_algo,
+        state.workspace,
+        l.workspace_size,
+        &beta, //&one,
+        l.dstTensorDesc,
+        l.output_gpu));
+#else  // CUDNN
     gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
+#endif  // CUDNN
+
     if (l.batch_normalize) {
         forward_batchnorm_layer_gpu(l, state);
     }
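Both branches compute the same product. Assuming the darknet convention set just above the shown context (m = l.batch, k = l.inputs, n = l.outputs), gemm_ongpu(0,1,...) forms output = input * weights^T with beta = 1 (accumulate), while the cuDNN call passes beta = 0 and overwrites l.output_gpu. A CPU reference of the product (illustration only, not code from this commit):

/* output[b][o] = sum_i input[b][i] * weights[o][i], i.e. input * weights^T. */
void fc_forward_reference(const float *input, const float *weights,
                          float *output, int batch, int inputs, int outputs)
{
    for (int b = 0; b < batch; ++b)
        for (int o = 0; o < outputs; ++o) {
            float sum = 0;
            for (int i = 0; i < inputs; ++i)
                sum += input[b*inputs + i] * weights[o*inputs + i];
            output[b*outputs + o] = sum;  /* overwrite, like the beta = 0 path */
        }
}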
@@ -312,12 +369,51 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
         backward_batchnorm_layer_gpu(l, state);
     }
 
+#ifdef CUDNN_DISABLED
+    float one = 1;
+    // calculate conv weight updates
+    // if used: beta=1 then loss decreases faster
+    CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
+        &one,
+        l.srcTensorDesc,
+        state.input,
+        l.ddstTensorDesc,
+        l.delta_gpu,
+        l.convDesc,
+        l.bf_algo,
+        state.workspace,
+        l.workspace_size,
+        &one,
+        l.dweightDesc,
+        l.weight_updates_gpu));
+
+    if (state.delta) {
+        // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
+        // calculate delta for the next layer
+
+        CHECK_CUDNN(cudnnConvolutionBackwardData(cudnn_handle(),
+            &one,
+            l.weightDesc,
+            l.weights_gpu,
+            l.ddstTensorDesc,
+            l.delta_gpu,
+            l.convDesc,
+            l.bd_algo,
+            state.workspace,
+            l.workspace_size,
+            &one,
+            l.dsrcTensorDesc,
+            state.delta));
+    }
+#else  // CUDNN
+
     int m = l.outputs;
     int k = l.batch;
     int n = l.inputs;
     float * a = l.delta_gpu;
     float * b = state.input;
     float * c = l.weight_updates_gpu;
+
     gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n);
 
     m = l.batch;
@@ -329,5 +425,6 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
     c = state.delta;
 
     if(c) gemm_ongpu(0,0,m,n,k,1,a,k,b,n,1,c,n);
+#endif  // CUDNN
 }
 #endif
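In the active #else path, two GEMMs implement the backward pass: weight_updates += delta^T * input, and, when state.delta is non-NULL, the gradient for the layer below, state.delta += delta * weights (the second call's remaining operand assignments are elided by the hunk boundary). The disabled cuDNN path would compute the same two products, passing &one for beta so gradients accumulate. A CPU reference of both products (illustration only, not code from this commit):

/* 1) weight gradient: weight_updates[o][i] += sum_b delta[b][o] * input[b][i]
 * 2) input gradient : prev_delta[b][i]     += sum_o delta[b][o] * weights[o][i] */
void fc_backward_reference(const float *delta, const float *input,
                           const float *weights, float *weight_updates,
                           float *prev_delta, int batch, int inputs, int outputs)
{
    for (int o = 0; o < outputs; ++o)
        for (int i = 0; i < inputs; ++i) {
            float sum = 0;
            for (int b = 0; b < batch; ++b)
                sum += delta[b*outputs + o] * input[b*inputs + i];
            weight_updates[o*inputs + i] += sum;  /* beta = 1: accumulate */
        }
    if (prev_delta) {  /* mirrors the if(c) guard above */
        for (int b = 0; b < batch; ++b)
            for (int i = 0; i < inputs; ++i) {
                float sum = 0;
                for (int o = 0; o < outputs; ++o)
                    sum += delta[b*outputs + o] * weights[o*inputs + i];
                prev_delta[b*inputs + i] += sum;
            }
    }
}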