From cad4d1618fee74471d335314cb77070fee951a42 Mon Sep 17 00:00:00 2001
From: AlexeyAB
Date: Sun, 25 Feb 2018 16:29:44 +0300
Subject: [PATCH] Added support for Tensor Cores (CC >= 7.0, e.g. V100). For
 FP16/32 (mixed precision) the define CUDNN_HALF should be used.

---
 build/darknet/darknet.vcxproj |   2 +-
 src/convolutional_kernels.cu  | 122 ++++++++++++++++++++++++++++++----
 src/convolutional_layer.c     |   8 ++-
 src/convolutional_layer.h    |   1 +
 src/layer.c                  |   2 +
 src/layer.h                  |   1 +
 src/network.c                |   3 +
 src/network_kernels.cu       |   2 +-
 8 files changed, 123 insertions(+), 18 deletions(-)

diff --git a/build/darknet/darknet.vcxproj b/build/darknet/darknet.vcxproj
index 0ff87992..a6c2b518 100644
--- a/build/darknet/darknet.vcxproj
+++ b/build/darknet/darknet.vcxproj
@@ -145,7 +145,7 @@
       true
       true
       true
-      C:\opencv_3.0\opencv\build\x64\vc14\lib;C:\opencv_2.4.13\opencv\build\x64\vc12\lib;C:\opencv_2.4.13\opencv\build\x64\vc14\lib;$(CUDA_PATH)lib\$(PlatformName);$(cudnn)\lib\x64;%(AdditionalLibraryDirectories)
+      C:\opencv_3.0\opencv\build\x64\vc14\lib;C:\opencv_2.4.13\opencv\build\x64\vc14\lib;$(CUDA_PATH)lib\$(PlatformName);$(cudnn)\lib\x64;%(AdditionalLibraryDirectories)
       ..\..\3rdparty\lib\x64\pthreadVC2.lib;cublas.lib;curand.lib;cudart.lib;%(AdditionalDependencies)
       $(OutDir)\$(TargetName)$(TargetExt)
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 3b2a349e..9d88a88c 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -81,8 +81,8 @@ __global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
     //if (idx < size) *((unsigned short *)output_f16 + idx) = __float2half(input_f32[idx]);
 }
 
-void cuda_convert_f32_to_f16(float* input_f32, size_t size, half *output_f16) {
-    cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16);
+void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16) {
+    cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16);
 }
 
 __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
@@ -92,8 +92,8 @@ __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
     //if (idx < size) output_f32[idx] = __half2float(*((unsigned short *)input_f16 + idx));
 }
 
-void cuda_convert_f16_to_f32(half* input_f16, size_t size, float *output_f32) {
-    cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f16, size, output_f32);
+void cuda_convert_f16_to_f32(float* input_f16, size_t size, float *output_f32) {
+    cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32);
 }
 
 half *cuda_make_f16_from_f32_array(float *src, size_t n)
@@ -102,7 +102,7 @@ half *cuda_make_f16_from_f32_array(float *src, size_t n)
     size_t size = sizeof(half)*n;
     check_error(cudaMalloc((void **)&dst16, size));
     if (src) {
-        cuda_convert_f32_to_f16(src, n, dst16);
+        cuda_convert_f32_to_f16(src, n, (float *)dst16);
     }
     if (!dst16) error("Cuda malloc failed\n");
     return dst16;
@@ -124,8 +124,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     }
 
 #ifdef CUDNN
-    //float one = 1;    // alpha[0], beta[0] is float for HALF and FLOAT
-    float alpha = 1, beta = 0;
+    float one = 1;    // alpha[0], beta[0] is float for HALF and FLOAT
+    float alpha = 1, beta = 0;
 
 #ifdef CUDNN_HALF
     // Note: For improved performance it is advised to use beta[0] = 0.0.
@@ -154,8 +154,9 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
     }
 
-    cuda_convert_f32_to_f16(state.input, input16_size, input16);
+    cuda_convert_f32_to_f16(state.input, input16_size, (float *)input16);
+    //fill_ongpu(output16_size / 2, 0, (float *)output16, 1);
 
     cudnnConvolutionForward(cudnn_handle(),
         &alpha,
         l.srcTensorDesc,
@@ -170,11 +171,12 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         l.dstTensorDesc,
         output16);
 
-    cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
+    cuda_convert_f16_to_f32((float *)output16, output16_size, l.output_gpu);
+
 #else
 
     cudnnConvolutionForward(cudnn_handle(),
-        &alpha,
+        &one,
         l.srcTensorDesc,
         state.input,
         l.weightDesc,
@@ -183,7 +185,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         l.fw_algo,
         state.workspace,
         l.workspace_size,
-        &beta,
+        &one,
         l.dstTensorDesc,
         l.output_gpu);
 #endif
@@ -230,7 +232,88 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     if(l.xnor) state.input = l.binary_input_gpu;
 
 #ifdef CUDNN
-    float one = 1;
+    float one = 1;
+    float alpha = 1, beta = 0;
+
+#ifdef CUDNN_HALF
+
+    const size_t input16_size = l.batch*l.c*l.w*l.h;
+    static size_t max_input16_size = input16_size;
+    static half* input16 = cuda_make_f16_from_f32_array(NULL, max_input16_size);
+
+    const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
+    static size_t max_delta16_size = delta16_size;
+    static half* delta16 = cuda_make_f16_from_f32_array(NULL, max_delta16_size);
+
+    if (max_input16_size < input16_size) {
+        max_input16_size = input16_size;
+        cuda_free((float *)input16);
+        input16 = cuda_make_f16_from_f32_array(state.input, max_input16_size);
+    }
+
+    if (max_delta16_size < delta16_size) {
+        max_delta16_size = delta16_size;
+        cuda_free((float *)delta16);
+        delta16 = cuda_make_f16_from_f32_array(NULL, max_delta16_size);
+    }
+
+    cuda_convert_f32_to_f16(state.input, input16_size, (float *)input16);
+    cuda_convert_f32_to_f16(l.delta_gpu, delta16_size, (float *)delta16);
+
+    // convert input: state.input (x), l.delta_gpu (y) from fp32 to fp16
+    // get output: l.weight_updates_gpu (dw) and convert it to fp32 (ONLY if it is fp16)
+
+    // calculate conv weight updates
+    // Already: l.weight_updates_gpu = (l.weight_updates_gpu - l.weight*decay*batch*subdivision)*momentum
+    //   so we should copy f32 to f16, or compute: f16=(w_up - w*d*b*s)*m
+    cuda_convert_f32_to_f16(l.weight_updates_gpu, l.c*l.n*l.size*l.size, l.weight_updates_gpu16);
+
+    cudnnConvolutionBackwardFilter(cudnn_handle(),
+        &one,
+        l.srcTensorDesc,
+        input16, //state.input,
+        l.ddstTensorDesc,
+        delta16, //l.delta_gpu,
+        l.convDesc,
+        l.bf_algo,
+        state.workspace,
+        l.workspace_size,
+        &one,
+        l.dweightDesc,
+        l.weight_updates_gpu16);    // l.weight_updates_gpu);
+
+    cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.c*l.n*l.size*l.size, l.weight_updates_gpu);
+
+    if (state.delta) {
+        if (l.binary || l.xnor) swap_binary(&l);
+
+        // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
+        // calculate delta for the next layer
+        // convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
+        // get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
+        cudnnConvolutionBackwardData(cudnn_handle(),
+            &alpha,
+            l.weightDesc,
+            l.weights_gpu16, //l.weights_gpu,
+            l.ddstTensorDesc,
+            delta16, //l.delta_gpu,
+            l.convDesc,
+            l.bd_algo,
+            state.workspace,
+            l.workspace_size,
+            &beta,
+            l.dsrcTensorDesc,
+            input16);   // state.delta);
+
+        cuda_convert_f16_to_f32((float *)input16, input16_size, state.delta);
+
+        if (l.binary || l.xnor) swap_binary(&l);
+        if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
+    }
+#else    // CUDNN_HALF
+
+    // calculate conv weight updates
+    // if used: beta=1 then loss decreases faster
     cudnnConvolutionBackwardFilter(cudnn_handle(),
             &one,
             l.srcTensorDesc,
@@ -248,6 +331,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     if(state.delta){
         if(l.binary || l.xnor) swap_binary(&l);
         // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
+        // calculate delta for the next layer
         cudnnConvolutionBackwardData(cudnn_handle(),
             &one,
             l.weightDesc,
@@ -265,7 +349,9 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         if(l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
     }
 
-#else
+#endif    // CUDNN_HALF
+
+#else    // CUDNN
     int m = l.n;
     int n = l.size*l.size*l.c;
     int k = l.out_w*l.out_h;
@@ -318,7 +404,7 @@ void push_convolutional_layer(convolutional_layer layer)
 {
     cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
 #ifdef CUDNN_HALF
-    cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, (half *)layer.weights_gpu16);
+    cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, layer.weights_gpu16);
 #endif
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
     cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
@@ -358,6 +444,14 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
         adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
         fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
     }else{
+        // update weights:
+        //   weights_gpu = weights_gpu*(1 - decay*lr) + weight_updates_gpu*lr / (batch*subdivision) =
+        //   weights_gpu*(1 - 0.0005*0.001) + weight_updates_gpu*0.001/(64*8) =
+        //   weights_gpu * 0.999 999 5 + weight_updates_gpu * 0.000 001 953125
+        //
+        // weight_updates_gpu = (weight_updates_gpu - weights_gpu*decay*batch*subdivision)*momentum =
+        //   (weight_updates_gpu - weights_gpu * 0.0005 * 64 * 8) * 0.9 =
+        //   weight_updates_gpu*0.9 - weights_gpu*0.2304
         axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
         axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
         scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index d35246ed..377b898e 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -141,7 +141,8 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
 {
 
 #ifdef CUDNN_HALF
-    // TRUE_HALF_CONFIG is only supported on architectures with true fp16 support (compute capability 5.3 and 6.0): Tegra X1, Jetson TX1, DRIVE CX, DRIVE PX, Quadro GP100, Tesla P100
+    // TRUE_HALF_CONFIG is only supported on architectures with true fp16 support (compute capability 5.3 and 6.0):
+    //   Tegra X1, Jetson TX1, DRIVE CX, DRIVE PX, Quadro GP100, Tesla P100
     // PSEUDO_HALF_CONFIG is required for Tensor Cores - our case!
     const cudnnDataType_t data_type = CUDNN_DATA_HALF;
 #else
@@ -164,10 +165,12 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
     // on architectures with DP4A support (compute capability 6.1 and later).
     //cudnnDataType_t data_type = CUDNN_DATA_INT8;
 
+    // backward delta
     cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
     cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
     cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+
+    // forward
     cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
     cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
     cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
@@ -302,7 +305,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
 
         l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
 #ifdef CUDNN_HALF
-        l.weights_gpu16 = cuda_make_array(l.weights, c*n*size*size/2);
+        l.weights_gpu16 = cuda_make_array(l.weights, c*n*size*size / 2);
+        l.weight_updates_gpu16 = cuda_make_array(l.weight_updates, c*n*size*size / 2);
 #endif
         l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
 
diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h
index da98dcef..6d1e5177 100644
--- a/src/convolutional_layer.h
+++ b/src/convolutional_layer.h
@@ -21,6 +21,7 @@ void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
 void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
 #ifdef CUDNN
 void cudnn_convolutional_setup(layer *l, int cudnn_preference);
+void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
 #endif
 
 #endif
diff --git a/src/layer.c b/src/layer.c
index b88c9412..582cbb39 100644
--- a/src/layer.c
+++ b/src/layer.c
@@ -83,6 +83,8 @@ void free_layer(layer l)
     if (l.x_norm_gpu) cuda_free(l.x_norm_gpu);
     if (l.weights_gpu) cuda_free(l.weights_gpu);
     if (l.weight_updates_gpu) cuda_free(l.weight_updates_gpu);
+    if (l.weights_gpu16) cuda_free(l.weights_gpu16);
+    if (l.weight_updates_gpu16) cuda_free(l.weight_updates_gpu16);
     if (l.biases_gpu) cuda_free(l.biases_gpu);
     if (l.bias_updates_gpu) cuda_free(l.bias_updates_gpu);
     if (l.scales_gpu) cuda_free(l.scales_gpu);
diff --git a/src/layer.h b/src/layer.h
index 0f5addac..93aca6c7 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -243,6 +243,7 @@ struct layer{
     float * weight_updates_gpu;
 
     float * weights_gpu16;
+    float * weight_updates_gpu16;
 
     float * biases_gpu;
     float * bias_updates_gpu;
diff --git a/src/network.c b/src/network.c
index c906b585..d23468de 100644
--- a/src/network.c
+++ b/src/network.c
@@ -316,6 +316,8 @@ void set_batch_network(network *net, int b)
         net->layers[i].batch = b;
 #ifdef CUDNN
         if(net->layers[i].type == CONVOLUTIONAL){
+            cudnn_convolutional_setup(net->layers + i, cudnn_fastest);
+            /*
             layer *l = net->layers + i;
             cudnn_convolutional_setup(l, cudnn_fastest);
             // check for excessive memory consumption
@@ -327,6 +329,7 @@ void set_batch_network(network *net, int b)
                 cudnn_convolutional_setup(l, cudnn_smallest);
                 l->workspace_size = get_workspace_size(*l);
             }
+            */
         }
 #endif
     }
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 6090bb09..503a1b8d 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -117,7 +117,7 @@ void forward_backward_network_gpu(network net, float *x, float *y)
     int i;
     for (i = 0; i < net.n; ++i) {
         layer l = net.layers[i];
-        cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, (half *)l.weights_gpu16);
+        cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, l.weights_gpu16);
     }
 #endif
     forward_network_gpu(net, state);
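
Note (illustration, not part of the patch): the CUDNN_HALF path above uses the same pattern in both the forward and backward passes: convert FP32 data to FP16, run the cuDNN convolution on half tensors (descriptors created with CUDNN_DATA_HALF in cudnn_convolutional_setup, alpha/beta kept as float, i.e. the pseudo-half configuration that Tensor Cores need), then convert the result back to FP32 for the rest of darknet. The standalone sketch below shows only that pattern; the kernel names, the fixed block size of 256, the helper conv_forward_fp16 and the assumption that descriptors, algorithm and workspace are already prepared are illustrative assumptions, not darknet code.

    // sketch of the FP32 -> FP16 -> cuDNN -> FP32 forward path (pseudo-half / Tensor Core config)
    #include <cuda_fp16.h>
    #include <cudnn.h>

    __global__ void f32_to_f16(const float *src, size_t n, half *dst)
    {
        size_t i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) dst[i] = __float2half(src[i]);
    }

    __global__ void f16_to_f32(const half *src, size_t n, float *dst)
    {
        size_t i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) dst[i] = __half2float(src[i]);
    }

    // x_f32: FP32 input, w_f16: filters already converted to half, y_f32: FP32 output.
    // All descriptors are assumed to be NCHW with CUDNN_DATA_HALF, matching cudnn_convolutional_setup().
    static void conv_forward_fp16(cudnnHandle_t h,
                                  cudnnTensorDescriptor_t xDesc, const float *x_f32, half *x_f16, size_t x_n,
                                  cudnnFilterDescriptor_t wDesc, const half *w_f16,
                                  cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo,
                                  void *workspace, size_t workspace_bytes,
                                  cudnnTensorDescriptor_t yDesc, half *y_f16, float *y_f32, size_t y_n)
    {
        const float alpha = 1.0f, beta = 0.0f;                        // scaling factors stay FP32

        f32_to_f16<<<(x_n + 255) / 256, 256>>>(x_f32, x_n, x_f16);    // FP32 -> FP16 input

        cudnnConvolutionForward(h, &alpha,
                                xDesc, x_f16,
                                wDesc, w_f16,
                                convDesc, algo,
                                workspace, workspace_bytes,
                                &beta,
                                yDesc, y_f16);                        // convolution runs on half tensors

        f16_to_f32<<<(y_n + 255) / 256, 256>>>(y_f16, y_n, y_f32);    // FP16 -> FP32 output
    }

The backward pass in the patch is the mirror image of this: state.input and l.delta_gpu are converted to input16/delta16, cudnnConvolutionBackwardFilter and cudnnConvolutionBackwardData run on the half buffers, and the results are converted back into l.weight_updates_gpu and state.delta.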
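
Note (illustration, not part of the patch): the arithmetic documented in the new comment inside update_convolutional_layer_gpu() can be checked with a few lines of plain C. The values lr = 0.001, decay = 0.0005, momentum = 0.9 and batch*subdivisions = 64*8 are the example numbers from that comment; the scalars w and wu below stand in for one element of weights_gpu and weight_updates_gpu and are purely illustrative.

    #include <stdio.h>

    int main(void)
    {
        float lr = 0.001f, decay = 0.0005f, momentum = 0.9f;
        float batch = 64 * 8;          /* batch * subdivisions, as in the comment's example */

        float w = 1.0f, wu = 0.5f;     /* one weight and its accumulated gradient sum */

        /* mirrors axpy_ongpu(size, -decay*batch, weights_gpu, 1, weight_updates_gpu, 1); */
        wu += -decay * batch * w;      /* wu = wu - w*0.0005*512 */

        /* mirrors axpy_ongpu(size, learning_rate/batch, weight_updates_gpu, 1, weights_gpu, 1); */
        w += (lr / batch) * wu;        /* w = w*(1 - decay*lr) + wu*lr/batch */

        /* mirrors scal_ongpu(size, momentum, weight_updates_gpu, 1); */
        wu *= momentum;                /* keep 0.9 of the (decayed) update for the next step */

        printf("w = %f, wu = %f\n", w, wu);
        return 0;
    }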