mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Commit for pjreddie master PR for cudnnBatchNormalizationForwardInference
This commit is contained in:
parent
508381b37f
commit
2be86da6b8
@ -228,9 +228,29 @@ void forward_batchnorm_layer_gpu(layer l, network net)
|
||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
|
||||
#endif
|
||||
} else {
|
||||
#ifdef CUDNN
|
||||
float one = 1;
|
||||
float zero = 0;
|
||||
cudnnBatchNormalizationForwardInference(cudnn_handle(),
|
||||
CUDNN_BATCHNORM_SPATIAL,
|
||||
&one,
|
||||
&zero,
|
||||
l.dstTensorDesc,
|
||||
l.x_gpu,
|
||||
l.dstTensorDesc,
|
||||
l.output_gpu,
|
||||
l.normTensorDesc,
|
||||
l.scales_gpu,
|
||||
l.biases_gpu,
|
||||
l.rolling_mean_gpu,
|
||||
l.rolling_variance_gpu,
|
||||
.00001);
|
||||
|
||||
#else
|
||||
normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
|
||||
scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
|
||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user