add avgpool layer

This commit is contained in:
Joseph Redmon 2015-07-13 15:04:21 -07:00
parent 4b36675471
commit 8561e49b5a
10 changed files with 194 additions and 10 deletions

View File

@ -34,9 +34,9 @@ CFLAGS+= -DGPU
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o
ifeq ($(GPU), 1)
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
endif
OBJS = $(addprefix $(OBJDIR), $(OBJ))

66
src/avgpool_layer.c Normal file
View File

@ -0,0 +1,66 @@
#include "avgpool_layer.h"
#include "cuda.h"
#include <stdio.h>
avgpool_layer make_avgpool_layer(int batch, int w, int h, int c)
{
fprintf(stderr, "Avgpool Layer: %d x %d x %d image\n", w,h,c);
avgpool_layer l = {0};
l.type = AVGPOOL;
l.batch = batch;
l.h = h;
l.w = w;
l.c = c;
l.out_w = 1;
l.out_h = 1;
l.out_c = c;
l.outputs = l.out_c;
l.inputs = h*w*c;
int output_size = l.outputs * batch;
l.output = calloc(output_size, sizeof(float));
l.delta = calloc(output_size, sizeof(float));
#ifdef GPU
l.output_gpu = cuda_make_array(l.output, output_size);
l.delta_gpu = cuda_make_array(l.delta, output_size);
#endif
return l;
}
void resize_avgpool_layer(avgpool_layer *l, int w, int h)
{
l->h = h;
l->w = w;
}
void forward_avgpool_layer(const avgpool_layer l, network_state state)
{
int b,i,k;
for(b = 0; b < l.batch; ++b){
for(k = 0; k < l.c; ++k){
int out_index = k + b*l.c;
l.output[out_index] = 0;
for(i = 0; i < l.h*l.w; ++i){
int in_index = i + l.h*l.w*(k + b*l.c);
l.output[out_index] += state.input[in_index];
}
l.output[out_index] /= l.h*l.w;
}
}
}
void backward_avgpool_layer(const avgpool_layer l, network_state state)
{
int b,i,k;
for(b = 0; b < l.batch; ++b){
for(k = 0; k < l.c; ++k){
int out_index = k + b*l.c;
for(i = 0; i < l.h*l.w; ++i){
int in_index = i + l.h*l.w*(k + b*l.c);
state.delta[in_index] = l.delta[out_index] / (l.h*l.w);
}
}
}
}

23
src/avgpool_layer.h Normal file
View File

@ -0,0 +1,23 @@
#ifndef AVGPOOL_LAYER_H
#define AVGPOOL_LAYER_H
#include "image.h"
#include "params.h"
#include "cuda.h"
#include "layer.h"
typedef layer avgpool_layer;
image get_avgpool_image(avgpool_layer l);
avgpool_layer make_avgpool_layer(int batch, int w, int h, int c);
void resize_avgpool_layer(avgpool_layer *l, int w, int h);
void forward_avgpool_layer(const avgpool_layer l, network_state state);
void backward_avgpool_layer(const avgpool_layer l, network_state state);
#ifdef GPU
void forward_avgpool_layer_gpu(avgpool_layer l, network_state state);
void backward_avgpool_layer_gpu(avgpool_layer l, network_state state);
#endif
#endif

View File

@ -0,0 +1,57 @@
extern "C" {
#include "avgpool_layer.h"
#include "cuda.h"
}
__global__ void forward_avgpool_layer_kernel(int n, int w, int h, int c, float *input, float *output)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(id >= n) return;
int k = id % c;
id /= c;
int b = id;
int i;
int out_index = (k + c*b);
output[out_index] = 0;
for(i = 0; i < w*h; ++i){
int in_index = i + h*w*(k + b*c);
output[out_index] += input[in_index];
}
output[out_index] /= w*h;
}
__global__ void backward_avgpool_layer_kernel(int n, int w, int h, int c, float *in_delta, float *out_delta)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(id >= n) return;
int k = id % c;
id /= c;
int b = id;
int i;
int out_index = (k + c*b);
for(i = 0; i < w*h; ++i){
int in_index = i + h*w*(k + b*c);
in_delta[in_index] = out_delta[out_index] / (w*h);
}
}
extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
{
size_t n = layer.c*layer.batch;
forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.input, layer.output_gpu);
check_error(cudaPeekAtLastError());
}
extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
{
size_t n = layer.c*layer.batch;
backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.delta, layer.delta_gpu);
check_error(cudaPeekAtLastError());
}

View File

