mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
normalization layer
This commit is contained in:
parent
fc323c6310
commit
9c0b76ab8d
2
Makefile
2
Makefile
@ -34,7 +34,7 @@ CFLAGS+= -DGPU
|
||||
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
||||
endif
|
||||
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o
|
||||
ifeq ($(GPU), 1)
|
||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
|
||||
endif
|
||||
|
19
src/blas.c
19
src/blas.c
@ -1,4 +1,23 @@
|
||||
#include "blas.h"
|
||||
#include "math.h"
|
||||
|
||||
void const_cpu(int N, float ALPHA, float *X, int INCX)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < N; ++i) X[i*INCX] = ALPHA;
|
||||
}
|
||||
|
||||
void mul_cpu(int N, float *X, int INCX, float *Y, int INCY)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < N; ++i) Y[i*INCY] *= X[i*INCX];
|
||||
}
|
||||
|
||||
void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < N; ++i) Y[i*INCY] = pow(X[i*INCX], ALPHA);
|
||||
}
|
||||
|
||||
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
|
||||
{
|
||||
|
@ -6,6 +6,10 @@ void time_random_matrix(int TA, int TB, int m, int k, int n);
|
||||
|
||||
void test_blas();
|
||||
|
||||
void const_cpu(int N, float ALPHA, float *X, int INCX);
|
||||
void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
|
||||
void mul_cpu(int N, float *X, int INCX, float *Y, int INCY);
|
||||
|
||||
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
|
||||
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY);
|
||||
void scal_cpu(int N, float ALPHA, float *X, int INCX);
|
||||
@ -19,5 +23,9 @@ void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY);
|
||||
void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
|
||||
void scal_ongpu(int N, float ALPHA, float * X, int INCX);
|
||||
void mask_ongpu(int N, float * X, float * mask);
|
||||
void const_ongpu(int N, float ALPHA, float *X, int INCX);
|
||||
void pow_ongpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
|
||||
void mul_ongpu(int N, float *X, int INCX, float *Y, int INCY);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@ -9,6 +9,18 @@ __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, f
|
||||
if(i < N) Y[OFFY+i*INCY] += ALPHA*X[OFFX+i*INCX];
|
||||
}
|
||||
|
||||
__global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if(i < N) Y[i*INCY] = pow(X[i*INCX], ALPHA);
|
||||
}
|
||||
|
||||
__global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if(i < N) X[i*INCX] = ALPHA;
|
||||
}
|
||||
|
||||
__global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
@ -27,11 +39,23 @@ __global__ void copy_kernel(int N, float *X, int OFFX, int INCX, float *Y, int
|
||||
if(i < N) Y[i*INCY + OFFY] = X[i*INCX + OFFX];
|
||||
}
|
||||
|
||||
__global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if(i < N) Y[i*INCY] *= X[i*INCX];
|
||||
}
|
||||
|
||||
extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
|
||||
{
|
||||
axpy_ongpu_offset(N, ALPHA, X, 0, INCX, Y, 0, INCY);
|
||||
}
|
||||
|
||||
extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
|
||||
{
|
||||
pow_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX, Y, INCY);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
|
||||
{
|
||||
axpy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
|
||||
@ -43,6 +67,12 @@ extern "C" void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY)
|
||||
copy_ongpu_offset(N, X, 0, INCX, Y, 0, INCY);
|
||||
}
|
||||
|
||||
extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
|
||||
{
|
||||
mul_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, INCX, Y, INCY);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
|
||||
{
|
||||
copy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
|
||||
@ -55,6 +85,12 @@ extern "C" void mask_ongpu(int N, float * X, float * mask)
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void const_ongpu(int N, float ALPHA, float * X, int INCX)
|
||||
{
|
||||
const_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
|
||||
{
|
||||
scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
|
||||
|
11
src/layer.h
11
src/layer.h
@ -13,7 +13,8 @@ typedef enum {
|
||||
DROPOUT,
|
||||
CROP,
|
||||
ROUTE,
|
||||
COST
|
||||
COST,
|
||||
NORMALIZATION
|
||||
} LAYER_TYPE;
|
||||
|
||||
typedef enum{
|
||||
@ -48,6 +49,10 @@ typedef struct {
|
||||
int does_cost;
|
||||
int joint;
|
||||
|
||||
float alpha;
|
||||
float beta;
|
||||
float kappa;
|
||||
|
||||
int dontload;
|
||||
|
||||
float probability;
|
||||
@ -69,6 +74,8 @@ typedef struct {
|
||||
int * input_sizes;
|
||||
float * delta;
|
||||
float * output;
|
||||
float * squared;
|
||||
float * norms;
|
||||
|
||||
#ifdef GPU
|
||||
int *indexes_gpu;
|
||||
@ -86,6 +93,8 @@ typedef struct {
|
||||
float * output_gpu;
|
||||
float * delta_gpu;
|
||||
float * rand_gpu;
|
||||
float * squared_gpu;
|
||||
float * norms_gpu;
|
||||
#endif
|
||||
} layer;
|
||||
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "detection_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
#include "cost_layer.h"
|
||||
#include "softmax_layer.h"
|
||||
@ -39,6 +40,8 @@ char *get_layer_string(LAYER_TYPE a)
|
||||
return "cost";
|
||||
case ROUTE:
|
||||
return "route";
|
||||
case NORMALIZATION:
|
||||
return "normalization";
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -66,6 +69,8 @@ void forward_network(network net, network_state state)
|
||||
forward_convolutional_layer(l, state);
|
||||
} else if(l.type == DECONVOLUTIONAL){
|
||||
forward_deconvolutional_layer(l, state);
|
||||
} else if(l.type == NORMALIZATION){
|
||||
forward_normalization_layer(l, state);
|
||||
} else if(l.type == DETECTION){
|
||||
forward_detection_layer(l, state);
|
||||
} else if(l.type == CONNECTED){
|
||||
@ -147,6 +152,8 @@ void backward_network(network net, network_state state)
|
||||
backward_convolutional_layer(l, state);
|
||||
} else if(l.type == DECONVOLUTIONAL){
|
||||
backward_deconvolutional_layer(l, state);
|
||||
} else if(l.type == NORMALIZATION){
|
||||
backward_normalization_layer(l, state);
|
||||
} else if(l.type == MAXPOOL){
|
||||
if(i != 0) backward_maxpool_layer(l, state);
|
||||
} else if(l.type == DROPOUT){
|
||||
@ -266,6 +273,8 @@ int resize_network(network *net, int w, int h)
|
||||
resize_convolutional_layer(&l, w, h);
|
||||
}else if(l.type == MAXPOOL){
|
||||
resize_maxpool_layer(&l, w, h);
|
||||
}else if(l.type == NORMALIZATION){
|
||||
resize_normalization_layer(&l, w, h);
|
||||
}else{
|
||||
error("Cannot resize this type of layer");
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ extern "C" {
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
#include "cost_layer.h"
|
||||
#include "softmax_layer.h"
|
||||
#include "dropout_layer.h"
|
||||
@ -44,6 +45,8 @@ void forward_network_gpu(network net, network_state state)
|
||||
forward_cost_layer_gpu(l, state);
|
||||
} else if(l.type == SOFTMAX){
|
||||
forward_softmax_layer_gpu(l, state);
|
||||
} else if(l.type == NORMALIZATION){
|
||||
forward_normalization_layer_gpu(l, state);
|
||||
} else if(l.type == MAXPOOL){
|
||||
forward_maxpool_layer_gpu(l, state);
|
||||
} else if(l.type == DROPOUT){
|
||||
@ -80,6 +83,8 @@ void backward_network_gpu(network net, network_state state)
|
||||
backward_dropout_layer_gpu(l, state);
|
||||
} else if(l.type == DETECTION){
|
||||
backward_detection_layer_gpu(l, state);
|
||||
} else if(l.type == NORMALIZATION){
|
||||
backward_normalization_layer_gpu(l, state);
|
||||
} else if(l.type == SOFTMAX){
|
||||
if(i != 0) backward_softmax_layer_gpu(l, state);
|
||||
} else if(l.type == CONNECTED){
|
||||
@ -136,20 +141,7 @@ float *get_network_output_layer_gpu(network net, int i)
|
||||
{
|
||||
layer l = net.layers[i];
|
||||
cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
|
||||
if(l.type == CONVOLUTIONAL){
|
||||
return l.output;
|
||||
} else if(l.type == DECONVOLUTIONAL){
|
||||
return l.output;
|
||||
} else if(l.type == CONNECTED){
|
||||
return l.output;
|
||||
} else if(l.type == DETECTION){
|
||||
return l.output;
|
||||
} else if(l.type == MAXPOOL){
|
||||
return l.output;
|
||||
} else if(l.type == SOFTMAX){
|
||||
return l.output;
|
||||
}
|
||||
return 0;
|
||||
return l.output;
|
||||
}
|
||||
|
||||
float *get_network_output_gpu(network net)
|
||||
|
@ -130,6 +130,7 @@ void run_nightmare(int argc, char **argv)
|
||||
float rate = find_float_arg(argc, argv, "-rate", .04);
|
||||
float thresh = find_float_arg(argc, argv, "-thresh", 1.);
|
||||
float rotate = find_float_arg(argc, argv, "-rotate", 0);
|
||||
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
|
||||
|
||||
network net = parse_network_cfg(cfg);
|
||||
load_weights(&net, weights);
|
||||
@ -168,7 +169,11 @@ void run_nightmare(int argc, char **argv)
|
||||
im = g;
|
||||
}
|
||||
char buff[256];
|
||||
sprintf(buff, "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e);
|
||||
if (prefix){
|
||||
sprintf(buff, "%s/%s_%s_%d_%06d",prefix, imbase, cfgbase, max_layer, e);
|
||||
}else{
|
||||
sprintf(buff, "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e);
|
||||
}
|
||||
printf("%d %s\n", e, buff);
|
||||
save_image(im, buff);
|
||||
//show_image(im, buff);
|
||||
|
143
src/normalization_layer.c
Normal file
143
src/normalization_layer.c
Normal file
@ -0,0 +1,143 @@
|
||||
#include "normalization_layer.h"
|
||||
#include "blas.h"
|
||||
#include <stdio.h>
|
||||
|
||||
layer make_normalization_layer(int batch, int w, int h, int c, int size, float alpha, float beta, float kappa)
|
||||
{
|
||||
fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", w,h,c,size);
|
||||
layer layer = {0};
|
||||
layer.type = NORMALIZATION;
|
||||
layer.batch = batch;
|
||||
layer.h = layer.out_h = h;
|
||||
layer.w = layer.out_w = w;
|
||||
layer.c = layer.out_c = c;
|
||||
layer.kappa = kappa;
|
||||
layer.size = size;
|
||||
layer.alpha = alpha;
|
||||
layer.beta = beta;
|
||||
layer.output = calloc(h * w * c * batch, sizeof(float));
|
||||
layer.delta = calloc(h * w * c * batch, sizeof(float));
|
||||
layer.squared = calloc(h * w * c * batch, sizeof(float));
|
||||
layer.norms = calloc(h * w * c * batch, sizeof(float));
|
||||
layer.inputs = w*h*c;
|
||||
layer.outputs = layer.inputs;
|
||||
#ifdef GPU
|
||||
layer.output_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
layer.delta_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
layer.squared_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
layer.norms_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
#endif
|
||||
return layer;
|
||||
}
|
||||
|
||||
void resize_normalization_layer(layer *layer, int w, int h)
|
||||
{
|
||||
int c = layer->c;
|
||||
int batch = layer->batch;
|
||||
layer->h = h;
|
||||
layer->w = w;
|
||||
layer->out_h = h;
|
||||
layer->out_w = w;
|
||||
layer->inputs = w*h*c;
|
||||
layer->outputs = layer->inputs;
|
||||
layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
|
||||
layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
|
||||
layer->squared = realloc(layer->squared, h * w * layer->c * layer->batch * sizeof(float));
|
||||
layer->norms = realloc(layer->norms, h * w * layer->c * layer->batch * sizeof(float));
|
||||
#ifdef GPU
|
||||
cuda_free(layer->output_gpu);
|
||||
cuda_free(layer->delta_gpu);
|
||||
cuda_free(layer->squared_gpu);
|
||||
cuda_free(layer->norms_gpu);
|
||||
layer->output_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
layer->delta_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
layer->squared_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
layer->norms_gpu = cuda_make_array(0, h * w * c * batch);
|
||||
#endif
|
||||
}
|
||||
|
||||
void forward_normalization_layer(const layer layer, network_state state)
|
||||
{
|
||||
int k,b;
|
||||
int w = layer.w;
|
||||
int h = layer.h;
|
||||
int c = layer.c;
|
||||
scal_cpu(w*h*c*layer.batch, 0, layer.squared, 1);
|
||||
|
||||
for(b = 0; b < layer.batch; ++b){
|
||||
float *squared = layer.squared + w*h*c*b;
|
||||
float *norms = layer.norms + w*h*c*b;
|
||||
float *input = state.input + w*h*c*b;
|
||||
pow_cpu(w*h*c, 2, input, 1, squared, 1);
|
||||
|
||||
const_cpu(w*h, layer.kappa, norms, 1);
|
||||
for(k = 0; k < layer.size/2; ++k){
|
||||
axpy_cpu(w*h, layer.alpha, squared + w*h*k, 1, norms, 1);
|
||||
}
|
||||
|
||||
for(k = 1; k < layer.c; ++k){
|
||||
copy_cpu(w*h, norms + w*h*(k-1), 1, norms + w*h*k, 1);
|
||||
int prev = k - ((layer.size-1)/2) - 1;
|
||||
int next = k + (layer.size/2);
|
||||
if(prev >= 0) axpy_cpu(w*h, -layer.alpha, squared + w*h*prev, 1, norms + w*h*k, 1);
|
||||
if(next < layer.c) axpy_cpu(w*h, layer.alpha, squared + w*h*next, 1, norms + w*h*k, 1);
|
||||
}
|
||||
}
|
||||
pow_cpu(w*h*c*layer.batch, -layer.beta, layer.norms, 1, layer.output, 1);
|
||||
mul_cpu(w*h*c*layer.batch, state.input, 1, layer.output, 1);
|
||||
}
|
||||
|
||||
void backward_normalization_layer(const layer layer, network_state state)
|
||||
{
|
||||
// TODO This is approximate ;-)
|
||||
|
||||
int w = layer.w;
|
||||
int h = layer.h;
|
||||
int c = layer.c;
|
||||
pow_cpu(w*h*c*layer.batch, -layer.beta, layer.norms, 1, state.delta, 1);
|
||||
mul_cpu(w*h*c*layer.batch, layer.delta, 1, state.delta, 1);
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
void forward_normalization_layer_gpu(const layer layer, network_state state)
|
||||
{
|
||||
int k,b;
|
||||
int w = layer.w;
|
||||
int h = layer.h;
|
||||
int c = layer.c;
|
||||
scal_ongpu(w*h*c*layer.batch, 0, layer.squared_gpu, 1);
|
||||
|
||||
for(b = 0; b < layer.batch; ++b){
|
||||
float *squared = layer.squared_gpu + w*h*c*b;
|
||||
float *norms = layer.norms_gpu + w*h*c*b;
|
||||
float *input = state.input + w*h*c*b;
|
||||
pow_ongpu(w*h*c, 2, input, 1, squared, 1);
|
||||
|
||||
const_ongpu(w*h, layer.kappa, norms, 1);
|
||||
for(k = 0; k < layer.size/2; ++k){
|
||||
axpy_ongpu(w*h, layer.alpha, squared + w*h*k, 1, norms, 1);
|
||||
}
|
||||
|
||||
for(k = 1; k < layer.c; ++k){
|
||||
copy_ongpu(w*h, norms + w*h*(k-1), 1, norms + w*h*k, 1);
|
||||
int prev = k - ((layer.size-1)/2) - 1;
|
||||
int next = k + (layer.size/2);
|
||||
if(prev >= 0) axpy_ongpu(w*h, -layer.alpha, squared + w*h*prev, 1, norms + w*h*k, 1);
|
||||
if(next < layer.c) axpy_ongpu(w*h, layer.alpha, squared + w*h*next, 1, norms + w*h*k, 1);
|
||||
}
|
||||
}
|
||||
pow_ongpu(w*h*c*layer.batch, -layer.beta, layer.norms_gpu, 1, layer.output_gpu, 1);
|
||||
mul_ongpu(w*h*c*layer.batch, state.input, 1, layer.output_gpu, 1);
|
||||
}
|
||||
|
||||
void backward_normalization_layer_gpu(const layer layer, network_state state)
|
||||
{
|
||||
// TODO This is approximate ;-)
|
||||
|
||||
int w = layer.w;
|
||||
int h = layer.h;
|
||||
int c = layer.c;
|
||||
pow_ongpu(w*h*c*layer.batch, -layer.beta, layer.norms_gpu, 1, state.delta, 1);
|
||||
mul_ongpu(w*h*c*layer.batch, layer.delta_gpu, 1, state.delta, 1);
|
||||
}
|
||||
#endif
|
19
src/normalization_layer.h
Normal file
19
src/normalization_layer.h
Normal file
@ -0,0 +1,19 @@
|
||||
#ifndef NORMALIZATION_LAYER_H
|
||||
#define NORMALIZATION_LAYER_H
|
||||
|
||||
#include "image.h"
|
||||
#include "layer.h"
|
||||
#include "params.h"
|
||||
|
||||
layer make_normalization_layer(int batch, int w, int h, int c, int size, float alpha, float beta, float kappa);
|
||||
void resize_normalization_layer(layer *layer, int h, int w);
|
||||
void forward_normalization_layer(const layer layer, network_state state);
|
||||
void backward_normalization_layer(const layer layer, network_state state);
|
||||
void visualize_normalization_layer(layer layer, char *window);
|
||||
|
||||
#ifdef GPU
|
||||
void forward_normalization_layer_gpu(const layer layer, network_state state);
|
||||
void backward_normalization_layer_gpu(const layer layer, network_state state);
|
||||
#endif
|
||||
|
||||
#endif
|
21
src/parser.c
21
src/parser.c
@ -7,6 +7,7 @@
|
||||
#include "crop_layer.h"
|
||||
#include "cost_layer.h"
|
||||
#include "convolutional_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "connected_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
@ -30,6 +31,7 @@ int is_connected(section *s);
|
||||
int is_maxpool(section *s);
|
||||
int is_dropout(section *s);
|
||||
int is_softmax(section *s);
|
||||
int is_normalization(section *s);
|
||||
int is_crop(section *s);
|
||||
int is_cost(section *s);
|
||||
int is_detection(section *s);
|
||||
@ -228,6 +230,17 @@ dropout_layer parse_dropout(list *options, size_params params)
|
||||
return layer;
|
||||
}
|
||||
|
||||
layer parse_normalization(list *options, size_params params)
|
||||
{
|
||||
float alpha = option_find_float(options, "alpha", .0001);
|
||||
float beta = option_find_float(options, "beta" , .75);
|
||||
float kappa = option_find_float(options, "kappa", 1);
|
||||
int size = option_find_int(options, "size", 5);
|
||||
layer l = make_normalization_layer(params.batch, params.w, params.h, params.c, size, alpha, beta, kappa);
|
||||
option_unused(options);
|
||||
return l;
|
||||
}
|
||||
|
||||
route_layer parse_route(list *options, size_params params, network net)
|
||||
{
|
||||
char *l = option_find(options, "layers");
|
||||
@ -328,6 +341,8 @@ network parse_network_cfg(char *filename)
|
||||
l = parse_detection(options, params);
|
||||
}else if(is_softmax(s)){
|
||||
l = parse_softmax(options, params);
|
||||
}else if(is_normalization(s)){
|
||||
l = parse_normalization(options, params);
|
||||
}else if(is_maxpool(s)){
|
||||
l = parse_maxpool(options, params);
|
||||
}else if(is_route(s)){
|
||||
@ -403,6 +418,12 @@ int is_dropout(section *s)
|
||||
return (strcmp(s->type, "[dropout]")==0);
|
||||
}
|
||||
|
||||
int is_normalization(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[lrn]")==0
|
||||
|| strcmp(s->type, "[normalization]")==0);
|
||||
}
|
||||
|
||||
int is_softmax(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[soft]")==0
|
||||
|
15
src/utils.c
15
src/utils.c
@ -58,6 +58,21 @@ float find_float_arg(int argc, char **argv, char *arg, float def)
|
||||
return def;
|
||||
}
|
||||
|
||||
char *find_char_arg(int argc, char **argv, char *arg, char *def)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < argc-1; ++i){
|
||||
if(!argv[i]) continue;
|
||||
if(0==strcmp(argv[i], arg)){
|
||||
def = argv[i+1];
|
||||
del_arg(argc, argv, i);
|
||||
del_arg(argc, argv, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return def;
|
||||
}
|
||||
|
||||
|
||||
char *basecfg(char *cfgfile)
|
||||
{
|
||||
|
@ -39,6 +39,7 @@ float sec(clock_t clocks);
|
||||
int find_int_arg(int argc, char **argv, char *arg, int def);
|
||||
float find_float_arg(int argc, char **argv, char *arg, float def);
|
||||
int find_arg(int argc, char* argv[], char *arg);
|
||||
char *find_char_arg(int argc, char **argv, char *arg, char *def);
|
||||
|
||||
#endif
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user