normalization layer

This commit is contained in:
Joseph Redmon 2015-07-09 15:22:14 -07:00
parent fc323c6310
commit 9c0b76ab8d
13 changed files with 294 additions and 17 deletions

View File

@ -34,7 +34,7 @@ CFLAGS+= -DGPU
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o
ifeq ($(GPU), 1)
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
endif

View File

@ -1,4 +1,23 @@
#include "blas.h"
#include "math.h"
void const_cpu(int N, float ALPHA, float *X, int INCX)
{
int i;
for(i = 0; i < N; ++i) X[i*INCX] = ALPHA;
}
void mul_cpu(int N, float *X, int INCX, float *Y, int INCY)
{
int i;
for(i = 0; i < N; ++i) Y[i*INCY] *= X[i*INCX];
}
void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{
int i;
for(i = 0; i < N; ++i) Y[i*INCY] = pow(X[i*INCX], ALPHA);
}
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{

View File

@ -6,6 +6,10 @@ void time_random_matrix(int TA, int TB, int m, int k, int n);
void test_blas();
void const_cpu(int N, float ALPHA, float *X, int INCX);
void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void mul_cpu(int N, float *X, int INCX, float *Y, int INCY);
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY);
void scal_cpu(int N, float ALPHA, float *X, int INCX);
@ -19,5 +23,9 @@ void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY);
void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
void scal_ongpu(int N, float ALPHA, float * X, int INCX);
void mask_ongpu(int N, float * X, float * mask);
void const_ongpu(int N, float ALPHA, float *X, int INCX);
void pow_ongpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void mul_ongpu(int N, float *X, int INCX, float *Y, int INCY);
#endif
#endif

View File

@ -9,6 +9,18 @@ __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, f
if(i < N) Y[OFFY+i*INCY] += ALPHA*X[OFFX+i*INCX];
}
__global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(i < N) Y[i*INCY] = pow(X[i*INCX], ALPHA);
}
__global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(i < N) X[i*INCX] = ALPHA;
}
__global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@ -27,11 +39,23 @@ __global__ void copy_kernel(int N, float *X, int OFFX, int INCX, float *Y, int
if(i < N) Y[i*INCY + OFFY] = X[i*INCX + OFFX];
}
__global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(i < N) Y[i*INCY] *= X[i*INCX];
}
extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
{
axpy_ongpu_offset(N, ALPHA, X, 0, INCX, Y, 0, INCY);
}
extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
{
pow_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX, Y, INCY);
check_error(cudaPeekAtLastError());
}
extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
axpy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
@ -43,6 +67,12 @@ extern "C" void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY)
copy_ongpu_offset(N, X, 0, INCX, Y, 0, INCY);
}
extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
{
mul_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, INCX, Y, INCY);
check_error(cudaPeekAtLastError());
}
extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
copy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
@ -55,6 +85,12 @@ extern "C" void mask_ongpu(int N, float * X, float * mask)
check_error(cudaPeekAtLastError());
}
extern "C" void const_ongpu(int N, float ALPHA, float * X, int INCX)
{
const_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
}
extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
{
scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);

View File

