LSTM, RNN, GRU - use connected_layer that uses cuDNN. Fixed CRNN for conv-layer with cuDNN.

AlexeyAB
2019-01-28 23:50:51 +03:00
parent 0e1f3eaf35
commit 640bdbc063
13 changed files with 292 additions and 101 deletions
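The core idea behind this change is that a fully connected layer over `inputs` values is mathematically a 1x1 convolution applied to a 1x1 spatial map with `inputs` channels: the connected layer already has out_h = out_w = 1, and the newly added fields l.n = l.out_c, l.size = 1, l.stride = 1, l.pad = 0 describe exactly that geometry. With this mapping, make_connected_layer() can reuse create_convolutional_cudnn_tensors() and cudnn_convolutional_setup(), and forward_connected_layer_gpu() can call cudnnConvolutionForward() instead of gemm_ongpu(). A minimal CPU-only sketch of the equivalence, with hypothetical helper names and no darknet or cuDNN dependency:

#include <stdio.h>

/* Fully connected: y[o] = sum_i W[o][i] * x[i]  (bias omitted for brevity). */
static void fully_connected(const float *W, const float *x, float *y,
                            int inputs, int outputs)
{
    for (int o = 0; o < outputs; ++o) {
        float sum = 0;
        for (int i = 0; i < inputs; ++i) sum += W[o*inputs + i] * x[i];
        y[o] = sum;
    }
}

/* The same weights viewed as a 1x1 convolution over a 1x1 spatial map:
 * n = outputs filters, c = inputs channels, size = 1, stride = 1, pad = 0. */
static void conv_1x1_on_1x1_map(const float *W, const float *x, float *y,
                                int c, int n)
{
    for (int f = 0; f < n; ++f) {            /* one output channel per filter */
        float sum = 0;
        for (int ch = 0; ch < c; ++ch) sum += W[f*c + ch] * x[ch];
        y[f] = sum;                          /* out_h = out_w = 1 */
    }
}

int main(void)
{
    const float W[6] = { 1, 2, 3,  4, 5, 6 };   /* 2 outputs x 3 inputs */
    const float x[3] = { 0.5f, -1.0f, 2.0f };
    float y_fc[2], y_conv[2];

    fully_connected(W, x, y_fc, 3, 2);
    conv_1x1_on_1x1_map(W, x, y_conv, 3, 2);

    printf("fc:   %f %f\n", y_fc[0], y_fc[1]);
    printf("conv: %f %f\n", y_conv[0], y_conv[1]);  /* identical values */
    return 0;
}

Both functions produce the same values, which is why the convolutional descriptor and workspace machinery can be reused unchanged for the connected layer in the hunks below.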

View File

@ -9,8 +9,8 @@ rem darknet.exe rnn train cfg/lstm.train.cfg backup/lstm.backup -file text.txt
pause
darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -srand 2 -len 500 -seed apple
darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -len 500 -seed apple
darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -srand 2 -len 500 -seed apple > text_gen.txt
darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -len 500 -seed apple > text_gen.txt
pause

View File

@ -1,5 +1,6 @@
#include "connected_layer.h"
#include "batchnorm_layer.h"
#include "convolutional_layer.h"
#include "utils.h"
#include "cuda.h"
#include "blas.h"
@ -10,6 +11,41 @@
#include <stdlib.h>
#include <string.h>
static size_t get_connected_workspace_size(layer l) {
#ifdef CUDNN
if (gpu_index >= 0) {
size_t most = 0;
size_t s = 0;
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
l.srcTensorDesc,
l.weightDesc,
l.convDesc,
l.dstTensorDesc,
l.fw_algo,
&s));
if (s > most) most = s;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
l.srcTensorDesc,
l.ddstTensorDesc,
l.convDesc,
l.dweightDesc,
l.bf_algo,
&s));
if (s > most) most = s;
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
l.weightDesc,
l.ddstTensorDesc,
l.convDesc,
l.dsrcTensorDesc,
l.bd_algo,
&s));
if (s > most) most = s;
return most;
}
#endif
return 0;
}
connected_layer make_connected_layer(int batch, int steps, int inputs, int outputs, ACTIVATION activation, int batch_normalize)
{
int total_batch = batch*steps;
@ -27,6 +63,10 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu
l.out_h = 1;
l.out_w = 1;
l.out_c = outputs;
l.n = l.out_c;
l.size = 1;
l.stride = 1;
l.pad = 0;
l.output = calloc(total_batch*outputs, sizeof(float));
l.delta = calloc(total_batch*outputs, sizeof(float));
@ -98,16 +138,13 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu
l.x_gpu = cuda_make_array(l.output, total_batch*outputs);
l.x_norm_gpu = cuda_make_array(l.output, total_batch*outputs);
#ifdef CUDNN
cudnnCreateTensorDescriptor(&l.normDstTensorDesc);
cudnnCreateTensorDescriptor(&l.normTensorDesc);
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
cudnnSetTensor4dDescriptor(l.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l.out_c, 1, 1);
cudnnSetTensor4dDescriptor(l.normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
#endif
}
#endif
#ifdef CUDNN
create_convolutional_cudnn_tensors(&l);
cudnn_convolutional_setup(&l, cudnn_fastest); // cudnn_fastest, cudnn_smallest
l.workspace_size = get_connected_workspace_size(l);
#endif // CUDNN
#endif // GPU
l.activation = activation;
fprintf(stderr, "connected %4d -> %4d\n", inputs, outputs);
return l;
@ -288,7 +325,27 @@ void forward_connected_layer_gpu(connected_layer l, network_state state)
float * a = state.input;
float * b = l.weights_gpu;
float * c = l.output_gpu;
#ifdef CUDNN
float one = 1; // alpha[0], beta[0]
float alpha = 1, beta = 0;
CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(),
&alpha, //&one,
l.srcTensorDesc,
state.input,
l.weightDesc,
l.weights_gpu,
l.convDesc,
l.fw_algo,
state.workspace,
l.workspace_size,
&beta, //&one,
l.dstTensorDesc,
l.output_gpu));
#else // CUDNN
gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
#endif // CUDNN
if (l.batch_normalize) {
forward_batchnorm_layer_gpu(l, state);
}
@ -312,12 +369,51 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
backward_batchnorm_layer_gpu(l, state);
}
#ifdef CUDNN_DISABLED
float one = 1;
// calculate conv weight updates
// if beta=1 is used, the loss decreases faster
CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
&one,
l.srcTensorDesc,
state.input,
l.ddstTensorDesc,
l.delta_gpu,
l.convDesc,
l.bf_algo,
state.workspace,
l.workspace_size,
&one,
l.dweightDesc,
l.weight_updates_gpu));
if (state.delta) {
// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
// calculate delta for the next layer
CHECK_CUDNN(cudnnConvolutionBackwardData(cudnn_handle(),
&one,
l.weightDesc,
l.weights_gpu,
l.ddstTensorDesc,
l.delta_gpu,
l.convDesc,
l.bd_algo,
state.workspace,
l.workspace_size,
&one,
l.dsrcTensorDesc,
state.delta));
}
#else // CUDNN
int m = l.outputs;
int k = l.batch;
int n = l.inputs;
float * a = l.delta_gpu;
float * b = state.input;
float * c = l.weight_updates_gpu;
gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n);
m = l.batch;
@ -329,5 +425,6 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
c = state.delta;
if(c) gemm_ongpu(0,0,m,n,k,1,a,k,b,n,1,c,n);
#endif // CUDNN
}
#endif

View File

@ -444,7 +444,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
cuda_convert_f32_to_f16(state.input, input16_size, input16);
//fill_ongpu(output16_size / 2, 0, (float *)output16, 1);
cudnnConvolutionForward(cudnn_handle(),
CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(),
&alpha,
l.srcTensorDesc16,
input16,
@ -456,7 +456,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
l.workspace_size,
&beta,
l.dstTensorDesc16,
output16);
output16));
if (l.batch_normalize)
@ -469,7 +469,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
float zero = 0;
// Batch-normalization can still take FP16 inputs and outputs, saving half the bandwidth
// compared to FP32, it's just that the statistics and value adjustment should be done in FP32.
cudnnBatchNormalizationForwardTraining(cudnn_handle(),
CHECK_CUDNN(cudnnBatchNormalizationForwardTraining(cudnn_handle(),
CUDNN_BATCHNORM_SPATIAL,
&one,
&zero,
@ -485,7 +485,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
l.rolling_variance_gpu, // output (should be FP32)
.00001,
l.mean_gpu, // output (should be FP32)
l.variance_gpu); // output (should be FP32)
l.variance_gpu)); // output (should be FP32)
cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
//forward_batchnorm_layer_gpu(l, state);
@ -508,7 +508,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
//#else
cudnnConvolutionForward(cudnn_handle(),
CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(),
&alpha, //&one,
l.srcTensorDesc,
state.input,
@ -520,7 +520,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
l.workspace_size,
&beta, //&one,
l.dstTensorDesc,
l.output_gpu);
l.output_gpu));
//cudaDeviceSynchronize();
if (l.batch_normalize) {
@ -624,7 +624,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
//}
float one = 1;
float zero = 0;
cudnnBatchNormalizationBackward(cudnn_handle(),
CHECK_CUDNN(cudnnBatchNormalizationBackward(cudnn_handle(),
CUDNN_BATCHNORM_SPATIAL,
&one,
&zero,
@ -642,7 +642,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.bias_updates_gpu, // output (should be FP32)
.00001,
l.mean_gpu, // input (should be FP32)
l.variance_gpu); // input (should be FP32)
l.variance_gpu)); // input (should be FP32)
copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, 1, delta16, 1);
//cudaMemcpyAsync(delta16, l.x_norm_gpu, l.outputs*l.batch * sizeof(half), cudaMemcpyDefault, get_cuda_stream());
}
@ -659,7 +659,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
// so we should copy f32 to f16, or compute: f16=(w_up - w*d*b*s)*m
cuda_convert_f32_to_f16(l.weight_updates_gpu, l.c*l.n*l.size*l.size, l.weight_updates_gpu16);
cudnnConvolutionBackwardFilter(cudnn_handle(),
CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
&one,
l.srcTensorDesc16,
input16, //state.input,
@ -671,7 +671,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.workspace_size,
&one,
l.dweightDesc16,
l.weight_updates_gpu16); // l.weight_updates_gpu);
l.weight_updates_gpu16)); // l.weight_updates_gpu);
cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.c*l.n*l.size*l.size, l.weight_updates_gpu);
@ -682,7 +682,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
// calculate delta for the next layer
// convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
// get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
cudnnConvolutionBackwardData(cudnn_handle(),
CHECK_CUDNN(cudnnConvolutionBackwardData(cudnn_handle(),
&alpha,
l.weightDesc16,
l.weights_gpu16, //l.weights_gpu,
@ -694,7 +694,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.workspace_size,
&beta,
l.dsrcTensorDesc16,
input16); // state.delta);
input16)); // state.delta);
cuda_convert_f16_to_f32(input16, input16_size, state.delta);
@ -711,7 +711,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
// calculate conv weight updates
// if beta=1 is used, the loss decreases faster
cudnnConvolutionBackwardFilter(cudnn_handle(),
CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(),
&one,
l.srcTensorDesc,
state.input,
@ -723,13 +723,13 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.workspace_size,
&one,
l.dweightDesc,
l.weight_updates_gpu);
l.weight_updates_gpu));
if (state.delta) {
if (l.binary || l.xnor) swap_binary(&l);
// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
// calculate delta for the next layer
cudnnConvolutionBackwardData(cudnn_handle(),
CHECK_CUDNN(cudnnConvolutionBackwardData(cudnn_handle(),
&one,
l.weightDesc,
l.weights_gpu,
@ -741,7 +741,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.workspace_size,
&one,
l.dsrcTensorDesc,
state.delta);
state.delta));
if (l.binary || l.xnor) swap_binary(&l);
if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
}

View File

@ -105,29 +105,29 @@ size_t get_workspace_size(layer l){
if(gpu_index >= 0){
size_t most = 0;
size_t s = 0;
cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
l.srcTensorDesc,
l.weightDesc,
l.convDesc,
l.dstTensorDesc,
l.fw_algo,
&s);
&s));
if (s > most) most = s;
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
l.srcTensorDesc,
l.ddstTensorDesc,
l.convDesc,
l.dweightDesc,
l.bf_algo,
&s);
&s));
if (s > most) most = s;
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
l.weightDesc,
l.ddstTensorDesc,
l.convDesc,
l.dsrcTensorDesc,
l.bd_algo,
&s);
&s));
if (s > most) most = s;
return most;
}
@ -141,29 +141,29 @@ size_t get_workspace_size16(layer l) {
if (gpu_index >= 0) {
size_t most = 0;
size_t s = 0;
cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
l.srcTensorDesc16,
l.weightDesc16,
l.convDesc,
l.dstTensorDesc16,
l.fw_algo16,
&s);
&s));
if (s > most) most = s;
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
l.srcTensorDesc16,
l.ddstTensorDesc16,
l.convDesc,
l.dweightDesc16,
l.bf_algo16,
&s);
&s));
if (s > most) most = s;
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
l.weightDesc16,
l.ddstTensorDesc16,
l.convDesc,
l.dsrcTensorDesc16,
l.bd_algo16,
&s);
&s));
if (s > most) most = s;
return most;
}
@ -175,6 +175,29 @@ size_t get_workspace_size16(layer l) {
#ifdef GPU
#ifdef CUDNN
void create_convolutional_cudnn_tensors(layer *l)
{
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normTensorDesc));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normDstTensorDesc));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->srcTensorDesc));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dstTensorDesc));
CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->weightDesc));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dsrcTensorDesc));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->ddstTensorDesc));
CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->dweightDesc));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normDstTensorDescF16));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->srcTensorDesc16));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dstTensorDesc16));
CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->weightDesc16));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dsrcTensorDesc16));
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->ddstTensorDesc16));
CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->dweightDesc16));
CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&l->convDesc));
}
void cudnn_convolutional_setup(layer *l, int cudnn_preference)
{
@ -194,9 +217,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
// 2. Loss Scaling - required only for activation gradients; we do not use it.
// 3. FP32 Master Copy of Weights
// More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops
cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH);
CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH));
#if((CUDNN_MAJOR*10 + CUDNN_MINOR) >= 72) // cuDNN >= 7.2
cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION);
CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION));
#endif
#endif
@ -205,38 +228,38 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
//cudnnDataType_t data_type = CUDNN_DATA_INT8;
// backward delta
cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w));
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w));
CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
// forward
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w));
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w));
CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
//#ifdef CUDNN_HALF
// backward delta
cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->dweightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w));
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w));
CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
// forward
cudnnSetTensor4dDescriptor(l->srcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->dstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w));
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w));
CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size));
// batch norm
cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w);
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w));
//#endif
// batch norm
cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1));
CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w));
#if(CUDNN_MAJOR >= 6)
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn >= 6.0
CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT)); // cudnn >= 6.0
#else
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); // cudnn 5.1
CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION)); // cudnn 5.1
#endif
int forward_algo = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
int backward_algo = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
@ -249,30 +272,30 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
printf(" CUDNN-slow ");
}
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
l->srcTensorDesc,
l->weightDesc,
l->convDesc,
l->dstTensorDesc,
forward_algo,
0,
&l->fw_algo);
cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
&l->fw_algo));
CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
l->weightDesc,
l->ddstTensorDesc,
l->convDesc,
l->dsrcTensorDesc,
backward_algo,
0,
&l->bd_algo);
cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
&l->bd_algo));
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
l->srcTensorDesc,
l->ddstTensorDesc,
l->convDesc,
l->dweightDesc,
backward_filter,
0,
&l->bf_algo);
&l->bf_algo));
//if (data_type == CUDNN_DATA_HALF)
{
@ -290,8 +313,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index)
{
int total_batch = batch*steps;
int i;
convolutional_layer l = {0};
l.type = CONVOLUTIONAL;
@ -327,8 +351,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = l.w * l.h * l.c;
l.output = calloc(l.batch*l.outputs, sizeof(float));
l.delta = calloc(l.batch*l.outputs, sizeof(float));
l.output = calloc(total_batch*l.outputs, sizeof(float));
l.delta = calloc(total_batch*l.outputs, sizeof(float));
l.forward = forward_convolutional_layer;
l.backward = backward_convolutional_layer;
@ -364,8 +388,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.rolling_mean = calloc(n, sizeof(float));
l.rolling_variance = calloc(n, sizeof(float));
l.x = calloc(l.batch*l.outputs, sizeof(float));
l.x_norm = calloc(l.batch*l.outputs, sizeof(float));
l.x = calloc(total_batch*l.outputs, sizeof(float));
l.x_norm = calloc(total_batch*l.outputs, sizeof(float));
}
if(adam){
l.adam = 1;
@ -402,8 +426,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.biases_gpu = cuda_make_array(l.biases, n);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
l.delta_gpu = cuda_make_array(l.delta, total_batch*out_h*out_w*n);
l.output_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
if(binary){
l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
@ -427,29 +451,11 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.scales_gpu = cuda_make_array(l.scales, n);
l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n);
}
#ifdef CUDNN
cudnnCreateTensorDescriptor(&l.normTensorDesc);
cudnnCreateTensorDescriptor(&l.normDstTensorDesc);
cudnnCreateTensorDescriptor(&l.srcTensorDesc);
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
cudnnCreateFilterDescriptor(&l.weightDesc);
cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
cudnnCreateFilterDescriptor(&l.dweightDesc);
cudnnCreateTensorDescriptor(&l.normDstTensorDescF16);
cudnnCreateTensorDescriptor(&l.srcTensorDesc16);
cudnnCreateTensorDescriptor(&l.dstTensorDesc16);
cudnnCreateFilterDescriptor(&l.weightDesc16);
cudnnCreateTensorDescriptor(&l.dsrcTensorDesc16);
cudnnCreateTensorDescriptor(&l.ddstTensorDesc16);
cudnnCreateFilterDescriptor(&l.dweightDesc16);
cudnnCreateConvolutionDescriptor(&l.convDesc);
create_convolutional_cudnn_tensors(&l);
cudnn_convolutional_setup(&l, cudnn_fastest);
#endif
}
@ -486,7 +492,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
void test_convolutional_layer()
{
convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0);
convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0);
l.batch_normalize = 1;
float data[] = {1,1,1,1,1,
1,1,1,1,1,

View File

@ -21,11 +21,12 @@ void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
#ifdef CUDNN
void cudnn_convolutional_setup(layer *l, int cudnn_preference);
void create_convolutional_cudnn_tensors(layer *l);
void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index);
void denormalize_convolutional_layer(convolutional_layer l);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);

View File

@ -48,18 +48,21 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.input_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0);
*(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0);
l.input_layer->batch = batch;
if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
l.self_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0);
*(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0);
l.self_layer->batch = batch;
if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
l.output_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0);
*(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0);
l.output_layer->batch = batch;
if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;
l.output = l.output_layer->output;
l.delta = l.output_layer->delta;
@ -92,6 +95,7 @@ void forward_crnn_layer(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@ -133,6 +137,7 @@ void backward_crnn_layer(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@ -206,6 +211,7 @@ void forward_crnn_layer_gpu(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@ -247,6 +253,7 @@ void backward_crnn_layer_gpu(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);

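The recurring pattern in this file, and in the LSTM/RNN/GRU files further down, is workspace sharing: each sub-layer reports its own workspace_size, the wrapper layer keeps the maximum (if (l.workspace_size < l.input_layer->workspace_size) ...), and the forward/backward functions now copy state.workspace into the nested network_state (s.workspace = state.workspace) so every sub-layer reuses the single buffer the network allocates in parse_network_cfg_custom(). A standalone sketch of that pattern with hypothetical types and names, not darknet's actual API:

#include <stdio.h>
#include <stdlib.h>

/* Hypothetical stand-ins for darknet's layer / network_state structs. */
typedef struct { size_t workspace_size; } sublayer;
typedef struct { size_t workspace_size; sublayer input, self, output; } wrapper_layer;

static void keep_max(size_t *dst, size_t candidate)
{
    if (*dst < candidate) *dst = candidate;
}

int main(void)
{
    wrapper_layer l = { 0, { 4096 }, { 16384 }, { 8192 } };

    /* Same idea as: if (l.workspace_size < l.self_layer->workspace_size) ... */
    keep_max(&l.workspace_size, l.input.workspace_size);
    keep_max(&l.workspace_size, l.self.workspace_size);
    keep_max(&l.workspace_size, l.output.workspace_size);

    /* The network allocates one buffer of the largest requested size and every
     * sub-layer call receives the same pointer (s.workspace = state.workspace). */
    void *shared_workspace = malloc(l.workspace_size);
    printf("shared workspace: %zu bytes\n", l.workspace_size);  /* 16384 */
    free(shared_workspace);
    return 0;
}

The printf(" Allocate workspace_size = %zu \n", ...) added to parser.c below makes the resulting size visible at load time.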
View File

@ -35,6 +35,9 @@ void check_error(cudaError_t status)
printf("CUDA Error: %s\n", s);
assert(0);
snprintf(buffer, 256, "CUDA Error: %s", s);
#ifdef WIN32
getchar();
#endif
error(buffer);
}
if (status2 != cudaSuccess)
@ -44,10 +47,20 @@ void check_error(cudaError_t status)
printf("CUDA Error Prev: %s\n", s);
assert(0);
snprintf(buffer, 256, "CUDA Error Prev: %s", s);
#ifdef WIN32
getchar();
#endif
error(buffer);
}
}
void check_error_extended(cudaError_t status, char *file, int line, char *date_time)
{
if (status != cudaSuccess)
printf("CUDA Error: file: %s() : line: %d : build time: %s \n", file, line, date_time);
check_error(status);
}
dim3 cuda_gridsize(size_t n){
size_t k = (n-1) / BLOCK + 1;
size_t x = k;
@ -117,6 +130,45 @@ cudnnHandle_t cudnn_handle()
}
return handle[i];
}
void cudnn_check_error(cudnnStatus_t status)
{
//cudaDeviceSynchronize();
cudnnStatus_t status2;
cudnnStatus_t status_tmp = cudnnQueryRuntimeError(cudnn_handle(), &status2, CUDNN_ERRQUERY_RAWCODE, NULL);
if (status != CUDNN_STATUS_SUCCESS)
{
const char *s = cudnnGetErrorString(status);
char buffer[256];
printf("cuDNN Error: %s\n", s);
assert(0);
snprintf(buffer, 256, "cuDNN Error: %s", s);
#ifdef WIN32
getchar();
#endif
error(buffer);
}
if (status2 != CUDNN_STATUS_SUCCESS)
{
const char *s = cudnnGetErrorString(status2);
char buffer[256];
printf("cuDNN Error Prev: %s\n", s);
assert(0);
snprintf(buffer, 256, "cuDNN Error Prev: %s", s);
#ifdef WIN32
getchar();
#endif
error(buffer);
}
}
void cudnn_check_error_extended(cudnnStatus_t status, char *file, int line, char *date_time)
{
if (status != CUDNN_STATUS_SUCCESS)
printf("\n cuDNN Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time);
cudnn_check_error(status);
}
#endif
cublasHandle_t blas_handle()

View File

@ -24,6 +24,9 @@ extern int gpu_index;
extern "C" {
#endif // __cplusplus
void check_error(cudaError_t status);
void check_error_extended(cudaError_t status, char *file, int line, char *date_time);
#define CHECK_CUDA(X) check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ );
cublasHandle_t blas_handle();
float *cuda_make_array(float *x, size_t n);
int *cuda_make_int_array(size_t n);
@ -46,6 +49,9 @@ extern "C" {
#ifdef CUDNN
cudnnHandle_t cudnn_handle();
enum {cudnn_fastest, cudnn_smallest};
void cudnn_check_error_extended(cudnnStatus_t status, char *file, int line, char *date_time);
#define CHECK_CUDNN(X) cudnn_check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ );
#endif
#else // GPU

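These two macros are what every CHECK_CUDA(...) / CHECK_CUDNN(...) call in the earlier hunks expands to: the returned status is forwarded to check_error_extended() / cudnn_check_error_extended() together with the call site (built from __FILE__, __FUNCTION__, __LINE__) and the build timestamp, so a failing call is reported with its location before error() aborts. A dependency-free sketch of the same wrap-and-report pattern, using a hypothetical status type and passing file and line as separate arguments (a simplification of the real macros):

#include <stdio.h>
#include <stdlib.h>

/* Hypothetical status type standing in for cudaError_t / cudnnStatus_t. */
typedef enum { LIB_SUCCESS = 0, LIB_BAD_PARAM = 3 } lib_status_t;

static const char *lib_status_string(lib_status_t s)
{
    return s == LIB_SUCCESS ? "success" : "bad param";
}

/* Same shape as check_error_extended(): log where the failure came from,
 * then abort instead of silently continuing with a bad status. */
static void check_status_extended(lib_status_t status, const char *file,
                                  int line, const char *build_time)
{
    if (status != LIB_SUCCESS) {
        printf("Error: %s : file: %s : line: %d : build time: %s\n",
               lib_status_string(status), file, line, build_time);
        exit(1);
    }
}

/* Wrap-and-report macro in the spirit of CHECK_CUDA / CHECK_CUDNN. */
#define CHECK_LIB(X) check_status_extended((X), __FILE__, __LINE__, __DATE__ " - " __TIME__)

static lib_status_t fake_library_call(int arg)   /* stands in for a cuDNN call */
{
    return arg >= 0 ? LIB_SUCCESS : LIB_BAD_PARAM;
}

int main(void)
{
    CHECK_LIB(fake_library_call(42));    /* passes silently */
    CHECK_LIB(fake_library_call(-1));    /* prints the location and exits */
    return 0;
}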
View File

@ -119,6 +119,7 @@ void forward_gru_layer(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_z_layer = *(l.input_z_layer);
layer input_r_layer = *(l.input_r_layer);
@ -219,6 +220,7 @@ void forward_gru_layer_gpu(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_z_layer = *(l.input_z_layer);
layer input_r_layer = *(l.input_r_layer);
@ -295,6 +297,7 @@ void backward_gru_layer_gpu(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_z_layer = *(l.input_z_layer);
layer input_r_layer = *(l.input_r_layer);

View File

@ -40,41 +40,49 @@ layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_n
fprintf(stderr, "\t\t");
*(l.uf) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize);
l.uf->batch = batch;
if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size;
l.ui = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.ui) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize);
l.ui->batch = batch;
if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size;
l.ug = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.ug) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize);
l.ug->batch = batch;
if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size;
l.uo = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.uo) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize);
l.uo->batch = batch;
if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size;
l.wf = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.wf) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize);
l.wf->batch = batch;
if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size;
l.wi = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.wi) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize);
l.wi->batch = batch;
if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size;
l.wg = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.wg) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize);
l.wg->batch = batch;
if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size;
l.wo = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.wo) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize);
l.wo->batch = batch;
if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size;
l.batch_normalize = batch_normalize;
l.outputs = outputs;
@ -161,6 +169,7 @@ void forward_lstm_layer(layer l, network_state state)
{
network_state s = { 0 };
s.train = state.train;
s.workspace = state.workspace;
int i;
layer wf = *(l.wf);
layer wi = *(l.wi);
@ -247,6 +256,7 @@ void backward_lstm_layer(layer l, network_state state)
{
network_state s = { 0 };
s.train = state.train;
s.workspace = state.workspace;
int i;
layer wf = *(l.wf);
layer wi = *(l.wi);
@ -403,6 +413,7 @@ void forward_lstm_layer_gpu(layer l, network_state state)
{
network_state s = { 0 };
s.train = state.train;
s.workspace = state.workspace;
int i;
layer wf = *(l.wf);
layer wi = *(l.wi);
@ -489,6 +500,7 @@ void backward_lstm_layer_gpu(layer l, network_state state)
{
network_state s = { 0 };
s.train = state.train;
s.workspace = state.workspace;
int i;
layer wf = *(l.wf);
layer wi = *(l.wi);

View File

@ -58,7 +58,7 @@ void forward_network_gpu(network net, network_state state)
if(l.delta_gpu && state.train){
fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
}
//printf("%d - type: %d - ", i, l.type);
//printf("\n layer %d - type: %d - \n", i, l.type);
//start_timer();
l.forward_gpu(l, state);
//cudaDeviceSynchronize();

View File

@ -167,7 +167,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
int xnor = option_find_int_quiet(options, "xnor", 0);
int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index);
convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index);
layer.flipped = option_find_int_quiet(options, "flipped", 0);
layer.dot = option_find_float_quiet(options, "dot", 0);
@ -845,7 +845,6 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
net.outputs = get_network_output_size(net);
net.output = get_network_output(net);
printf("Total BFLOPS %5.3f \n", bflops);
//printf("%ld\n", workspace_size);
#ifdef GPU
get_cuda_stream();
get_cuda_memcpy_stream();
@ -864,6 +863,7 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
check_error(cudaMalloc((void **)net.output16_gpu, *net.max_output16_size * sizeof(short))); //sizeof(half)
}
if (workspace_size) {
printf(" Allocate workspace_size = %zu \n", workspace_size);
net.workspace = cuda_make_array(0, workspace_size / sizeof(float) + 1);
}
else {

View File

@ -43,16 +43,19 @@ layer make_rnn_layer(int batch, int inputs, int hidden, int outputs, int steps,
fprintf(stderr, "\t\t");
*(l.input_layer) = make_connected_layer(batch, steps, inputs, hidden, activation, batch_normalize);
l.input_layer->batch = batch;
if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size;
l.self_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.self_layer) = make_connected_layer(batch, steps, hidden, hidden, (log==2)?LOGGY:(log==1?LOGISTIC:activation), batch_normalize);
l.self_layer->batch = batch;
if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size;
l.output_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.output_layer) = make_connected_layer(batch, steps, hidden, outputs, activation, batch_normalize);
l.output_layer->batch = batch;
if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size;
l.outputs = outputs;
l.output = l.output_layer->output;
@ -84,6 +87,7 @@ void forward_rnn_layer(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@ -126,6 +130,7 @@ void backward_rnn_layer(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@ -199,6 +204,7 @@ void forward_rnn_layer_gpu(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);
@ -241,6 +247,7 @@ void backward_rnn_layer_gpu(layer l, network_state state)
{
network_state s = {0};
s.train = state.train;
s.workspace = state.workspace;
int i;
layer input_layer = *(l.input_layer);
layer self_layer = *(l.self_layer);