Mirror of https://github.com/pjreddie/darknet.git (synced 2023-08-10 21:13:14 +03:00)

Commit: fbd48ab606 "grouped convolutions 🐍 🐍 🐍"
Parent: 62b781af4d

@@ -1,6 +1,10 @@
 [net]
+# Train
 batch=128
 subdivisions=1
+# Test
+# batch=1
+# subdivisions=1
 height=224
 width=224
 channels=3
@@ -1,6 +1,10 @@
 [net]
+# Train
 batch=128
 subdivisions=1
+# Test
+# batch=1
+# subdivisions=1
 height=224
 width=224
 channels=3
@@ -90,7 +90,7 @@ long numops(network net)
     for(i = 0; i < net.n; ++i){
         layer l = net.layers[i];
         if(l.type == CONVOLUTIONAL){
-            ops += 2l * l.n * l.size*l.size*l.c * l.out_h*l.out_w;
+            ops += 2l * l.n * l.size*l.size*l.c/l.groups * l.out_h*l.out_w;
         } else if(l.type == CONNECTED){
            ops += 2l * l.inputs * l.outputs;
         } else if (l.type == RNN){
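The updated ops count above charges each output position only for the c/groups input channels a filter actually sees. A quick standalone check of that formula, with made-up layer dimensions that are not part of this commit:

    #include <stdio.h>

    int main()
    {
        /* Hypothetical layer for illustration: 128 filters, 3x3 kernel,
         * 128 input channels, 56x56 output map. */
        long n = 128, size = 3, c = 128, out_h = 56, out_w = 56;
        long groups = 2;
        long dense   = 2l * n * size*size*c        * out_h*out_w;
        long grouped = 2l * n * size*size*c/groups * out_h*out_w;
        /* grouped is exactly dense/groups: each filter convolves half the channels */
        printf("dense: %ld ops, groups=2: %ld ops\n", dense, grouped);
        return 0;
    }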
@@ -74,12 +74,12 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network net)
 {
     fill_gpu(l.outputs*l.batch, 0, l.output_gpu, 1);
     if(l.binary){
-        binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
+        binarize_weights_gpu(l.weights_gpu, l.n, l.c/l.groups*l.size*l.size, l.binary_weights_gpu);
         swap_binary(&l);
     }
 
     if(l.xnor){
-        binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
+        binarize_weights_gpu(l.weights_gpu, l.n, l.c/l.groups*l.size*l.size, l.binary_weights_gpu);
         swap_binary(&l);
         binarize_gpu(net.input_gpu, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
         net.input_gpu = l.binary_input_gpu;
@@ -102,16 +102,20 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network net)
                 l.output_gpu);
 
 #else
-    int i;
-    int m = l.n;
-    int k = l.size*l.size*l.c;
+    int i, j;
+    int m = l.n/l.groups;
+    int k = l.size*l.size*l.c/l.groups;
     int n = l.out_w*l.out_h;
     for(i = 0; i < l.batch; ++i){
-        im2col_gpu(net.input_gpu + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, net.workspace);
-        float * a = l.weights_gpu;
-        float * b = net.workspace;
-        float * c = l.output_gpu;
-        gemm_gpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
+        for(j = 0; j < l.groups; ++j){
+            float *a = l.weights_gpu + j*l.nweights/l.groups;
+            float *b = net.workspace;
+            float *c = l.output_gpu + (i*l.groups + j)*n*m;
+
+            im2col_gpu(net.input_gpu + (i*l.groups + j)*l.c/l.groups*l.h*l.w,
+                l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+            gemm_gpu(0,0,m,n,k,1,a,k,b,n,1,c,n);
+        }
     }
 #endif
 
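In the fallback (non-cuDNN) path above, each (batch item, group) pair works on its own slice of the weight, input, and output buffers. A small sketch of that pointer arithmetic as standalone functions, assuming the NCHW layout used by the layer; the function names are illustrative, not darknet API:

    #include <stddef.h>

    /* Offset of group j's weights: each group owns nweights/groups floats. */
    size_t group_weight_offset(size_t j, size_t nweights, size_t groups)
    {
        return j * (nweights / groups);
    }

    /* Offset of the input slice for batch item i, group j:
     * each slice is c/groups channels of h*w pixels. */
    size_t group_input_offset(size_t i, size_t j, size_t groups,
                              size_t c, size_t h, size_t w)
    {
        return (i*groups + j) * (c/groups) * h * w;
    }

    /* Offset of the output slice for batch item i, group j:
     * each slice is n/groups filters of out_h*out_w pixels, i.e. the
     * n*m block written by the gemm_gpu call above. */
    size_t group_output_offset(size_t i, size_t j, size_t groups,
                               size_t n, size_t out_h, size_t out_w)
    {
        return (i*groups + j) * (n/groups) * out_h * out_w;
    }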
@@ -221,30 +225,36 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network net)
     }
 
 #else
-    int m = l.n;
-    int n = l.size*l.size*l.c;
+    int m = l.n/l.groups;
+    int n = l.size*l.size*l.c/l.groups;
     int k = l.out_w*l.out_h;
 
-    int i;
+    int i, j;
     for(i = 0; i < l.batch; ++i){
-        float * a = l.delta_gpu;
-        float * b = net.workspace;
-        float * c = l.weight_updates_gpu;
+        for(j = 0; j < l.groups; ++j){
+            float *a = l.delta_gpu + (i*l.groups + j)*m*k;
+            float *b = net.workspace;
+            float *c = l.weight_updates_gpu + j*l.nweights/l.groups;
 
-        im2col_gpu(net.input_gpu + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, net.workspace);
-        gemm_gpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
+            float *im = net.input+(i*l.groups + j)*l.c/l.groups*l.h*l.w;
 
-        if(net.delta_gpu){
-            if(l.binary || l.xnor) swap_binary(&l);
-            float * a = l.weights_gpu;
-            float * b = l.delta_gpu;
-            float * c = net.workspace;
+            im2col_gpu(im, l.c/l.groups, l.h, l.w,
+                    l.size, l.stride, l.pad, b);
+            gemm_gpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
 
-            gemm_gpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
+            if(net.delta_gpu){
+                if(l.binary || l.xnor) swap_binary(&l);
+                a = l.weights_gpu + j*l.nweights/l.groups;
+                b = l.delta_gpu + (i*l.groups + j)*m*k;
+                c = net.workspace;
 
-            col2im_gpu(net.workspace, l.c, l.h, l.w, l.size, l.stride, l.pad, net.delta_gpu + i*l.c*l.h*l.w);
-            if(l.binary || l.xnor) {
-                swap_binary(&l);
+                gemm_gpu(1,0,n,k,m,1,a,n,b,k,0,c,k);
+
+                col2im_gpu(net.workspace, l.c/l.groups, l.h, l.w, l.size, l.stride,
+                        l.pad, net.delta_gpu + (i*l.groups + j)*l.c/l.groups*l.h*l.w);
+                if(l.binary || l.xnor) {
+                    swap_binary(&l);
+                }
             }
             if(l.xnor) gradient_array_gpu(original_input + i*l.c*l.h*l.w, l.c*l.h*l.w, HARDTAN, net.delta_gpu + i*l.c*l.h*l.w);
         }
@@ -252,29 +262,29 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network net)
 #endif
 }
 
-void pull_convolutional_layer(convolutional_layer layer)
+void pull_convolutional_layer(layer l)
 {
-    cuda_pull_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
-    cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
-    cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
-    if (layer.batch_normalize){
-        cuda_pull_array(layer.scales_gpu, layer.scales, layer.n);
-        cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
-        cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
+    cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
+    cuda_pull_array(l.biases_gpu, l.biases, l.n);
+    cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
+    cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n);
+    if (l.batch_normalize){
+        cuda_pull_array(l.scales_gpu, l.scales, l.n);
+        cuda_pull_array(l.rolling_mean_gpu, l.rolling_mean, l.n);
+        cuda_pull_array(l.rolling_variance_gpu, l.rolling_variance, l.n);
     }
 }
 
-void push_convolutional_layer(convolutional_layer layer)
+void push_convolutional_layer(layer l)
 {
-    cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
-    cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
-    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
-    if (layer.batch_normalize){
-        cuda_push_array(layer.scales_gpu, layer.scales, layer.n);
-        cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
-        cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
+    cuda_push_array(l.weights_gpu, l.weights, l.nweights);
+    cuda_push_array(l.biases_gpu, l.biases, l.n);
+    cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
+    cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
+    if (l.batch_normalize){
+        cuda_push_array(l.scales_gpu, l.scales, l.n);
+        cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.n);
+        cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.n);
     }
 }
 
@@ -285,18 +295,16 @@ void update_convolutional_layer_gpu(layer l, update_args a)
     float decay = a.decay;
     int batch = a.batch;
 
-    int size = l.size*l.size*l.c*l.n;
-
     if(a.adam){
-        adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, a.B1, a.B2, a.eps, decay, learning_rate, size, batch, a.t);
+        adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, a.B1, a.B2, a.eps, decay, learning_rate, l.nweights, batch, a.t);
         adam_update_gpu(l.biases_gpu, l.bias_updates_gpu, l.bias_m_gpu, l.bias_v_gpu, a.B1, a.B2, a.eps, decay, learning_rate, l.n, batch, a.t);
         if(l.scales_gpu){
             adam_update_gpu(l.scales_gpu, l.scale_updates_gpu, l.scale_m_gpu, l.scale_v_gpu, a.B1, a.B2, a.eps, decay, learning_rate, l.n, batch, a.t);
         }
     }else{
-        axpy_gpu(size, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
-        axpy_gpu(size, learning_rate/batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
-        scal_gpu(size, momentum, l.weight_updates_gpu, 1);
+        axpy_gpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
+        axpy_gpu(l.nweights, learning_rate/batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
+        scal_gpu(l.nweights, momentum, l.weight_updates_gpu, 1);
 
         axpy_gpu(l.n, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1);
         scal_gpu(l.n, momentum, l.bias_updates_gpu, 1);
@@ -115,7 +115,7 @@ static size_t get_workspace_size(layer l){
         return most;
     }
 #endif
-    return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
+    return (size_t)l.out_h*l.out_w*l.size*l.size*l.c/l.groups*sizeof(float);
 }
 
 #ifdef GPU
@@ -124,17 +124,27 @@ void cudnn_convolutional_setup(layer *l)
 {
     cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
     cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
-    cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
 
     cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
     cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
     cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
-    cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+
+    cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c/l->groups, l->size, l->size);
+    cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c/l->groups, l->size, l->size);
 #if CUDNN_MAJOR >= 6
     cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);
 #else
     cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
 #endif
+
+#if CUDNN_MAJOR >= 7
+    cudnnSetConvolutionGroupCount(l->convDesc, l->groups);
+#else
+    if(l->groups > 1){
+        error("CUDNN < 7 doesn't support groups, please upgrade!");
+    }
+#endif
+
     cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
             l->srcTensorDesc,
             l->weightDesc,
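cuDNN expects a grouped convolution to be declared with c/groups input channels on the filter descriptor, while the group count itself is set on the convolution descriptor (available from cuDNN 7 onward, which is why older versions hit the error above). A stripped-down sketch of that pairing with placeholder dimensions; it illustrates the API convention, not the layer code itself:

    #include <cudnn.h>

    /* Declare an n-filter, size x size grouped convolution over c input
     * channels; all dimensions here are hypothetical. */
    void declare_grouped_conv(cudnnFilterDescriptor_t wDesc,
                              cudnnConvolutionDescriptor_t convDesc,
                              int n, int c, int size, int groups)
    {
        /* The filter descriptor only sees the per-group channel count. */
        cudnnSetFilter4dDescriptor(wDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW,
                                   n, c/groups, size, size);
    #if CUDNN_MAJOR >= 7
        /* The group count lives on the convolution descriptor. */
        cudnnSetConvolutionGroupCount(convDesc, groups);
    #endif
    }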
@@ -163,12 +173,13 @@ void cudnn_convolutional_setup(layer *l)
 #endif
 #endif
 
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam)
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam)
 {
     int i;
     convolutional_layer l = {0};
     l.type = CONVOLUTIONAL;
 
+    l.groups = groups;
     l.h = h;
     l.w = w;
     l.c = c;
@@ -181,20 +192,20 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     l.pad = padding;
     l.batch_normalize = batch_normalize;
 
-    l.weights = calloc(c*n*size*size, sizeof(float));
-    l.weight_updates = calloc(c*n*size*size, sizeof(float));
+    l.weights = calloc(c/groups*n*size*size, sizeof(float));
+    l.weight_updates = calloc(c/groups*n*size*size, sizeof(float));
 
     l.biases = calloc(n, sizeof(float));
     l.bias_updates = calloc(n, sizeof(float));
 
-    l.nweights = c*n*size*size;
+    l.nweights = c/groups*n*size*size;
     l.nbiases = n;
 
     // float scale = 1./sqrt(size*size*c);
-    float scale = sqrt(2./(size*size*c));
+    float scale = sqrt(2./(size*size*c/l.groups));
     //scale = .02;
     //for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1, 1);
-    for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_normal();
+    for(i = 0; i < l.nweights; ++i) l.weights[i] = scale*rand_normal();
     int out_w = convolutional_out_width(l);
     int out_h = convolutional_out_height(l);
     l.out_h = out_h;
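With groups, each of the n filters holds c/groups*size*size weights, so l.nweights shrinks by the group count; depthwise convolution is the extreme case where groups equals both c and n. A quick sanity check with hypothetical dimensions, not taken from this commit:

    #include <assert.h>

    int main()
    {
        int c = 32, n = 32, size = 3;

        int nweights_dense     = c/1 *n*size*size;   /* groups=1  -> 9216 floats */
        int nweights_depthwise = c/32*n*size*size;   /* groups=32 ->  288 floats */

        assert(nweights_dense == 9216);
        assert(nweights_depthwise == 288);
        return 0;
    }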
@@ -210,12 +221,12 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     l.backward = backward_convolutional_layer;
     l.update = update_convolutional_layer;
     if(binary){
-        l.binary_weights = calloc(c*n*size*size, sizeof(float));
-        l.cweights = calloc(c*n*size*size, sizeof(char));
+        l.binary_weights = calloc(l.nweights, sizeof(float));
+        l.cweights = calloc(l.nweights, sizeof(char));
         l.scales = calloc(n, sizeof(float));
     }
     if(xnor){
-        l.binary_weights = calloc(c*n*size*size, sizeof(float));
+        l.binary_weights = calloc(l.nweights, sizeof(float));
         l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
     }
 
@@ -238,8 +249,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
         l.x_norm = calloc(l.batch*l.outputs, sizeof(float));
     }
     if(adam){
-        l.m = calloc(c*n*size*size, sizeof(float));
-        l.v = calloc(c*n*size*size, sizeof(float));
+        l.m = calloc(l.nweights, sizeof(float));
+        l.v = calloc(l.nweights, sizeof(float));
         l.bias_m = calloc(n, sizeof(float));
         l.scale_m = calloc(n, sizeof(float));
         l.bias_v = calloc(n, sizeof(float));
@@ -253,16 +264,16 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
 
     if(gpu_index >= 0){
         if (adam) {
-            l.m_gpu = cuda_make_array(l.m, c*n*size*size);
-            l.v_gpu = cuda_make_array(l.v, c*n*size*size);
+            l.m_gpu = cuda_make_array(l.m, l.nweights);
+            l.v_gpu = cuda_make_array(l.v, l.nweights);
             l.bias_m_gpu = cuda_make_array(l.bias_m, n);
             l.bias_v_gpu = cuda_make_array(l.bias_v, n);
             l.scale_m_gpu = cuda_make_array(l.scale_m, n);
             l.scale_v_gpu = cuda_make_array(l.scale_v, n);
         }
 
-        l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
-        l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
+        l.weights_gpu = cuda_make_array(l.weights, l.nweights);
+        l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights);
 
         l.biases_gpu = cuda_make_array(l.biases, n);
         l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
@@ -271,10 +282,10 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
         l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 
         if(binary){
-            l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
+            l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights);
         }
         if(xnor){
-            l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
+            l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights);
             l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
         }
 
@@ -320,8 +331,8 @@ void denormalize_convolutional_layer(convolutional_layer l)
     int i, j;
     for(i = 0; i < l.n; ++i){
         float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001);
-        for(j = 0; j < l.c*l.size*l.size; ++j){
-            l.weights[i*l.c*l.size*l.size + j] *= scale;
+        for(j = 0; j < l.c/l.groups*l.size*l.size; ++j){
+            l.weights[i*l.c/l.groups*l.size*l.size + j] *= scale;
         }
         l.biases[i] -= l.rolling_mean[i] * scale;
         l.scales[i] = 1;
@@ -432,54 +443,50 @@ void backward_bias(float *bias_updates, float *delta, int batch, int n, int size
 
 void forward_convolutional_layer(convolutional_layer l, network net)
 {
-    int out_h = l.out_h;
-    int out_w = l.out_w;
-    int i;
+    int i, j;
 
     fill_cpu(l.outputs*l.batch, 0, l.output, 1);
 
     if(l.xnor){
-        binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
+        binarize_weights(l.weights, l.n, l.c/l.groups*l.size*l.size, l.binary_weights);
         swap_binary(&l);
         binarize_cpu(net.input, l.c*l.h*l.w*l.batch, l.binary_input);
         net.input = l.binary_input;
     }
 
-    int m = l.n;
-    int k = l.size*l.size*l.c;
-    int n = out_h*out_w;
-
-
-    float *a = l.weights;
-    float *b = net.workspace;
-    float *c = l.output;
-
+    int m = l.n/l.groups;
+    int k = l.size*l.size*l.c/l.groups;
+    int n = l.out_w*l.out_h;
     for(i = 0; i < l.batch; ++i){
-        im2col_cpu(net.input, l.c, l.h, l.w,
-                l.size, l.stride, l.pad, b);
-        gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-        c += n*m;
-        net.input += l.c*l.h*l.w;
+        for(j = 0; j < l.groups; ++j){
+            float *a = l.weights + j*l.nweights/l.groups;
+            float *b = net.workspace;
+            float *c = l.output + (i*l.groups + j)*n*m;
+
+            im2col_cpu(net.input + (i*l.groups + j)*l.c/l.groups*l.h*l.w,
+                l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+            gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+        }
     }
 
     if(l.batch_normalize){
         forward_batchnorm_layer(l, net);
     } else {
-        add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
+        add_bias(l.output, l.biases, l.batch, l.n, l.out_h*l.out_w);
     }
 
-    activate_array(l.output, m*n*l.batch, l.activation);
+    activate_array(l.output, l.outputs*l.batch, l.activation);
     if(l.binary || l.xnor) swap_binary(&l);
 }
 
 void backward_convolutional_layer(convolutional_layer l, network net)
 {
-    int i;
-    int m = l.n;
-    int n = l.size*l.size*l.c;
+    int i, j;
+    int m = l.n/l.groups;
+    int n = l.size*l.size*l.c/l.groups;
     int k = l.out_w*l.out_h;
 
-    gradient_array(l.output, m*k*l.batch, l.activation, l.delta);
+    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
 
     if(l.batch_normalize){
         backward_batchnorm_layer(l, net);
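The CPU forward pass above expresses grouped convolution as one im2col plus one GEMM per (batch item, group). As a cross-check of what that computes, here is a naive nested-loop version for a single image with stride 1 and no padding; it is an illustrative sketch only, not code from the repository:

    /* Naive grouped convolution for one NCHW image, stride 1, no padding.
     * in:  c x h x w          out: n x (h-size+1) x (w-size+1)
     * wts: n filters, each over c/groups channels (darknet weight layout).
     * Each group of n/groups filters reads only its own c/groups input slice. */
    void conv_grouped_naive(const float *in, const float *wts, float *out,
                            int c, int h, int w, int n, int size, int groups)
    {
        int out_h = h - size + 1, out_w = w - size + 1;
        int cg = c/groups, ng = n/groups;
        for(int g = 0; g < groups; ++g){
            for(int f = 0; f < ng; ++f){
                int of = g*ng + f;                    /* absolute output channel */
                for(int y = 0; y < out_h; ++y){
                    for(int x = 0; x < out_w; ++x){
                        float sum = 0;
                        for(int k = 0; k < cg; ++k){
                            int ic = g*cg + k;        /* absolute input channel */
                            for(int dy = 0; dy < size; ++dy){
                                for(int dx = 0; dx < size; ++dx){
                                    sum += in[(ic*h + y+dy)*w + x+dx]
                                         * wts[((of*cg + k)*size + dy)*size + dx];
                                }
                            }
                        }
                        out[(of*out_h + y)*out_w + x] = sum;
                    }
                }
            }
        }
    }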
@@ -488,24 +495,27 @@ void backward_convolutional_layer(convolutional_layer l, network net)
     }
 
     for(i = 0; i < l.batch; ++i){
-        float *a = l.delta + i*m*k;
-        float *b = net.workspace;
-        float *c = l.weight_updates;
+        for(j = 0; j < l.groups; ++j){
+            float *a = l.delta + (i*l.groups + j)*m*k;
+            float *b = net.workspace;
+            float *c = l.weight_updates + j*l.nweights/l.groups;
 
-        float *im = net.input+i*l.c*l.h*l.w;
+            float *im = net.input+(i*l.groups + j)*l.c/l.groups*l.h*l.w;
 
-        im2col_cpu(im, l.c, l.h, l.w,
-                l.size, l.stride, l.pad, b);
-        gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
+            im2col_cpu(im, l.c/l.groups, l.h, l.w,
+                    l.size, l.stride, l.pad, b);
+            gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
 
-        if(net.delta){
-            a = l.weights;
-            b = l.delta + i*m*k;
-            c = net.workspace;
+            if(net.delta){
+                a = l.weights + j*l.nweights/l.groups;
+                b = l.delta + (i*l.groups + j)*m*k;
+                c = net.workspace;
 
-            gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
+                gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
 
-            col2im_cpu(net.workspace, l.c, l.h, l.w, l.size, l.stride, l.pad, net.delta+i*l.c*l.h*l.w);
+                col2im_cpu(net.workspace, l.c/l.groups, l.h, l.w, l.size, l.stride,
+                    l.pad, net.delta + (i*l.groups + j)*l.c/l.groups*l.h*l.w);
+            }
         }
     }
 }
@@ -517,7 +527,6 @@ void update_convolutional_layer(convolutional_layer l, update_args a)
     float decay = a.decay;
     int batch = a.batch;
 
-    int size = l.size*l.size*l.c*l.n;
     axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
     scal_cpu(l.n, momentum, l.bias_updates, 1);
 
@@ -526,9 +535,9 @@ void update_convolutional_layer(convolutional_layer l, update_args a)
         scal_cpu(l.n, momentum, l.scale_updates, 1);
     }
 
-    axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1);
-    axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
-    scal_cpu(size, momentum, l.weight_updates, 1);
+    axpy_cpu(l.nweights, -decay*batch, l.weights, 1, l.weight_updates, 1);
+    axpy_cpu(l.nweights, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
+    scal_cpu(l.nweights, momentum, l.weight_updates, 1);
 }
 
 
@@ -536,7 +545,7 @@ image get_convolutional_weight(convolutional_layer l, int i)
 {
     int h = l.size;
     int w = l.size;
-    int c = l.c;
+    int c = l.c/l.groups;
     return float_to_image(w,h,c,l.weights+i*h*w*c);
 }
 
@@ -572,10 +581,10 @@ image *get_weights(convolutional_layer l)
         weights[i] = copy_image(get_convolutional_weight(l, i));
         normalize_image(weights[i]);
         /*
-        char buff[256];
-        sprintf(buff, "filter%d", i);
-        save_image(weights[i], buff);
-        */
+           char buff[256];
+           sprintf(buff, "filter%d", i);
+           save_image(weights[i], buff);
+           */
     }
     //error("hey");
     return weights;
@@ -25,7 +25,7 @@ void cudnn_convolutional_setup(layer *l);
 #endif
 #endif
 
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
 void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
 void forward_convolutional_layer(const convolutional_layer layer, network net);
 void update_convolutional_layer(convolutional_layer layer, update_args a);
@@ -48,17 +48,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
 
     l.input_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
+    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 1, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
     l.input_layer->batch = batch;
 
     l.self_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
+    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 1, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
     l.self_layer->batch = batch;
 
     l.output_layer = malloc(sizeof(layer));
     fprintf(stderr, "\t\t");
-    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
+    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 1, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
     l.output_layer->batch = batch;
 
     l.output = l.output_layer->output;
@@ -162,7 +162,7 @@ void merge_weights(layer l, layer base)
 {
     if (l.type == CONVOLUTIONAL) {
         axpy_cpu(l.n, 1, l.bias_updates, 1, base.biases, 1);
-        axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weight_updates, 1, base.weights, 1);
+        axpy_cpu(l.nweights, 1, l.weight_updates, 1, base.weights, 1);
         if (l.scales) {
             axpy_cpu(l.n, 1, l.scale_updates, 1, base.scales, 1);
         }
@@ -176,7 +176,7 @@ void scale_weights(layer l, float s)
 {
     if (l.type == CONVOLUTIONAL) {
         scal_cpu(l.n, s, l.biases, 1);
-        scal_cpu(l.n*l.size*l.size*l.c, s, l.weights, 1);
+        scal_cpu(l.nweights, s, l.weights, 1);
         if (l.scales) {
             scal_cpu(l.n, s, l.scales, 1);
         }
@@ -191,7 +191,7 @@ void pull_weights(layer l)
 {
     if(l.type == CONVOLUTIONAL || l.type == DECONVOLUTIONAL){
         cuda_pull_array(l.biases_gpu, l.bias_updates, l.n);
-        cuda_pull_array(l.weights_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
+        cuda_pull_array(l.weights_gpu, l.weight_updates, l.nweights);
         if(l.scales) cuda_pull_array(l.scales_gpu, l.scale_updates, l.n);
     } else if(l.type == CONNECTED){
         cuda_pull_array(l.biases_gpu, l.bias_updates, l.outputs);
@@ -203,7 +203,7 @@ void push_weights(layer l)
 {
     if(l.type == CONVOLUTIONAL || l.type == DECONVOLUTIONAL){
         cuda_push_array(l.biases_gpu, l.biases, l.n);
-        cuda_push_array(l.weights_gpu, l.weights, l.n*l.size*l.size*l.c);
+        cuda_push_array(l.weights_gpu, l.weights, l.nweights);
         if(l.scales) cuda_push_array(l.scales_gpu, l.scales, l.n);
     } else if(l.type == CONNECTED){
         cuda_push_array(l.biases_gpu, l.biases, l.outputs);
@@ -215,7 +215,7 @@ void distribute_weights(layer l, layer base)
 {
     if (l.type == CONVOLUTIONAL || l.type == DECONVOLUTIONAL) {
         cuda_push_array(l.biases_gpu, base.biases, l.n);
-        cuda_push_array(l.weights_gpu, base.weights, l.n*l.size*l.size*l.c);
+        cuda_push_array(l.weights_gpu, base.weights, l.nweights);
         if (base.scales) cuda_push_array(l.scales_gpu, base.scales, l.n);
     } else if (l.type == CONNECTED) {
         cuda_push_array(l.biases_gpu, base.biases, l.outputs);
@@ -230,7 +230,7 @@ void pull_updates(layer l)
 {
     if(l.type == CONVOLUTIONAL){
         cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n);
-        cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
+        cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
         if(l.scale_updates) cuda_pull_array(l.scale_updates_gpu, l.scale_updates, l.n);
     } else if(l.type == CONNECTED){
         cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
@@ -242,7 +242,7 @@ void push_updates(layer l)
 {
     if(l.type == CONVOLUTIONAL){
         cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
-        cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
+        cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
         if(l.scale_updates) cuda_push_array(l.scale_updates_gpu, l.scale_updates, l.n);
     } else if(l.type == CONNECTED){
         cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
@@ -263,7 +263,7 @@ void merge_updates(layer l, layer base)
 {
     if (l.type == CONVOLUTIONAL) {
         axpy_cpu(l.n, 1, l.bias_updates, 1, base.bias_updates, 1);
-        axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weight_updates, 1, base.weight_updates, 1);
+        axpy_cpu(l.nweights, 1, l.weight_updates, 1, base.weight_updates, 1);
         if (l.scale_updates) {
             axpy_cpu(l.n, 1, l.scale_updates, 1, base.scale_updates, 1);
         }
@@ -277,7 +277,7 @@ void distribute_updates(layer l, layer base)
 {
     if(l.type == CONVOLUTIONAL || l.type == DECONVOLUTIONAL){
         cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.n);
-        cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.n*l.size*l.size*l.c);
+        cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.nweights);
         if(base.scale_updates) cuda_push_array(l.scale_updates_gpu, base.scale_updates, l.n);
     } else if(l.type == CONNECTED){
         cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.outputs);
@@ -173,6 +173,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
     int stride = option_find_int(options, "stride",1);
     int pad = option_find_int_quiet(options, "pad",0);
     int padding = option_find_int_quiet(options, "padding",0);
+    int groups = option_find_int_quiet(options, "groups", 1);
     if(pad) padding = size/2;
 
     char *activation_s = option_find_str(options, "activation", "logistic");
@@ -188,7 +189,7 @@ convolutional_layer parse_convolutional(list *options, size_params params)
     int binary = option_find_int_quiet(options, "binary", 0);
     int xnor = option_find_int_quiet(options, "xnor", 0);
 
-    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam);
+    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,groups,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam);
     layer.flipped = option_find_int_quiet(options, "flipped", 0);
     layer.dot = option_find_float_quiet(options, "dot", 0);
 
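With the two parser hunks above, a cfg [convolutional] section can set groups (default 1), and the value is forwarded straight into make_convolutional_layer. A hypothetical depthwise-style block, assuming the preceding layer outputs 64 channels; the values are illustrative and not taken from any cfg in this commit:

    [convolutional]
    filters=64
    groups=64
    size=3
    stride=1
    pad=1
    batch_normalize=1
    activation=leaky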
@@ -841,7 +842,7 @@ void save_convolutional_weights(layer l, FILE *fp)
         pull_convolutional_layer(l);
     }
 #endif
-    int num = l.n*l.c*l.size*l.size;
+    int num = l.nweights;
     fwrite(l.biases, sizeof(float), l.n, fp);
     if (l.batch_normalize){
         fwrite(l.scales, sizeof(float), l.n, fp);
@@ -1041,7 +1042,7 @@ void load_convolutional_weights(layer l, FILE *fp)
         //load_convolutional_weights_binary(l, fp);
         //return;
     }
-    int num = l.n*l.c*l.size*l.size;
+    int num = l.nweights;
     fread(l.biases, sizeof(float), l.n, fp);
     if (l.batch_normalize && (!l.dontloadscales)){
         fread(l.scales, sizeof(float), l.n, fp);