Changes to make routing work better

This commit is contained in:
Joseph Redmon 2015-07-21 16:09:33 -07:00
parent e56d1eff13
commit d00f0a1ccd
18 changed files with 29 additions and 86 deletions

View File

@ -1,5 +1,5 @@
GPU=0
OPENCV=0
GPU=1
OPENCV=1
DEBUG=0
ARCH= -arch=sm_52

View File

@ -58,7 +58,7 @@ void backward_avgpool_layer(const avgpool_layer l, network_state state)
int out_index = k + b*l.c;
for(i = 0; i < l.h*l.w; ++i){
int in_index = i + l.h*l.w*(k + b*l.c);
state.delta[in_index] = l.delta[out_index] / (l.h*l.w);
state.delta[in_index] += l.delta[out_index] / (l.h*l.w);
}
}
}

View File

@ -35,7 +35,7 @@ __global__ void backward_avgpool_layer_kernel(int n, int w, int h, int c, float
int out_index = (k + c*b);
for(i = 0; i < w*h; ++i){
int in_index = i + h*w*(k + b*c);
in_delta[in_index] = out_delta[out_index] / (w*h);
in_delta[in_index] += out_delta[out_index] / (w*h);
}
}

View File

@ -33,7 +33,7 @@ __global__ void col2im_gpu_kernel(const int n, const float* data_col,
val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
}
}
data_im[index] = val;
data_im[index] += val;
}
}
@ -53,62 +53,3 @@ void col2im_ongpu(float *data_col,
width_col, data_im);
}
/*
__global__ void col2im_kernel(float *data_col,
int channels, int height, int width,
int ksize, int stride, int pad, float *data_im)
{
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
if (pad){
height_col = 1 + (height-1) / stride;
width_col = 1 + (width-1) / stride;
pad = ksize/2;
}
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(id >= channels*height*width) return;
int index = id;
int w = id%width + pad;
id /= width;
int h = id%height + pad;
id /= height;
int c = id%channels;
int w_start = (w-ksize+stride)/stride;
int w_end = w/stride + 1;
int h_start = (h-ksize+stride)/stride;
int h_end = h/stride + 1;
// int rows = channels * ksize * ksize;
// int cols = height_col*width_col;
int col_offset = (c*ksize*ksize + h * ksize + w)*height_col*width_col;
int h_coeff = (1-stride*ksize*height_col)*width_col;
int w_coeff = 1-stride*height_col*width_col;
float val = 0;
int h_col, w_col;
for(h_col = h_start; h_col < h_end; ++h_col){
for(w_col = w_start; w_col < w_end; ++w_col){
int col_index = col_offset +h_col*h_coeff + w_col*w_coeff;
float part = (w_col < 0 || h_col < 0 || h_col >= height_col || w_col >= width_col) ? 0 : data_col[col_index];
val += part;
}
}
data_im[index] = val;
}
extern "C" void col2im_ongpu(float *data_col,
int channels, int height, int width,
int ksize, int stride, int pad, float *data_im)
{
size_t n = channels*height*width;
col2im_kernel<<<cuda_gridsize(n), BLOCK>>>(data_col, channels, height, width, ksize, stride, pad, data_im);
check_error(cudaPeekAtLastError());
}
*/

View File

@ -103,7 +103,7 @@ void backward_connected_layer(connected_layer l, network_state state)
b = l.weights;
c = state.delta;
if(c) gemm(0,1,m,n,k,1,a,k,b,k,0,c,n);
if(c) gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
}
#ifdef GPU
@ -173,6 +173,6 @@ void backward_connected_layer_gpu(connected_layer l, network_state state)
b = l.weights_gpu;
c = state.delta;
if(c) gemm_ongpu(0,1,m,n,k,1,a,k,b,k,0,c,n);
if(c) gemm_ongpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
}
#endif

View File

