Mirror of https://github.com/pjreddie/darknet.git, synced 2023-08-10 21:13:14 +03:00
Accelerated by another 5% using FP16/32 Batch-norm for Tensor Cores.
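The change keeps convolution outputs in FP16 on the Tensor Core path and feeds them straight into cuDNN batch normalization: cuDNN accepts CUDNN_DATA_HALF x/y tensors as long as the scale, bias, and mean/variance tensors stay FP32 (hence the new normDstTensorDescF16 alongside the FP32 normTensorDesc below). A minimal sketch of that call pattern, not part of the commit; the handle, descriptor, and buffer names are illustrative, and cudnnDeriveBNTensorDescriptor() is used as a shortcut where the commit sets normTensorDesc by hand:

    #include <cudnn.h>

    /* Sketch: FP16 data path, FP32 statistics, mirroring the commit's
       cudnnBatchNormalizationForwardTraining call. */
    void bn_forward_mixed(cudnnHandle_t h,
                          cudnnTensorDescriptor_t xDescF16, /* NCHW, CUDNN_DATA_HALF */
                          void *x16, void *y16,             /* FP16 input/output */
                          float *scale, float *bias,        /* FP32 bnScale/bnBias */
                          float *rmean, float *rvar,        /* FP32 running stats */
                          float *smean, float *sinvvar)     /* FP32 saved stats */
    {
        cudnnTensorDescriptor_t bnDesc;
        cudnnCreateTensorDescriptor(&bnDesc);
        /* Derives the 1xCx1x1 parameter descriptor; for a HALF xDesc it comes out FLOAT. */
        cudnnDeriveBNTensorDescriptor(bnDesc, xDescF16, CUDNN_BATCHNORM_SPATIAL);

        float one = 1, zero = 0;
        cudnnBatchNormalizationForwardTraining(h, CUDNN_BATCHNORM_SPATIAL,
            &one, &zero,
            xDescF16, x16,           /* FP16 in */
            xDescF16, y16,           /* FP16 out */
            bnDesc, scale, bias,     /* FP32 parameters */
            .01, rmean, rvar,        /* averaging factor as in the commit */
            .00001, smean, sinvvar); /* epsilon as in the commit */

        cudnnDestroyTensorDescriptor(bnDesc);
    }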
@@ -91,6 +91,7 @@ endif
 ifeq ($(CUDNN_HALF), 1)
 COMMON+= -DCUDNN_HALF
 CFLAGS+= -DCUDNN_HALF
+ARCH+= -gencode arch=compute_70,code=[sm_70,compute_70]
 endif
 
 OBJ=http_stream.o gemm.o utils.o cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o
@@ -190,18 +190,18 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
             &one,
             &zero,
             l.normDstTensorDesc,
-            l.x_gpu,
+            l.x_gpu, // input
             l.normDstTensorDesc,
-            l.output_gpu,
+            l.output_gpu, // output
             l.normTensorDesc,
             l.scales_gpu,
             l.biases_gpu,
             .01,
-            l.rolling_mean_gpu,
-            l.rolling_variance_gpu,
+            l.rolling_mean_gpu, // output (should be FP32)
+            l.rolling_variance_gpu, // output (should be FP32)
             .00001,
-            l.mean_gpu,
-            l.variance_gpu);
+            l.mean_gpu, // output (should be FP32)
+            l.variance_gpu); // output (should be FP32)
 #else
         fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
         fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
@@ -243,18 +243,18 @@ void backward_batchnorm_layer_gpu(layer l, network_state state)
             &one,
             &one,
             l.normDstTensorDesc,
-            l.x_gpu,
+            l.x_gpu, // input
             l.normDstTensorDesc,
-            l.delta_gpu,
+            l.delta_gpu, // input
             l.normDstTensorDesc,
-            l.x_norm_gpu,
+            l.x_norm_gpu, // output
             l.normTensorDesc,
-            l.scales_gpu,
-            l.scale_updates_gpu,
-            l.bias_updates_gpu,
+            l.scales_gpu, // output (should be FP32)
+            l.scale_updates_gpu, // output (should be FP32)
+            l.bias_updates_gpu, // output (should be FP32)
             .00001,
-            l.mean_gpu,
-            l.variance_gpu);
+            l.mean_gpu, // input (should be FP32)
+            l.variance_gpu); // input (should be FP32)
         copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, 1, l.delta_gpu, 1);
 #else
         backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h);
@@ -169,7 +169,51 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
             l.dstTensorDesc,
             output16);
 
-        cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
+
+        if (l.batch_normalize)
+        {
+            if (state.train) // Training
+            {
+                copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1);
+                //cudaMemcpyAsync(l.x_gpu, output16, l.outputs*l.batch*sizeof(half), cudaMemcpyDefault, get_cuda_stream());
+                float one = 1;
+                float zero = 0;
+                // Batch-normalization can still take FP16 inputs and outputs, saving half the bandwidth
+                // compared to FP32, it's just that the statistics and value adjustment should be done in FP32.
+                cudnnBatchNormalizationForwardTraining(cudnn_handle(),
+                    CUDNN_BATCHNORM_SPATIAL,
+                    &one,
+                    &zero,
+                    l.normDstTensorDescF16,
+                    l.x_gpu, // input
+                    l.normDstTensorDescF16,
+                    output16, // output
+                    l.normTensorDesc,
+                    l.scales_gpu,
+                    l.biases_gpu,
+                    .01,
+                    l.rolling_mean_gpu, // output (should be FP32)
+                    l.rolling_variance_gpu, // output (should be FP32)
+                    .00001,
+                    l.mean_gpu, // output (should be FP32)
+                    l.variance_gpu); // output (should be FP32)
+
+                cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
+                //forward_batchnorm_layer_gpu(l, state);
+            }
+            else // Detection
+            {
+                cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
+                normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+                scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+                add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
+            }
+        }
+        else // BIAS only
+        {
+            cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
+            add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
+        }
 
 #else
 
@@ -186,7 +230,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
             &one,
             l.dstTensorDesc,
             l.output_gpu);
-#endif
+#endif // CUDNN_HALF
 
 
 #else
@@ -203,12 +247,14 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     }
 #endif
 
+#ifndef CUDNN_HALF
     if (l.batch_normalize) {
         forward_batchnorm_layer_gpu(l, state);
     }
     else {
         add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
     }
+#endif // no CUDNN_HALF
 
     activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.dot > 0) dot_error_gpu(l);
@@ -222,12 +268,13 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 
     backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
 
+#ifndef CUDNN_HALF
     if(l.batch_normalize){
         backward_batchnorm_layer_gpu(l, state);
         //axpy_ongpu(l.outputs*l.batch, -state.net.decay, l.x_gpu, 1, l.delta_gpu, 1);
     } else {
         //axpy_ongpu(l.outputs*l.batch, -state.net.decay, l.output_gpu, 1, l.delta_gpu, 1);
         //backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
     }
+#endif // no CUDNN_HALF
     float *original_input = state.input;
 
     if(l.xnor) state.input = l.binary_input_gpu;
@@ -256,7 +303,41 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 
         cuda_convert_f32_to_f16(state.input, input16_size, input16);
         cuda_convert_f32_to_f16(l.delta_gpu, delta16_size, delta16);
 
+
+        if (l.batch_normalize) {
+            //if (!state.train) {
+            //    l.mean_gpu = l.rolling_mean_gpu;
+            //    l.variance_gpu = l.rolling_variance_gpu;
+            //}
+            float one = 1;
+            float zero = 0;
+            cudnnBatchNormalizationBackward(cudnn_handle(),
+                CUDNN_BATCHNORM_SPATIAL,
+                &one,
+                &zero,
+                &one,
+                &one,
+                l.normDstTensorDescF16,
+                l.x_gpu, // input
+                l.normDstTensorDescF16,
+                delta16, // input
+                l.normDstTensorDescF16,
+                l.x_norm_gpu, // output
+                l.normTensorDesc,
+                l.scales_gpu, // output (should be FP32)
+                l.scale_updates_gpu, // output (should be FP32)
+                l.bias_updates_gpu, // output (should be FP32)
+                .00001,
+                l.mean_gpu, // input (should be FP32)
+                l.variance_gpu); // input (should be FP32)
+            copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, 1, delta16, 1);
+            //cudaMemcpyAsync(delta16, l.x_norm_gpu, l.outputs*l.batch * sizeof(half), cudaMemcpyDefault, get_cuda_stream());
+        }
+        else
+        {
+            //backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
+        }
 
         // convert input: state.input (x), l.delta_gpu (y) from fp32 to fp16
         // get output: l.weight_updates_gpu (dw) and convert it to fp32 (ONLY if it is fp16)
 
@@ -178,6 +178,8 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
     // batch norm
     cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
     cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
+
+    cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
 #if(CUDNN_MAJOR >= 6)
     cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn >= 6.0
 #else
@@ -379,6 +381,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     }
 #ifdef CUDNN
     cudnnCreateTensorDescriptor(&l.normDstTensorDesc);
+    cudnnCreateTensorDescriptor(&l.normDstTensorDescF16);
     cudnnCreateTensorDescriptor(&l.normTensorDesc);
     cudnnCreateTensorDescriptor(&l.srcTensorDesc);
     cudnnCreateTensorDescriptor(&l.dstTensorDesc);
@@ -281,7 +281,7 @@ struct layer{
 #ifdef CUDNN
     cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
     cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
-    cudnnTensorDescriptor_t normTensorDesc, normDstTensorDesc;
+    cudnnTensorDescriptor_t normTensorDesc, normDstTensorDesc, normDstTensorDescF16;
     cudnnFilterDescriptor_t weightDesc;
     cudnnFilterDescriptor_t dweightDesc;
     cudnnConvolutionDescriptor_t convDesc;
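Note on the helpers this diff calls but does not define: cuda_convert_f16_to_f32 / cuda_convert_f32_to_f16 are darknet's own conversion kernels, and copy_ongpu() copies 32-bit floats, which is why the FP16 buffers above are copied with l.outputs*l.batch / 2 elements: two halfs occupy one float. Below is a sketch of what such conversion kernels could look like; the names and signatures are assumptions, not the repository's actual definitions:

    #include <cuda_fp16.h>
    #include <stddef.h>

    /* Element-wise FP16 <-> FP32 conversion (illustrative only). */
    __global__ void f16_to_f32_kernel(const half *in, float *out, size_t n)
    {
        size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) out[i] = __half2float(in[i]);
    }

    __global__ void f32_to_f16_kernel(const float *in, half *out, size_t n)
    {
        size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) out[i] = __float2half(in[i]);
    }

    void convert_f16_to_f32(const half *in16, size_t n, float *out32)
    {
        const int block = 256;
        int grid = (int)((n + block - 1) / block);
        f16_to_f32_kernel<<<grid, block>>>(in16, out32, n);
    }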