diff --git a/Makefile b/Makefile index 19bff326..71802896 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ CFLAGS+= -DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o endif @@ -58,7 +58,6 @@ obj: results: mkdir -p results - .PHONY: clean clean: diff --git a/cfg/strided.cfg b/cfg/strided.cfg index f29f2667..31d8155b 100644 --- a/cfg/strided.cfg +++ b/cfg/strided.cfg @@ -13,9 +13,9 @@ seen=0 crop_height=224 crop_width=224 flip=1 -angle=15 -saturation=1.5 -exposure=1.5 +angle=0 +saturation=1 +exposure=1 [convolutional] filters=64 diff --git a/cfg/vgg-16.cfg b/cfg/vgg-16.cfg index e296eb94..72133d9e 100644 --- a/cfg/vgg-16.cfg +++ b/cfg/vgg-16.cfg @@ -13,9 +13,9 @@ decay=0.0005 crop_height=224 crop_width=224 flip=1 -exposure=2 -saturation=2 -angle=5 +exposure=1 +saturation=1 +angle=0 [convolutional] filters=64 diff --git a/cfg/vgg-conv.cfg b/cfg/vgg-conv.cfg new file mode 100644 index 00000000..ab0fb1e1 --- /dev/null +++ b/cfg/vgg-conv.cfg @@ -0,0 +1,122 @@ +[net] 
+batch=1 +subdivisions=1 +width=224 +height=224 +channels=3 +learning_rate=0.00001 +momentum=0.9 +seen=0 +decay=0.0005 + +[convolutional] +filters=64 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=64 +size=3 +stride=1 +pad=1 +activation=relu + +[maxpool] +size=2 +stride=2 + +[convolutional] +filters=128 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=128 +size=3 +stride=1 +pad=1 +activation=relu + +[maxpool] +size=2 +stride=2 + +[convolutional] +filters=256 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=256 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=256 +size=3 +stride=1 +pad=1 +activation=relu + +[maxpool] +size=2 +stride=2 + +[convolutional] +filters=512 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=512 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=512 +size=3 +stride=1 +pad=1 +activation=relu + +[maxpool] +size=2 +stride=2 + +[convolutional] +filters=512 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=512 +size=3 +stride=1 +pad=1 +activation=relu + +[convolutional] +filters=512 +size=3 +stride=1 +pad=1 +activation=relu + +[maxpool] +size=2 +stride=2 + diff --git a/data/scream.jpg b/data/scream.jpg new file mode 100644 index 00000000..40aea94d Binary files /dev/null and b/data/scream.jpg differ diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu index fb1126ed..16dd4d2a 100644 --- a/src/activation_kernels.cu +++ b/src/activation_kernels.cu @@ -8,6 +8,7 @@ __device__ float logistic_activate_kernel(float x){return 1./(1. + exp(-x));} __device__ float relu_activate_kernel(float x){return x*(x>0);} __device__ float relie_activate_kernel(float x){return x*(x>0);} __device__ float ramp_activate_kernel(float x){return x*(x>0)+.1*x;} +__device__ float leaky_activate_kernel(float x){return (x>0) ? 
x : .1*x;} __device__ float tanh_activate_kernel(float x){return (exp(2*x)-1)/(exp(2*x)+1);} __device__ float plse_activate_kernel(float x) { @@ -21,6 +22,7 @@ __device__ float logistic_gradient_kernel(float x){return (1-x)*x;} __device__ float relu_gradient_kernel(float x){return (x>0);} __device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01;} __device__ float ramp_gradient_kernel(float x){return (x>0)+.1;} +__device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1;} __device__ float tanh_gradient_kernel(float x){return 1-x*x;} __device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01 : .125;} @@ -37,6 +39,8 @@ __device__ float activate_kernel(float x, ACTIVATION a) return relie_activate_kernel(x); case RAMP: return ramp_activate_kernel(x); + case LEAKY: + return leaky_activate_kernel(x); case TANH: return tanh_activate_kernel(x); case PLSE: @@ -58,6 +62,8 @@ __device__ float gradient_kernel(float x, ACTIVATION a) return relie_gradient_kernel(x); case RAMP: return ramp_gradient_kernel(x); + case LEAKY: + return leaky_gradient_kernel(x); case TANH: return tanh_gradient_kernel(x); case PLSE: diff --git a/src/activations.c b/src/activations.c index 8b607c21..d31b1e41 100644 --- a/src/activations.c +++ b/src/activations.c @@ -22,6 +22,8 @@ char *get_activation_string(ACTIVATION a) return "tanh"; case PLSE: return "plse"; + case LEAKY: + return "leaky"; default: break; } @@ -36,6 +38,7 @@ ACTIVATION get_activation(char *s) if (strcmp(s, "plse")==0) return PLSE; if (strcmp(s, "linear")==0) return LINEAR; if (strcmp(s, "ramp")==0) return RAMP; + if (strcmp(s, "leaky")==0) return LEAKY; if (strcmp(s, "tanh")==0) return TANH; fprintf(stderr, "Couldn't find activation function %s, going with ReLU\n", s); return RELU; @@ -54,6 +57,8 @@ float activate(float x, ACTIVATION a) return relie_activate(x); case RAMP: return ramp_activate(x); + case LEAKY: + return leaky_activate(x); case TANH: return tanh_activate(x); case PLSE: @@ 
-83,6 +88,8 @@ float gradient(float x, ACTIVATION a) return relie_gradient(x); case RAMP: return ramp_gradient(x); + case LEAKY: + return leaky_gradient(x); case TANH: return tanh_gradient(x); case PLSE: diff --git a/src/activations.h b/src/activations.h index 0aa61c2a..22a713af 100644 --- a/src/activations.h +++ b/src/activations.h @@ -4,7 +4,7 @@ #include "math.h" typedef enum{ - LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE + LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY }ACTIVATION; ACTIVATION get_activation(char *s); @@ -24,6 +24,7 @@ static inline float logistic_activate(float x){return 1./(1. + exp(-x));} static inline float relu_activate(float x){return x*(x>0);} static inline float relie_activate(float x){return x*(x>0);} static inline float ramp_activate(float x){return x*(x>0)+.1*x;} +static inline float leaky_activate(float x){return (x>0) ? x : .1*x;} static inline float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);} static inline float plse_activate(float x) { @@ -37,6 +38,7 @@ static inline float logistic_gradient(float x){return (1-x)*x;} static inline float relu_gradient(float x){return (x>0);} static inline float relie_gradient(float x){return (x>0) ? 1 : .01;} static inline float ramp_gradient(float x){return (x>0)+.1;} +static inline float leaky_gradient(float x){return (x>0) ? 1 : .1;} static inline float tanh_gradient(float x){return 1-x*x;} static inline float plse_gradient(float x){return (x < 0 || x > 1) ? 
.01 : .125;} diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 9c0dabe7..c2669348 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -97,12 +97,18 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int return l; } -void resize_convolutional_layer(convolutional_layer *l, int h, int w) +void resize_convolutional_layer(convolutional_layer *l, int w, int h) { - l->h = h; l->w = w; - int out_h = convolutional_out_height(*l); + l->h = h; int out_w = convolutional_out_width(*l); + int out_h = convolutional_out_height(*l); + + l->out_w = out_w; + l->out_h = out_h; + + l->outputs = l->out_h * l->out_w * l->out_c; + l->inputs = l->w * l->h * l->c; l->col_image = realloc(l->col_image, out_h*out_w*l->size*l->size*l->c*sizeof(float)); @@ -116,9 +122,9 @@ void resize_convolutional_layer(convolutional_layer *l, int h, int w) cuda_free(l->delta_gpu); cuda_free(l->output_gpu); - l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c); - l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n); - l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n); + l->col_image_gpu = cuda_make_array(0, out_h*out_w*l->size*l->size*l->c); + l->delta_gpu = cuda_make_array(0, l->batch*out_h*out_w*l->n); + l->output_gpu = cuda_make_array(0, l->batch*out_h*out_w*l->n); #endif } diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 334759cd..3954f8a9 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -22,7 +22,7 @@ void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int #endif convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation); -void resize_convolutional_layer(convolutional_layer *layer, int h, int w); +void resize_convolutional_layer(convolutional_layer *layer, int w, int h); void forward_convolutional_layer(const 
convolutional_layer layer, network_state state); void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay); image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters); diff --git a/src/darknet.c b/src/darknet.c index 0a705da7..a34ccfba 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -13,41 +13,7 @@ extern void run_imagenet(int argc, char **argv); extern void run_detection(int argc, char **argv); extern void run_writing(int argc, char **argv); extern void run_captcha(int argc, char **argv); - -void del_arg(int argc, char **argv, int index) -{ - int i; - for(i = index; i < argc-1; ++i) argv[i] = argv[i+1]; - argv[i] = 0; -} - -int find_arg(int argc, char* argv[], char *arg) -{ - int i; - for(i = 0; i < argc; ++i) { - if(!argv[i]) continue; - if(0==strcmp(argv[i], arg)) { - del_arg(argc, argv, i); - return 1; - } - } - return 0; -} - -int find_int_arg(int argc, char **argv, char *arg, int def) -{ - int i; - for(i = 0; i < argc-1; ++i){ - if(!argv[i]) continue; - if(0==strcmp(argv[i], arg)){ - def = atoi(argv[i+1]); - del_arg(argc, argv, i); - del_arg(argc, argv, i); - break; - } - } - return def; -} +extern void run_nightmare(int argc, char **argv); void change_rate(char *filename, float scale, float add) { @@ -135,6 +101,8 @@ int main(int argc, char **argv) test_resize(argv[2]); } else if (0 == strcmp(argv[1], "captcha")){ run_captcha(argc, argv); + } else if (0 == strcmp(argv[1], "nightmare")){ + run_nightmare(argc, argv); } else if (0 == strcmp(argv[1], "change")){ change_rate(argv[2], atof(argv[3]), (argc > 4) ? 
atof(argv[4]) : 0); } else if (0 == strcmp(argv[1], "rgbgr")){ diff --git a/src/image.c b/src/image.c index 657db68a..046de0c9 100644 --- a/src/image.c +++ b/src/image.c @@ -187,6 +187,7 @@ void show_image_cv(image p, char *name) { int x,y,k; image copy = copy_image(p); + constrain_image(copy); rgbgr_image(copy); //normalize_image(copy); @@ -207,7 +208,8 @@ void show_image_cv(image p, char *name) } } free_image(copy); - if(disp->height < 448 || disp->width < 448 || disp->height > 1000){ + if(0){ + //if(disp->height < 448 || disp->width < 448 || disp->height > 1000){ int w = 448; int h = w*p.h/p.w; if(h > 1000){ diff --git a/src/image.h b/src/image.h index 18065749..578d6900 100644 --- a/src/image.h +++ b/src/image.h @@ -37,6 +37,8 @@ void exposure_image(image im, float sat); void saturate_exposure_image(image im, float sat, float exposure); void hsv_to_rgb(image im); void rgbgr_image(image im); +void constrain_image(image im); +image grayscale_image(image im); image collapse_image_layers(image source, int border); image collapse_images_horz(image *ims, int n); diff --git a/src/imagenet.c b/src/imagenet.c index fe66a657..0e272fc1 100644 --- a/src/imagenet.c +++ b/src/imagenet.c @@ -48,7 +48,6 @@ void train_imagenet(char *cfgfile, char *weightfile) printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); free_data(train); if((i % 30000) == 0) net.learning_rate *= .1; - //if(i%100 == 0 && net.learning_rate > .00001) net.learning_rate *= .97; if(i%1000==0){ char buff[256]; sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); diff --git a/src/layer.h b/src/layer.h index 7255f499..a591f038 100644 --- a/src/layer.h +++ b/src/layer.h @@ -48,6 +48,8 @@ typedef struct { int does_cost; int joint; + int dontload; + float probability; float scale; int *indexes; diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c index 159bd0af..bc3aa68f 100644 --- a/src/maxpool_layer.c +++ b/src/maxpool_layer.c @@ -4,16 +4,16 @@ 
image get_maxpool_image(maxpool_layer l) { - int h = (l.h-1)/l.stride + 1; - int w = (l.w-1)/l.stride + 1; + int h = l.out_h; + int w = l.out_w; int c = l.c; return float_to_image(w,h,c,l.output); } image get_maxpool_delta(maxpool_layer l) { - int h = (l.h-1)/l.stride + 1; - int w = (l.w-1)/l.stride + 1; + int h = l.out_h; + int w = l.out_w; int c = l.c; return float_to_image(w,h,c,l.delta); } @@ -27,11 +27,11 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s l.h = h; l.w = w; l.c = c; - l.out_h = (h-1)/stride + 1; l.out_w = (w-1)/stride + 1; + l.out_h = (h-1)/stride + 1; l.out_c = c; l.outputs = l.out_h * l.out_w * l.out_c; - l.inputs = l.outputs; + l.inputs = h*w*c; l.size = size; l.stride = stride; int output_size = l.out_h * l.out_w * l.out_c * batch; @@ -46,11 +46,18 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s return l; } -void resize_maxpool_layer(maxpool_layer *l, int h, int w) +void resize_maxpool_layer(maxpool_layer *l, int w, int h) { + int stride = l->stride; l->h = h; l->w = w; - int output_size = ((h-1)/l->stride+1) * ((w-1)/l->stride+1) * l->c * l->batch; + + l->out_w = (w-1)/stride + 1; + l->out_h = (h-1)/stride + 1; + l->outputs = l->out_w * l->out_h * l->c; + int output_size = l->outputs * l->batch; + + l->indexes = realloc(l->indexes, output_size * sizeof(int)); l->output = realloc(l->output, output_size * sizeof(float)); l->delta = realloc(l->delta, output_size * sizeof(float)); @@ -59,8 +66,8 @@ void resize_maxpool_layer(maxpool_layer *l, int h, int w) cuda_free(l->output_gpu); cuda_free(l->delta_gpu); l->indexes_gpu = cuda_make_int_array(output_size); - l->output_gpu = cuda_make_array(l->output, output_size); - l->delta_gpu = cuda_make_array(l->delta, output_size); + l->output_gpu = cuda_make_array(0, output_size); + l->delta_gpu = cuda_make_array(0, output_size); #endif } diff --git a/src/maxpool_layer.h b/src/maxpool_layer.h index 4456863b..ab13874b 100644 --- 
a/src/maxpool_layer.h +++ b/src/maxpool_layer.h @@ -10,7 +10,7 @@ typedef layer maxpool_layer; image get_maxpool_image(maxpool_layer l); maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride); -void resize_maxpool_layer(maxpool_layer *l, int h, int w); +void resize_maxpool_layer(maxpool_layer *l, int w, int h); void forward_maxpool_layer(const maxpool_layer l, network_state state); void backward_maxpool_layer(const maxpool_layer l, network_state state); diff --git a/src/network.c b/src/network.c index 68790e58..c691600b 100644 --- a/src/network.c +++ b/src/network.c @@ -132,10 +132,11 @@ void backward_network(network net, network_state state) { int i; float *original_input = state.input; + float *original_delta = state.delta; for(i = net.n-1; i >= 0; --i){ if(i == 0){ state.input = original_input; - state.delta = 0; + state.delta = original_delta; }else{ layer prev = net.layers[i-1]; state.input = prev.output; @@ -171,6 +172,7 @@ float train_network_datum(network net, float *x, float *y) #endif network_state state; state.input = x; + state.delta = 0; state.truth = y; state.train = 1; forward_network(net, state); @@ -224,6 +226,7 @@ float train_network_batch(network net, data d, int n) int i,j; network_state state; state.train = 1; + state.delta = 0; float sum = 0; int batch = 2; for(i = 0; i < n; ++i){ @@ -249,43 +252,30 @@ void set_batch_network(network *net, int b) } } -/* -int resize_network(network net, int h, int w, int c) +int resize_network(network *net, int w, int h) { - fprintf(stderr, "Might be broken, careful!!"); int i; - for (i = 0; i < net.n; ++i){ - if(net.types[i] == CONVOLUTIONAL){ - convolutional_layer *layer = (convolutional_layer *)net.layers[i]; - resize_convolutional_layer(layer, h, w); - image output = get_convolutional_image(*layer); - h = output.h; - w = output.w; - c = output.c; - } else if(net.types[i] == DECONVOLUTIONAL){ - deconvolutional_layer *layer = (deconvolutional_layer *)net.layers[i]; - 
resize_deconvolutional_layer(layer, h, w); - image output = get_deconvolutional_image(*layer); - h = output.h; - w = output.w; - c = output.c; - }else if(net.types[i] == MAXPOOL){ - maxpool_layer *layer = (maxpool_layer *)net.layers[i]; - resize_maxpool_layer(layer, h, w); - image output = get_maxpool_image(*layer); - h = output.h; - w = output.w; - c = output.c; - }else if(net.types[i] == DROPOUT){ - dropout_layer *layer = (dropout_layer *)net.layers[i]; - resize_dropout_layer(layer, h*w*c); + //if(w == net->w && h == net->h) return 0; + net->w = w; + net->h = h; + //fprintf(stderr, "Resizing to %d x %d...", w, h); + //fflush(stderr); + for (i = 0; i < net->n; ++i){ + layer l = net->layers[i]; + if(l.type == CONVOLUTIONAL){ + resize_convolutional_layer(&l, w, h); + }else if(l.type == MAXPOOL){ + resize_maxpool_layer(&l, w, h); }else{ error("Cannot resize this type of layer"); } + net->layers[i] = l; + w = l.out_w; + h = l.out_h; } + //fprintf(stderr, " Done!\n"); return 0; } -*/ int get_network_output_size(network net) { diff --git a/src/network.h b/src/network.h index 9a8033c8..b684d33a 100644 --- a/src/network.h +++ b/src/network.h @@ -34,6 +34,8 @@ float *network_predict_gpu(network net, float *input); float * get_network_output_gpu_layer(network net, int i); float * get_network_delta_gpu_layer(network net, int i); float *get_network_output_gpu(network net); +void forward_network_gpu(network net, network_state state); +void backward_network_gpu(network net, network_state state); #endif void compare_networks(network n1, network n2, data d); @@ -65,7 +67,7 @@ image get_network_image_layer(network net, int i); int get_predicted_class_network(network net); void print_network(network net); void visualize_network(network net); -int resize_network(network net, int h, int w, int c); +int resize_network(network *net, int w, int h); void set_batch_network(network *net, int b); int get_network_input_size(network net); float get_network_cost(network net); diff --git 
a/src/network_kernels.cu b/src/network_kernels.cu index 5e353aee..36f55944 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -59,11 +59,12 @@ void backward_network_gpu(network net, network_state state) { int i; float * original_input = state.input; + float * original_delta = state.delta; for(i = net.n-1; i >= 0; --i){ layer l = net.layers[i]; if(i == 0){ state.input = original_input; - state.delta = 0; + state.delta = original_delta; }else{ layer prev = net.layers[i-1]; state.input = prev.output_gpu; @@ -120,6 +121,7 @@ float train_network_datum_gpu(network net, float *x, float *y) cuda_push_array(*net.truth_gpu, y, y_size); } state.input = *net.input_gpu; + state.delta = 0; state.truth = *net.truth_gpu; state.train = 1; forward_network_gpu(net, state); diff --git a/src/nightmare.c b/src/nightmare.c new file mode 100644 index 00000000..882c0ebe --- /dev/null +++ b/src/nightmare.c @@ -0,0 +1,189 @@ + +#include "network.h" +#include "parser.h" +#include "blas.h" +#include "utils.h" + +float abs_mean(float *x, int n) /* mean absolute value of x[0..n-1] */ +{ + int i; + float sum = 0; + for (i = 0; i < n; ++i){ + sum += fabs(x[i]); /* fabs, not abs: abs() takes an int and would truncate each float */ + } + return sum/n; +} + +void calculate_loss(float *output, float *delta, int n, float thresh) +{ + int i; + float mean = mean_array(output, n); + float var = variance_array(output, n); + for(i = 0; i < n; ++i){ + if(delta[i] > mean + thresh*sqrt(var)) delta[i] = output[i]; + else delta[i] = 0; + } +} + +void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh) +{ + scale_image(orig, 2); + translate_image(orig, -1); + net->n = max_layer + 1; + + int dx = rand()%16 - 8; + int dy = rand()%16 - 8; + int flip = rand()%2; + + image crop = crop_image(orig, dx, dy, orig.w, orig.h); + image im = resize_image(crop, (int)(orig.w * scale), (int)(orig.h * scale)); + if(flip) flip_image(im); + + resize_network(net, im.w, im.h); + layer last = net->layers[net->n-1]; + //net->layers[net->n - 1].activation = LINEAR; + + image delta = 
make_image(im.w, im.h, im.c); + + network_state state = {0}; + +#ifdef GPU + state.input = cuda_make_array(im.data, im.w*im.h*im.c); + state.delta = cuda_make_array(0, im.w*im.h*im.c); + + forward_network_gpu(*net, state); + copy_ongpu(last.outputs, last.output_gpu, 1, last.delta_gpu, 1); + + cuda_pull_array(last.delta_gpu, last.delta, last.outputs); + calculate_loss(last.delta, last.delta, last.outputs, thresh); + cuda_push_array(last.delta_gpu, last.delta, last.outputs); + + backward_network_gpu(*net, state); + + cuda_pull_array(state.delta, delta.data, im.w*im.h*im.c); + cuda_free(state.input); + cuda_free(state.delta); +#else + state.input = im.data; + state.delta = delta.data; + forward_network(*net, state); + copy_cpu(last.outputs, last.output, 1, last.delta, 1); + calculate_loss(last.output, last.delta, last.outputs, thresh); + backward_network(*net, state); +#endif + + if(flip) flip_image(delta); + //normalize_array(delta.data, delta.w*delta.h*delta.c); + image resized = resize_image(delta, orig.w, orig.h); + image out = crop_image(resized, -dx, -dy, orig.w, orig.h); + + /* + image g = grayscale_image(out); + free_image(out); + out = g; + */ + + //rate = rate / abs_mean(out.data, out.w*out.h*out.c); + + normalize_array(out.data, out.w*out.h*out.c); + axpy_cpu(orig.w*orig.h*orig.c, rate, out.data, 1, orig.data, 1); + + /* + normalize_array(orig.data, orig.w*orig.h*orig.c); + scale_image(orig, sqrt(var)); + translate_image(orig, mean); + */ + + translate_image(orig, 1); + scale_image(orig, .5); + //normalize_image(orig); + + constrain_image(orig); + + free_image(crop); + free_image(im); + free_image(delta); + free_image(resized); + free_image(out); + +} + + +void run_nightmare(int argc, char **argv) +{ + srand(0); + if(argc < 6){ /* argv[4] (image) and argv[5] (layer) are read unconditionally below */ + fprintf(stderr, "usage: %s %s [cfg] [weights] [image] [layer] [options! 
(optional)]\n", argv[0], argv[1]); + return; + } + + char *cfg = argv[2]; + char *weights = argv[3]; + char *input = argv[4]; + int max_layer = atoi(argv[5]); + + int range = find_int_arg(argc, argv, "-range", 1); + int rounds = find_int_arg(argc, argv, "-rounds", 1); + int iters = find_int_arg(argc, argv, "-iters", 10); + int octaves = find_int_arg(argc, argv, "-octaves", 4); + float zoom = find_float_arg(argc, argv, "-zoom", 1.); + float rate = find_float_arg(argc, argv, "-rate", .04); + float thresh = find_float_arg(argc, argv, "-thresh", 1.); + float rotate = find_float_arg(argc, argv, "-rotate", 0); + + network net = parse_network_cfg(cfg); + load_weights(&net, weights); + char *cfgbase = basecfg(cfg); + char *imbase = basecfg(input); + + set_batch_network(&net, 1); + image im = load_image_color(input, 0, 0); + if(0){ + float scale = 1; + if(im.w > 512 || im.h > 512){ + if(im.w > im.h) scale = 512.0/im.w; + else scale = 512.0/im.h; + } + image resized = resize_image(im, scale*im.w, scale*im.h); + free_image(im); + im = resized; + } + + int e; + int n; + for(e = 0; e < rounds; ++e){ + fprintf(stderr, "Iteration: "); + fflush(stderr); + for(n = 0; n < iters; ++n){ + fprintf(stderr, "%d, ", n); + fflush(stderr); + int layer = max_layer + rand()%range - range/2; + int octave = rand()%octaves; + optimize_picture(&net, im, layer, 1/pow(1.33333333, octave), rate, thresh); + } + fprintf(stderr, "done\n"); + if(0){ + image g = grayscale_image(im); + free_image(im); + im = g; + } + char buff[256]; + snprintf(buff, sizeof(buff), "%s_%s_%d_%06d",imbase, cfgbase, max_layer, e); /* bounded: imbase/cfgbase derive from user-supplied paths */ + printf("%d %s\n", e, buff); + save_image(im, buff); + //show_image(im, buff); + //cvWaitKey(0); + + if(rotate){ + image rot = rotate_image(im, rotate); + free_image(im); + im = rot; + } + image crop = crop_image(im, im.w * (1. 
- zoom)/2., im.h * (1.-zoom)/2., im.w*zoom, im.h*zoom); + image resized = resize_image(crop, im.w, im.h); + free_image(im); + free_image(crop); + im = resized; + } +} + diff --git a/src/parser.c b/src/parser.c index 2caf96e5..18c38603 100644 --- a/src/parser.c +++ b/src/parser.c @@ -343,6 +343,7 @@ network parse_network_cfg(char *filename) }else{ fprintf(stderr, "Type not recognized: %s\n", s->type); } + l.dontload = option_find_int_quiet(options, "dontload", 0); net.layers[count] = l; free_section(s); n = n->next; @@ -527,6 +528,7 @@ void load_weights_upto(network *net, char *filename, int cutoff) int i; for(i = 0; i < net->n && i < cutoff; ++i){ layer l = net->layers[i]; + if (l.dontload) continue; if(l.type == CONVOLUTIONAL){ int num = l.n*l.c*l.size*l.size; fread(l.biases, sizeof(float), l.n, fp); diff --git a/src/utils.c b/src/utils.c index 63664ec9..af22caab 100644 --- a/src/utils.c +++ b/src/utils.c @@ -8,6 +8,56 @@ #include "utils.h" +void del_arg(int argc, char **argv, int index) +{ + int i; + for(i = index; i < argc-1; ++i) argv[i] = argv[i+1]; + argv[i] = 0; +} + +int find_arg(int argc, char* argv[], char *arg) +{ + int i; + for(i = 0; i < argc; ++i) { + if(!argv[i]) continue; + if(0==strcmp(argv[i], arg)) { + del_arg(argc, argv, i); + return 1; + } + } + return 0; +} + +int find_int_arg(int argc, char **argv, char *arg, int def) +{ + int i; + for(i = 0; i < argc-1; ++i){ + if(!argv[i]) continue; + if(0==strcmp(argv[i], arg)){ + def = atoi(argv[i+1]); + del_arg(argc, argv, i); + del_arg(argc, argv, i); + break; + } + } + return def; +} + +float find_float_arg(int argc, char **argv, char *arg, float def) +{ + int i; + for(i = 0; i < argc-1; ++i){ + if(!argv[i]) continue; + if(0==strcmp(argv[i], arg)){ + def = atof(argv[i+1]); + del_arg(argc, argv, i); + del_arg(argc, argv, i); + break; + } + } + return def; +} + char *basecfg(char *cfgfile) { diff --git a/src/utils.h b/src/utils.h index 7148e547..674fc181 100644 --- a/src/utils.h +++ b/src/utils.h @@ 
-36,6 +36,9 @@ float variance_array(float *a, int n); float mag_array(float *a, int n); float **one_hot_encode(float *a, int n, int k); float sec(clock_t clocks); +int find_int_arg(int argc, char **argv, char *arg, int def); +float find_float_arg(int argc, char **argv, char *arg, float def); +int find_arg(int argc, char* argv[], char *arg); #endif