@ -25,7 +25,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
pthread_t load_thread;
data train;
data buffer;
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
while(1){
++i;
time=clock();
@ -38,7 +38,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
cvWaitKey(0);
*/
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
float loss = train_network(net, train);
@ -47,7 +47,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
free_data(train);
if((i % 30000) == 0) net.learning_rate *= .1;
if((i % 20000) == 0) net.learning_rate *= .1;
if(i%1000==0){
char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);

View File

@ -14,7 +14,8 @@ typedef enum {
CROP,
ROUTE,
COST,
NORMALIZATION
NORMALIZATION,
AVGPOOL
} LAYER_TYPE;
typedef enum{

View File

@ -12,6 +12,7 @@
#include "detection_layer.h"
#include "normalization_layer.h"
#include "maxpool_layer.h"
#include "avgpool_layer.h"
#include "cost_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
@ -28,6 +29,8 @@ char *get_layer_string(LAYER_TYPE a)
return "connected";
case MAXPOOL:
return "maxpool";
case AVGPOOL:
return "avgpool";
case SOFTMAX:
return "softmax";
case DETECTION:
@ -83,6 +86,8 @@ void forward_network(network net, network_state state)
forward_softmax_layer(l, state);
} else if(l.type == MAXPOOL){
forward_maxpool_layer(l, state);
} else if(l.type == AVGPOOL){
forward_avgpool_layer(l, state);
} else if(l.type == DROPOUT){
forward_dropout_layer(l, state);
} else if(l.type == ROUTE){
@ -156,6 +161,8 @@ void backward_network(network net, network_state state)
backward_normalization_layer(l, state);
} else if(l.type == MAXPOOL){
if(i != 0) backward_maxpool_layer(l, state);
} else if(l.type == AVGPOOL){
backward_avgpool_layer(l, state);
} else if(l.type == DROPOUT){
backward_dropout_layer(l, state);
} else if(l.type == DETECTION){
@ -273,6 +280,9 @@ int resize_network(network *net, int w, int h)
resize_convolutional_layer(&l, w, h);
}else if(l.type == MAXPOOL){
resize_maxpool_layer(&l, w, h);
}else if(l.type == AVGPOOL){
resize_avgpool_layer(&l, w, h);
break;
}else if(l.type == NORMALIZATION){
resize_normalization_layer(&l, w, h);
}else{

View File

@ -15,6 +15,7 @@ extern "C" {
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "maxpool_layer.h"
#include "avgpool_layer.h"
#include "normalization_layer.h"
#include "cost_layer.h"
#include "softmax_layer.h"
@ -49,6 +50,8 @@ void forward_network_gpu(network net, network_state state)
forward_normalization_layer_gpu(l, state);
} else if(l.type == MAXPOOL){
forward_maxpool_layer_gpu(l, state);
} else if(l.type == AVGPOOL){
forward_avgpool_layer_gpu(l, state);
} else if(l.type == DROPOUT){
forward_dropout_layer_gpu(l, state);
} else if(l.type == ROUTE){
@ -79,6 +82,8 @@ void backward_network_gpu(network net, network_state state)
backward_deconvolutional_layer_gpu(l, state);
} else if(l.type == MAXPOOL){
if(i != 0) backward_maxpool_layer_gpu(l, state);
} else if(l.type == AVGPOOL){
if(i != 0) backward_avgpool_layer_gpu(l, state);
} else if(l.type == DROPOUT){
backward_dropout_layer_gpu(l, state);
} else if(l.type == DETECTION){

View File

@ -40,10 +40,10 @@ void resize_normalization_layer(layer *layer, int w, int h)
layer->out_w = w;
layer->inputs = w*h*c;
layer->outputs = layer->inputs;
layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
layer->squared = realloc(layer->squared, h * w * layer->c * layer->batch * sizeof(float));
layer->norms = realloc(layer->norms, h * w * layer->c * layer->batch * sizeof(float));
layer->output = realloc(layer->output, h * w * c * batch * sizeof(float));
layer->delta = realloc(layer->delta, h * w * c * batch * sizeof(float));
layer->squared = realloc(layer->squared, h * w * c * batch * sizeof(float));
layer->norms = realloc(layer->norms, h * w * c * batch * sizeof(float));
#ifdef GPU
cuda_free(layer->output_gpu);
cuda_free(layer->delta_gpu);

View File

@ -14,6 +14,7 @@
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
#include "avgpool_layer.h"
#include "route_layer.h"
#include "list.h"
#include "option_list.h"
@ -29,6 +30,7 @@ int is_convolutional(section *s);
int is_deconvolutional(section *s);
int is_connected(section *s);
int is_maxpool(section *s);
int is_avgpool(section *s);
int is_dropout(section *s);
int is_softmax(section *s);
int is_normalization(section *s);
@ -214,6 +216,19 @@ maxpool_layer parse_maxpool(list *options, size_params params)
return layer;
}
avgpool_layer parse_avgpool(list *options, size_params params)
{
int batch,w,h,c;
w = params.w;
h = params.h;
c = params.c;
batch=params.batch;
if(!(h && w && c)) error("Layer before avgpool layer must output image.");
avgpool_layer layer = make_avgpool_layer(batch,w,h,c);
return layer;
}
dropout_layer parse_dropout(list *options, size_params params)
{
float probability = option_find_float(options, "probability", .5);
@ -333,6 +348,8 @@ network parse_network_cfg(char *filename)
l = parse_normalization(options, params);
}else if(is_maxpool(s)){
l = parse_maxpool(options, params);
}else if(is_avgpool(s)){
l = parse_avgpool(options, params);
}else if(is_route(s)){
l = parse_route(options, params, net);
}else if(is_dropout(s)){
@ -402,6 +419,11 @@ int is_maxpool(section *s)
return (strcmp(s->type, "[max]")==0
|| strcmp(s->type, "[maxpool]")==0);
}
int is_avgpool(section *s)
{
return (strcmp(s->type, "[avg]")==0
|| strcmp(s->type, "[avgpool]")==0);
}
int is_dropout(section *s)
{
return (strcmp(s->type, "[dropout]")==0);