diff --git a/build/darknet/x64/rnn_lstm.cmd b/build/darknet/x64/rnn_lstm.cmd index 22ecc341..de74c9d2 100644 --- a/build/darknet/x64/rnn_lstm.cmd +++ b/build/darknet/x64/rnn_lstm.cmd @@ -9,8 +9,8 @@ rem darknet.exe rnn train cfg/lstm.train.cfg backup/lstm.backup -file text.txt pause -darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -srand 2 -len 500 -seed apple +darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -len 500 -seed apple -darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -srand 2 -len 500 -seed apple > text_gen.txt +darknet.exe rnn generate cfg/lstm.train.cfg backup/lstm.backup -len 500 -seed apple > text_gen.txt pause \ No newline at end of file diff --git a/src/connected_layer.c b/src/connected_layer.c index 96f7aaf2..2637f946 100644 --- a/src/connected_layer.c +++ b/src/connected_layer.c @@ -1,5 +1,6 @@ #include "connected_layer.h" #include "batchnorm_layer.h" +#include "convolutional_layer.h" #include "utils.h" #include "cuda.h" #include "blas.h" @@ -10,6 +11,41 @@ #include #include +static size_t get_connected_workspace_size(layer l) { +#ifdef CUDNN + if (gpu_index >= 0) { + size_t most = 0; + size_t s = 0; + CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(), + l.srcTensorDesc, + l.weightDesc, + l.convDesc, + l.dstTensorDesc, + l.fw_algo, + &s)); + if (s > most) most = s; + CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(), + l.srcTensorDesc, + l.ddstTensorDesc, + l.convDesc, + l.dweightDesc, + l.bf_algo, + &s)); + if (s > most) most = s; + CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), + l.weightDesc, + l.ddstTensorDesc, + l.convDesc, + l.dsrcTensorDesc, + l.bd_algo, + &s)); + if (s > most) most = s; + return most; + } +#endif + return 0; +} + connected_layer make_connected_layer(int batch, int steps, int inputs, int outputs, ACTIVATION activation, int batch_normalize) { int total_batch = batch*steps; @@ -27,6 +63,10 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu l.out_h = 1; l.out_w = 1; l.out_c = outputs; + l.n = l.out_c; + l.size = 1; + l.stride = 1; + l.pad = 0; l.output = calloc(total_batch*outputs, sizeof(float)); l.delta = calloc(total_batch*outputs, sizeof(float)); @@ -83,7 +123,7 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu l.output_gpu = cuda_make_array(l.output, outputs*total_batch); l.delta_gpu = cuda_make_array(l.delta, outputs*total_batch); - if(batch_normalize){ + if (batch_normalize) { l.scales_gpu = cuda_make_array(l.scales, outputs); l.scale_updates_gpu = cuda_make_array(l.scale_updates, outputs); @@ -98,16 +138,13 @@ connected_layer make_connected_layer(int batch, int steps, int inputs, int outpu l.x_gpu = cuda_make_array(l.output, total_batch*outputs); l.x_norm_gpu = cuda_make_array(l.output, total_batch*outputs); -#ifdef CUDNN - cudnnCreateTensorDescriptor(&l.normDstTensorDesc); - cudnnCreateTensorDescriptor(&l.normTensorDesc); - cudnnCreateTensorDescriptor(&l.dstTensorDesc); - cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); - cudnnSetTensor4dDescriptor(l.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l.out_c, 1, 1); - cudnnSetTensor4dDescriptor(l.normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); -#endif } -#endif +#ifdef CUDNN + create_convolutional_cudnn_tensors(&l); + cudnn_convolutional_setup(&l, cudnn_fastest); // cudnn_fastest, cudnn_smallest + 
l.workspace_size = get_connected_workspace_size(l); +#endif // CUDNN +#endif // GPU l.activation = activation; fprintf(stderr, "connected %4d -> %4d\n", inputs, outputs); return l; @@ -288,7 +325,27 @@ void forward_connected_layer_gpu(connected_layer l, network_state state) float * a = state.input; float * b = l.weights_gpu; float * c = l.output_gpu; +#ifdef CUDNN + float one = 1; // alpha[0], beta[0] + float alpha = 1, beta = 0; + + CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(), + &alpha, //&one, + l.srcTensorDesc, + state.input, + l.weightDesc, + l.weights_gpu, + l.convDesc, + l.fw_algo, + state.workspace, + l.workspace_size, + &beta, //&one, + l.dstTensorDesc, + l.output_gpu)); +#else // CUDNN gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n); +#endif // CUDNN + if (l.batch_normalize) { forward_batchnorm_layer_gpu(l, state); } @@ -312,12 +369,51 @@ void backward_connected_layer_gpu(connected_layer l, network_state state) backward_batchnorm_layer_gpu(l, state); } +#ifdef CUDNN_DISABLED + float one = 1; + // calculate conv weight updates + // if used: beta=1 then loss decreases faster + CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(), + &one, + l.srcTensorDesc, + state.input, + l.ddstTensorDesc, + l.delta_gpu, + l.convDesc, + l.bf_algo, + state.workspace, + l.workspace_size, + &one, + l.dweightDesc, + l.weight_updates_gpu)); + + if (state.delta) { + // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData + // calculate delta for the next layer + + CHECK_CUDNN(cudnnConvolutionBackwardData(cudnn_handle(), + &one, + l.weightDesc, + l.weights_gpu, + l.ddstTensorDesc, + l.delta_gpu, + l.convDesc, + l.bd_algo, + state.workspace, + l.workspace_size, + &one, + l.dsrcTensorDesc, + state.delta)); + } +#else // CUDNN + int m = l.outputs; int k = l.batch; int n = l.inputs; float * a = l.delta_gpu; float * b = state.input; float * c = l.weight_updates_gpu; + gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n); m = l.batch; @@ -329,5 +425,6 @@ void backward_connected_layer_gpu(connected_layer l, network_state state) c = state.delta; if(c) gemm_ongpu(0,0,m,n,k,1,a,k,b,n,1,c,n); +#endif // CUDNN } #endif diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index f8ab4971..279c176f 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -444,7 +444,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) cuda_convert_f32_to_f16(state.input, input16_size, input16); //fill_ongpu(output16_size / 2, 0, (float *)output16, 1); - cudnnConvolutionForward(cudnn_handle(), + CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(), &alpha, l.srcTensorDesc16, input16, @@ -456,7 +456,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) l.workspace_size, &beta, l.dstTensorDesc16, - output16); + output16)); if (l.batch_normalize) @@ -469,7 +469,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) float zero = 0; // Batch-normalization can still take FP16 inputs and outputs, saving half the bandwidth // compared to FP32, it’s just that the statistics and value adjustment should be done in FP32. 
- cudnnBatchNormalizationForwardTraining(cudnn_handle(), + CHECK_CUDNN(cudnnBatchNormalizationForwardTraining(cudnn_handle(), CUDNN_BATCHNORM_SPATIAL, &one, &zero, @@ -485,7 +485,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) l.rolling_variance_gpu, // output (should be FP32) .00001, l.mean_gpu, // output (should be FP32) - l.variance_gpu); // output (should be FP32) + l.variance_gpu)); // output (should be FP32) cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu); //forward_batchnorm_layer_gpu(l, state); @@ -508,7 +508,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) //#else - cudnnConvolutionForward(cudnn_handle(), + CHECK_CUDNN(cudnnConvolutionForward(cudnn_handle(), &alpha, //&one, l.srcTensorDesc, state.input, @@ -520,7 +520,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) l.workspace_size, &beta, //&one, l.dstTensorDesc, - l.output_gpu); + l.output_gpu)); //cudaDeviceSynchronize(); if (l.batch_normalize) { @@ -624,7 +624,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state //} float one = 1; float zero = 0; - cudnnBatchNormalizationBackward(cudnn_handle(), + CHECK_CUDNN(cudnnBatchNormalizationBackward(cudnn_handle(), CUDNN_BATCHNORM_SPATIAL, &one, &zero, @@ -642,7 +642,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state l.bias_updates_gpu, // output (should be FP32) .00001, l.mean_gpu, // input (should be FP32) - l.variance_gpu); // input (should be FP32) + l.variance_gpu)); // input (should be FP32) copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, 1, delta16, 1); //cudaMemcpyAsync(delta16, l.x_norm_gpu, l.outputs*l.batch * sizeof(half), cudaMemcpyDefault, get_cuda_stream()); } @@ -659,7 +659,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state // so we should copy f32 to f16, or compute: f16=(w_up - w*d*b*s)*m cuda_convert_f32_to_f16(l.weight_updates_gpu, l.c*l.n*l.size*l.size, l.weight_updates_gpu16); - cudnnConvolutionBackwardFilter(cudnn_handle(), + CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(), &one, l.srcTensorDesc16, input16, //state.input, @@ -671,7 +671,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state l.workspace_size, &one, l.dweightDesc16, - l.weight_updates_gpu16); // l.weight_updates_gpu); + l.weight_updates_gpu16)); // l.weight_updates_gpu); cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.c*l.n*l.size*l.size, l.weight_updates_gpu); @@ -682,7 +682,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state // calculate delta for the next layer // convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16 // get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16) - cudnnConvolutionBackwardData(cudnn_handle(), + CHECK_CUDNN(cudnnConvolutionBackwardData(cudnn_handle(), &alpha, l.weightDesc16, l.weights_gpu16, //l.weights_gpu, @@ -694,7 +694,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state l.workspace_size, &beta, l.dsrcTensorDesc16, - input16); // state.delta); + input16)); // state.delta); cuda_convert_f16_to_f32(input16, input16_size, state.delta); @@ -711,7 +711,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state // calculate conv weight updates // if used: beta=1 then loss decreases faster - cudnnConvolutionBackwardFilter(cudnn_handle(), + 
CHECK_CUDNN(cudnnConvolutionBackwardFilter(cudnn_handle(), &one, l.srcTensorDesc, state.input, @@ -723,13 +723,13 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state l.workspace_size, &one, l.dweightDesc, - l.weight_updates_gpu); + l.weight_updates_gpu)); if (state.delta) { if (l.binary || l.xnor) swap_binary(&l); // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData // calculate delta for the next layer - cudnnConvolutionBackwardData(cudnn_handle(), + CHECK_CUDNN(cudnnConvolutionBackwardData(cudnn_handle(), &one, l.weightDesc, l.weights_gpu, @@ -741,7 +741,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state l.workspace_size, &one, l.dsrcTensorDesc, - state.delta); + state.delta)); if (l.binary || l.xnor) swap_binary(&l); if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta); } diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index fd055a37..39a72d3b 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -105,29 +105,29 @@ size_t get_workspace_size(layer l){ if(gpu_index >= 0){ size_t most = 0; size_t s = 0; - cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(), + CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(), l.srcTensorDesc, l.weightDesc, l.convDesc, l.dstTensorDesc, l.fw_algo, - &s); + &s)); if (s > most) most = s; - cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(), + CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(), l.srcTensorDesc, l.ddstTensorDesc, l.convDesc, l.dweightDesc, l.bf_algo, - &s); + &s)); if (s > most) most = s; - cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), + CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), l.weightDesc, l.ddstTensorDesc, l.convDesc, l.dsrcTensorDesc, l.bd_algo, - &s); + &s)); if (s > most) most = s; return most; } @@ -141,29 +141,29 @@ size_t get_workspace_size16(layer l) { if (gpu_index >= 0) { size_t most = 0; size_t s = 0; - cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(), + CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(), l.srcTensorDesc16, l.weightDesc16, l.convDesc, l.dstTensorDesc16, l.fw_algo16, - &s); + &s)); if (s > most) most = s; - cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(), + CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(), l.srcTensorDesc16, l.ddstTensorDesc16, l.convDesc, l.dweightDesc16, l.bf_algo16, - &s); + &s)); if (s > most) most = s; - cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), + CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), l.weightDesc16, l.ddstTensorDesc16, l.convDesc, l.dsrcTensorDesc16, l.bd_algo16, - &s); + &s)); if (s > most) most = s; return most; } @@ -175,6 +175,29 @@ size_t get_workspace_size16(layer l) { #ifdef GPU #ifdef CUDNN +void create_convolutional_cudnn_tensors(layer *l) +{ + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normTensorDesc)); + + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normDstTensorDesc)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->srcTensorDesc)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dstTensorDesc)); + CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->weightDesc)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dsrcTensorDesc)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->ddstTensorDesc)); + CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->dweightDesc)); + + 
CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->normDstTensorDescF16)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->srcTensorDesc16)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dstTensorDesc16)); + CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->weightDesc16)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->dsrcTensorDesc16)); + CHECK_CUDNN(cudnnCreateTensorDescriptor(&l->ddstTensorDesc16)); + CHECK_CUDNN(cudnnCreateFilterDescriptor(&l->dweightDesc16)); + + CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&l->convDesc)); +} + void cudnn_convolutional_setup(layer *l, int cudnn_preference) { @@ -194,9 +217,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) // 2. Loss Scaling - required only for: activation gradients. We do not use. // 3. FP32 Master Copy of Weights // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops - cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH); + CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH)); #if((CUDNN_MAJOR*10 + CUDNN_MINOR) >= 72) // cuDNN >= 7.2 - cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); + CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); #endif #endif @@ -205,38 +228,38 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) //cudnnDataType_t data_type = CUDNN_DATA_INT8; // backward delta - cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); // forward - cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); //#ifdef CUDNN_HALF // backward delta - cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->dweightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc16, 
CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); // forward - cudnnSetTensor4dDescriptor(l->srcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w); - cudnnSetTensor4dDescriptor(l->dstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w); - cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); // batch norm - cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w)); //#endif // batch norm - cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1); - cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w)); #if(CUDNN_MAJOR >= 6) - cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn >= 6.0 + CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT)); // cudnn >= 6.0 #else - cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION); // cudnn 5.1 + CHECK_CUDNN(cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION)); // cudnn 5.1 #endif int forward_algo = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST; int backward_algo = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST; @@ -249,30 +272,30 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) printf(" CUDNN-slow "); } - cudnnGetConvolutionForwardAlgorithm(cudnn_handle(), + CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(cudnn_handle(), l->srcTensorDesc, l->weightDesc, l->convDesc, l->dstTensorDesc, forward_algo, 0, - &l->fw_algo); - cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(), + &l->fw_algo)); + CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(), l->weightDesc, l->ddstTensorDesc, l->convDesc, l->dsrcTensorDesc, backward_algo, 0, - &l->bd_algo); - cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(), + &l->bd_algo)); + CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(), l->srcTensorDesc, l->ddstTensorDesc, l->convDesc, l->dweightDesc, backward_filter, 0, - &l->bf_algo); + &l->bf_algo)); //if (data_type == CUDNN_DATA_HALF) { @@ -290,8 +313,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) #endif #endif -convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int 
use_bin_output, int index) +convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index) { + int total_batch = batch*steps; int i; convolutional_layer l = {0}; l.type = CONVOLUTIONAL; @@ -327,8 +351,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int l.outputs = l.out_h * l.out_w * l.out_c; l.inputs = l.w * l.h * l.c; - l.output = calloc(l.batch*l.outputs, sizeof(float)); - l.delta = calloc(l.batch*l.outputs, sizeof(float)); + l.output = calloc(total_batch*l.outputs, sizeof(float)); + l.delta = calloc(total_batch*l.outputs, sizeof(float)); l.forward = forward_convolutional_layer; l.backward = backward_convolutional_layer; @@ -364,8 +388,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int l.rolling_mean = calloc(n, sizeof(float)); l.rolling_variance = calloc(n, sizeof(float)); - l.x = calloc(l.batch*l.outputs, sizeof(float)); - l.x_norm = calloc(l.batch*l.outputs, sizeof(float)); + l.x = calloc(total_batch*l.outputs, sizeof(float)); + l.x_norm = calloc(total_batch*l.outputs, sizeof(float)); } if(adam){ l.adam = 1; @@ -402,8 +426,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int l.biases_gpu = cuda_make_array(l.biases, n); l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); - l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n); - l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + l.delta_gpu = cuda_make_array(l.delta, total_batch*out_h*out_w*n); + l.output_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); if(binary){ l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size); @@ -427,29 +451,11 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int l.scales_gpu = cuda_make_array(l.scales, n); l.scale_updates_gpu = cuda_make_array(l.scale_updates, n); - l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); - l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + l.x_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); + l.x_norm_gpu = cuda_make_array(l.output, total_batch*out_h*out_w*n); } #ifdef CUDNN - cudnnCreateTensorDescriptor(&l.normTensorDesc); - - cudnnCreateTensorDescriptor(&l.normDstTensorDesc); - cudnnCreateTensorDescriptor(&l.srcTensorDesc); - cudnnCreateTensorDescriptor(&l.dstTensorDesc); - cudnnCreateFilterDescriptor(&l.weightDesc); - cudnnCreateTensorDescriptor(&l.dsrcTensorDesc); - cudnnCreateTensorDescriptor(&l.ddstTensorDesc); - cudnnCreateFilterDescriptor(&l.dweightDesc); - - cudnnCreateTensorDescriptor(&l.normDstTensorDescF16); - cudnnCreateTensorDescriptor(&l.srcTensorDesc16); - cudnnCreateTensorDescriptor(&l.dstTensorDesc16); - cudnnCreateFilterDescriptor(&l.weightDesc16); - cudnnCreateTensorDescriptor(&l.dsrcTensorDesc16); - cudnnCreateTensorDescriptor(&l.ddstTensorDesc16); - cudnnCreateFilterDescriptor(&l.dweightDesc16); - - cudnnCreateConvolutionDescriptor(&l.convDesc); + create_convolutional_cudnn_tensors(&l); cudnn_convolutional_setup(&l, cudnn_fastest); #endif } @@ -486,7 +492,7 @@ void denormalize_convolutional_layer(convolutional_layer l) void test_convolutional_layer() { - convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0); + convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0); 
l.batch_normalize = 1; float data[] = {1,1,1,1,1, 1,1,1,1,1, diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 0bd9849e..19a06251 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -21,11 +21,12 @@ void add_bias_gpu(float *output, float *biases, int batch, int n, int size); void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size); #ifdef CUDNN void cudnn_convolutional_setup(layer *l, int cudnn_preference); +void create_convolutional_cudnn_tensors(layer *l); void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16); #endif #endif -convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index); +convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index); void denormalize_convolutional_layer(convolutional_layer l); void resize_convolutional_layer(convolutional_layer *layer, int w, int h); void forward_convolutional_layer(const convolutional_layer layer, network_state state); diff --git a/src/crnn_layer.c b/src/crnn_layer.c index db384cfb..d76c5e22 100644 --- a/src/crnn_layer.c +++ b/src/crnn_layer.c @@ -48,18 +48,21 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou l.input_layer = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); - *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0); + *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0); l.input_layer->batch = batch; + if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size; l.self_layer = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); - *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0); + *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0); l.self_layer->batch = batch; + if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size; l.output_layer = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); - *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0); + *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0, 0, 0); l.output_layer->batch = batch; + if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size; l.output = l.output_layer->output; l.delta = l.output_layer->delta; @@ -92,6 +95,7 @@ void forward_crnn_layer(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer self_layer = *(l.self_layer); @@ -133,6 +137,7 @@ void backward_crnn_layer(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer self_layer = 
*(l.self_layer); @@ -206,6 +211,7 @@ void forward_crnn_layer_gpu(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer self_layer = *(l.self_layer); @@ -247,6 +253,7 @@ void backward_crnn_layer_gpu(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer self_layer = *(l.self_layer); diff --git a/src/cuda.c b/src/cuda.c index 382a6cbd..d1f9afef 100644 --- a/src/cuda.c +++ b/src/cuda.c @@ -35,6 +35,9 @@ void check_error(cudaError_t status) printf("CUDA Error: %s\n", s); assert(0); snprintf(buffer, 256, "CUDA Error: %s", s); +#ifdef WIN32 + getchar(); +#endif error(buffer); } if (status2 != cudaSuccess) @@ -44,10 +47,20 @@ void check_error(cudaError_t status) printf("CUDA Error Prev: %s\n", s); assert(0); snprintf(buffer, 256, "CUDA Error Prev: %s", s); +#ifdef WIN32 + getchar(); +#endif error(buffer); } } +void check_error_extended(cudaError_t status, char *file, int line, char *date_time) +{ + if (status != cudaSuccess) + printf("CUDA Error: file: %s() : line: %d : build time: %s \n", file, line, date_time); + check_error(status); +} + dim3 cuda_gridsize(size_t n){ size_t k = (n-1) / BLOCK + 1; size_t x = k; @@ -117,6 +130,45 @@ cudnnHandle_t cudnn_handle() } return handle[i]; } + + +void cudnn_check_error(cudnnStatus_t status) +{ + //cudaDeviceSynchronize(); + cudnnStatus_t status2; + cudnnStatus_t status_tmp = cudnnQueryRuntimeError(cudnn_handle(), &status2, CUDNN_ERRQUERY_RAWCODE, NULL); + if (status != CUDNN_STATUS_SUCCESS) + { + const char *s = cudnnGetErrorString(status); + char buffer[256]; + printf("cuDNN Error: %s\n", s); + assert(0); + snprintf(buffer, 256, "cuDNN Error: %s", s); +#ifdef WIN32 + getchar(); +#endif + error(buffer); + } + if (status2 != CUDNN_STATUS_SUCCESS) + { + const char *s = cudnnGetErrorString(status); + char buffer[256]; + printf("cuDNN Error Prev: %s\n", s); + assert(0); + snprintf(buffer, 256, "cuDNN Error Prev: %s", s); +#ifdef WIN32 + getchar(); +#endif + error(buffer); + } +} + +void cudnn_check_error_extended(cudnnStatus_t status, char *file, int line, char *date_time) +{ + if (status != cudaSuccess) + printf("\n cuDNN Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time); + cudnn_check_error(status); +} #endif cublasHandle_t blas_handle() diff --git a/src/cuda.h b/src/cuda.h index b803c17f..d2c7da00 100644 --- a/src/cuda.h +++ b/src/cuda.h @@ -24,6 +24,9 @@ extern int gpu_index; extern "C" { #endif // __cplusplus void check_error(cudaError_t status); + void check_error_extended(cudaError_t status, char *file, int line, char *date_time); +#define CHECK_CUDA(X) check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ ); + cublasHandle_t blas_handle(); float *cuda_make_array(float *x, size_t n); int *cuda_make_int_array(size_t n); @@ -46,6 +49,9 @@ extern "C" { #ifdef CUDNN cudnnHandle_t cudnn_handle(); enum {cudnn_fastest, cudnn_smallest}; + +void cudnn_check_error_extended(cudnnStatus_t status, char *file, int line, char *date_time); +#define CHECK_CUDNN(X) cudnn_check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ ); #endif #else // GPU diff --git a/src/gru_layer.c b/src/gru_layer.c index fa03e7eb..045bbb33 100644 --- a/src/gru_layer.c +++ b/src/gru_layer.c @@ -119,6 +119,7 @@ void forward_gru_layer(layer l, network_state state) { network_state s = {0}; 
s.train = state.train; + s.workspace = state.workspace; int i; layer input_z_layer = *(l.input_z_layer); layer input_r_layer = *(l.input_r_layer); @@ -219,6 +220,7 @@ void forward_gru_layer_gpu(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_z_layer = *(l.input_z_layer); layer input_r_layer = *(l.input_r_layer); @@ -295,6 +297,7 @@ void backward_gru_layer_gpu(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_z_layer = *(l.input_z_layer); layer input_r_layer = *(l.input_r_layer); diff --git a/src/lstm_layer.c b/src/lstm_layer.c index aefe9c46..cf3411e6 100644 --- a/src/lstm_layer.c +++ b/src/lstm_layer.c @@ -40,41 +40,49 @@ layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_n fprintf(stderr, "\t\t"); *(l.uf) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.uf->batch = batch; + if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size; l.ui = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.ui) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.ui->batch = batch; + if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size; l.ug = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.ug) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.ug->batch = batch; + if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size; l.uo = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.uo) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.uo->batch = batch; + if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size; l.wf = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wf) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wf->batch = batch; + if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size; l.wi = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wi) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wi->batch = batch; + if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size; l.wg = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wg) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wg->batch = batch; + if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size; l.wo = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wo) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wo->batch = batch; + if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size; l.batch_normalize = batch_normalize; l.outputs = outputs; @@ -161,6 +169,7 @@ void forward_lstm_layer(layer l, network_state state) { network_state s = { 0 }; s.train = state.train; + s.workspace = state.workspace; int i; layer wf = *(l.wf); layer wi = *(l.wi); @@ -247,6 +256,7 @@ void backward_lstm_layer(layer l, network_state state) { network_state s = { 0 }; s.train = state.train; + s.workspace = state.workspace; int i; layer wf = *(l.wf); layer wi = *(l.wi); @@ -403,6 +413,7 @@ void forward_lstm_layer_gpu(layer l, network_state state) { network_state s = { 0 }; s.train = state.train; + s.workspace = state.workspace; int i; layer wf = *(l.wf); layer wi = *(l.wi); @@ -489,6 
+500,7 @@ void backward_lstm_layer_gpu(layer l, network_state state) { network_state s = { 0 }; s.train = state.train; + s.workspace = state.workspace; int i; layer wf = *(l.wf); layer wi = *(l.wi); diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 0076a87b..0e14ce56 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -58,7 +58,7 @@ void forward_network_gpu(network net, network_state state) if(l.delta_gpu && state.train){ fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1); } - //printf("%d - type: %d - ", i, l.type); + //printf("\n layer %d - type: %d - \n", i, l.type); //start_timer(); l.forward_gpu(l, state); //cudaDeviceSynchronize(); diff --git a/src/parser.c b/src/parser.c index 92850c95..90594c92 100644 --- a/src/parser.c +++ b/src/parser.c @@ -167,7 +167,7 @@ convolutional_layer parse_convolutional(list *options, size_params params) int xnor = option_find_int_quiet(options, "xnor", 0); int use_bin_output = option_find_int_quiet(options, "bin_output", 0); - convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index); + convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index); layer.flipped = option_find_int_quiet(options, "flipped", 0); layer.dot = option_find_float_quiet(options, "dot", 0); @@ -845,7 +845,6 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps) net.outputs = get_network_output_size(net); net.output = get_network_output(net); printf("Total BFLOPS %5.3f \n", bflops); - //printf("%ld\n", workspace_size); #ifdef GPU get_cuda_stream(); get_cuda_memcpy_stream(); @@ -864,6 +863,7 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps) check_error(cudaMalloc((void **)net.output16_gpu, *net.max_output16_size * sizeof(short))); //sizeof(half) } if (workspace_size) { + printf(" Allocate workspace_size = %zu \n", workspace_size); net.workspace = cuda_make_array(0, workspace_size / sizeof(float) + 1); } else { diff --git a/src/rnn_layer.c b/src/rnn_layer.c index 1b377b14..c1843f06 100644 --- a/src/rnn_layer.c +++ b/src/rnn_layer.c @@ -43,16 +43,19 @@ layer make_rnn_layer(int batch, int inputs, int hidden, int outputs, int steps, fprintf(stderr, "\t\t"); *(l.input_layer) = make_connected_layer(batch, steps, inputs, hidden, activation, batch_normalize); l.input_layer->batch = batch; + if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size; l.self_layer = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.self_layer) = make_connected_layer(batch, steps, hidden, hidden, (log==2)?LOGGY:(log==1?LOGISTIC:activation), batch_normalize); l.self_layer->batch = batch; + if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size; l.output_layer = malloc(sizeof(layer)); fprintf(stderr, "\t\t"); *(l.output_layer) = make_connected_layer(batch, steps, hidden, outputs, activation, batch_normalize); l.output_layer->batch = batch; + if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size; l.outputs = outputs; l.output = l.output_layer->output; @@ -84,6 +87,7 @@ void forward_rnn_layer(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer 
self_layer = *(l.self_layer); @@ -126,6 +130,7 @@ void backward_rnn_layer(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer self_layer = *(l.self_layer); @@ -199,6 +204,7 @@ void forward_rnn_layer_gpu(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer self_layer = *(l.self_layer); @@ -241,6 +247,7 @@ void backward_rnn_layer_gpu(layer l, network_state state) { network_state s = {0}; s.train = state.train; + s.workspace = state.workspace; int i; layer input_layer = *(l.input_layer); layer self_layer = *(l.self_layer);
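
Note on the error-handling pattern: throughout this patch, raw CUDA/cuDNN calls are wrapped in CHECK_CUDA()/CHECK_CUDNN(), which forward the returned status plus call-site and build-time information to check_error_extended() / cudnn_check_error_extended() and abort with a readable message (the checkers also query the previously recorded asynchronous error via cudaGetLastError / cudnnQueryRuntimeError). Below is a minimal, self-contained sketch of the same idea; it uses the portable __func__ and local helper names, so it illustrates the pattern rather than reproducing the exact macros added in src/cuda.h.

/* error_check_sketch.c -- illustrative only; assumes the CUDA runtime and cuDNN
 * headers/libraries are available. Helper names are local to this sketch. */
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <cudnn.h>

static void cuda_check(cudaError_t status, const char *file, int line, const char *func)
{
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s (%s:%d, %s)\n",
                cudaGetErrorString(status), file, line, func);
        exit(EXIT_FAILURE);
    }
}

static void cudnn_check(cudnnStatus_t status, const char *file, int line, const char *func)
{
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "cuDNN Error: %s (%s:%d, %s)\n",
                cudnnGetErrorString(status), file, line, func);
        exit(EXIT_FAILURE);
    }
}

/* Wrap any call: CHECK_CUDA(cudaMalloc(...));  CHECK_CUDNN(cudnnCreate(...)); */
#define CHECK_CUDA(call)  cuda_check((call),  __FILE__, __LINE__, __func__)
#define CHECK_CUDNN(call) cudnn_check((call), __FILE__, __LINE__, __func__)

int main(void)
{
    float *buf = NULL;
    CHECK_CUDA(cudaMalloc((void **)&buf, 1024 * sizeof(float)));   /* fails loudly if no GPU */

    cudnnHandle_t handle;
    CHECK_CUDNN(cudnnCreate(&handle));
    CHECK_CUDNN(cudnnDestroy(handle));

    CHECK_CUDA(cudaFree(buf));
    return 0;
}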
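
Note on the connected-layer change: make_connected_layer() now fills in l.n = l.out_c, l.size = 1, l.stride = 1, l.pad = 0 and reuses create_convolutional_cudnn_tensors() / cudnn_convolutional_setup(), because a fully connected layer over an inputs-long vector is the same linear map as a 1x1 convolution over a 1x1 spatial grid with inputs channels and outputs filters; get_connected_workspace_size() then asks cuDNN how much scratch memory the chosen algorithms need. The sketch below shows that descriptor setup and workspace query with local names and the cuDNN 7-era API used in this patch (cudnnGetConvolutionForwardAlgorithm is not available in cuDNN 8); it is an illustration, not darknet code.

/* connected_as_conv1x1.c -- describe a connected layer (batch x inputs -> outputs)
 * to cuDNN as a 1x1 convolution and query its forward workspace size.
 * Requires cuDNN 7.x; all identifiers here are local to the sketch. */
#include <stdio.h>
#include <stdlib.h>
#include <cudnn.h>

#define CK(call) do { cudnnStatus_t s_ = (call); if (s_ != CUDNN_STATUS_SUCCESS) { \
    fprintf(stderr, "cuDNN error %s at %s:%d\n", cudnnGetErrorString(s_), __FILE__, __LINE__); \
    exit(1); } } while (0)

int main(void)
{
    const int batch = 4, inputs = 256, outputs = 128;

    cudnnHandle_t h;
    CK(cudnnCreate(&h));

    /* Input vector seen as an NCHW tensor: N=batch, C=inputs, H=W=1. */
    cudnnTensorDescriptor_t src, dst;
    CK(cudnnCreateTensorDescriptor(&src));
    CK(cudnnSetTensor4dDescriptor(src, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, inputs, 1, 1));

    /* Weight matrix seen as `outputs` filters of shape inputs x 1 x 1. */
    cudnnFilterDescriptor_t w;
    CK(cudnnCreateFilterDescriptor(&w));
    CK(cudnnSetFilter4dDescriptor(w, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, outputs, inputs, 1, 1));

    /* 1x1 cross-correlation, stride 1, no padding: exactly the linear map y = Wx. */
    cudnnConvolutionDescriptor_t conv;
    CK(cudnnCreateConvolutionDescriptor(&conv));
    CK(cudnnSetConvolution2dDescriptor(conv, 0, 0, 1, 1, 1, 1,
                                       CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));

    /* Output tensor: N=batch, C=outputs, H=W=1. */
    CK(cudnnCreateTensorDescriptor(&dst));
    CK(cudnnSetTensor4dDescriptor(dst, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, outputs, 1, 1));

    /* Pick an algorithm and ask how much scratch memory it needs -- the value that
     * get_connected_workspace_size() folds into l.workspace_size. */
    cudnnConvolutionFwdAlgo_t algo;
    CK(cudnnGetConvolutionForwardAlgorithm(h, src, w, conv, dst,
                                           CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo));
    size_t ws = 0;
    CK(cudnnGetConvolutionForwardWorkspaceSize(h, src, w, conv, dst, algo, &ws));
    printf("forward algo %d needs %zu bytes of workspace\n", (int)algo, ws);

    CK(cudnnDestroyConvolutionDescriptor(conv));
    CK(cudnnDestroyFilterDescriptor(w));
    CK(cudnnDestroyTensorDescriptor(src));
    CK(cudnnDestroyTensorDescriptor(dst));
    CK(cudnnDestroy(h));
    return 0;
}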
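
Note on workspace propagation: once the inner connected/convolutional layers may need cuDNN scratch memory, the recurrent wrappers have to cooperate in two ways, both visible in the rnn/lstm/crnn/gru hunks above: they advertise the largest workspace_size any of their sub-layers requests (the repeated "if (l.workspace_size < sub->workspace_size) ..." lines), and they copy state.workspace into the per-step network_state s, so the sub-layers see the single network-level buffer instead of a null pointer. A tiny CPU-only sketch of the sizing rule, with a hypothetical sublayer_t type standing in for darknet's layer:

/* workspace_propagation.c -- the wrapper layer's requirement is the maximum over
 * its sub-layers; the network then allocates one buffer of that size and threads
 * it through every forward/backward call. Types here are hypothetical. */
#include <stddef.h>
#include <stdio.h>

typedef struct { size_t workspace_size; } sublayer_t;

static size_t wrapper_workspace_size(const sublayer_t *subs, int n)
{
    size_t most = 0;
    for (int i = 0; i < n; ++i)
        if (subs[i].workspace_size > most) most = subs[i].workspace_size;
    return most;   /* same effect as the per-sublayer "if (l.workspace_size < ...)" updates */
}

int main(void)
{
    /* e.g. input, self and output sub-layers of a CRNN cell */
    sublayer_t subs[3] = { { 1u << 20 }, { 4u << 20 }, { 2u << 20 } };
    printf("wrapper needs %zu bytes of workspace\n", wrapper_workspace_size(subs, 3));
    return 0;
}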