Mirror of https://github.com/pjreddie/darknet.git, synced 2023-08-10 21:13:14 +03:00
LSTM, RNN, GRU: use connected_layer, which uses cuDNN. Fixed CRNN for convolutional layers with cuDNN.
@@ -105,29 +105,29 @@ size_t get_workspace_size(layer l){
     if(gpu_index >= 0){
         size_t most = 0;
         size_t s = 0;
-        cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+        CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
                 l.srcTensorDesc,
                 l.weightDesc,
                 l.convDesc,
                 l.dstTensorDesc,
                 l.fw_algo,
-                &s);
+                &s));
         if (s > most) most = s;
-        cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+        CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
                 l.srcTensorDesc,
                 l.ddstTensorDesc,
                 l.convDesc,
                 l.dweightDesc,
                 l.bf_algo,
-                &s);
+                &s));
         if (s > most) most = s;
-        cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+        CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
                 l.weightDesc,
                 l.ddstTensorDesc,
                 l.convDesc,
                 l.dsrcTensorDesc,
                 l.bd_algo,
-                &s);
+                &s));
         if (s > most) most = s;
         return most;
     }
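The CHECK_CUDNN macro that this commit wraps around every cuDNN call is defined elsewhere in the tree, not in this file. A minimal sketch of what such an error-check macro typically looks like, assuming it aborts on any non-success status (the real definition may log or recover differently):

    #include <stdio.h>
    #include <stdlib.h>
    #include <cudnn.h>

    /* Sketch only: evaluate the cuDNN call once and abort with the
       status string and source location on any failure. */
    #define CHECK_CUDNN(call) do { \
        cudnnStatus_t status = (call); \
        if (status != CUDNN_STATUS_SUCCESS) { \
            fprintf(stderr, "cuDNN error: %s at %s:%d\n", \
                    cudnnGetErrorString(status), __FILE__, __LINE__); \
            exit(EXIT_FAILURE); \
        } \
    } while (0)

Without such a wrapper, a failed workspace query above would silently leave s at 0 and the layer would run with no workspace at all.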
@@ -141,29 +141,29 @@ size_t get_workspace_size16(layer l) {
     if (gpu_index >= 0) {
         size_t most = 0;
         size_t s = 0;
-        cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+        CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
                 l.srcTensorDesc16,
                 l.weightDesc16,
                 l.convDesc,
                 l.dstTensorDesc16,
                 l.fw_algo16,
-                &s);
+                &s));
         if (s > most) most = s;
-        cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+        CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
                 l.srcTensorDesc16,
                 l.ddstTensorDesc16,
                 l.convDesc,
                 l.dweightDesc16,
                 l.bf_algo16,
-                &s);
+                &s));
         if (s > most) most = s;
-        cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+        CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
                 l.weightDesc16,
                 l.ddstTensorDesc16,
                 l.convDesc,
                 l.dsrcTensorDesc16,
                 l.bd_algo16,
-                &s);
+                &s));
         if (s > most) most = s;
         return most;
     }
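Both functions return the largest workspace any of the three convolution passes (forward, backward-filter, backward-data) needs, so the caller can allocate one buffer and reuse it for every pass. A hedged sketch of how the two precisions might be combined when sizing the shared buffer (the max-of-both logic is an assumption about the caller, not code from this commit):

    /* Assumed caller pattern: one workspace large enough for both the
       FP32 and the FP16 (Tensor Core) code paths. */
    size_t s32 = get_workspace_size(l);
    size_t s16 = get_workspace_size16(l);
    l.workspace_size = (s32 > s16) ? s32 : s16;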
@@ -175,6 +175,29 @@ size_t get_workspace_size16(layer l) {
 
 #ifdef GPU
 #ifdef CUDNN
+void create_convolutional_cudnn_tensors(layer *l)
+{
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normTensorDesc));
+
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normDstTensorDesc));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->srcTensorDesc));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dstTensorDesc));
+    CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->weightDesc));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dsrcTensorDesc));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->ddstTensorDesc));
+    CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->dweightDesc));
+
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normDstTensorDescF16));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->srcTensorDesc16));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dstTensorDesc16));
+    CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->weightDesc16));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dsrcTensorDesc16));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->ddstTensorDesc16));
+    CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->dweightDesc16));
+
+    CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&l->convDesc));
+}
+
 void cudnn_convolutional_setup(layer *l, int cudnn_preference)
 {
 
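The new helper centralizes descriptor creation; the commit adds no matching teardown. A hypothetical counterpart that releases the same descriptors might look like this (its name and existence are assumptions, not part of the commit):

    /* Hypothetical teardown mirroring create_convolutional_cudnn_tensors. */
    void destroy_convolutional_cudnn_tensors(layer *l)
    {
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->normTensorDesc));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->normDstTensorDesc));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->srcTensorDesc));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->dstTensorDesc));
        CHECK_CUDNN(cudnnDestroyFilterDescriptor(l->weightDesc));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->dsrcTensorDesc));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->ddstTensorDesc));
        CHECK_CUDNN(cudnnDestroyFilterDescriptor(l->dweightDesc));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->normDstTensorDescF16));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->srcTensorDesc16));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->dstTensorDesc16));
        CHECK_CUDNN(cudnnDestroyFilterDescriptor(l->weightDesc16));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->dsrcTensorDesc16));
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(l->ddstTensorDesc16));
        CHECK_CUDNN(cudnnDestroyFilterDescriptor(l->dweightDesc16));
        CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(l->convDesc));
    }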
@@ -194,9 +217,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
     // 2. Loss Scaling - required only for: activation gradients. We do not use.
     // 3. FP32 Master Copy of Weights
     // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops
-    cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH);
+    CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH));
 #if((CUDNN_MAJOR*10 + CUDNN_MINOR) >= 72) // cuDNN >= 7.2
-    cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION);
+    CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION));
 #endif
 #endif
 
@@ -205,38 +228,38 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
     //cudnnDataType_t data_type = CUDNN_DATA_INT8;
 
     // backward delta
-    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
-    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
-    cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w));
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w));
+    CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
 
     // forward
-    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
-    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
-    cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w));
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w));
+    CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
 
     //#ifdef CUDNN_HALF
     // backward delta
-    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w);
-    cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
-    cudnnSetFilter4dDescriptor(l->dweightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w));
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w));
+    CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
 
     // forward
-    cudnnSetTensor4dDescriptor(l->srcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w);
-    cudnnSetTensor4dDescriptor(l->dstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
-    cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w));
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w));
+    CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
 
     // batch norm
-    cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w));
     //#endif
 
     // batch norm
-    cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
-    cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1));
+    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w));
 
 #if(CUDNN_MAJOR >= 6)
-    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn >= 6.0
+    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT)); // cudnn >= 6.0
 #else
-    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); // cudnn 5.1
+    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION)); // cudnn 5.1
 #endif
     int forward_algo = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
     int backward_algo = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
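Every descriptor above is re-set in place each time setup runs: 4-D tensors in NCHW order (batch, channels, height, width), with the *16 variants differing only in using CUDNN_DATA_HALF. A standalone sketch of the convention, assuming a batch of 4 RGB images of 416x416 (example values, not from this commit):

    cudnnTensorDescriptor_t desc;
    CHECK_CUDNN(cudnnCreateTensorDescriptor(&desc));
    /* NCHW: n=4 images, c=3 channels, h=416, w=416 */
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW,
            CUDNN_DATA_FLOAT, 4, 3, 416, 416));
    /* An FP16 twin of the same geometry would pass CUDNN_DATA_HALF. */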
@@ -249,30 +272,30 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
         printf(" CUDNN-slow ");
     }
 
-    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+    CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
             l->srcTensorDesc,
             l->weightDesc,
             l->convDesc,
             l->dstTensorDesc,
             forward_algo,
             0,
-            &l->fw_algo);
-    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+            &l->fw_algo));
+    CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
             l->weightDesc,
             l->ddstTensorDesc,
             l->convDesc,
             l->dsrcTensorDesc,
             backward_algo,
             0,
-            &l->bd_algo);
-    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+            &l->bd_algo));
+    CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
             l->srcTensorDesc,
             l->ddstTensorDesc,
             l->convDesc,
             l->dweightDesc,
             backward_filter,
             0,
-            &l->bf_algo);
+            &l->bf_algo));
 
     //if (data_type == CUDNN_DATA_HALF)
     {
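A portability note: cudnnGetConvolutionForwardAlgorithm and the two backward getters used here were later removed in cuDNN 8, which queries ranked candidates instead. A hedged sketch of the forward replacement, assuming cuDNN >= 7 where the _v7 call is available:

    #if (CUDNN_MAJOR >= 7)
        int returned = 0;
        cudnnConvolutionFwdAlgoPerf_t perf;
        /* Ask for the single best-ranked forward algorithm; perf.algo
           plays the role of the old &l->fw_algo output. */
        CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle(),
                l->srcTensorDesc, l->weightDesc, l->convDesc, l->dstTensorDesc,
                1, &returned, &perf));
        l->fw_algo = perf.algo;
    #endif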
@@ -290,8 +313,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
 #endif
 #endif
 
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
+convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
 {
+    int total_batch = batch*steps;
     int i;
     convolutional_layer l = {0};
     l.type = CONVOLUTIONAL;
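The new steps parameter exists for the recurrent wrappers named in the commit message: a CRNN/LSTM/GRU layer reuses one convolution across steps timesteps and must keep every timestep's activations and deltas for backprop through time, hence buffers sized total_batch = batch*steps. A usage sketch with example values (the calling convention is inferred from this diff, not taken from the callers):

    int batch = 4, T = 16;               /* example mini-batch and timesteps */
    int h = 32, w = 32, c = 3, n = 64;   /* example geometry */
    /* Plain conv layer: steps = 1, buffers hold batch*outputs floats. */
    convolutional_layer conv = make_convolutional_layer(batch, 1, h, w, c, n,
            3, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0);
    /* Inside a recurrent wrapper: steps = T, buffers hold batch*T*outputs. */
    convolutional_layer rec = make_convolutional_layer(batch, T, h, w, c, n,
            3, 1, 1, LEAKY, 1, 0, 0, 0, 0, 0);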
@@ -327,8 +351,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = l.w * l.h * l.c;
 
-    l.output = calloc(l.batch*l.outputs, sizeof(float));
-    l.delta = calloc(l.batch*l.outputs, sizeof(float));
+    l.output = calloc(total_batch*l.outputs, sizeof(float));
+    l.delta = calloc(total_batch*l.outputs, sizeof(float));
 
     l.forward = forward_convolutional_layer;
     l.backward = backward_convolutional_layer;
@@ -364,8 +388,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
 
         l.rolling_mean = calloc(n, sizeof(float));
         l.rolling_variance = calloc(n, sizeof(float));
-        l.x = calloc(l.batch*l.outputs, sizeof(float));
-        l.x_norm = calloc(l.batch*l.outputs, sizeof(float));
+        l.x = calloc(total_batch*l.outputs, sizeof(float));
+        l.x_norm = calloc(total_batch*l.outputs, sizeof(float));
     }
     if(adam){
         l.adam = 1;
@@ -402,8 +426,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
         l.biases_gpu = cuda_make_array(l.biases, n);
         l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
 
-        l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
-        l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+        l.delta_gpu = cuda_make_array(l.delta, total_batch*out_h*out_w*n);
+        l.output_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
 
         if(binary){
             l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
@@ -427,29 +451,11 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
             l.scales_gpu = cuda_make_array(l.scales, n);
             l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
 
-            l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
-            l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+            l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
+            l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
         }
 #ifdef CUDNN
-        cudnnCreateTensorDescriptor(&l.normTensorDesc);
-
-        cudnnCreateTensorDescriptor(&l.normDstTensorDesc);
-        cudnnCreateTensorDescriptor(&l.srcTensorDesc);
-        cudnnCreateTensorDescriptor(&l.dstTensorDesc);
-        cudnnCreateFilterDescriptor(&l.weightDesc);
-        cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
-        cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
-        cudnnCreateFilterDescriptor(&l.dweightDesc);
-
-        cudnnCreateTensorDescriptor(&l.normDstTensorDescF16);
-        cudnnCreateTensorDescriptor(&l.srcTensorDesc16);
-        cudnnCreateTensorDescriptor(&l.dstTensorDesc16);
-        cudnnCreateFilterDescriptor(&l.weightDesc16);
-        cudnnCreateTensorDescriptor(&l.dsrcTensorDesc16);
-        cudnnCreateTensorDescriptor(&l.ddstTensorDesc16);
-        cudnnCreateFilterDescriptor(&l.dweightDesc16);
-
-        cudnnCreateConvolutionDescriptor(&l.convDesc);
+        create_convolutional_cudnn_tensors(&l);
         cudnn_convolutional_setup(&l, cudnn_fastest);
 #endif
     }
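Replacing the seventeen inline cudnnCreate* calls with create_convolutional_cudnn_tensors separates the create-once step from the set-on-every-reconfigure step. A sketch of the resulting lifecycle (the resize call site is an assumption; only the construction path appears in this diff):

    /* At construction: create all descriptors, then configure them. */
    create_convolutional_cudnn_tensors(&l);
    cudnn_convolutional_setup(&l, cudnn_fastest);

    /* On a later input resize: descriptors already exist, so only
       cudnn_convolutional_setup needs to run again. */
    cudnn_convolutional_setup(&l, cudnn_fastest);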
@@ -486,7 +492,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
 
 void test_convolutional_layer()
 {
-    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0);
+    convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0);
     l.batch_normalize = 1;
     float data[] = {1,1,1,1,1,
             1,1,1,1,1,