@ -82,8 +82,6 @@ void backward_convolutional_layer_gpu(convolutional_layer layer, network_state s
gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu);
backward_bias_gpu(layer.bias_updates_gpu, layer.delta_gpu, layer.batch, layer.n, k);
if(state.delta) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, state.delta, 1);
for(i = 0; i < layer.batch; ++i){
float * a = layer.delta_gpu;
float * b = layer.col_image_gpu;

View File

@ -188,8 +188,6 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
gradient_array(l.output, m*k*l.batch, l.activation, l.delta);
backward_bias(l.bias_updates, l.delta, l.batch, l.n, k);
if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
for(i = 0; i < l.batch; ++i){
float *a = l.delta + i*m*k;
float *b = l.col_image;

View File

@ -61,7 +61,7 @@ void forward_cost_layer(cost_layer l, network_state state)
void backward_cost_layer(const cost_layer l, network_state state)
{
copy_cpu(l.batch*l.inputs, l.delta, 1, state.delta, 1);
axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
}
#ifdef GPU
@ -92,7 +92,7 @@ void forward_cost_layer_gpu(cost_layer l, network_state state)
void backward_cost_layer_gpu(const cost_layer l, network_state state)
{
copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
}
#endif

View File

@ -159,8 +159,6 @@ void backward_deconvolutional_layer(deconvolutional_layer l, network_state state
gradient_array(l.output, size*l.n*l.batch, l.activation, l.delta);
backward_bias(l.bias_updates, l.delta, l.batch, l.n, size);
if(state.delta) memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
for(i = 0; i < l.batch; ++i){
int m = l.c;
int n = l.size*l.size*l.n;

View File

@ -141,20 +141,20 @@ void backward_detection_layer(const detection_layer l, network_state state)
float scale = 1;
float latent_delta = 0;
if(l.joint) scale = state.input[in_i++];
else if (l.objectness) state.delta[in_i++] = -l.delta[out_i++];
else if (l.background) state.delta[in_i++] = scale*l.delta[out_i++];
else if (l.objectness) state.delta[in_i++] += -l.delta[out_i++];
else if (l.background) state.delta[in_i++] += scale*l.delta[out_i++];
for(j = 0; j < l.classes; ++j){
latent_delta += state.input[in_i]*l.delta[out_i];
state.delta[in_i++] = scale*l.delta[out_i++];
state.delta[in_i++] += scale*l.delta[out_i++];
}
if (l.objectness) {
}else if (l.background) gradient_array(l.output + out_i, l.coords, LOGISTIC, l.delta + out_i);
for(j = 0; j < l.coords; ++j){
state.delta[in_i++] = l.delta[out_i++];
state.delta[in_i++] += l.delta[out_i++];
}
if(l.joint) state.delta[in_i-l.coords-l.classes-l.joint] = latent_delta;
if(l.joint) state.delta[in_i-l.coords-l.classes-l.joint] += latent_delta;
}
}
@ -198,7 +198,8 @@ void backward_detection_layer_gpu(detection_layer l, network_state state)
cpu_state.truth = truth_cpu;
cpu_state.delta = delta_cpu;
cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
cuda_pull_array(state.delta, delta_cpu, l.batch*l.inputs);
cuda_pull_array(l.delta_gpu, l.delta, l.batch*outputs);
backward_detection_layer(l, cpu_state);
cuda_push_array(state.delta, delta_cpu, l.batch*l.inputs);

View File

@ -114,7 +114,6 @@ void backward_maxpool_layer(const maxpool_layer l, network_state state)
int h = (l.h-1)/l.stride + 1;
int w = (l.w-1)/l.stride + 1;
int c = l.c;
memset(state.delta, 0, l.batch*l.h*l.w*l.c*sizeof(float));
for(i = 0; i < h*w*c*l.batch; ++i){
int index = l.indexes[i];
state.delta[index] += l.delta[i];

View File

@ -77,7 +77,7 @@ __global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_
d += (valid && indexes[out_index] == index) ? delta[out_index] : 0;
}
}
prev_delta[index] = d;
prev_delta[index] += d;
}
extern "C" void forward_maxpool_layer_gpu(maxpool_layer layer, network_state state)

View File

@ -68,6 +68,9 @@ void forward_network(network net, network_state state)
int i;
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
if(l.delta){
scal_cpu(l.outputs * l.batch, 0, l.delta, 1);
}
if(l.type == CONVOLUTIONAL){
forward_convolutional_layer(l, state);
} else if(l.type == DECONVOLUTIONAL){

View File

@ -21,6 +21,7 @@ extern "C" {
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "route_layer.h"
#include "blas.h"
}
float * get_network_output_gpu_layer(network net, int i);
@ -32,6 +33,9 @@ void forward_network_gpu(network net, network_state state)
int i;
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
if(l.delta){
scal_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
}
if(l.type == CONVOLUTIONAL){
forward_convolutional_layer_gpu(l, state);
} else if(l.type == DECONVOLUTIONAL){

View File

@ -90,6 +90,7 @@ void forward_normalization_layer(const layer layer, network_state state)
void backward_normalization_layer(const layer layer, network_state state)
{
// TODO This is approximate ;-)
// Also this should add in to delta instead of overwritting.
int w = layer.w;
int h = layer.h;

View File

@ -54,7 +54,7 @@ void backward_route_layer(const route_layer l, network net)
float *delta = net.layers[index].delta;
int input_size = l.input_sizes[i];
for(j = 0; j < l.batch; ++j){
copy_cpu(input_size, l.delta + offset + j*l.outputs, 1, delta + j*input_size, 1);
axpy_cpu(input_size, 1, l.delta + offset + j*l.outputs, 1, delta + j*input_size, 1);
}
offset += input_size;
}
@ -85,7 +85,7 @@ void backward_route_layer_gpu(const route_layer l, network net)
float *delta = net.layers[index].delta_gpu;
int input_size = l.input_sizes[i];
for(j = 0; j < l.batch; ++j){
copy_ongpu(input_size, l.delta_gpu + offset + j*l.outputs, 1, delta + j*input_size, 1);
axpy_ongpu(input_size, 1, l.delta_gpu + offset + j*l.outputs, 1, delta + j*input_size, 1);
}
offset += input_size;
}

View File

@ -58,7 +58,7 @@ void backward_softmax_layer(const softmax_layer l, network_state state)
{
int i;
for(i = 0; i < l.inputs*l.batch; ++i){
state.delta[i] = l.delta[i];
state.delta[i] += l.delta[i];
}
}

View File

@ -42,7 +42,7 @@ extern "C" void forward_softmax_layer_gpu(const softmax_layer layer, network_sta
extern "C" void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
{
copy_ongpu(layer.batch*layer.inputs, layer.delta_gpu, 1, state.delta, 1);
axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, state.delta, 1);
}
/* This is if you want softmax w/o log-loss classification. You probably don't.