mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Improve training performance - batch-norm using cuDNN.
This commit is contained in:
@ -52,6 +52,12 @@ layer make_batchnorm_layer(int batch, int w, int h, int c)
|
|||||||
|
|
||||||
layer.x_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
|
layer.x_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
|
||||||
layer.x_norm_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
|
layer.x_norm_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
|
||||||
|
#ifdef CUDNN
|
||||||
|
cudnnCreateTensorDescriptor(&layer.normTensorDesc);
|
||||||
|
cudnnCreateTensorDescriptor(&layer.dstTensorDesc);
|
||||||
|
cudnnSetTensor4dDescriptor(layer.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, layer.batch, layer.out_c, layer.out_h, layer.out_w);
|
||||||
|
cudnnSetTensor4dDescriptor(layer.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, layer.out_c, 1, 1);
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
return layer;
|
return layer;
|
||||||
}
|
}
|
||||||
@ -170,7 +176,7 @@ void push_batchnorm_layer(layer l)
|
|||||||
cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
|
cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
|
||||||
cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
|
cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
void forward_batchnorm_layer_gpu(layer l, network_state state)
|
void forward_batchnorm_layer_gpu(layer l, network_state state)
|
||||||
{
|
{
|
||||||
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
|
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
|
||||||
@ -209,3 +215,98 @@ void backward_batchnorm_layer_gpu(const layer l, network_state state)
|
|||||||
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
|
if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
void forward_batchnorm_layer_gpu(layer l, network_state state)
|
||||||
|
{
|
||||||
|
if (l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
|
||||||
|
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
|
||||||
|
if (state.train) {
|
||||||
|
#ifdef CUDNN
|
||||||
|
float one = 1;
|
||||||
|
float zero = 0;
|
||||||
|
cudnnBatchNormalizationForwardTraining(cudnn_handle(),
|
||||||
|
CUDNN_BATCHNORM_SPATIAL,
|
||||||
|
&one,
|
||||||
|
&zero,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
l.x_gpu,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
l.output_gpu,
|
||||||
|
l.normTensorDesc,
|
||||||
|
l.scales_gpu,
|
||||||
|
l.biases_gpu,
|
||||||
|
.01,
|
||||||
|
l.rolling_mean_gpu,
|
||||||
|
l.rolling_variance_gpu,
|
||||||
|
.00001,
|
||||||
|
l.mean_gpu,
|
||||||
|
l.variance_gpu);
|
||||||
|
#else
|
||||||
|
fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
|
||||||
|
fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
|
||||||
|
|
||||||
|
scal_ongpu(l.out_c, .99, l.rolling_mean_gpu, 1);
|
||||||
|
axpy_ongpu(l.out_c, .01, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
|
||||||
|
scal_ongpu(l.out_c, .99, l.rolling_variance_gpu, 1);
|
||||||
|
axpy_ongpu(l.out_c, .01, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
|
||||||
|
|
||||||
|
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
|
||||||
|
normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
|
||||||
|
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
|
||||||
|
|
||||||
|
scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
|
||||||
|
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
|
||||||
|
scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
|
||||||
|
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward_batchnorm_layer_gpu(layer l, network_state state)
|
||||||
|
{
|
||||||
|
if (!state.train) {
|
||||||
|
l.mean_gpu = l.rolling_mean_gpu;
|
||||||
|
l.variance_gpu = l.rolling_variance_gpu;
|
||||||
|
}
|
||||||
|
#ifdef CUDNN
|
||||||
|
float one = 1;
|
||||||
|
float zero = 0;
|
||||||
|
cudnnBatchNormalizationBackward(cudnn_handle(),
|
||||||
|
CUDNN_BATCHNORM_SPATIAL,
|
||||||
|
&one,
|
||||||
|
&zero,
|
||||||
|
&one,
|
||||||
|
&one,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
l.x_gpu,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
l.delta_gpu,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
l.x_norm_gpu,
|
||||||
|
l.normTensorDesc,
|
||||||
|
l.scales_gpu,
|
||||||
|
l.scale_updates_gpu,
|
||||||
|
l.bias_updates_gpu,
|
||||||
|
.00001,
|
||||||
|
l.mean_gpu,
|
||||||
|
l.variance_gpu);
|
||||||
|
copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, 1, l.delta_gpu, 1);
|
||||||
|
#else
|
||||||
|
backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h);
|
||||||
|
backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates_gpu);
|
||||||
|
|
||||||
|
scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
|
||||||
|
|
||||||
|
fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta_gpu);
|
||||||
|
fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta_gpu);
|
||||||
|
normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.delta_gpu);
|
||||||
|
#endif
|
||||||
|
if (l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
|
||||||
|
}
|
||||||
|
#endif
|
@ -80,6 +80,7 @@ void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forwa
|
|||||||
|
|
||||||
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
|
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
|
||||||
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
|
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
|
||||||
|
void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t);
|
||||||
|
|
||||||
void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out);
|
void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out);
|
||||||
|
|
||||||
|
@ -145,8 +145,8 @@ __global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float
|
|||||||
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||||
if (index >= N) return;
|
if (index >= N) return;
|
||||||
|
|
||||||
x[index] = x[index] - (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps));
|
x[index] = x[index] - (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrtf(v[index]) + eps));
|
||||||
//if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
|
//if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
|
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
|
||||||
@ -155,13 +155,27 @@ extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2
|
|||||||
check_error(cudaPeekAtLastError());
|
check_error(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extern "C" void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t)
|
||||||
|
{
|
||||||
|
scal_ongpu(n, B1, m, 1);
|
||||||
|
scal_ongpu(n, B2, v, 1);
|
||||||
|
axpy_ongpu(n, -decay*batch, w, 1, d, 1);
|
||||||
|
|
||||||
|
axpy_ongpu(n, (1 - B1), d, 1, m, 1);
|
||||||
|
mul_ongpu(n, d, 1, d, 1);
|
||||||
|
axpy_ongpu(n, (1 - B2), d, 1, v, 1);
|
||||||
|
|
||||||
|
adam_gpu(n, w, m, v, B1, B2, rate, eps, t);
|
||||||
|
fill_ongpu(n, 0, d, 1);
|
||||||
|
}
|
||||||
|
|
||||||
__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
|
__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
|
||||||
{
|
{
|
||||||
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||||
if (index >= N) return;
|
if (index >= N) return;
|
||||||
int f = (index/spatial)%filters;
|
int f = (index/spatial)%filters;
|
||||||
|
|
||||||
x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .000001f);
|
x[index] = (x[index] - mean[f])/(sqrtf(variance[f]) + .000001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
|
__global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
|
||||||
@ -170,7 +184,7 @@ __global__ void normalize_delta_kernel(int N, float *x, float *mean, float *vari
|
|||||||
if (index >= N) return;
|
if (index >= N) return;
|
||||||
int f = (index/spatial)%filters;
|
int f = (index/spatial)%filters;
|
||||||
|
|
||||||
delta[index] = delta[index] * 1./(sqrt(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
|
delta[index] = delta[index] * 1.F/(sqrtf(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
|
extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
|
||||||
@ -192,7 +206,7 @@ __global__ void variance_delta_kernel(float *x, float *delta, float *mean, floa
|
|||||||
variance_delta[i] += delta[index]*(x[index] - mean[i]);
|
variance_delta[i] += delta[index]*(x[index] - mean[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
variance_delta[i] *= -.5 * pow(variance[i] + .000001f, (float)(-3./2.));
|
variance_delta[i] *= -.5 * powf(variance[i] + .000001f, (float)(-3./2.));
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
|
__global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
|
||||||
@ -230,7 +244,7 @@ __global__ void fast_mean_delta_kernel(float *delta, float *variance, int batch,
|
|||||||
for(i = 0; i < threads; ++i){
|
for(i = 0; i < threads; ++i){
|
||||||
mean_delta[filter] += local[i];
|
mean_delta[filter] += local[i];
|
||||||
}
|
}
|
||||||
mean_delta[filter] *= (-1./sqrt(variance[filter] + .000001f));
|
mean_delta[filter] *= (-1.F/sqrtf(variance[filter] + .000001f));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,7 +273,7 @@ __global__ void fast_variance_delta_kernel(float *x, float *delta, float *mean,
|
|||||||
for(i = 0; i < threads; ++i){
|
for(i = 0; i < threads; ++i){
|
||||||
variance_delta[filter] += local[i];
|
variance_delta[filter] += local[i];
|
||||||
}
|
}
|
||||||
variance_delta[filter] *= -.5 * pow(variance[filter] + .000001f, (float)(-3./2.));
|
variance_delta[filter] *= -.5 * powf(variance[filter] + .000001f, (float)(-3./2.));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -276,7 +290,7 @@ __global__ void mean_delta_kernel(float *delta, float *variance, int batch, int
|
|||||||
mean_delta[i] += delta[index];
|
mean_delta[i] += delta[index];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mean_delta[i] *= (-1./sqrt(variance[i] + .000001f));
|
mean_delta[i] *= (-1.F/sqrtf(variance[i] + .000001f));
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
|
extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
|
||||||
@ -299,7 +313,7 @@ extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, flo
|
|||||||
|
|
||||||
__global__ void mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
|
__global__ void mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
|
||||||
{
|
{
|
||||||
float scale = 1./(batch * spatial);
|
float scale = 1.F/(batch * spatial);
|
||||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||||
if (i >= filters) return;
|
if (i >= filters) return;
|
||||||
int j,k;
|
int j,k;
|
||||||
@ -315,7 +329,7 @@ __global__ void mean_kernel(float *x, int batch, int filters, int spatial, floa
|
|||||||
|
|
||||||
__global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
|
__global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
|
||||||
{
|
{
|
||||||
float scale = 1./(batch * spatial - 1);
|
float scale = 1.F/(batch * spatial - 1);
|
||||||
int j,k;
|
int j,k;
|
||||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||||
if (i >= filters) return;
|
if (i >= filters) return;
|
||||||
@ -323,7 +337,7 @@ __global__ void variance_kernel(float *x, float *mean, int batch, int filters, i
|
|||||||
for(j = 0; j < batch; ++j){
|
for(j = 0; j < batch; ++j){
|
||||||
for(k = 0; k < spatial; ++k){
|
for(k = 0; k < spatial; ++k){
|
||||||
int index = j*filters*spatial + i*spatial + k;
|
int index = j*filters*spatial + i*spatial + k;
|
||||||
variance[i] += pow((x[index] - mean[i]), 2);
|
variance[i] += powf((x[index] - mean[i]), 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
variance[i] *= scale;
|
variance[i] *= scale;
|
||||||
@ -370,7 +384,7 @@ __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, f
|
|||||||
__global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
|
__global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
|
||||||
{
|
{
|
||||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||||
if(i < N) Y[i*INCY] = pow(X[i*INCX], ALPHA);
|
if(i < N) Y[i*INCY] = powf(X[i*INCX], ALPHA);
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
|
__global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
|
||||||
@ -474,7 +488,7 @@ __global__ void fast_variance_kernel(float *x, float *mean, int batch, int filt
|
|||||||
for(i = 0; i < spatial; i += threads){
|
for(i = 0; i < spatial; i += threads){
|
||||||
int index = j*spatial*filters + filter*spatial + i + id;
|
int index = j*spatial*filters + filter*spatial + i + id;
|
||||||
|
|
||||||
local[id] += (i+id < spatial) ? pow((x[index] - mean[filter]), 2) : 0;
|
local[id] += (i+id < spatial) ? powf((x[index] - mean[filter]), 2) : 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@ -646,7 +660,7 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int
|
|||||||
if(sample < 1) sample = 1;
|
if(sample < 1) sample = 1;
|
||||||
|
|
||||||
int size = batch * minw * minh * minc;
|
int size = batch * minw * minh * minc;
|
||||||
shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
|
shortcut_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
|
||||||
check_error(cudaPeekAtLastError());
|
check_error(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -769,3 +783,4 @@ extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float t
|
|||||||
softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
|
softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
|
||||||
check_error(cudaPeekAtLastError());
|
check_error(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,6 +97,12 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
|
|||||||
|
|
||||||
l.x_gpu = cuda_make_array(l.output, l.batch*outputs);
|
l.x_gpu = cuda_make_array(l.output, l.batch*outputs);
|
||||||
l.x_norm_gpu = cuda_make_array(l.output, l.batch*outputs);
|
l.x_norm_gpu = cuda_make_array(l.output, l.batch*outputs);
|
||||||
|
#ifdef CUDNN
|
||||||
|
cudnnCreateTensorDescriptor(&l.normTensorDesc);
|
||||||
|
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
|
||||||
|
cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
|
||||||
|
cudnnSetTensor4dDescriptor(l.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l.out_c, 1, 1);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
l.activation = activation;
|
l.activation = activation;
|
||||||
@ -283,9 +289,10 @@ void forward_connected_layer_gpu(connected_layer l, network_state state)
|
|||||||
if (l.batch_normalize) {
|
if (l.batch_normalize) {
|
||||||
forward_batchnorm_layer_gpu(l, state);
|
forward_batchnorm_layer_gpu(l, state);
|
||||||
}
|
}
|
||||||
for(i = 0; i < l.batch; ++i){
|
else {
|
||||||
axpy_ongpu(l.outputs, 1, l.biases_gpu, 1, l.output_gpu + i*l.outputs, 1);
|
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.outputs, 1);
|
||||||
}
|
}
|
||||||
|
//for(i = 0; i < l.batch; ++i) axpy_ongpu(l.outputs, 1, l.biases_gpu, 1, l.output_gpu + i*l.outputs, 1);
|
||||||
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ __global__ void binarize_input_kernel(float *input, int n, int size, float *bina
|
|||||||
int i = 0;
|
int i = 0;
|
||||||
float mean = 0;
|
float mean = 0;
|
||||||
for(i = 0; i < n; ++i){
|
for(i = 0; i < n; ++i){
|
||||||
mean += abs(input[i*size + s]);
|
mean += fabs(input[i*size + s]);
|
||||||
}
|
}
|
||||||
mean = mean / n;
|
mean = mean / n;
|
||||||
for(i = 0; i < n; ++i){
|
for(i = 0; i < n; ++i){
|
||||||
@ -59,7 +59,7 @@ __global__ void binarize_weights_kernel(float *weights, int n, int size, float *
|
|||||||
int i = 0;
|
int i = 0;
|
||||||
float mean = 0;
|
float mean = 0;
|
||||||
for(i = 0; i < size; ++i){
|
for(i = 0; i < size; ++i){
|
||||||
mean += abs(weights[f*size + i]);
|
mean += fabs(weights[f*size + i]);
|
||||||
}
|
}
|
||||||
mean = mean / size;
|
mean = mean / size;
|
||||||
for(i = 0; i < size; ++i){
|
for(i = 0; i < size; ++i){
|
||||||
@ -206,7 +206,9 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
if (l.batch_normalize) {
|
if (l.batch_normalize) {
|
||||||
forward_batchnorm_layer_gpu(l, state);
|
forward_batchnorm_layer_gpu(l, state);
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
||||||
|
}
|
||||||
|
|
||||||
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
||||||
//if(l.dot > 0) dot_error_gpu(l);
|
//if(l.dot > 0) dot_error_gpu(l);
|
||||||
|
@ -174,6 +174,9 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
|
|||||||
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
|
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
|
||||||
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
|
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
|
||||||
cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
|
cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
|
||||||
|
|
||||||
|
// batch norm
|
||||||
|
cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
|
||||||
#if(CUDNN_MAJOR >= 6)
|
#if(CUDNN_MAJOR >= 6)
|
||||||
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn >= 6.0
|
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT); // cudnn >= 6.0
|
||||||
#else
|
#else
|
||||||
@ -341,6 +344,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
|
|||||||
l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
||||||
}
|
}
|
||||||
#ifdef CUDNN
|
#ifdef CUDNN
|
||||||
|
cudnnCreateTensorDescriptor(&l.normTensorDesc);
|
||||||
cudnnCreateTensorDescriptor(&l.srcTensorDesc);
|
cudnnCreateTensorDescriptor(&l.srcTensorDesc);
|
||||||
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
|
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
|
||||||
cudnnCreateFilterDescriptor(&l.weightDesc);
|
cudnnCreateFilterDescriptor(&l.weightDesc);
|
||||||
|
@ -19,6 +19,9 @@ extern int gpu_index;
|
|||||||
#include "cudnn.h"
|
#include "cudnn.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
void check_error(cudaError_t status);
|
void check_error(cudaError_t status);
|
||||||
cublasHandle_t blas_handle();
|
cublasHandle_t blas_handle();
|
||||||
float *cuda_make_array(float *x, size_t n);
|
float *cuda_make_array(float *x, size_t n);
|
||||||
@ -32,6 +35,9 @@ void cuda_random(float *x_gpu, size_t n);
|
|||||||
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
|
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
|
||||||
dim3 cuda_gridsize(size_t n);
|
dim3 cuda_gridsize(size_t n);
|
||||||
cudaStream_t get_cuda_stream();
|
cudaStream_t get_cuda_stream();
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef CUDNN
|
#ifdef CUDNN
|
||||||
cudnnHandle_t cudnn_handle();
|
cudnnHandle_t cudnn_handle();
|
||||||
|
@ -91,7 +91,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
|
|||||||
args.small_object = l.small_object;
|
args.small_object = l.small_object;
|
||||||
args.d = &buffer;
|
args.d = &buffer;
|
||||||
args.type = DETECTION_DATA;
|
args.type = DETECTION_DATA;
|
||||||
args.threads = 8; // 64
|
args.threads = 64; // 8
|
||||||
|
|
||||||
args.angle = net.angle;
|
args.angle = net.angle;
|
||||||
args.exposure = net.exposure;
|
args.exposure = net.exposure;
|
||||||
@ -1031,6 +1031,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
|
|||||||
}
|
}
|
||||||
image im = load_image_color(input,0,0);
|
image im = load_image_color(input,0,0);
|
||||||
image sized = resize_image(im, net.w, net.h);
|
image sized = resize_image(im, net.w, net.h);
|
||||||
|
//image sized = letterbox_image(im, net.w, net.h);
|
||||||
layer l = net.layers[net.n-1];
|
layer l = net.layers[net.n-1];
|
||||||
|
|
||||||
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
|
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
|
||||||
|
@ -352,6 +352,7 @@ IplImage* draw_train_chart(float max_img_loss, int max_batches, int number_of_li
|
|||||||
}
|
}
|
||||||
cvPutText(img, "Iteration number", cvPoint(draw_size / 2, img_size - 10), &font, CV_RGB(0, 0, 0));
|
cvPutText(img, "Iteration number", cvPoint(draw_size / 2, img_size - 10), &font, CV_RGB(0, 0, 0));
|
||||||
cvPutText(img, "Press 's' to save: chart.jpg", cvPoint(5, img_size - 10), &font, CV_RGB(0, 0, 0));
|
cvPutText(img, "Press 's' to save: chart.jpg", cvPoint(5, img_size - 10), &font, CV_RGB(0, 0, 0));
|
||||||
|
printf(" If error occurs - run training with flag: -dont_show \n");
|
||||||
cvNamedWindow("average loss", CV_WINDOW_NORMAL);
|
cvNamedWindow("average loss", CV_WINDOW_NORMAL);
|
||||||
cvMoveWindow("average loss", 0, 0);
|
cvMoveWindow("average loss", 0, 0);
|
||||||
cvResizeWindow("average loss", img_size, img_size);
|
cvResizeWindow("average loss", img_size, img_size);
|
||||||
|
13
src/layer.h
13
src/layer.h
@ -42,6 +42,18 @@ typedef enum{
|
|||||||
SSE, MASKED, SMOOTH
|
SSE, MASKED, SMOOTH
|
||||||
} COST_TYPE;
|
} COST_TYPE;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int batch;
|
||||||
|
float learning_rate;
|
||||||
|
float momentum;
|
||||||
|
float decay;
|
||||||
|
int adam;
|
||||||
|
float B1;
|
||||||
|
float B2;
|
||||||
|
float eps;
|
||||||
|
int t;
|
||||||
|
} update_args;
|
||||||
|
|
||||||
struct layer{
|
struct layer{
|
||||||
LAYER_TYPE type;
|
LAYER_TYPE type;
|
||||||
ACTIVATION activation;
|
ACTIVATION activation;
|
||||||
@ -261,6 +273,7 @@ struct layer{
|
|||||||
#ifdef CUDNN
|
#ifdef CUDNN
|
||||||
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
|
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
|
||||||
cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
|
cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
|
||||||
|
cudnnTensorDescriptor_t normTensorDesc;
|
||||||
cudnnFilterDescriptor_t weightDesc;
|
cudnnFilterDescriptor_t weightDesc;
|
||||||
cudnnFilterDescriptor_t dweightDesc;
|
cudnnFilterDescriptor_t dweightDesc;
|
||||||
cudnnConvolutionDescriptor_t convDesc;
|
cudnnConvolutionDescriptor_t convDesc;
|
||||||
|
@ -121,7 +121,7 @@ void forward_backward_network_gpu(network net, float *x, float *y)
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
forward_network_gpu(net, state);
|
forward_network_gpu(net, state);
|
||||||
cudaStreamSynchronize(get_cuda_stream());
|
//cudaStreamSynchronize(get_cuda_stream());
|
||||||
backward_network_gpu(net, state);
|
backward_network_gpu(net, state);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -434,7 +434,7 @@ void forward_region_layer_gpu(const region_layer l, network_state state)
|
|||||||
cuda_pull_array(state.truth, truth_cpu, num_truth);
|
cuda_pull_array(state.truth, truth_cpu, num_truth);
|
||||||
}
|
}
|
||||||
cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
|
cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
|
||||||
cudaStreamSynchronize(get_cuda_stream());
|
//cudaStreamSynchronize(get_cuda_stream());
|
||||||
network_state cpu_state = state;
|
network_state cpu_state = state;
|
||||||
cpu_state.train = state.train;
|
cpu_state.train = state.train;
|
||||||
cpu_state.truth = truth_cpu;
|
cpu_state.truth = truth_cpu;
|
||||||
@ -444,7 +444,7 @@ void forward_region_layer_gpu(const region_layer l, network_state state)
|
|||||||
free(cpu_state.input);
|
free(cpu_state.input);
|
||||||
if(!state.train) return;
|
if(!state.train) return;
|
||||||
cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
|
cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
|
||||||
cudaStreamSynchronize(get_cuda_stream());
|
//cudaStreamSynchronize(get_cuda_stream());
|
||||||
if(cpu_state.truth) free(cpu_state.truth);
|
if(cpu_state.truth) free(cpu_state.truth);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user