mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
add avgpool layer
This commit is contained in:
parent
4b36675471
commit
8561e49b5a
4
Makefile
4
Makefile
@ -34,9 +34,9 @@ CFLAGS+= -DGPU
|
|||||||
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
||||||
endif
|
endif
|
||||||
|
|
||||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o
|
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o
|
||||||
ifeq ($(GPU), 1)
|
ifeq ($(GPU), 1)
|
||||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
|
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
|
||||||
endif
|
endif
|
||||||
|
|
||||||
OBJS = $(addprefix $(OBJDIR), $(OBJ))
|
OBJS = $(addprefix $(OBJDIR), $(OBJ))
|
||||||
|
66
src/avgpool_layer.c
Normal file
66
src/avgpool_layer.c
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
#include "avgpool_layer.h"
|
||||||
|
#include "cuda.h"
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
avgpool_layer make_avgpool_layer(int batch, int w, int h, int c)
|
||||||
|
{
|
||||||
|
fprintf(stderr, "Avgpool Layer: %d x %d x %d image\n", w,h,c);
|
||||||
|
avgpool_layer l = {0};
|
||||||
|
l.type = AVGPOOL;
|
||||||
|
l.batch = batch;
|
||||||
|
l.h = h;
|
||||||
|
l.w = w;
|
||||||
|
l.c = c;
|
||||||
|
l.out_w = 1;
|
||||||
|
l.out_h = 1;
|
||||||
|
l.out_c = c;
|
||||||
|
l.outputs = l.out_c;
|
||||||
|
l.inputs = h*w*c;
|
||||||
|
int output_size = l.outputs * batch;
|
||||||
|
l.output = calloc(output_size, sizeof(float));
|
||||||
|
l.delta = calloc(output_size, sizeof(float));
|
||||||
|
#ifdef GPU
|
||||||
|
l.output_gpu = cuda_make_array(l.output, output_size);
|
||||||
|
l.delta_gpu = cuda_make_array(l.delta, output_size);
|
||||||
|
#endif
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
 * Update the layer for a new input resolution of w x h.
 *
 * The output of this layer is always 1 x 1 x c, so the output/delta buffers
 * and out_* fields do not depend on w or h and are left untouched. The
 * per-item input count does depend on the spatial size and is recomputed
 * here so it stays consistent with the value set in make_avgpool_layer
 * (inputs = h*w*c).
 */
void resize_avgpool_layer(avgpool_layer *l, int w, int h)
{
    l->w = w;
    l->h = h;
    l->inputs = h*w*l->c;
}
|
||||||
|
|
||||||
|
/*
 * CPU forward pass: for every batch item and channel, write the mean of the
 * h*w input values of that channel into l.output.
 *
 * Input is indexed as i + h*w*(channel + batch_item*c); output is indexed as
 * channel + batch_item*c.
 */
void forward_avgpool_layer(const avgpool_layer l, network_state state)
{
    const int spatial = l.h*l.w;
    int item, chan, px;

    for(item = 0; item < l.batch; ++item){
        for(chan = 0; chan < l.c; ++chan){
            const int out_index = chan + item*l.c;
            /* first input element of this (item, chan) plane */
            const float *plane = state.input + spatial*out_index;
            float *dst = l.output + out_index;

            *dst = 0;
            for(px = 0; px < spatial; ++px){
                *dst += plane[px];
            }
            *dst /= spatial;
        }
    }
}
|
||||||
|
|
||||||
|
/*
 * CPU backward pass: distribute each output gradient evenly over the h*w
 * input positions of its channel, writing (not accumulating) into
 * state.delta.
 *
 * state.delta is indexed as i + h*w*(channel + batch_item*c); l.delta is
 * indexed as channel + batch_item*c.
 */
void backward_avgpool_layer(const avgpool_layer l, network_state state)
{
    const int spatial = l.h*l.w;
    int b,i,k;

    /* NOTE(review): the maxpool and GPU avgpool dispatch paths skip the
     * backward call for layer 0; the CPU avgpool path does not. Presumably
     * state.delta is NULL when there is no previous layer to propagate to --
     * confirm against backward_network. Bail out instead of writing through
     * a null pointer. */
    if(!state.delta) return;

    for(b = 0; b < l.batch; ++b){
        for(k = 0; k < l.c; ++k){
            int out_index = k + b*l.c;
            /* every input position of this channel gets the same share */
            float share = l.delta[out_index] / spatial;
            float *dst = state.delta + spatial*out_index;
            for(i = 0; i < spatial; ++i){
                dst[i] = share;
            }
        }
    }
}
|
||||||
|
|
23
src/avgpool_layer.h
Normal file
23
src/avgpool_layer.h
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
#ifndef AVGPOOL_LAYER_H
|
||||||
|
#define AVGPOOL_LAYER_H
|
||||||
|
|
||||||
|
#include "image.h"
|
||||||
|
#include "params.h"
|
||||||
|
#include "cuda.h"
|
||||||
|
#include "layer.h"
|
||||||
|
|
||||||
|
typedef layer avgpool_layer;
|
||||||
|
|
||||||
|
image get_avgpool_image(avgpool_layer l);
|
||||||
|
avgpool_layer make_avgpool_layer(int batch, int w, int h, int c);
|
||||||
|
void resize_avgpool_layer(avgpool_layer *l, int w, int h);
|
||||||
|
void forward_avgpool_layer(const avgpool_layer l, network_state state);
|
||||||
|
void backward_avgpool_layer(const avgpool_layer l, network_state state);
|
||||||
|
|
||||||
|
#ifdef GPU
|
||||||
|
void forward_avgpool_layer_gpu(avgpool_layer l, network_state state);
|
||||||
|
void backward_avgpool_layer_gpu(avgpool_layer l, network_state state);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
57
src/avgpool_layer_kernels.cu
Normal file
57
src/avgpool_layer_kernels.cu
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
extern "C" {
|
||||||
|
#include "avgpool_layer.h"
|
||||||
|
#include "cuda.h"
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
 * Forward average pooling, one thread per (batch item, channel) pair.
 *
 * n      - number of (batch item, channel) pairs to process; threads whose
 *          flat id is >= n exit immediately.
 * w,h,c  - input width, height and channel count.
 * input  - read at i + h*w*(channel + batch_item*c) for i in [0, w*h).
 * output - written at channel + batch_item*c with the mean of that plane.
 *
 * The flat thread id is built from a 2D grid of 1D blocks:
 * (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x.
 */
__global__ void forward_avgpool_layer_kernel(int n, int w, int h, int c, float *input, float *output)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(id >= n) return;

    int k = id % c;
    id /= c;
    int b = id;

    int i;
    int area = w*h;
    int out_index = (k + c*b);
    int in_base = area*out_index;

    /* Accumulate in a local variable and store once, instead of doing a
     * read-modify-write on output[] for every element. The additions happen
     * in the same order, so the result is unchanged. */
    float sum = 0;
    for(i = 0; i < area; ++i){
        sum += input[in_base + i];
    }
    output[out_index] = sum / area;
}
|
||||||
|
|
||||||
|
/*
 * Backward average pooling, one thread per (batch item, channel) pair.
 *
 * n         - number of (batch item, channel) pairs; threads with a flat id
 *             >= n exit immediately.
 * w,h,c     - input width, height and channel count.
 * in_delta  - gradient buffer for the layer input; each of the w*h entries
 *             of the plane is overwritten with out_delta / (w*h).
 * out_delta - gradient of the layer output, indexed channel + batch_item*c.
 */
__global__ void backward_avgpool_layer_kernel(int n, int w, int h, int c, float *in_delta, float *out_delta)
{
    int tid = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(tid >= n) return;

    const int chan = tid % c;
    const int item = tid / c;
    const int out_index = chan + c*item;

    int px = 0;
    while(px < w*h){
        const int in_index = px + h*w*(chan + item*c);
        in_delta[in_index] = out_delta[out_index] / (w*h);
        ++px;
    }
}
|
||||||
|
|
||||||
|
/*
 * Host wrapper: launch the forward average-pooling kernel with one thread
 * per (batch item, channel) pair, reading state.input and writing
 * layer.output_gpu, then check for a launch error.
 */
extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
{
    /* one work item per channel of every batch element */
    size_t work_items = layer.c*layer.batch;

    forward_avgpool_layer_kernel<<<cuda_gridsize(work_items), BLOCK>>>(
            work_items, layer.w, layer.h, layer.c,
            state.input, layer.output_gpu);
    check_error(cudaPeekAtLastError());
}
|
||||||
|
|
||||||
|
/*
 * Host wrapper: launch the backward average-pooling kernel with one thread
 * per (batch item, channel) pair, reading layer.delta_gpu and writing
 * state.delta, then check for a launch error.
 */
extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
{
    /* one work item per channel of every batch element */
    size_t work_items = layer.c*layer.batch;

    backward_avgpool_layer_kernel<<<cuda_gridsize(work_items), BLOCK>>>(
            work_items, layer.w, layer.h, layer.c,
            state.delta, layer.delta_gpu);
    check_error(cudaPeekAtLastError());
}
|
||||||
|
|
@ -25,7 +25,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
|
|||||||
pthread_t load_thread;
|
pthread_t load_thread;
|
||||||
data train;
|
data train;
|
||||||
data buffer;
|
data buffer;
|
||||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
|
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
|
||||||
while(1){
|
while(1){
|
||||||
++i;
|
++i;
|
||||||
time=clock();
|
time=clock();
|
||||||
@ -38,7 +38,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
|
|||||||
cvWaitKey(0);
|
cvWaitKey(0);
|
||||||
*/
|
*/
|
||||||
|
|
||||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
|
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
|
||||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||||
time=clock();
|
time=clock();
|
||||||
float loss = train_network(net, train);
|
float loss = train_network(net, train);
|
||||||
@ -47,7 +47,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
|
|||||||
avg_loss = avg_loss*.9 + loss*.1;
|
avg_loss = avg_loss*.9 + loss*.1;
|
||||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
|
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
|
||||||
free_data(train);
|
free_data(train);
|
||||||
if((i % 30000) == 0) net.learning_rate *= .1;
|
if((i % 20000) == 0) net.learning_rate *= .1;
|
||||||
if(i%1000==0){
|
if(i%1000==0){
|
||||||
char buff[256];
|
char buff[256];
|
||||||
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
||||||
|
@ -14,7 +14,8 @@ typedef enum {
|
|||||||
CROP,
|
CROP,
|
||||||
ROUTE,
|
ROUTE,
|
||||||
COST,
|
COST,
|
||||||
NORMALIZATION
|
NORMALIZATION,
|
||||||
|
AVGPOOL
|
||||||
} LAYER_TYPE;
|
} LAYER_TYPE;
|
||||||
|
|
||||||
typedef enum{
|
typedef enum{
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#include "detection_layer.h"
|
#include "detection_layer.h"
|
||||||
#include "normalization_layer.h"
|
#include "normalization_layer.h"
|
||||||
#include "maxpool_layer.h"
|
#include "maxpool_layer.h"
|
||||||
|
#include "avgpool_layer.h"
|
||||||
#include "cost_layer.h"
|
#include "cost_layer.h"
|
||||||
#include "softmax_layer.h"
|
#include "softmax_layer.h"
|
||||||
#include "dropout_layer.h"
|
#include "dropout_layer.h"
|
||||||
@ -28,6 +29,8 @@ char *get_layer_string(LAYER_TYPE a)
|
|||||||
return "connected";
|
return "connected";
|
||||||
case MAXPOOL:
|
case MAXPOOL:
|
||||||
return "maxpool";
|
return "maxpool";
|
||||||
|
case AVGPOOL:
|
||||||
|
return "avgpool";
|
||||||
case SOFTMAX:
|
case SOFTMAX:
|
||||||
return "softmax";
|
return "softmax";
|
||||||
case DETECTION:
|
case DETECTION:
|
||||||
@ -83,6 +86,8 @@ void forward_network(network net, network_state state)
|
|||||||
forward_softmax_layer(l, state);
|
forward_softmax_layer(l, state);
|
||||||
} else if(l.type == MAXPOOL){
|
} else if(l.type == MAXPOOL){
|
||||||
forward_maxpool_layer(l, state);
|
forward_maxpool_layer(l, state);
|
||||||
|
} else if(l.type == AVGPOOL){
|
||||||
|
forward_avgpool_layer(l, state);
|
||||||
} else if(l.type == DROPOUT){
|
} else if(l.type == DROPOUT){
|
||||||
forward_dropout_layer(l, state);
|
forward_dropout_layer(l, state);
|
||||||
} else if(l.type == ROUTE){
|
} else if(l.type == ROUTE){
|
||||||
@ -156,6 +161,8 @@ void backward_network(network net, network_state state)
|
|||||||
backward_normalization_layer(l, state);
|
backward_normalization_layer(l, state);
|
||||||
} else if(l.type == MAXPOOL){
|
} else if(l.type == MAXPOOL){
|
||||||
if(i != 0) backward_maxpool_layer(l, state);
|
if(i != 0) backward_maxpool_layer(l, state);
|
||||||
|
} else if(l.type == AVGPOOL){
|
||||||
|
backward_avgpool_layer(l, state);
|
||||||
} else if(l.type == DROPOUT){
|
} else if(l.type == DROPOUT){
|
||||||
backward_dropout_layer(l, state);
|
backward_dropout_layer(l, state);
|
||||||
} else if(l.type == DETECTION){
|
} else if(l.type == DETECTION){
|
||||||
@ -273,6 +280,9 @@ int resize_network(network *net, int w, int h)
|
|||||||
resize_convolutional_layer(&l, w, h);
|
resize_convolutional_layer(&l, w, h);
|
||||||
}else if(l.type == MAXPOOL){
|
}else if(l.type == MAXPOOL){
|
||||||
resize_maxpool_layer(&l, w, h);
|
resize_maxpool_layer(&l, w, h);
|
||||||
|
}else if(l.type == AVGPOOL){
|
||||||
|
resize_avgpool_layer(&l, w, h);
|
||||||
|
break;
|
||||||
}else if(l.type == NORMALIZATION){
|
}else if(l.type == NORMALIZATION){
|
||||||
resize_normalization_layer(&l, w, h);
|
resize_normalization_layer(&l, w, h);
|
||||||
}else{
|
}else{
|
||||||
|
@ -15,6 +15,7 @@ extern "C" {
|
|||||||
#include "convolutional_layer.h"
|
#include "convolutional_layer.h"
|
||||||
#include "deconvolutional_layer.h"
|
#include "deconvolutional_layer.h"
|
||||||
#include "maxpool_layer.h"
|
#include "maxpool_layer.h"
|
||||||
|
#include "avgpool_layer.h"
|
||||||
#include "normalization_layer.h"
|
#include "normalization_layer.h"
|
||||||
#include "cost_layer.h"
|
#include "cost_layer.h"
|
||||||
#include "softmax_layer.h"
|
#include "softmax_layer.h"
|
||||||
@ -49,6 +50,8 @@ void forward_network_gpu(network net, network_state state)
|
|||||||
forward_normalization_layer_gpu(l, state);
|
forward_normalization_layer_gpu(l, state);
|
||||||
} else if(l.type == MAXPOOL){
|
} else if(l.type == MAXPOOL){
|
||||||
forward_maxpool_layer_gpu(l, state);
|
forward_maxpool_layer_gpu(l, state);
|
||||||
|
} else if(l.type == AVGPOOL){
|
||||||
|
forward_avgpool_layer_gpu(l, state);
|
||||||
} else if(l.type == DROPOUT){
|
} else if(l.type == DROPOUT){
|
||||||
forward_dropout_layer_gpu(l, state);
|
forward_dropout_layer_gpu(l, state);
|
||||||
} else if(l.type == ROUTE){
|
} else if(l.type == ROUTE){
|
||||||
@ -79,6 +82,8 @@ void backward_network_gpu(network net, network_state state)
|
|||||||
backward_deconvolutional_layer_gpu(l, state);
|
backward_deconvolutional_layer_gpu(l, state);
|
||||||
} else if(l.type == MAXPOOL){
|
} else if(l.type == MAXPOOL){
|
||||||
if(i != 0) backward_maxpool_layer_gpu(l, state);
|
if(i != 0) backward_maxpool_layer_gpu(l, state);
|
||||||
|
} else if(l.type == AVGPOOL){
|
||||||
|
if(i != 0) backward_avgpool_layer_gpu(l, state);
|
||||||
} else if(l.type == DROPOUT){
|
} else if(l.type == DROPOUT){
|
||||||
backward_dropout_layer_gpu(l, state);
|
backward_dropout_layer_gpu(l, state);
|
||||||
} else if(l.type == DETECTION){
|
} else if(l.type == DETECTION){
|
||||||
|
@ -40,10 +40,10 @@ void resize_normalization_layer(layer *layer, int w, int h)
|
|||||||
layer->out_w = w;
|
layer->out_w = w;
|
||||||
layer->inputs = w*h*c;
|
layer->inputs = w*h*c;
|
||||||
layer->outputs = layer->inputs;
|
layer->outputs = layer->inputs;
|
||||||
layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
|
layer->output = realloc(layer->output, h * w * c * batch * sizeof(float));
|
||||||
layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
|
layer->delta = realloc(layer->delta, h * w * c * batch * sizeof(float));
|
||||||
layer->squared = realloc(layer->squared, h * w * layer->c * layer->batch * sizeof(float));
|
layer->squared = realloc(layer->squared, h * w * c * batch * sizeof(float));
|
||||||
layer->norms = realloc(layer->norms, h * w * layer->c * layer->batch * sizeof(float));
|
layer->norms = realloc(layer->norms, h * w * c * batch * sizeof(float));
|
||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
cuda_free(layer->output_gpu);
|
cuda_free(layer->output_gpu);
|
||||||
cuda_free(layer->delta_gpu);
|
cuda_free(layer->delta_gpu);
|
||||||
|
22
src/parser.c
22
src/parser.c
@ -14,6 +14,7 @@
|
|||||||
#include "softmax_layer.h"
|
#include "softmax_layer.h"
|
||||||
#include "dropout_layer.h"
|
#include "dropout_layer.h"
|
||||||
#include "detection_layer.h"
|
#include "detection_layer.h"
|
||||||
|
#include "avgpool_layer.h"
|
||||||
#include "route_layer.h"
|
#include "route_layer.h"
|
||||||
#include "list.h"
|
#include "list.h"
|
||||||
#include "option_list.h"
|
#include "option_list.h"
|
||||||
@ -29,6 +30,7 @@ int is_convolutional(section *s);
|
|||||||
int is_deconvolutional(section *s);
|
int is_deconvolutional(section *s);
|
||||||
int is_connected(section *s);
|
int is_connected(section *s);
|
||||||
int is_maxpool(section *s);
|
int is_maxpool(section *s);
|
||||||
|
int is_avgpool(section *s);
|
||||||
int is_dropout(section *s);
|
int is_dropout(section *s);
|
||||||
int is_softmax(section *s);
|
int is_softmax(section *s);
|
||||||
int is_normalization(section *s);
|
int is_normalization(section *s);
|
||||||
@ -214,6 +216,19 @@ maxpool_layer parse_maxpool(list *options, size_params params)
|
|||||||
return layer;
|
return layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
 * Build an avgpool layer from the running size parameters.
 *
 * The layer takes no options of its own; its geometry comes entirely from
 * params (w, h, c, batch). Any zero dimension is reported through error().
 */
avgpool_layer parse_avgpool(list *options, size_params params)
{
    const int batch = params.batch;
    const int w = params.w;
    const int h = params.h;
    const int c = params.c;

    if(h == 0 || w == 0 || c == 0) error("Layer before avgpool layer must output image.");

    return make_avgpool_layer(batch,w,h,c);
}
|
||||||
|
|
||||||
dropout_layer parse_dropout(list *options, size_params params)
|
dropout_layer parse_dropout(list *options, size_params params)
|
||||||
{
|
{
|
||||||
float probability = option_find_float(options, "probability", .5);
|
float probability = option_find_float(options, "probability", .5);
|
||||||
@ -333,6 +348,8 @@ network parse_network_cfg(char *filename)
|
|||||||
l = parse_normalization(options, params);
|
l = parse_normalization(options, params);
|
||||||
}else if(is_maxpool(s)){
|
}else if(is_maxpool(s)){
|
||||||
l = parse_maxpool(options, params);
|
l = parse_maxpool(options, params);
|
||||||
|
}else if(is_avgpool(s)){
|
||||||
|
l = parse_avgpool(options, params);
|
||||||
}else if(is_route(s)){
|
}else if(is_route(s)){
|
||||||
l = parse_route(options, params, net);
|
l = parse_route(options, params, net);
|
||||||
}else if(is_dropout(s)){
|
}else if(is_dropout(s)){
|
||||||
@ -402,6 +419,11 @@ int is_maxpool(section *s)
|
|||||||
return (strcmp(s->type, "[max]")==0
|
return (strcmp(s->type, "[max]")==0
|
||||||
|| strcmp(s->type, "[maxpool]")==0);
|
|| strcmp(s->type, "[maxpool]")==0);
|
||||||
}
|
}
|
||||||
|
/* Return 1 if the section header is "[avg]" or "[avgpool]", 0 otherwise. */
int is_avgpool(section *s)
{
    if(strcmp(s->type, "[avg]") == 0) return 1;
    return strcmp(s->type, "[avgpool]") == 0;
}
|
||||||
int is_dropout(section *s)
|
int is_dropout(section *s)
|
||||||
{
|
{
|
||||||
return (strcmp(s->type, "[dropout]")==0);
|
return (strcmp(s->type, "[dropout]")==0);
|
||||||
|
Loading…
Reference in New Issue
Block a user