Mirror of https://github.com/pjreddie/darknet.git (synced 2023-08-10 21:13:14 +03:00)
Use half (float16) instead of float32 if both CUDNN and CUDNN_HALF are defined. Use Tensor Cores.
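The "Use Tensor Cores" part of this commit is only hinted at in a comment in the diff below; the diff itself never calls cudnnSetConvolutionMathType(). As a hedged sketch (assuming cuDNN 7+, and that l.convDesc is the convolution descriptor darknet configures elsewhere), opting a convolution into Tensor Core kernels would look like:

// Sketch only, not part of this commit: enable Tensor Core math on a
// cuDNN convolution descriptor. Requires cuDNN 7+, and Volta-class
// Tensor Cores additionally need half-precision I/O (CUDNN_DATA_HALF tensors).
#if CUDNN_MAJOR >= 7
    cudnnSetConvolutionMathType(l.convDesc, CUDNN_TENSOR_OP_MATH);
#endif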
@@ -74,6 +74,38 @@ void binarize_weights_gpu(float *weights, int n, int size, float *binary)
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < size) output_f16[idx] = input_f32[idx];
+}
+
+void cuda_convert_f32_to_f16(float* input_f32, size_t size, half *output_f16) {
+    cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16);
+}
+
+__global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < size) output_f32[idx] = input_f16[idx];
+}
+
+void cuda_convert_f16_to_f32(half* input_f16, size_t size, float *output_f32) {
+    cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f16, size, output_f32);
+}
+
+half *cuda_make_f16_from_f32_array(float *src, size_t n)
+{
+    half *dst16;
+    size_t size = sizeof(half)*n;
+    check_error(cudaMalloc((void **)&dst16, size));
+    if (src) {
+        cuda_convert_f32_to_f16(src, n, dst16);
+    }
+    if (!dst16) error("Cuda malloc failed\n");
+    return dst16;
+}
+
 void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
     fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
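One note on the launch configuration above: size / BLOCK + 1 always rounds up, but launches one surplus block whenever size is an exact multiple of BLOCK; the if (idx < size) guard makes that harmless. A minimal sketch of the equivalent ceiling-division form (BLOCK is darknet's thread-block size constant from cuda.h):

// Sketch: ceiling division avoids the surplus block when size % BLOCK == 0;
// behavior is otherwise identical because the kernels guard on idx < size.
size_t grid = (size + BLOCK - 1) / BLOCK;
cuda_f32_to_f16 <<< grid, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16);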
@@ -90,9 +122,57 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     }
 
 #ifdef CUDNN
-    float one = 1;
+    //float one = 1;    // alpha[0], beta[0] is float for HALF and FLOAT
+    float alpha = 1, beta = 0;
+
+#ifdef CUDNN_HALF
+    // Note: For improved performance it is advised to use beta[0] = 0.0.
+    // For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH;
+    // 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF
+    // 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
+    // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops
+
+    const size_t input16_size = l.batch*l.c*l.w*l.h;
+    static size_t max_input16_size = input16_size;
+    static half* input16 = cuda_make_f16_from_f32_array(NULL, max_input16_size);
+
+    const size_t output16_size = l.batch*l.out_c*l.out_h*l.out_w;
+    static size_t max_output16_size = output16_size;
+    static half* output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
+
+    if (max_input16_size < input16_size) {
+        max_input16_size = input16_size;
+        cuda_free((float *)input16);
+        input16 = cuda_make_f16_from_f32_array(state.input, max_input16_size);
+    }
+
+    if (max_output16_size < output16_size) {
+        max_output16_size = output16_size;
+        cuda_free((float *)output16);
+        output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
+    }
+
+    cuda_convert_f32_to_f16(state.input, input16_size, input16);
+
+    cudnnConvolutionForward(cudnn_handle(),
+            &alpha,
+            l.srcTensorDesc,
+            input16,
+            l.weightDesc,
+            l.weights_gpu16,
+            l.convDesc,
+            l.fw_algo,
+            state.workspace,
+            l.workspace_size,
+            &beta,
+            l.dstTensorDesc,
+            output16);
+
+    cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
+#else
+
     cudnnConvolutionForward(cudnn_handle(),
-            &one,
+            &alpha,
             l.srcTensorDesc,
             state.input,
             l.weightDesc,
@@ -101,9 +181,11 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
             l.fw_algo,
             state.workspace,
             l.workspace_size,
-            &one,
+            &beta,
             l.dstTensorDesc,
             l.output_gpu);
+#endif
+
 
 #else
     int i;
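The half-precision path above caches its conversion buffers in function-local statics that only ever grow, so reallocation happens at most once per new maximum layer size rather than on every forward pass. A stripped-down sketch of that grow-only pattern (variable names here are hypothetical, for illustration only):

// Sketch of the grow-only scratch buffer used by the forward pass above.
// Caveat: function-local statics are shared by every layer that runs through
// this function, so this pattern is not safe for concurrent multi-stream use.
static size_t cap16 = 0;       // hypothetical names, not from the commit
static half *scratch16 = NULL;
if (cap16 < needed) {
    cap16 = needed;
    if (scratch16) cuda_free((float *)scratch16);
    scratch16 = cuda_make_f16_from_f32_array(NULL, cap16);  // fresh, unfilled
}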
@@ -232,6 +314,9 @@ void pull_convolutional_layer(convolutional_layer layer)
 void push_convolutional_layer(convolutional_layer layer)
 {
     cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
+#ifdef CUDNN_HALF
+    cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, (half *)layer.weights_gpu16);
+#endif
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
     cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
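Finally, a hedged usage sketch of the new conversion helpers as a round-trip check (cuda_make_array and cuda_pull_array are darknet's existing float32 device helpers; small integers are exactly representable in half, so the round trip should be exact):

// Round-trip check for the new f32 <-> f16 helpers (sketch, not in this commit).
float host_in[8] = {0, 1, 2, 3, 4, 5, 6, 7}, host_out[8];
float *dev32 = cuda_make_array(host_in, 8);              // host -> device f32
half  *dev16 = cuda_make_f16_from_f32_array(dev32, 8);   // cudaMalloc + f32 -> f16
cuda_convert_f16_to_f32(dev16, 8, dev32);                // f16 -> f32, back in place
cuda_pull_array(dev32, host_out, 8);                     // host_out == host_in
cuda_free(dev32);
cuda_free((float *)dev16);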