tweak to maxpool layers

This commit is contained in:
Joseph Redmon 2018-08-02 23:45:54 -07:00
parent 49ba88d9f7
commit e209b3bbbf
2 changed files with 8 additions and 8 deletions

View File

@ -27,8 +27,8 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
l.w = w;
l.c = c;
l.pad = padding;
l.out_w = (w + 2*padding)/stride;
l.out_h = (h + 2*padding)/stride;
l.out_w = (w + 2*padding - size)/stride + 1;
l.out_h = (h + 2*padding - size)/stride + 1;
l.out_c = c;
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = h*w*c;
@ -57,8 +57,8 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
l->w = w;
l->inputs = h*w*l->c;
l->out_w = (w + 2*l->pad)/l->stride;
l->out_h = (h + 2*l->pad)/l->stride;
l->out_w = (w + 2*l->pad - l->size)/l->stride + 1;
l->out_h = (h + 2*l->pad - l->size)/l->stride + 1;
l->outputs = l->out_w * l->out_h * l->c;
int output_size = l->outputs * l->batch;

View File

@ -9,8 +9,8 @@ extern "C" {
__global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *input, float *output, int *indexes)
{
int h = (in_h + 2*pad)/stride;
int w = (in_w + 2*pad)/stride;
int h = (in_h + 2*pad - size)/stride + 1;
int w = (in_w + 2*pad - size)/stride + 1;
int c = in_c;
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@ -49,8 +49,8 @@ __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c
__global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *delta, float *prev_delta, int *indexes)
{
int h = (in_h + 2*pad)/stride;
int w = (in_w + 2*pad)/stride;
int h = (in_h + 2*pad - size)/stride + 1;
int w = (in_w + 2*pad - size)/stride + 1;
int c = in_c;
int area = (size-1)/stride;