Fixed nan issue for training with CUDNN_HALF=1 by using Tensor Cores
@@ -138,7 +138,9 @@ void fast_binarize_weights_gpu(float *weights, int n, int size, float *binary, f

__global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    //if (idx < size) output_f16[idx] = __float2half(input_f32[idx]);
    if (idx < size) output_f16[idx] = __float2half_rn(input_f32[idx]);
    // __float2half_ru, __float2half_rd, __float2half_rz, __float2half_rn
    //if (idx < size) *((unsigned short *)output_f16 + idx) = __float2half(input_f32[idx]);
}
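
The kernel is launched from a small host-side wrapper such as the cuda_convert_f32_to_f16() used below. A minimal standalone sketch of such a launcher (the function name, BLOCK_SKETCH size, and missing stream/error handling are illustrative assumptions, not darknet's exact code):

    #include <cuda_fp16.h>

    #define BLOCK_SKETCH 512

    // Hypothetical launcher: one thread per element, rounding to nearest even.
    void launch_f32_to_f16_sketch(float *input_f32, size_t size, half *output_f16)
    {
        size_t grid = (size + BLOCK_SKETCH - 1) / BLOCK_SKETCH;
        cuda_f32_to_f16<<<(unsigned int)grid, BLOCK_SKETCH>>>(input_f32, size, output_f16);
    }

The substantive change here is swapping __float2half for __float2half_rn, which makes the FP32-to-FP16 rounding mode explicit (round to nearest even); the commented line lists the other rounding intrinsics that were considered.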
@@ -290,113 +292,128 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

    float one = 1;    // alpha[0], beta[0] is float for HALF and FLOAT
    float alpha = 1, beta = 0;

#ifdef CUDNN_HALF
    // Note: For improved performance it is advised to use beta[0] = 0.0.
    // For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH;
    // 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF
    // 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
    // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops
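
The cudnnSetConvolutionMathType() mentioned above is how a convolution opts into Tensor Cores. A standalone sketch under that assumption (cuDNN 7+; hypothetical helper name, no error checking):

    #include <cudnn.h>

    // Allow Tensor Core kernels for this convolution; cuDNN may still choose
    // a non-Tensor-Core algorithm if it is faster for the given shapes.
    void enable_tensor_cores_sketch(cudnnConvolutionDescriptor_t convDesc)
    {
        cudnnSetConvolutionMathType(convDesc, CUDNN_TENSOR_OP_MATH);
    }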
    //#ifdef CUDNN_HALF
    //if (state.use_mixed_precision) {
    int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
    if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > state.net.burn_in))
    {
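        // (annotation) This runtime gate is the heart of the NaN fix: the FP16
        // Tensor Core path is skipped for the first layer (state.index == 0),
        // for XNOR layers, and, during training, until burn_in iterations have
        // elapsed, presumably because early gradients are large enough to
        // overflow FP16's range. The first iterations therefore run in FP32,
        // and the half-precision path takes over once training has settled.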
        const size_t input16_size = l.batch*l.c*l.w*l.h;
        const size_t output16_size = l.batch*l.out_c*l.out_h*l.out_w;

        if (*state.net.max_input16_size < input16_size) {
            //printf("\n input16_size: cur = %zu \t max = %zu \n", input16_size, *state.net.max_input16_size);
            *state.net.max_input16_size = input16_size;
            if (*state.net.input16_gpu) cuda_free(*state.net.input16_gpu);
            *state.net.input16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_input16_size);
        }
        float *input16 = *state.net.input16_gpu;

        if (*state.net.max_output16_size < output16_size) {
            *state.net.max_output16_size = output16_size;
            if (*state.net.output16_gpu) cuda_free(*state.net.output16_gpu);
            *state.net.output16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_output16_size);
        }
        float *output16 = *state.net.output16_gpu;
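
The allocation idiom above is a grow-only scratch buffer shared across layers. A standalone sketch of the pattern (hypothetical names, plain CUDA runtime API instead of darknet's cuda_* wrappers):

    #include <cuda_runtime.h>

    typedef struct { void *ptr; size_t capacity; } scratch_sketch;

    // Reallocate only when a layer needs more FP16 elements than ever before,
    // so steady-state forward passes do no allocation at all.
    void ensure_f16_capacity_sketch(scratch_sketch *buf, size_t needed)
    {
        if (buf->capacity < needed) {
            buf->capacity = needed;
            if (buf->ptr) cudaFree(buf->ptr);
            cudaMalloc(&buf->ptr, buf->capacity * 2);   // half = 2 bytes
        }
    }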

        cuda_convert_f32_to_f16(state.input, input16_size, input16);

        //fill_ongpu(output16_size / 2, 0, (float *)output16, 1);
        cudnnConvolutionForward(cudnn_handle(),
            &alpha,
            l.srcTensorDesc16,
            input16,
            l.weightDesc16,
            l.weights_gpu16,
            l.convDesc,
            l.fw_algo16,
            state.workspace,
            l.workspace_size,
            &beta,
            l.dstTensorDesc16,
            output16);
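
        // (annotation) The *16 descriptors and fw_algo16 are presumably the
        // CUDNN_DATA_HALF counterparts of the FP32 descriptors, set up during
        // layer initialization; the FP32 versions remain in use on the
        // fallback paths below.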

        if (l.batch_normalize)
        {
            if (state.train) // Training
            {
                copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1);
                //cudaMemcpyAsync(l.x_gpu, output16, l.outputs*l.batch*sizeof(half), cudaMemcpyDefault, get_cuda_stream());
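                // (annotation) The "/2" is a reinterpretation trick: copy_ongpu
                // moves 4-byte floats, but these buffers hold 2-byte half
                // values, so copying l.outputs*l.batch/2 floats transfers
                // exactly l.outputs*l.batch halves.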
                float one = 1;
                float zero = 0;
                // Batch-normalization can still take FP16 inputs and outputs, saving half the bandwidth
                // compared to FP32, it's just that the statistics and value adjustment should be done in FP32.
                cudnnBatchNormalizationForwardTraining(cudnn_handle(),
                    CUDNN_BATCHNORM_SPATIAL,
                    &one,
                    &zero,
                    l.normDstTensorDescF16,
                    l.x_gpu,                // input
                    l.normDstTensorDescF16,
                    output16,               // output
                    l.normTensorDesc,
                    l.scales_gpu,
                    l.biases_gpu,
                    .01,
                    l.rolling_mean_gpu,     // output (should be FP32)
                    l.rolling_variance_gpu, // output (should be FP32)
                    .00001,
                    l.mean_gpu,             // output (should be FP32)
                    l.variance_gpu);        // output (should be FP32)
                //printf("\n CUDNN_HALF!!! state.index = %d \n", state.index);

                cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
                //forward_batchnorm_layer_gpu(l, state);
            }
            else // Detection
            {
                cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
                normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
                scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
                add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
            }
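            // (annotation) Inference path: the FP16 conv output is converted
            // back to FP32 first, then batch norm is applied with the rolling
            // statistics entirely in FP32, consistent with the precision note
            // above.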
        }
        else // BIAS only
        {
            cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
            add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
        }
    }
    else {
        //#else

        cudnnConvolutionForward(cudnn_handle(),
            &alpha, //&one,
            l.srcTensorDesc,
            state.input,
            l.weightDesc,
            l.weights_gpu,
            l.convDesc,
            l.fw_algo,
            state.workspace,
            l.workspace_size,
            &beta, //&one,
            l.dstTensorDesc,
            l.output_gpu);

        //cudaDeviceSynchronize();
        if (l.batch_normalize) {
            forward_batchnorm_layer_gpu(l, state);
        }
        else {
            add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
        }
        //#endif // CUDNN_HALF
    }

#else

    cudnnConvolutionForward(cudnn_handle(),
        &alpha, //&one,
        l.srcTensorDesc,
        state.input,
        l.weightDesc,
        l.weights_gpu,
        l.convDesc,
        l.fw_algo,
        state.workspace,
        l.workspace_size,
        &beta, //&one,
        l.dstTensorDesc,
        l.output_gpu);

    //cudaDeviceSynchronize();
#endif // CUDNN_HALF

#else
    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
@@ -418,16 +435,17 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

        }
        gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
    }
#endif

#ifndef CUDNN_HALF
    if (l.batch_normalize) {
        forward_batchnorm_layer_gpu(l, state);
    }
    else {
        add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
    }
#endif // no CUDNN_HALF
    //#ifndef CUDNN_HALF
    //#endif // no CUDNN_HALF

    if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
    //if(l.dot > 0) dot_error_gpu(l);
@@ -441,13 +459,13 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)

    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);

#ifndef CUDNN_HALF
    if(l.batch_normalize){
        backward_batchnorm_layer_gpu(l, state);
    } else {
        //backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
    }
#endif // no CUDNN_HALF
    //#ifndef CUDNN_HALF
    //if(l.batch_normalize){
    //    backward_batchnorm_layer_gpu(l, state);
    //} else {
    //    //backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
    //}
    //#endif // no CUDNN_HALF

    float *original_input = state.input;

    if(l.xnor) state.input = l.binary_input_gpu;
@@ -455,117 +473,126 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)

    float one = 1;
    float alpha = 1, beta = 0;

#ifdef CUDNN_HALF
    //#ifdef CUDNN_HALF
    int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
    if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > state.net.burn_in))
    {
        const size_t input16_size = l.batch*l.c*l.w*l.h;
        const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;

        if (*state.net.max_input16_size < input16_size) {
            *state.net.max_input16_size = input16_size;
            if (*state.net.input16_gpu) cuda_free(*state.net.input16_gpu);
            *state.net.input16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_input16_size);
        }
        float *input16 = *state.net.input16_gpu;

        if (*state.net.max_output16_size < delta16_size) {
            *state.net.max_output16_size = delta16_size;
            if (*state.net.output16_gpu) cuda_free(*state.net.output16_gpu);
            *state.net.output16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_output16_size);
        }
        float *delta16 = *state.net.output16_gpu;

        cuda_convert_f32_to_f16(state.input, input16_size, input16);
        cuda_convert_f32_to_f16(l.delta_gpu, delta16_size, delta16);

        if (l.batch_normalize) {
            //if (!state.train) {
            //    l.mean_gpu = l.rolling_mean_gpu;
            //    l.variance_gpu = l.rolling_variance_gpu;
            //}
            float one = 1;
            float zero = 0;
            cudnnBatchNormalizationBackward(cudnn_handle(),
                CUDNN_BATCHNORM_SPATIAL,
                &one,
                &zero,
                &one,
                &one,
                l.normDstTensorDescF16,
                l.x_gpu,                // input
                l.normDstTensorDescF16,
                delta16,                // input
                l.normDstTensorDescF16,
                l.x_norm_gpu,           // output
                l.normTensorDesc,
                l.scales_gpu,           // output (should be FP32)
                l.scale_updates_gpu,    // output (should be FP32)
                l.bias_updates_gpu,     // output (should be FP32)
                .00001,
                l.mean_gpu,             // input (should be FP32)
                l.variance_gpu);        // input (should be FP32)
            copy_ongpu(l.outputs*l.batch / 2, l.x_norm_gpu, 1, delta16, 1);
            //cudaMemcpyAsync(delta16, l.x_norm_gpu, l.outputs*l.batch * sizeof(half), cudaMemcpyDefault, get_cuda_stream());
        }
        else
        {
            //backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
        }

        // convert input: state.input (x), l.delta_gpu (y) from fp32 to fp16
        // get output: l.weight_updates_gpu (dw) and convert it to fp32 (ONLY if it is fp16)

        // calculate conv weight updates
        // Already: l.weight_updates_gpu = (l.weight_updates_gpu - l.weight*decay*batch*subdivision)*momentum
        // so we should copy f32 to f16, or compute: f16 = (w_up - w*d*b*s)*m
        cuda_convert_f32_to_f16(l.weight_updates_gpu, l.c*l.n*l.size*l.size, l.weight_updates_gpu16);
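        // (annotation) Why stage dw in FP16: the filter-gradient call below
        // accumulates with beta = &one, and the FP32 l.weight_updates_gpu
        // already carries the momentum/decay term described above. Converting
        // it to FP16 first, accumulating there, and converting back after the
        // call preserves that term instead of overwriting it.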

        cudnnConvolutionBackwardFilter(cudnn_handle(),
            &one,
            l.srcTensorDesc16,
            input16, //state.input,
            l.ddstTensorDesc16,
            delta16, //l.delta_gpu,
            l.convDesc,
            l.bf_algo16,
            state.workspace,
            l.workspace_size,
            &one,
            l.dweightDesc,
            l.weight_updates_gpu16); // l.weight_updates_gpu);

        cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.c*l.n*l.size*l.size, l.weight_updates_gpu);

        if (state.delta) {
            if (l.binary || l.xnor) swap_binary(&l);

            // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
            // calculate delta for the next layer
            // convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
            // get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
            cudnnConvolutionBackwardData(cudnn_handle(),
                &alpha,
                l.weightDesc16,
                l.weights_gpu16, //l.weights_gpu,
                l.ddstTensorDesc16,
                delta16, //l.delta_gpu,
                l.convDesc,
                l.bd_algo16,
                state.workspace,
                l.workspace_size,
                &beta,
                l.dsrcTensorDesc16,
                input16); // state.delta);

            cuda_convert_f16_to_f32(input16, input16_size, state.delta);

            if (l.binary || l.xnor) swap_binary(&l);
            if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
        }
    }
    else {
        //#else // CUDNN_HALF

        if(l.batch_normalize){
            backward_batchnorm_layer_gpu(l, state);
        }

        // calculate conv weight updates
        // if used: beta=1 then loss decreases faster
        cudnnConvolutionBackwardFilter(cudnn_handle(),
            &one,
            l.srcTensorDesc,
            state.input,
@@ -579,11 +606,11 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)

            l.dweightDesc,
            l.weight_updates_gpu);

        if (state.delta) {
            if (l.binary || l.xnor) swap_binary(&l);
            // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
            // calculate delta for the next layer
            cudnnConvolutionBackwardData(cudnn_handle(),
                &one,
                l.weightDesc,
                l.weights_gpu,
@@ -596,13 +623,18 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)

                &one,
                l.dsrcTensorDesc,
                state.delta);

            if (l.binary || l.xnor) swap_binary(&l);
            if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
        }
    }

    //#endif // CUDNN_HALF

#else // CUDNN
    if (l.batch_normalize) {
        backward_batchnorm_layer_gpu(l, state);
    }

    int m = l.n;
    int n = l.size*l.size*l.c;
    int k = l.out_w*l.out_h;