@ -13,7 +13,8 @@ typedef enum {
DROPOUT,
CROP,
ROUTE,
COST
COST,
NORMALIZATION
} LAYER_TYPE;
typedef enum{
@ -48,6 +49,10 @@ typedef struct {
int does_cost;
int joint;
float alpha;
float beta;
float kappa;
int dontload;
float probability;
@ -69,6 +74,8 @@ typedef struct {
int * input_sizes;
float * delta;
float * output;
float * squared;
float * norms;
#ifdef GPU
int *indexes_gpu;
@ -86,6 +93,8 @@ typedef struct {
float * output_gpu;
float * delta_gpu;
float * rand_gpu;
float * squared_gpu;
float * norms_gpu;
#endif
} layer;

View File

@ -10,6 +10,7 @@
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "detection_layer.h"
#include "normalization_layer.h"
#include "maxpool_layer.h"
#include "cost_layer.h"
#include "softmax_layer.h"
@ -39,6 +40,8 @@ char *get_layer_string(LAYER_TYPE a)
return "cost";
case ROUTE:
return "route";
case NORMALIZATION:
return "normalization";
default:
break;
}
@ -66,6 +69,8 @@ void forward_network(network net, network_state state)
forward_convolutional_layer(l, state);
} else if(l.type == DECONVOLUTIONAL){
forward_deconvolutional_layer(l, state);
} else if(l.type == NORMALIZATION){
forward_normalization_layer(l, state);
} else if(l.type == DETECTION){
forward_detection_layer(l, state);
} else if(l.type == CONNECTED){
@ -147,6 +152,8 @@ void backward_network(network net, network_state state)
backward_convolutional_layer(l, state);
} else if(l.type == DECONVOLUTIONAL){
backward_deconvolutional_layer(l, state);
} else if(l.type == NORMALIZATION){
backward_normalization_layer(l, state);
} else if(l.type == MAXPOOL){
if(i != 0) backward_maxpool_layer(l, state);
} else if(l.type == DROPOUT){
@ -266,6 +273,8 @@ int resize_network(network *net, int w, int h)
resize_convolutional_layer(&l, w, h);
}else if(l.type == MAXPOOL){
resize_maxpool_layer(&l, w, h);
}else if(l.type == NORMALIZATION){
resize_normalization_layer(&l, w, h);
}else{
error("Cannot resize this type of layer");
}

View File

@ -15,6 +15,7 @@ extern "C" {
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "cost_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
@ -44,6 +45,8 @@ void forward_network_gpu(network net, network_state state)
forward_cost_layer_gpu(l, state);
} else if(l.type == SOFTMAX){
forward_softmax_layer_gpu(l, state);
} else if(l.type == NORMALIZATION){
forward_normalization_layer_gpu(l, state);
} else if(l.type == MAXPOOL){
forward_maxpool_layer_gpu(l, state);
} else if(l.type == DROPOUT){
@ -80,6 +83,8 @@ void backward_network_gpu(network net, network_state state)
backward_dropout_layer_gpu(l, state);
} else if(l.type == DETECTION){
backward_detection_layer_gpu(l, state);
} else if(l.type == NORMALIZATION){
backward_normalization_layer_gpu(l, state);
} else if(l.type == SOFTMAX){
if(i != 0) backward_softmax_layer_gpu(l, state);
} else if(l.type == CONNECTED){
@ -136,20 +141,7 @@ float *get_network_output_layer_gpu(network net, int i)
{
layer l = net.layers[i];
cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
if(l.type == CONVOLUTIONAL){
return l.output;
} else if(l.type == DECONVOLUTIONAL){
return l.output;
} else if(l.type == CONNECTED){
return l.output;
} else if(l.type == DETECTION){
return l.output;
} else if(l.type == MAXPOOL){
return l.output;
} else if(l.type == SOFTMAX){
return l.output;
}
return 0;
return l.output;
}
float *get_network_output_gpu(network net)

View File

@ -130,6 +130,7 @@ void run_nightmare(int argc, char **argv)
float rate = find_float_arg(argc, argv, "-rate", .04);
float thresh = find_float_arg(argc, argv, "-thresh", 1.);
float rotate = find_float_arg(argc, argv, "-rotate", 0);
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
network net = parse_network_cfg(cfg);
load_weights(&net, weights);
@ -168,7 +169,11 @@ void run_nightmare(int argc, char **argv)
im = g;
}
char buff[256];
sprintf(buff, "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e);
if (prefix){
sprintf(buff, "%s/%s_%s_%d_%06d",prefix, imbase, cfgbase, max_layer, e);
}else{
sprintf(buff, "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e);
}
printf("%d %s\n", e, buff);
save_image(im, buff);
//show_image(im, buff);

143
src/normalization_layer.c Normal file
View File

@ -0,0 +1,143 @@
#include "normalization_layer.h"
#include "blas.h"
#include <stdio.h>
layer make_normalization_layer(int batch, int w, int h, int c, int size, float alpha, float beta, float kappa)
{
fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", w,h,c,size);
layer layer = {0};
layer.type = NORMALIZATION;
layer.batch = batch;
layer.h = layer.out_h = h;
layer.w = layer.out_w = w;
layer.c = layer.out_c = c;
layer.kappa = kappa;
layer.size = size;
layer.alpha = alpha;
layer.beta = beta;
layer.output = calloc(h * w * c * batch, sizeof(float));
layer.delta = calloc(h * w * c * batch, sizeof(float));
layer.squared = calloc(h * w * c * batch, sizeof(float));
layer.norms = calloc(h * w * c * batch, sizeof(float));
layer.inputs = w*h*c;
layer.outputs = layer.inputs;
#ifdef GPU
layer.output_gpu = cuda_make_array(0, h * w * c * batch);
layer.delta_gpu = cuda_make_array(0, h * w * c * batch);
layer.squared_gpu = cuda_make_array(0, h * w * c * batch);
layer.norms_gpu = cuda_make_array(0, h * w * c * batch);
#endif
return layer;
}
void resize_normalization_layer(layer *layer, int w, int h)
{
int c = layer->c;
int batch = layer->batch;
layer->h = h;
layer->w = w;
layer->out_h = h;
layer->out_w = w;
layer->inputs = w*h*c;
layer->outputs = layer->inputs;
layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
layer->squared = realloc(layer->squared, h * w * layer->c * layer->batch * sizeof(float));
layer->norms = realloc(layer->norms, h * w * layer->c * layer->batch * sizeof(float));
#ifdef GPU
cuda_free(layer->output_gpu);
cuda_free(layer->delta_gpu);
cuda_free(layer->squared_gpu);
cuda_free(layer->norms_gpu);
layer->output_gpu = cuda_make_array(0, h * w * c * batch);
layer->delta_gpu = cuda_make_array(0, h * w * c * batch);
layer->squared_gpu = cuda_make_array(0, h * w * c * batch);
layer->norms_gpu = cuda_make_array(0, h * w * c * batch);
#endif
}
void forward_normalization_layer(const layer layer, network_state state)
{
int k,b;
int w = layer.w;
int h = layer.h;
int c = layer.c;
scal_cpu(w*h*c*layer.batch, 0, layer.squared, 1);
for(b = 0; b < layer.batch; ++b){
float *squared = layer.squared + w*h*c*b;
float *norms = layer.norms + w*h*c*b;
float *input = state.input + w*h*c*b;
pow_cpu(w*h*c, 2, input, 1, squared, 1);
const_cpu(w*h, layer.kappa, norms, 1);
for(k = 0; k < layer.size/2; ++k){
axpy_cpu(w*h, layer.alpha, squared + w*h*k, 1, norms, 1);
}
for(k = 1; k < layer.c; ++k){
copy_cpu(w*h, norms + w*h*(k-1), 1, norms + w*h*k, 1);
int prev = k - ((layer.size-1)/2) - 1;
int next = k + (layer.size/2);
if(prev >= 0) axpy_cpu(w*h, -layer.alpha, squared + w*h*prev, 1, norms + w*h*k, 1);
if(next < layer.c) axpy_cpu(w*h, layer.alpha, squared + w*h*next, 1, norms + w*h*k, 1);
}
}
pow_cpu(w*h*c*layer.batch, -layer.beta, layer.norms, 1, layer.output, 1);
mul_cpu(w*h*c*layer.batch, state.input, 1, layer.output, 1);
}
void backward_normalization_layer(const layer layer, network_state state)
{
// TODO This is approximate ;-)
int w = layer.w;
int h = layer.h;
int c = layer.c;
pow_cpu(w*h*c*layer.batch, -layer.beta, layer.norms, 1, state.delta, 1);
mul_cpu(w*h*c*layer.batch, layer.delta, 1, state.delta, 1);
}
#ifdef GPU
void forward_normalization_layer_gpu(const layer layer, network_state state)
{
int k,b;
int w = layer.w;
int h = layer.h;
int c = layer.c;
scal_ongpu(w*h*c*layer.batch, 0, layer.squared_gpu, 1);
for(b = 0; b < layer.batch; ++b){
float *squared = layer.squared_gpu + w*h*c*b;
float *norms = layer.norms_gpu + w*h*c*b;
float *input = state.input + w*h*c*b;
pow_ongpu(w*h*c, 2, input, 1, squared, 1);
const_ongpu(w*h, layer.kappa, norms, 1);
for(k = 0; k < layer.size/2; ++k){
axpy_ongpu(w*h, layer.alpha, squared + w*h*k, 1, norms, 1);
}
for(k = 1; k < layer.c; ++k){
copy_ongpu(w*h, norms + w*h*(k-1), 1, norms + w*h*k, 1);
int prev = k - ((layer.size-1)/2) - 1;
int next = k + (layer.size/2);
if(prev >= 0) axpy_ongpu(w*h, -layer.alpha, squared + w*h*prev, 1, norms + w*h*k, 1);
if(next < layer.c) axpy_ongpu(w*h, layer.alpha, squared + w*h*next, 1, norms + w*h*k, 1);
}
}
pow_ongpu(w*h*c*layer.batch, -layer.beta, layer.norms_gpu, 1, layer.output_gpu, 1);
mul_ongpu(w*h*c*layer.batch, state.input, 1, layer.output_gpu, 1);
}
void backward_normalization_layer_gpu(const layer layer, network_state state)
{
// TODO This is approximate ;-)
int w = layer.w;
int h = layer.h;
int c = layer.c;
pow_ongpu(w*h*c*layer.batch, -layer.beta, layer.norms_gpu, 1, state.delta, 1);
mul_ongpu(w*h*c*layer.batch, layer.delta_gpu, 1, state.delta, 1);
}
#endif

19
src/normalization_layer.h Normal file
View File

@ -0,0 +1,19 @@
#ifndef NORMALIZATION_LAYER_H
#define NORMALIZATION_LAYER_H
#include "image.h"
#include "layer.h"
#include "params.h"
layer make_normalization_layer(int batch, int w, int h, int c, int size, float alpha, float beta, float kappa);
void resize_normalization_layer(layer *layer, int h, int w);
void forward_normalization_layer(const layer layer, network_state state);
void backward_normalization_layer(const layer layer, network_state state);
void visualize_normalization_layer(layer layer, char *window);
#ifdef GPU
void forward_normalization_layer_gpu(const layer layer, network_state state);
void backward_normalization_layer_gpu(const layer layer, network_state state);
#endif
#endif

View File

@ -7,6 +7,7 @@
#include "crop_layer.h"
#include "cost_layer.h"
#include "convolutional_layer.h"
#include "normalization_layer.h"
#include "deconvolutional_layer.h"
#include "connected_layer.h"
#include "maxpool_layer.h"
@ -30,6 +31,7 @@ int is_connected(section *s);
int is_maxpool(section *s);
int is_dropout(section *s);
int is_softmax(section *s);
int is_normalization(section *s);
int is_crop(section *s);
int is_cost(section *s);
int is_detection(section *s);
@ -228,6 +230,17 @@ dropout_layer parse_dropout(list *options, size_params params)
return layer;
}
layer parse_normalization(list *options, size_params params)
{
float alpha = option_find_float(options, "alpha", .0001);
float beta = option_find_float(options, "beta" , .75);
float kappa = option_find_float(options, "kappa", 1);
int size = option_find_int(options, "size", 5);
layer l = make_normalization_layer(params.batch, params.w, params.h, params.c, size, alpha, beta, kappa);
option_unused(options);
return l;
}
route_layer parse_route(list *options, size_params params, network net)
{
char *l = option_find(options, "layers");
@ -328,6 +341,8 @@ network parse_network_cfg(char *filename)
l = parse_detection(options, params);
}else if(is_softmax(s)){
l = parse_softmax(options, params);
}else if(is_normalization(s)){
l = parse_normalization(options, params);
}else if(is_maxpool(s)){
l = parse_maxpool(options, params);
}else if(is_route(s)){
@ -403,6 +418,12 @@ int is_dropout(section *s)
return (strcmp(s->type, "[dropout]")==0);
}
int is_normalization(section *s)
{
return (strcmp(s->type, "[lrn]")==0
|| strcmp(s->type, "[normalization]")==0);
}
int is_softmax(section *s)
{
return (strcmp(s->type, "[soft]")==0

View File

@ -58,6 +58,21 @@ float find_float_arg(int argc, char **argv, char *arg, float def)
return def;
}
char *find_char_arg(int argc, char **argv, char *arg, char *def)
{
int i;
for(i = 0; i < argc-1; ++i){
if(!argv[i]) continue;
if(0==strcmp(argv[i], arg)){
def = argv[i+1];
del_arg(argc, argv, i);
del_arg(argc, argv, i);
break;
}
}
return def;
}
char *basecfg(char *cfgfile)
{

View File

@ -39,6 +39,7 @@ float sec(clock_t clocks);
int find_int_arg(int argc, char **argv, char *arg, int def);
float find_float_arg(int argc, char **argv, char *arg, float def);
int find_arg(int argc, char* argv[], char *arg);
char *find_char_arg(int argc, char **argv, char *arg, char *def);
#endif