mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
NIGHTMARE!!!!
This commit is contained in:
parent
d1d56a2a72
commit
a08ef29e08
3
Makefile
3
Makefile
@ -34,7 +34,7 @@ CFLAGS+= -DGPU
|
|||||||
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
||||||
endif
|
endif
|
||||||
|
|
||||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o
|
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o
|
||||||
ifeq ($(GPU), 1)
|
ifeq ($(GPU), 1)
|
||||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
|
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
|
||||||
endif
|
endif
|
||||||
@ -58,7 +58,6 @@ obj:
|
|||||||
results:
|
results:
|
||||||
mkdir -p results
|
mkdir -p results
|
||||||
|
|
||||||
|
|
||||||
.PHONY: clean
|
.PHONY: clean
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
|
@ -13,9 +13,9 @@ seen=0
|
|||||||
crop_height=224
|
crop_height=224
|
||||||
crop_width=224
|
crop_width=224
|
||||||
flip=1
|
flip=1
|
||||||
angle=15
|
angle=0
|
||||||
saturation=1.5
|
saturation=1
|
||||||
exposure=1.5
|
exposure=1
|
||||||
|
|
||||||
[convolutional]
|
[convolutional]
|
||||||
filters=64
|
filters=64
|
||||||
|
@ -13,9 +13,9 @@ decay=0.0005
|
|||||||
crop_height=224
|
crop_height=224
|
||||||
crop_width=224
|
crop_width=224
|
||||||
flip=1
|
flip=1
|
||||||
exposure=2
|
exposure=1
|
||||||
saturation=2
|
saturation=1
|
||||||
angle=5
|
angle=0
|
||||||
|
|
||||||
[convolutional]
|
[convolutional]
|
||||||
filters=64
|
filters=64
|
||||||
|
122
cfg/vgg-conv.cfg
Normal file
122
cfg/vgg-conv.cfg
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
[net]
|
||||||
|
batch=1
|
||||||
|
subdivisions=1
|
||||||
|
width=224
|
||||||
|
height=224
|
||||||
|
channels=3
|
||||||
|
learning_rate=0.00001
|
||||||
|
momentum=0.9
|
||||||
|
seen=0
|
||||||
|
decay=0.0005
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=64
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=64
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[maxpool]
|
||||||
|
size=2
|
||||||
|
stride=2
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=128
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=128
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[maxpool]
|
||||||
|
size=2
|
||||||
|
stride=2
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=256
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=256
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=256
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[maxpool]
|
||||||
|
size=2
|
||||||
|
stride=2
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=512
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=512
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=512
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[maxpool]
|
||||||
|
size=2
|
||||||
|
stride=2
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=512
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=512
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[convolutional]
|
||||||
|
filters=512
|
||||||
|
size=3
|
||||||
|
stride=1
|
||||||
|
pad=1
|
||||||
|
activation=relu
|
||||||
|
|
||||||
|
[maxpool]
|
||||||
|
size=2
|
||||||
|
stride=2
|
||||||
|
|
BIN
data/scream.jpg
Normal file
BIN
data/scream.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 329 KiB |
@ -8,6 +8,7 @@ __device__ float logistic_activate_kernel(float x){return 1./(1. + exp(-x));}
|
|||||||
__device__ float relu_activate_kernel(float x){return x*(x>0);}
|
__device__ float relu_activate_kernel(float x){return x*(x>0);}
|
||||||
__device__ float relie_activate_kernel(float x){return x*(x>0);}
|
__device__ float relie_activate_kernel(float x){return x*(x>0);}
|
||||||
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1*x;}
|
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1*x;}
|
||||||
|
__device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1*x;}
|
||||||
__device__ float tanh_activate_kernel(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
|
__device__ float tanh_activate_kernel(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
|
||||||
__device__ float plse_activate_kernel(float x)
|
__device__ float plse_activate_kernel(float x)
|
||||||
{
|
{
|
||||||
@ -21,6 +22,7 @@ __device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
|
|||||||
__device__ float relu_gradient_kernel(float x){return (x>0);}
|
__device__ float relu_gradient_kernel(float x){return (x>0);}
|
||||||
__device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01;}
|
__device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01;}
|
||||||
__device__ float ramp_gradient_kernel(float x){return (x>0)+.1;}
|
__device__ float ramp_gradient_kernel(float x){return (x>0)+.1;}
|
||||||
|
__device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1;}
|
||||||
__device__ float tanh_gradient_kernel(float x){return 1-x*x;}
|
__device__ float tanh_gradient_kernel(float x){return 1-x*x;}
|
||||||
__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01 : .125;}
|
__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01 : .125;}
|
||||||
|
|
||||||
@ -37,6 +39,8 @@ __device__ float activate_kernel(float x, ACTIVATION a)
|
|||||||
return relie_activate_kernel(x);
|
return relie_activate_kernel(x);
|
||||||
case RAMP:
|
case RAMP:
|
||||||
return ramp_activate_kernel(x);
|
return ramp_activate_kernel(x);
|
||||||
|
case LEAKY:
|
||||||
|
return leaky_activate_kernel(x);
|
||||||
case TANH:
|
case TANH:
|
||||||
return tanh_activate_kernel(x);
|
return tanh_activate_kernel(x);
|
||||||
case PLSE:
|
case PLSE:
|
||||||
@ -58,6 +62,8 @@ __device__ float gradient_kernel(float x, ACTIVATION a)
|
|||||||
return relie_gradient_kernel(x);
|
return relie_gradient_kernel(x);
|
||||||
case RAMP:
|
case RAMP:
|
||||||
return ramp_gradient_kernel(x);
|
return ramp_gradient_kernel(x);
|
||||||
|
case LEAKY:
|
||||||
|
return leaky_gradient_kernel(x);
|
||||||
case TANH:
|
case TANH:
|
||||||
return tanh_gradient_kernel(x);
|
return tanh_gradient_kernel(x);
|
||||||
case PLSE:
|
case PLSE:
|
||||||
|
@ -22,6 +22,8 @@ char *get_activation_string(ACTIVATION a)
|
|||||||
return "tanh";
|
return "tanh";
|
||||||
case PLSE:
|
case PLSE:
|
||||||
return "plse";
|
return "plse";
|
||||||
|
case LEAKY:
|
||||||
|
return "leaky";
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -36,6 +38,7 @@ ACTIVATION get_activation(char *s)
|
|||||||
if (strcmp(s, "plse")==0) return PLSE;
|
if (strcmp(s, "plse")==0) return PLSE;
|
||||||
if (strcmp(s, "linear")==0) return LINEAR;
|
if (strcmp(s, "linear")==0) return LINEAR;
|
||||||
if (strcmp(s, "ramp")==0) return RAMP;
|
if (strcmp(s, "ramp")==0) return RAMP;
|
||||||
|
if (strcmp(s, "leaky")==0) return LEAKY;
|
||||||
if (strcmp(s, "tanh")==0) return TANH;
|
if (strcmp(s, "tanh")==0) return TANH;
|
||||||
fprintf(stderr, "Couldn't find activation function %s, going with ReLU\n", s);
|
fprintf(stderr, "Couldn't find activation function %s, going with ReLU\n", s);
|
||||||
return RELU;
|
return RELU;
|
||||||
@ -54,6 +57,8 @@ float activate(float x, ACTIVATION a)
|
|||||||
return relie_activate(x);
|
return relie_activate(x);
|
||||||
case RAMP:
|
case RAMP:
|
||||||
return ramp_activate(x);
|
return ramp_activate(x);
|
||||||
|
case LEAKY:
|
||||||
|
return leaky_activate(x);
|
||||||
case TANH:
|
case TANH:
|
||||||
return tanh_activate(x);
|
return tanh_activate(x);
|
||||||
case PLSE:
|
case PLSE:
|
||||||
@ -83,6 +88,8 @@ float gradient(float x, ACTIVATION a)
|
|||||||
return relie_gradient(x);
|
return relie_gradient(x);
|
||||||
case RAMP:
|
case RAMP:
|
||||||
return ramp_gradient(x);
|
return ramp_gradient(x);
|
||||||
|
case LEAKY:
|
||||||
|
return leaky_gradient(x);
|
||||||
case TANH:
|
case TANH:
|
||||||
return tanh_gradient(x);
|
return tanh_gradient(x);
|
||||||
case PLSE:
|
case PLSE:
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
#include "math.h"
|
#include "math.h"
|
||||||
|
|
||||||
typedef enum{
|
typedef enum{
|
||||||
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE
|
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY
|
||||||
}ACTIVATION;
|
}ACTIVATION;
|
||||||
|
|
||||||
ACTIVATION get_activation(char *s);
|
ACTIVATION get_activation(char *s);
|
||||||
@ -24,6 +24,7 @@ static inline float logistic_activate(float x){return 1./(1. + exp(-x));}
|
|||||||
static inline float relu_activate(float x){return x*(x>0);}
|
static inline float relu_activate(float x){return x*(x>0);}
|
||||||
static inline float relie_activate(float x){return x*(x>0);}
|
static inline float relie_activate(float x){return x*(x>0);}
|
||||||
static inline float ramp_activate(float x){return x*(x>0)+.1*x;}
|
static inline float ramp_activate(float x){return x*(x>0)+.1*x;}
|
||||||
|
static inline float leaky_activate(float x){return (x>0) ? x : .1*x;}
|
||||||
static inline float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
|
static inline float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
|
||||||
static inline float plse_activate(float x)
|
static inline float plse_activate(float x)
|
||||||
{
|
{
|
||||||
@ -37,6 +38,7 @@ static inline float logistic_gradient(float x){return (1-x)*x;}
|
|||||||
static inline float relu_gradient(float x){return (x>0);}
|
static inline float relu_gradient(float x){return (x>0);}
|
||||||
static inline float relie_gradient(float x){return (x>0) ? 1 : .01;}
|
static inline float relie_gradient(float x){return (x>0) ? 1 : .01;}
|
||||||
static inline float ramp_gradient(float x){return (x>0)+.1;}
|
static inline float ramp_gradient(float x){return (x>0)+.1;}
|
||||||
|
static inline float leaky_gradient(float x){return (x>0) ? 1 : .1;}
|
||||||
static inline float tanh_gradient(float x){return 1-x*x;}
|
static inline float tanh_gradient(float x){return 1-x*x;}
|
||||||
static inline float plse_gradient(float x){return (x < 0 || x > 1) ? .01 : .125;}
|
static inline float plse_gradient(float x){return (x < 0 || x > 1) ? .01 : .125;}
|
||||||
|
|
||||||
|
@ -97,12 +97,18 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
|
|||||||
return l;
|
return l;
|
||||||
}
|
}
|
||||||
|
|
||||||
void resize_convolutional_layer(convolutional_layer *l, int h, int w)
|
void resize_convolutional_layer(convolutional_layer *l, int w, int h)
|
||||||
{
|
{
|
||||||
l->h = h;
|
|
||||||
l->w = w;
|
l->w = w;
|
||||||
int out_h = convolutional_out_height(*l);
|
l->h = h;
|
||||||
int out_w = convolutional_out_width(*l);
|
int out_w = convolutional_out_width(*l);
|
||||||
|
int out_h = convolutional_out_height(*l);
|
||||||
|
|
||||||
|
l->out_w = out_w;
|
||||||
|
l->out_h = out_h;
|
||||||
|
|
||||||
|
l->outputs = l->out_h * l->out_w * l->out_c;
|
||||||
|
l->inputs = l->w * l->h * l->c;
|
||||||
|
|
||||||
l->col_image = realloc(l->col_image,
|
l->col_image = realloc(l->col_image,
|
||||||
out_h*out_w*l->size*l->size*l->c*sizeof(float));
|
out_h*out_w*l->size*l->size*l->c*sizeof(float));
|
||||||
@ -116,9 +122,9 @@ void resize_convolutional_layer(convolutional_layer *l, int h, int w)
|
|||||||
cuda_free(l->delta_gpu);
|
cuda_free(l->delta_gpu);
|
||||||
cuda_free(l->output_gpu);
|
cuda_free(l->output_gpu);
|
||||||
|
|
||||||
l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
|
l->col_image_gpu = cuda_make_array(0, out_h*out_w*l->size*l->size*l->c);
|
||||||
l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
|
l->delta_gpu = cuda_make_array(0, l->batch*out_h*out_w*l->n);
|
||||||
l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
|
l->output_gpu = cuda_make_array(0, l->batch*out_h*out_w*l->n);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation);
|
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation);
|
||||||
void resize_convolutional_layer(convolutional_layer *layer, int h, int w);
|
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
|
||||||
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
|
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
|
||||||
void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
|
void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
|
||||||
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
|
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
|
||||||
|
@ -13,41 +13,7 @@ extern void run_imagenet(int argc, char **argv);
|
|||||||
extern void run_detection(int argc, char **argv);
|
extern void run_detection(int argc, char **argv);
|
||||||
extern void run_writing(int argc, char **argv);
|
extern void run_writing(int argc, char **argv);
|
||||||
extern void run_captcha(int argc, char **argv);
|
extern void run_captcha(int argc, char **argv);
|
||||||
|
extern void run_nightmare(int argc, char **argv);
|
||||||
void del_arg(int argc, char **argv, int index)
|
|
||||||
{
|
|
||||||
int i;
|
|
||||||
for(i = index; i < argc-1; ++i) argv[i] = argv[i+1];
|
|
||||||
argv[i] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int find_arg(int argc, char* argv[], char *arg)
|
|
||||||
{
|
|
||||||
int i;
|
|
||||||
for(i = 0; i < argc; ++i) {
|
|
||||||
if(!argv[i]) continue;
|
|
||||||
if(0==strcmp(argv[i], arg)) {
|
|
||||||
del_arg(argc, argv, i);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int find_int_arg(int argc, char **argv, char *arg, int def)
|
|
||||||
{
|
|
||||||
int i;
|
|
||||||
for(i = 0; i < argc-1; ++i){
|
|
||||||
if(!argv[i]) continue;
|
|
||||||
if(0==strcmp(argv[i], arg)){
|
|
||||||
def = atoi(argv[i+1]);
|
|
||||||
del_arg(argc, argv, i);
|
|
||||||
del_arg(argc, argv, i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return def;
|
|
||||||
}
|
|
||||||
|
|
||||||
void change_rate(char *filename, float scale, float add)
|
void change_rate(char *filename, float scale, float add)
|
||||||
{
|
{
|
||||||
@ -135,6 +101,8 @@ int main(int argc, char **argv)
|
|||||||
test_resize(argv[2]);
|
test_resize(argv[2]);
|
||||||
} else if (0 == strcmp(argv[1], "captcha")){
|
} else if (0 == strcmp(argv[1], "captcha")){
|
||||||
run_captcha(argc, argv);
|
run_captcha(argc, argv);
|
||||||
|
} else if (0 == strcmp(argv[1], "nightmare")){
|
||||||
|
run_nightmare(argc, argv);
|
||||||
} else if (0 == strcmp(argv[1], "change")){
|
} else if (0 == strcmp(argv[1], "change")){
|
||||||
change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);
|
change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);
|
||||||
} else if (0 == strcmp(argv[1], "rgbgr")){
|
} else if (0 == strcmp(argv[1], "rgbgr")){
|
||||||
|
@ -187,6 +187,7 @@ void show_image_cv(image p, char *name)
|
|||||||
{
|
{
|
||||||
int x,y,k;
|
int x,y,k;
|
||||||
image copy = copy_image(p);
|
image copy = copy_image(p);
|
||||||
|
constrain_image(copy);
|
||||||
rgbgr_image(copy);
|
rgbgr_image(copy);
|
||||||
//normalize_image(copy);
|
//normalize_image(copy);
|
||||||
|
|
||||||
@ -207,7 +208,8 @@ void show_image_cv(image p, char *name)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
free_image(copy);
|
free_image(copy);
|
||||||
if(disp->height < 448 || disp->width < 448 || disp->height > 1000){
|
if(0){
|
||||||
|
//if(disp->height < 448 || disp->width < 448 || disp->height > 1000){
|
||||||
int w = 448;
|
int w = 448;
|
||||||
int h = w*p.h/p.w;
|
int h = w*p.h/p.w;
|
||||||
if(h > 1000){
|
if(h > 1000){
|
||||||
|
@ -37,6 +37,8 @@ void exposure_image(image im, float sat);
|
|||||||
void saturate_exposure_image(image im, float sat, float exposure);
|
void saturate_exposure_image(image im, float sat, float exposure);
|
||||||
void hsv_to_rgb(image im);
|
void hsv_to_rgb(image im);
|
||||||
void rgbgr_image(image im);
|
void rgbgr_image(image im);
|
||||||
|
void constrain_image(image im);
|
||||||
|
image grayscale_image(image im);
|
||||||
|
|
||||||
image collapse_image_layers(image source, int border);
|
image collapse_image_layers(image source, int border);
|
||||||
image collapse_images_horz(image *ims, int n);
|
image collapse_images_horz(image *ims, int n);
|
||||||
|
@ -48,7 +48,6 @@ void train_imagenet(char *cfgfile, char *weightfile)
|
|||||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
|
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
|
||||||
free_data(train);
|
free_data(train);
|
||||||
if((i % 30000) == 0) net.learning_rate *= .1;
|
if((i % 30000) == 0) net.learning_rate *= .1;
|
||||||
//if(i%100 == 0 && net.learning_rate > .00001) net.learning_rate *= .97;
|
|
||||||
if(i%1000==0){
|
if(i%1000==0){
|
||||||
char buff[256];
|
char buff[256];
|
||||||
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
||||||
|
@ -48,6 +48,8 @@ typedef struct {
|
|||||||
int does_cost;
|
int does_cost;
|
||||||
int joint;
|
int joint;
|
||||||
|
|
||||||
|
int dontload;
|
||||||
|
|
||||||
float probability;
|
float probability;
|
||||||
float scale;
|
float scale;
|
||||||
int *indexes;
|
int *indexes;
|
||||||
|
@ -4,16 +4,16 @@
|
|||||||
|
|
||||||
image get_maxpool_image(maxpool_layer l)
|
image get_maxpool_image(maxpool_layer l)
|
||||||
{
|
{
|
||||||
int h = (l.h-1)/l.stride + 1;
|
int h = l.out_h;
|
||||||
int w = (l.w-1)/l.stride + 1;
|
int w = l.out_w;
|
||||||
int c = l.c;
|
int c = l.c;
|
||||||
return float_to_image(w,h,c,l.output);
|
return float_to_image(w,h,c,l.output);
|
||||||
}
|
}
|
||||||
|
|
||||||
image get_maxpool_delta(maxpool_layer l)
|
image get_maxpool_delta(maxpool_layer l)
|
||||||
{
|
{
|
||||||
int h = (l.h-1)/l.stride + 1;
|
int h = l.out_h;
|
||||||
int w = (l.w-1)/l.stride + 1;
|
int w = l.out_w;
|
||||||
int c = l.c;
|
int c = l.c;
|
||||||
return float_to_image(w,h,c,l.delta);
|
return float_to_image(w,h,c,l.delta);
|
||||||
}
|
}
|
||||||
@ -27,11 +27,11 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
|
|||||||
l.h = h;
|
l.h = h;
|
||||||
l.w = w;
|
l.w = w;
|
||||||
l.c = c;
|
l.c = c;
|
||||||
l.out_h = (h-1)/stride + 1;
|
|
||||||
l.out_w = (w-1)/stride + 1;
|
l.out_w = (w-1)/stride + 1;
|
||||||
|
l.out_h = (h-1)/stride + 1;
|
||||||
l.out_c = c;
|
l.out_c = c;
|
||||||
l.outputs = l.out_h * l.out_w * l.out_c;
|
l.outputs = l.out_h * l.out_w * l.out_c;
|
||||||
l.inputs = l.outputs;
|
l.inputs = h*w*c;
|
||||||
l.size = size;
|
l.size = size;
|
||||||
l.stride = stride;
|
l.stride = stride;
|
||||||
int output_size = l.out_h * l.out_w * l.out_c * batch;
|
int output_size = l.out_h * l.out_w * l.out_c * batch;
|
||||||
@ -46,11 +46,18 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
|
|||||||
return l;
|
return l;
|
||||||
}
|
}
|
||||||
|
|
||||||
void resize_maxpool_layer(maxpool_layer *l, int h, int w)
|
void resize_maxpool_layer(maxpool_layer *l, int w, int h)
|
||||||
{
|
{
|
||||||
|
int stride = l->stride;
|
||||||
l->h = h;
|
l->h = h;
|
||||||
l->w = w;
|
l->w = w;
|
||||||
int output_size = ((h-1)/l->stride+1) * ((w-1)/l->stride+1) * l->c * l->batch;
|
|
||||||
|
l->out_w = (w-1)/stride + 1;
|
||||||
|
l->out_h = (h-1)/stride + 1;
|
||||||
|
l->outputs = l->out_w * l->out_h * l->c;
|
||||||
|
int output_size = l->outputs * l->batch;
|
||||||
|
|
||||||
|
l->indexes = realloc(l->indexes, output_size * sizeof(int));
|
||||||
l->output = realloc(l->output, output_size * sizeof(float));
|
l->output = realloc(l->output, output_size * sizeof(float));
|
||||||
l->delta = realloc(l->delta, output_size * sizeof(float));
|
l->delta = realloc(l->delta, output_size * sizeof(float));
|
||||||
|
|
||||||
@ -59,8 +66,8 @@ void resize_maxpool_layer(maxpool_layer *l, int h, int w)
|
|||||||
cuda_free(l->output_gpu);
|
cuda_free(l->output_gpu);
|
||||||
cuda_free(l->delta_gpu);
|
cuda_free(l->delta_gpu);
|
||||||
l->indexes_gpu = cuda_make_int_array(output_size);
|
l->indexes_gpu = cuda_make_int_array(output_size);
|
||||||
l->output_gpu = cuda_make_array(l->output, output_size);
|
l->output_gpu = cuda_make_array(0, output_size);
|
||||||
l->delta_gpu = cuda_make_array(l->delta, output_size);
|
l->delta_gpu = cuda_make_array(0, output_size);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ typedef layer maxpool_layer;
|
|||||||
|
|
||||||
image get_maxpool_image(maxpool_layer l);
|
image get_maxpool_image(maxpool_layer l);
|
||||||
maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride);
|
maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride);
|
||||||
void resize_maxpool_layer(maxpool_layer *l, int h, int w);
|
void resize_maxpool_layer(maxpool_layer *l, int w, int h);
|
||||||
void forward_maxpool_layer(const maxpool_layer l, network_state state);
|
void forward_maxpool_layer(const maxpool_layer l, network_state state);
|
||||||
void backward_maxpool_layer(const maxpool_layer l, network_state state);
|
void backward_maxpool_layer(const maxpool_layer l, network_state state);
|
||||||
|
|
||||||
|
@ -132,10 +132,11 @@ void backward_network(network net, network_state state)
|
|||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
float *original_input = state.input;
|
float *original_input = state.input;
|
||||||
|
float *original_delta = state.delta;
|
||||||
for(i = net.n-1; i >= 0; --i){
|
for(i = net.n-1; i >= 0; --i){
|
||||||
if(i == 0){
|
if(i == 0){
|
||||||
state.input = original_input;
|
state.input = original_input;
|
||||||
state.delta = 0;
|
state.delta = original_delta;
|
||||||
}else{
|
}else{
|
||||||
layer prev = net.layers[i-1];
|
layer prev = net.layers[i-1];
|
||||||
state.input = prev.output;
|
state.input = prev.output;
|
||||||
@ -171,6 +172,7 @@ float train_network_datum(network net, float *x, float *y)
|
|||||||
#endif
|
#endif
|
||||||
network_state state;
|
network_state state;
|
||||||
state.input = x;
|
state.input = x;
|
||||||
|
state.delta = 0;
|
||||||
state.truth = y;
|
state.truth = y;
|
||||||
state.train = 1;
|
state.train = 1;
|
||||||
forward_network(net, state);
|
forward_network(net, state);
|
||||||
@ -224,6 +226,7 @@ float train_network_batch(network net, data d, int n)
|
|||||||
int i,j;
|
int i,j;
|
||||||
network_state state;
|
network_state state;
|
||||||
state.train = 1;
|
state.train = 1;
|
||||||
|
state.delta = 0;
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
int batch = 2;
|
int batch = 2;
|
||||||
for(i = 0; i < n; ++i){
|
for(i = 0; i < n; ++i){
|
||||||
@ -249,43 +252,30 @@ void set_batch_network(network *net, int b)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
int resize_network(network *net, int w, int h)
|
||||||
int resize_network(network net, int h, int w, int c)
|
|
||||||
{
|
{
|
||||||
fprintf(stderr, "Might be broken, careful!!");
|
|
||||||
int i;
|
int i;
|
||||||
for (i = 0; i < net.n; ++i){
|
//if(w == net->w && h == net->h) return 0;
|
||||||
if(net.types[i] == CONVOLUTIONAL){
|
net->w = w;
|
||||||
convolutional_layer *layer = (convolutional_layer *)net.layers[i];
|
net->h = h;
|
||||||
resize_convolutional_layer(layer, h, w);
|
//fprintf(stderr, "Resizing to %d x %d...", w, h);
|
||||||
image output = get_convolutional_image(*layer);
|
//fflush(stderr);
|
||||||
h = output.h;
|
for (i = 0; i < net->n; ++i){
|
||||||
w = output.w;
|
layer l = net->layers[i];
|
||||||
c = output.c;
|
if(l.type == CONVOLUTIONAL){
|
||||||
} else if(net.types[i] == DECONVOLUTIONAL){
|
resize_convolutional_layer(&l, w, h);
|
||||||
deconvolutional_layer *layer = (deconvolutional_layer *)net.layers[i];
|
}else if(l.type == MAXPOOL){
|
||||||
resize_deconvolutional_layer(layer, h, w);
|
resize_maxpool_layer(&l, w, h);
|
||||||
image output = get_deconvolutional_image(*layer);
|
|
||||||
h = output.h;
|
|
||||||
w = output.w;
|
|
||||||
c = output.c;
|
|
||||||
}else if(net.types[i] == MAXPOOL){
|
|
||||||
maxpool_layer *layer = (maxpool_layer *)net.layers[i];
|
|
||||||
resize_maxpool_layer(layer, h, w);
|
|
||||||
image output = get_maxpool_image(*layer);
|
|
||||||
h = output.h;
|
|
||||||
w = output.w;
|
|
||||||
c = output.c;
|
|
||||||
}else if(net.types[i] == DROPOUT){
|
|
||||||
dropout_layer *layer = (dropout_layer *)net.layers[i];
|
|
||||||
resize_dropout_layer(layer, h*w*c);
|
|
||||||
}else{
|
}else{
|
||||||
error("Cannot resize this type of layer");
|
error("Cannot resize this type of layer");
|
||||||
}
|
}
|
||||||
|
net->layers[i] = l;
|
||||||
|
w = l.out_w;
|
||||||
|
h = l.out_h;
|
||||||
}
|
}
|
||||||
|
//fprintf(stderr, " Done!\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
int get_network_output_size(network net)
|
int get_network_output_size(network net)
|
||||||
{
|
{
|
||||||
|
@ -34,6 +34,8 @@ float *network_predict_gpu(network net, float *input);
|
|||||||
float * get_network_output_gpu_layer(network net, int i);
|
float * get_network_output_gpu_layer(network net, int i);
|
||||||
float * get_network_delta_gpu_layer(network net, int i);
|
float * get_network_delta_gpu_layer(network net, int i);
|
||||||
float *get_network_output_gpu(network net);
|
float *get_network_output_gpu(network net);
|
||||||
|
void forward_network_gpu(network net, network_state state);
|
||||||
|
void backward_network_gpu(network net, network_state state);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void compare_networks(network n1, network n2, data d);
|
void compare_networks(network n1, network n2, data d);
|
||||||
@ -65,7 +67,7 @@ image get_network_image_layer(network net, int i);
|
|||||||
int get_predicted_class_network(network net);
|
int get_predicted_class_network(network net);
|
||||||
void print_network(network net);
|
void print_network(network net);
|
||||||
void visualize_network(network net);
|
void visualize_network(network net);
|
||||||
int resize_network(network net, int h, int w, int c);
|
int resize_network(network *net, int w, int h);
|
||||||
void set_batch_network(network *net, int b);
|
void set_batch_network(network *net, int b);
|
||||||
int get_network_input_size(network net);
|
int get_network_input_size(network net);
|
||||||
float get_network_cost(network net);
|
float get_network_cost(network net);
|
||||||
|
@ -59,11 +59,12 @@ void backward_network_gpu(network net, network_state state)
|
|||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
float * original_input = state.input;
|
float * original_input = state.input;
|
||||||
|
float * original_delta = state.delta;
|
||||||
for(i = net.n-1; i >= 0; --i){
|
for(i = net.n-1; i >= 0; --i){
|
||||||
layer l = net.layers[i];
|
layer l = net.layers[i];
|
||||||
if(i == 0){
|
if(i == 0){
|
||||||
state.input = original_input;
|
state.input = original_input;
|
||||||
state.delta = 0;
|
state.delta = original_delta;
|
||||||
}else{
|
}else{
|
||||||
layer prev = net.layers[i-1];
|
layer prev = net.layers[i-1];
|
||||||
state.input = prev.output_gpu;
|
state.input = prev.output_gpu;
|
||||||
@ -120,6 +121,7 @@ float train_network_datum_gpu(network net, float *x, float *y)
|
|||||||
cuda_push_array(*net.truth_gpu, y, y_size);
|
cuda_push_array(*net.truth_gpu, y, y_size);
|
||||||
}
|
}
|
||||||
state.input = *net.input_gpu;
|
state.input = *net.input_gpu;
|
||||||
|
state.delta = 0;
|
||||||
state.truth = *net.truth_gpu;
|
state.truth = *net.truth_gpu;
|
||||||
state.train = 1;
|
state.train = 1;
|
||||||
forward_network_gpu(net, state);
|
forward_network_gpu(net, state);
|
||||||
|
189
src/nightmare.c
Normal file
189
src/nightmare.c
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
|
||||||
|
#include "network.h"
|
||||||
|
#include "parser.h"
|
||||||
|
#include "blas.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
float abs_mean(float *x, int n)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
float sum = 0;
|
||||||
|
for (i = 0; i < n; ++i){
|
||||||
|
sum += abs(x[i]);
|
||||||
|
}
|
||||||
|
return sum/n;
|
||||||
|
}
|
||||||
|
|
||||||
|
void calculate_loss(float *output, float *delta, int n, float thresh)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
float mean = mean_array(output, n);
|
||||||
|
float var = variance_array(output, n);
|
||||||
|
for(i = 0; i < n; ++i){
|
||||||
|
if(delta[i] > mean + thresh*sqrt(var)) delta[i] = output[i];
|
||||||
|
else delta[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh)
|
||||||
|
{
|
||||||
|
scale_image(orig, 2);
|
||||||
|
translate_image(orig, -1);
|
||||||
|
net->n = max_layer + 1;
|
||||||
|
|
||||||
|
int dx = rand()%16 - 8;
|
||||||
|
int dy = rand()%16 - 8;
|
||||||
|
int flip = rand()%2;
|
||||||
|
|
||||||
|
image crop = crop_image(orig, dx, dy, orig.w, orig.h);
|
||||||
|
image im = resize_image(crop, (int)(orig.w * scale), (int)(orig.h * scale));
|
||||||
|
if(flip) flip_image(im);
|
||||||
|
|
||||||
|
resize_network(net, im.w, im.h);
|
||||||
|
layer last = net->layers[net->n-1];
|
||||||
|
//net->layers[net->n - 1].activation = LINEAR;
|
||||||
|
|
||||||
|
image delta = make_image(im.w, im.h, im.c);
|
||||||
|
|
||||||
|
network_state state = {0};
|
||||||
|
|
||||||
|
#ifdef GPU
|
||||||
|
state.input = cuda_make_array(im.data, im.w*im.h*im.c);
|
||||||
|
state.delta = cuda_make_array(0, im.w*im.h*im.c);
|
||||||
|
|
||||||
|
forward_network_gpu(*net, state);
|
||||||
|
copy_ongpu(last.outputs, last.output_gpu, 1, last.delta_gpu, 1);
|
||||||
|
|
||||||
|
cuda_pull_array(last.delta_gpu, last.delta, last.outputs);
|
||||||
|
calculate_loss(last.delta, last.delta, last.outputs, thresh);
|
||||||
|
cuda_push_array(last.delta_gpu, last.delta, last.outputs);
|
||||||
|
|
||||||
|
backward_network_gpu(*net, state);
|
||||||
|
|
||||||
|
cuda_pull_array(state.delta, delta.data, im.w*im.h*im.c);
|
||||||
|
cuda_free(state.input);
|
||||||
|
cuda_free(state.delta);
|
||||||
|
#else
|
||||||
|
state.input = im.data;
|
||||||
|
state.delta = delta.data;
|
||||||
|
forward_network(*net, state);
|
||||||
|
copy_cpu(last.outputs, last.output, 1, last.delta, 1);
|
||||||
|
calculate_loss(last.output, last.delta, last.outputs, thresh);
|
||||||
|
backward_network(*net, state);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if(flip) flip_image(delta);
|
||||||
|
//normalize_array(delta.data, delta.w*delta.h*delta.c);
|
||||||
|
image resized = resize_image(delta, orig.w, orig.h);
|
||||||
|
image out = crop_image(resized, -dx, -dy, orig.w, orig.h);
|
||||||
|
|
||||||
|
/*
|
||||||
|
image g = grayscale_image(out);
|
||||||
|
free_image(out);
|
||||||
|
out = g;
|
||||||
|
*/
|
||||||
|
|
||||||
|
//rate = rate / abs_mean(out.data, out.w*out.h*out.c);
|
||||||
|
|
||||||
|
normalize_array(out.data, out.w*out.h*out.c);
|
||||||
|
axpy_cpu(orig.w*orig.h*orig.c, rate, out.data, 1, orig.data, 1);
|
||||||
|
|
||||||
|
/*
|
||||||
|
normalize_array(orig.data, orig.w*orig.h*orig.c);
|
||||||
|
scale_image(orig, sqrt(var));
|
||||||
|
translate_image(orig, mean);
|
||||||
|
*/
|
||||||
|
|
||||||
|
translate_image(orig, 1);
|
||||||
|
scale_image(orig, .5);
|
||||||
|
//normalize_image(orig);
|
||||||
|
|
||||||
|
constrain_image(orig);
|
||||||
|
|
||||||
|
free_image(crop);
|
||||||
|
free_image(im);
|
||||||
|
free_image(delta);
|
||||||
|
free_image(resized);
|
||||||
|
free_image(out);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void run_nightmare(int argc, char **argv)
|
||||||
|
{
|
||||||
|
srand(0);
|
||||||
|
if(argc < 4){
|
||||||
|
fprintf(stderr, "usage: %s %s [cfg] [weights] [image] [layer] [options! (optional)]\n", argv[0], argv[1]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
char *cfg = argv[2];
|
||||||
|
char *weights = argv[3];
|
||||||
|
char *input = argv[4];
|
||||||
|
int max_layer = atoi(argv[5]);
|
||||||
|
|
||||||
|
int range = find_int_arg(argc, argv, "-range", 1);
|
||||||
|
int rounds = find_int_arg(argc, argv, "-rounds", 1);
|
||||||
|
int iters = find_int_arg(argc, argv, "-iters", 10);
|
||||||
|
int octaves = find_int_arg(argc, argv, "-octaves", 4);
|
||||||
|
float zoom = find_float_arg(argc, argv, "-zoom", 1.);
|
||||||
|
float rate = find_float_arg(argc, argv, "-rate", .04);
|
||||||
|
float thresh = find_float_arg(argc, argv, "-thresh", 1.);
|
||||||
|
float rotate = find_float_arg(argc, argv, "-rotate", 0);
|
||||||
|
|
||||||
|
network net = parse_network_cfg(cfg);
|
||||||
|
load_weights(&net, weights);
|
||||||
|
char *cfgbase = basecfg(cfg);
|
||||||
|
char *imbase = basecfg(input);
|
||||||
|
|
||||||
|
set_batch_network(&net, 1);
|
||||||
|
image im = load_image_color(input, 0, 0);
|
||||||
|
if(0){
|
||||||
|
float scale = 1;
|
||||||
|
if(im.w > 512 || im.h > 512){
|
||||||
|
if(im.w > im.h) scale = 512.0/im.w;
|
||||||
|
else scale = 512.0/im.h;
|
||||||
|
}
|
||||||
|
image resized = resize_image(im, scale*im.w, scale*im.h);
|
||||||
|
free_image(im);
|
||||||
|
im = resized;
|
||||||
|
}
|
||||||
|
|
||||||
|
int e;
|
||||||
|
int n;
|
||||||
|
for(e = 0; e < rounds; ++e){
|
||||||
|
fprintf(stderr, "Iteration: ");
|
||||||
|
fflush(stderr);
|
||||||
|
for(n = 0; n < iters; ++n){
|
||||||
|
fprintf(stderr, "%d, ", n);
|
||||||
|
fflush(stderr);
|
||||||
|
int layer = max_layer + rand()%range - range/2;
|
||||||
|
int octave = rand()%octaves;
|
||||||
|
optimize_picture(&net, im, layer, 1/pow(1.33333333, octave), rate, thresh);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "done\n");
|
||||||
|
if(0){
|
||||||
|
image g = grayscale_image(im);
|
||||||
|
free_image(im);
|
||||||
|
im = g;
|
||||||
|
}
|
||||||
|
char buff[256];
|
||||||
|
sprintf(buff, "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e);
|
||||||
|
printf("%d %s\n", e, buff);
|
||||||
|
save_image(im, buff);
|
||||||
|
//show_image(im, buff);
|
||||||
|
//cvWaitKey(0);
|
||||||
|
|
||||||
|
if(rotate){
|
||||||
|
image rot = rotate_image(im, rotate);
|
||||||
|
free_image(im);
|
||||||
|
im = rot;
|
||||||
|
}
|
||||||
|
image crop = crop_image(im, im.w * (1. - zoom)/2., im.h * (1.-zoom)/2., im.w*zoom, im.h*zoom);
|
||||||
|
image resized = resize_image(crop, im.w, im.h);
|
||||||
|
free_image(im);
|
||||||
|
free_image(crop);
|
||||||
|
im = resized;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -343,6 +343,7 @@ network parse_network_cfg(char *filename)
|
|||||||
}else{
|
}else{
|
||||||
fprintf(stderr, "Type not recognized: %s\n", s->type);
|
fprintf(stderr, "Type not recognized: %s\n", s->type);
|
||||||
}
|
}
|
||||||
|
l.dontload = option_find_int_quiet(options, "dontload", 0);
|
||||||
net.layers[count] = l;
|
net.layers[count] = l;
|
||||||
free_section(s);
|
free_section(s);
|
||||||
n = n->next;
|
n = n->next;
|
||||||
@ -527,6 +528,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
|
|||||||
int i;
|
int i;
|
||||||
for(i = 0; i < net->n && i < cutoff; ++i){
|
for(i = 0; i < net->n && i < cutoff; ++i){
|
||||||
layer l = net->layers[i];
|
layer l = net->layers[i];
|
||||||
|
if (l.dontload) continue;
|
||||||
if(l.type == CONVOLUTIONAL){
|
if(l.type == CONVOLUTIONAL){
|
||||||
int num = l.n*l.c*l.size*l.size;
|
int num = l.n*l.c*l.size*l.size;
|
||||||
fread(l.biases, sizeof(float), l.n, fp);
|
fread(l.biases, sizeof(float), l.n, fp);
|
||||||
|
50
src/utils.c
50
src/utils.c
@ -8,6 +8,56 @@
|
|||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
void del_arg(int argc, char **argv, int index)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
for(i = index; i < argc-1; ++i) argv[i] = argv[i+1];
|
||||||
|
argv[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int find_arg(int argc, char* argv[], char *arg)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
for(i = 0; i < argc; ++i) {
|
||||||
|
if(!argv[i]) continue;
|
||||||
|
if(0==strcmp(argv[i], arg)) {
|
||||||
|
del_arg(argc, argv, i);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int find_int_arg(int argc, char **argv, char *arg, int def)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
for(i = 0; i < argc-1; ++i){
|
||||||
|
if(!argv[i]) continue;
|
||||||
|
if(0==strcmp(argv[i], arg)){
|
||||||
|
def = atoi(argv[i+1]);
|
||||||
|
del_arg(argc, argv, i);
|
||||||
|
del_arg(argc, argv, i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return def;
|
||||||
|
}
|
||||||
|
|
||||||
|
float find_float_arg(int argc, char **argv, char *arg, float def)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
for(i = 0; i < argc-1; ++i){
|
||||||
|
if(!argv[i]) continue;
|
||||||
|
if(0==strcmp(argv[i], arg)){
|
||||||
|
def = atof(argv[i+1]);
|
||||||
|
del_arg(argc, argv, i);
|
||||||
|
del_arg(argc, argv, i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return def;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
char *basecfg(char *cfgfile)
|
char *basecfg(char *cfgfile)
|
||||||
{
|
{
|
||||||
|
@ -36,6 +36,9 @@ float variance_array(float *a, int n);
|
|||||||
float mag_array(float *a, int n);
|
float mag_array(float *a, int n);
|
||||||
float **one_hot_encode(float *a, int n, int k);
|
float **one_hot_encode(float *a, int n, int k);
|
||||||
float sec(clock_t clocks);
|
float sec(clock_t clocks);
|
||||||
|
int find_int_arg(int argc, char **argv, char *arg, int def);
|
||||||
|
float find_float_arg(int argc, char **argv, char *arg, float def);
|
||||||
|
int find_arg(int argc, char* argv[], char *arg);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user