mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Generalizing conv layer so deconv is easier
This commit is contained in:
parent
7ee45082f1
commit
979d02126b
@ -8,7 +8,7 @@ extern "C" {
|
|||||||
#include "cuda.h"
|
#include "cuda.h"
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void bias(int n, int size, float *biases, float *output)
|
__global__ void bias_output_kernel(float *output, float *biases, int n, int size)
|
||||||
{
|
{
|
||||||
int offset = blockIdx.x * blockDim.x + threadIdx.x;
|
int offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int filter = blockIdx.y;
|
int filter = blockIdx.y;
|
||||||
@ -17,18 +17,16 @@ __global__ void bias(int n, int size, float *biases, float *output)
|
|||||||
if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter];
|
if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter];
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void bias_output_gpu(const convolutional_layer layer)
|
extern "C" void bias_output_gpu(float *output, float *biases, int batch, int n, int size)
|
||||||
{
|
{
|
||||||
int size = convolutional_out_height(layer)*convolutional_out_width(layer);
|
|
||||||
|
|
||||||
dim3 dimBlock(BLOCK, 1, 1);
|
dim3 dimBlock(BLOCK, 1, 1);
|
||||||
dim3 dimGrid((size-1)/BLOCK + 1, layer.n, layer.batch);
|
dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
|
||||||
|
|
||||||
bias<<<dimGrid, dimBlock>>>(layer.n, size, layer.biases_gpu, layer.output_gpu);
|
bias_output_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
|
||||||
check_error(cudaPeekAtLastError());
|
check_error(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void learn_bias(int batch, int n, int size, float *delta, float *bias_updates, float scale)
|
__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size, float scale)
|
||||||
{
|
{
|
||||||
__shared__ float part[BLOCK];
|
__shared__ float part[BLOCK];
|
||||||
int i,b;
|
int i,b;
|
||||||
@ -48,36 +46,14 @@ __global__ void learn_bias(int batch, int n, int size, float *delta, float *bias
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void learn_bias_convolutional_layer_ongpu(convolutional_layer layer)
|
extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
|
||||||
{
|
{
|
||||||
int size = convolutional_out_height(layer)*convolutional_out_width(layer);
|
float alpha = 1./batch;
|
||||||
float alpha = 1./layer.batch;
|
|
||||||
|
|
||||||
learn_bias<<<layer.n, BLOCK>>>(layer.batch, layer.n, size, layer.delta_gpu, layer.bias_updates_gpu, alpha);
|
backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size, alpha);
|
||||||
check_error(cudaPeekAtLastError());
|
check_error(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void test_learn_bias(convolutional_layer l)
|
|
||||||
{
|
|
||||||
int i;
|
|
||||||
int size = convolutional_out_height(l) * convolutional_out_width(l);
|
|
||||||
for(i = 0; i < size*l.batch*l.n; ++i){
|
|
||||||
l.delta[i] = rand_uniform();
|
|
||||||
}
|
|
||||||
for(i = 0; i < l.n; ++i){
|
|
||||||
l.bias_updates[i] = rand_uniform();
|
|
||||||
}
|
|
||||||
cuda_push_array(l.delta_gpu, l.delta, size*l.batch*l.n);
|
|
||||||
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
|
|
||||||
float *gpu = (float *) calloc(l.n, sizeof(float));
|
|
||||||
cuda_pull_array(l.bias_updates_gpu, gpu, l.n);
|
|
||||||
for(i = 0; i < l.n; ++i) printf("%.9g %.9g\n", l.bias_updates[i], gpu[i]);
|
|
||||||
learn_bias_convolutional_layer_ongpu(l);
|
|
||||||
learn_bias_convolutional_layer(l);
|
|
||||||
cuda_pull_array(l.bias_updates_gpu, gpu, l.n);
|
|
||||||
for(i = 0; i < l.n; ++i) printf("%.9g %.9g\n", l.bias_updates[i], gpu[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, float *in)
|
extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, float *in)
|
||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
@ -86,7 +62,7 @@ extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, float
|
|||||||
int n = convolutional_out_height(layer)*
|
int n = convolutional_out_height(layer)*
|
||||||
convolutional_out_width(layer);
|
convolutional_out_width(layer);
|
||||||
|
|
||||||
bias_output_gpu(layer);
|
bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
|
||||||
|
|
||||||
for(i = 0; i < layer.batch; ++i){
|
for(i = 0; i < layer.batch; ++i){
|
||||||
im2col_ongpu(in, i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
|
im2col_ongpu(in, i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
|
||||||
@ -106,8 +82,9 @@ extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, floa
|
|||||||
int n = layer.size*layer.size*layer.c;
|
int n = layer.size*layer.size*layer.c;
|
||||||
int k = convolutional_out_height(layer)*
|
int k = convolutional_out_height(layer)*
|
||||||
convolutional_out_width(layer);
|
convolutional_out_width(layer);
|
||||||
|
|
||||||
gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu);
|
gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu);
|
||||||
learn_bias_convolutional_layer_ongpu(layer);
|
backward_bias_gpu(layer.bias_updates_gpu, layer.delta_gpu, layer.batch, layer.n, k);
|
||||||
|
|
||||||
if(delta_gpu) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, delta_gpu, 1);
|
if(delta_gpu) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, delta_gpu, 1);
|
||||||
|
|
||||||
|
@ -111,27 +111,37 @@ void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c)
|
|||||||
layer->batch*out_h * out_w * layer->n*sizeof(float));
|
layer->batch*out_h * out_w * layer->n*sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
||||||
void bias_output(const convolutional_layer layer)
|
void bias_output(float *output, float *biases, int batch, int n, int size)
|
||||||
{
|
{
|
||||||
int i,j,b;
|
int i,j,b;
|
||||||
int out_h = convolutional_out_height(layer);
|
for(b = 0; b < batch; ++b){
|
||||||
int out_w = convolutional_out_width(layer);
|
for(i = 0; i < n; ++i){
|
||||||
for(b = 0; b < layer.batch; ++b){
|
for(j = 0; j < size; ++j){
|
||||||
for(i = 0; i < layer.n; ++i){
|
output[(b*n + i)*size + j] = biases[i];
|
||||||
for(j = 0; j < out_h*out_w; ++j){
|
|
||||||
layer.output[(b*layer.n + i)*out_h*out_w + j] = layer.biases[i];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void backward_bias(float *bias_updates, float *delta, int batch, int n, int size)
|
||||||
|
{
|
||||||
|
float alpha = 1./batch;
|
||||||
|
int i,b;
|
||||||
|
for(b = 0; b < batch; ++b){
|
||||||
|
for(i = 0; i < n; ++i){
|
||||||
|
bias_updates[i] += alpha * sum_array(delta+size*(i+b*n), size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void forward_convolutional_layer(const convolutional_layer layer, float *in)
|
void forward_convolutional_layer(const convolutional_layer layer, float *in)
|
||||||
{
|
{
|
||||||
int out_h = convolutional_out_height(layer);
|
int out_h = convolutional_out_height(layer);
|
||||||
int out_w = convolutional_out_width(layer);
|
int out_w = convolutional_out_width(layer);
|
||||||
int i;
|
int i;
|
||||||
|
|
||||||
bias_output(layer);
|
bias_output(layer.output, layer.biases, layer.batch, layer.n, out_h*out_w);
|
||||||
|
|
||||||
int m = layer.n;
|
int m = layer.n;
|
||||||
int k = layer.size*layer.size*layer.c;
|
int k = layer.size*layer.size*layer.c;
|
||||||
@ -151,19 +161,6 @@ void forward_convolutional_layer(const convolutional_layer layer, float *in)
|
|||||||
activate_array(layer.output, m*n*layer.batch, layer.activation);
|
activate_array(layer.output, m*n*layer.batch, layer.activation);
|
||||||
}
|
}
|
||||||
|
|
||||||
void learn_bias_convolutional_layer(convolutional_layer layer)
|
|
||||||
{
|
|
||||||
float alpha = 1./layer.batch;
|
|
||||||
int i,b;
|
|
||||||
int size = convolutional_out_height(layer)
|
|
||||||
*convolutional_out_width(layer);
|
|
||||||
for(b = 0; b < layer.batch; ++b){
|
|
||||||
for(i = 0; i < layer.n; ++i){
|
|
||||||
layer.bias_updates[i] += alpha * sum_array(layer.delta+size*(i+b*layer.n), size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta)
|
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta)
|
||||||
{
|
{
|
||||||
float alpha = 1./layer.batch;
|
float alpha = 1./layer.batch;
|
||||||
@ -174,8 +171,7 @@ void backward_convolutional_layer(convolutional_layer layer, float *in, float *d
|
|||||||
convolutional_out_width(layer);
|
convolutional_out_width(layer);
|
||||||
|
|
||||||
gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta);
|
gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta);
|
||||||
|
backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, k);
|
||||||
learn_bias_convolutional_layer(layer);
|
|
||||||
|
|
||||||
if(delta) memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
|
if(delta) memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
|
||||||
|
|
||||||
|
@ -45,10 +45,12 @@ typedef struct {
|
|||||||
void forward_convolutional_layer_gpu(convolutional_layer layer, float * in);
|
void forward_convolutional_layer_gpu(convolutional_layer layer, float * in);
|
||||||
void backward_convolutional_layer_gpu(convolutional_layer layer, float * in, float * delta_gpu);
|
void backward_convolutional_layer_gpu(convolutional_layer layer, float * in, float * delta_gpu);
|
||||||
void update_convolutional_layer_gpu(convolutional_layer layer);
|
void update_convolutional_layer_gpu(convolutional_layer layer);
|
||||||
|
|
||||||
void push_convolutional_layer(convolutional_layer layer);
|
void push_convolutional_layer(convolutional_layer layer);
|
||||||
void pull_convolutional_layer(convolutional_layer layer);
|
void pull_convolutional_layer(convolutional_layer layer);
|
||||||
void learn_bias_convolutional_layer_ongpu(convolutional_layer layer);
|
|
||||||
void bias_output_gpu(const convolutional_layer layer);
|
void bias_output_gpu(float *output, float *biases, int batch, int n, int size);
|
||||||
|
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay);
|
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay);
|
||||||
@ -59,14 +61,15 @@ image *visualize_convolutional_layer(convolutional_layer layer, char *window, im
|
|||||||
|
|
||||||
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta);
|
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta);
|
||||||
|
|
||||||
void bias_output(const convolutional_layer layer);
|
void bias_output(float *output, float *biases, int batch, int n, int size);
|
||||||
|
void backward_bias(float *bias_updates, float *delta, int batch, int n, int size);
|
||||||
|
|
||||||
image get_convolutional_image(convolutional_layer layer);
|
image get_convolutional_image(convolutional_layer layer);
|
||||||
image get_convolutional_delta(convolutional_layer layer);
|
image get_convolutional_delta(convolutional_layer layer);
|
||||||
image get_convolutional_filter(convolutional_layer layer, int i);
|
image get_convolutional_filter(convolutional_layer layer, int i);
|
||||||
|
|
||||||
int convolutional_out_height(convolutional_layer layer);
|
int convolutional_out_height(convolutional_layer layer);
|
||||||
int convolutional_out_width(convolutional_layer layer);
|
int convolutional_out_width(convolutional_layer layer);
|
||||||
void learn_bias_convolutional_layer(convolutional_layer layer);
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -225,8 +225,7 @@ char *basename(char *cfgfile)
|
|||||||
void train_imagenet(char *cfgfile, char *weightfile)
|
void train_imagenet(char *cfgfile, char *weightfile)
|
||||||
{
|
{
|
||||||
float avg_loss = -1;
|
float avg_loss = -1;
|
||||||
// TODO
|
srand(time(0));
|
||||||
srand(0);
|
|
||||||
char *base = basename(cfgfile);
|
char *base = basename(cfgfile);
|
||||||
printf("%s\n", base);
|
printf("%s\n", base);
|
||||||
network net = parse_network_cfg(cfgfile);
|
network net = parse_network_cfg(cfgfile);
|
||||||
@ -585,25 +584,6 @@ void visualize_cat()
|
|||||||
cvWaitKey(0);
|
cvWaitKey(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GPU
|
|
||||||
void test_convolutional_layer()
|
|
||||||
{
|
|
||||||
network net = parse_network_cfg("cfg/nist_conv.cfg");
|
|
||||||
int size = get_network_input_size(net);
|
|
||||||
float *in = calloc(size, sizeof(float));
|
|
||||||
int i;
|
|
||||||
for(i = 0; i < size; ++i) in[i] = rand_normal();
|
|
||||||
convolutional_layer layer = *(convolutional_layer *)net.layers[0];
|
|
||||||
int out_size = convolutional_out_height(layer)*convolutional_out_width(layer)*layer.batch;
|
|
||||||
cuda_compare(layer.output_gpu, layer.output, out_size, "nothing");
|
|
||||||
cuda_compare(layer.biases_gpu, layer.biases, layer.n, "biases");
|
|
||||||
cuda_compare(layer.filters_gpu, layer.filters, layer.n*layer.size*layer.size*layer.c, "filters");
|
|
||||||
bias_output(layer);
|
|
||||||
bias_output_gpu(layer);
|
|
||||||
cuda_compare(layer.output_gpu, layer.output, out_size, "biased output");
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void test_correct_nist()
|
void test_correct_nist()
|
||||||
{
|
{
|
||||||
network net = parse_network_cfg("cfg/nist_conv.cfg");
|
network net = parse_network_cfg("cfg/nist_conv.cfg");
|
||||||
|
Loading…
Reference in New Issue
Block a user