From addcc4ef9692f874123bcb861c70d4b07bb3b960 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Wed, 12 Apr 2017 14:22:53 -0700 Subject: [PATCH] working on TED demo --- cfg/darknet9000.cfg | 205 +++++++++++++++++++++++++++++++ cfg/imagenet9k.hierarchy.dataset | 9 ++ cfg/yolo9000.cfg | 19 ++- src/classifier.c | 24 +--- src/convolutional_kernels.cu | 10 +- src/convolutional_layer.c | 2 +- src/convolutional_layer.h | 2 +- src/data.c | 19 +-- src/data.h | 5 +- src/deconvolutional_kernels.cu | 6 +- src/demo.c | 16 ++- src/demo.h | 4 +- src/detector.c | 4 +- src/image.c | 9 ++ src/image.h | 1 + src/lsd.c | 70 +++++------ src/network.c | 3 +- src/network.h | 1 + src/nightmare.c | 36 +++--- src/parser.c | 9 +- src/region_layer.c | 23 +++- src/region_layer.h | 1 + src/utils.c | 23 ++++ src/utils.h | 1 + 24 files changed, 392 insertions(+), 110 deletions(-) create mode 100644 cfg/darknet9000.cfg create mode 100644 cfg/imagenet9k.hierarchy.dataset diff --git a/cfg/darknet9000.cfg b/cfg/darknet9000.cfg new file mode 100644 index 00000000..9dd2dfbb --- /dev/null +++ b/cfg/darknet9000.cfg @@ -0,0 +1,205 @@ +[net] +# Training +# batch=128 +# subdivisions=4 +# Testing +batch = 1 +subdivisions = 1 +height=448 +width=448 +max_crop=512 +channels=3 +momentum=0.9 +decay=0.0005 + +learning_rate=0.001 +policy=poly +power=4 +max_batches=100000 + +angle=7 +hue=.1 +saturation=.75 +exposure=.75 +aspect=.75 + +[convolutional] +batch_normalize=1 +filters=32 +size=3 +stride=1 +pad=1 +activation=leaky + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=1 +pad=1 +activation=leaky + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=64 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[maxpool] +size=2 +stride=2 + +[convolutional] 
+batch_normalize=1 +filters=256 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=128 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=256 +size=3 +stride=1 +pad=1 +activation=leaky + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=256 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=256 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=1024 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=512 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=1024 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=512 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=1024 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +filters=9418 +size=1 +stride=1 +pad=1 +activation=linear + +[avgpool] + +[softmax] +groups=1 +tree=data/9k.tree + +[cost] +type=masked + diff --git a/cfg/imagenet9k.hierarchy.dataset b/cfg/imagenet9k.hierarchy.dataset new file mode 100644 index 00000000..41fb71b0 --- /dev/null +++ b/cfg/imagenet9k.hierarchy.dataset @@ -0,0 +1,9 @@ +classes=9418 +train = data/9k.train.list +valid = /data/imagenet/imagenet1k.valid.list +leaves = data/imagenet1k.labels +backup = /home/pjreddie/backup/ +labels = data/9k.labels +names = data/9k.names +top=5 + diff --git a/cfg/yolo9000.cfg b/cfg/yolo9000.cfg index 981491d8..9a8dd62f 100644 --- a/cfg/yolo9000.cfg +++ b/cfg/yolo9000.cfg @@ -1,17 +1,24 @@ [net] +# Testing +# 
batch=1 +# subdivisions=1 +# Training +batch=64 +subdivisions=8 batch=1 subdivisions=1 -height=416 -width=416 +height=544 +width=544 channels=3 momentum=0.9 decay=0.0005 -learning_rate=0.00001 -max_batches = 242200 +learning_rate=0.001 +burn_in=1000 +max_batches = 500200 policy=steps -steps=500,200000,240000 -scales=10,.1,.1 +steps=400000,450000 +scales=.1,.1 hue=.1 saturation=.75 diff --git a/src/classifier.c b/src/classifier.c index 1b7ff38a..32fd288a 100644 --- a/src/classifier.c +++ b/src/classifier.c @@ -1113,27 +1113,9 @@ void run_classifier(int argc, char **argv) } char *gpu_list = find_char_arg(argc, argv, "-gpus", 0); - int *gpus = 0; - int gpu = 0; - int ngpus = 0; - if(gpu_list){ - printf("%s\n", gpu_list); - int len = strlen(gpu_list); - ngpus = 1; - int i; - for(i = 0; i < len; ++i){ - if (gpu_list[i] == ',') ++ngpus; - } - gpus = calloc(ngpus, sizeof(int)); - for(i = 0; i < ngpus; ++i){ - gpus[i] = atoi(gpu_list); - gpu_list = strchr(gpu_list, ',')+1; - } - } else { - gpu = gpu_index; - gpus = &gpu; - ngpus = 1; - } + int ngpus; + int *gpus = read_intlist(gpu_list, &ngpus, gpu_index); + int cam_index = find_int_arg(argc, argv, "-c", 0); int top = find_int_arg(argc, argv, "-t", 0); diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 41dec50d..b53dd16a 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -286,7 +286,7 @@ void push_convolutional_layer(convolutional_layer layer) } } -void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch) +void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t) { scal_ongpu(n, B1, m, 1); scal_ongpu(n, B2, v, 1); @@ -296,7 +296,7 @@ void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, mul_ongpu(n, d, 1, d, 1); axpy_ongpu(n, (1-B2), d, 1, v, 1); - adam_gpu(n, w, m, v, B1, B2, 
rate/batch, eps, 1000); + adam_gpu(n, w, m, v, B1, B2, rate/batch, eps, t); fill_ongpu(n, 0, d, 1); } @@ -305,10 +305,10 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate, flo int size = l.size*l.size*l.c*l.n; if(l.adam){ - adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, size, batch); - adam_update_gpu(l.biases_gpu, l.bias_updates_gpu, l.bias_m_gpu, l.bias_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch); + adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, size, batch, l.t); + adam_update_gpu(l.biases_gpu, l.bias_updates_gpu, l.bias_m_gpu, l.bias_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch, l.t); if(l.scales_gpu){ - adam_update_gpu(l.scales_gpu, l.scale_updates_gpu, l.scale_m_gpu, l.scale_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch); + adam_update_gpu(l.scales_gpu, l.scale_updates_gpu, l.scale_m_gpu, l.scale_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch, l.t); } }else{ axpy_ongpu(size, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 182a113d..e5b5bb6f 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -188,7 +188,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int // float scale = 1./sqrt(size*size*c); float scale = sqrt(2./(size*size*c)); - scale = .02; + //scale = .02; //for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1, 1); for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_normal(); int out_w = convolutional_out_width(l); diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index e00e6788..d25ef649 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -19,7 +19,7 @@ void pull_convolutional_layer(convolutional_layer layer); void add_bias_gpu(float 
*output, float *biases, int batch, int n, int size); void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size); -void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch); +void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t); #ifdef CUDNN void cudnn_convolutional_setup(layer *l); #endif diff --git a/src/data.c b/src/data.c index 533ae8e0..e78b17ed 100644 --- a/src/data.c +++ b/src/data.c @@ -102,7 +102,7 @@ matrix load_image_paths(char **paths, int n, int w, int h) return X; } -matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure) +matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure, int center) { int i; matrix X; @@ -112,7 +112,12 @@ matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, for(i = 0; i < n; ++i){ image im = load_image_color(paths[i], 0, 0); - image crop = random_augment_image(im, angle, aspect, min, max, size); + image crop; + if(center){ + crop = center_crop_image(im, size, size); + } else { + crop = random_augment_image(im, angle, aspect, min, max, size); + } int flip = rand()%2; if (flip) flip_image(crop); random_distort_image(crop, hue, saturation, exposure); @@ -742,7 +747,7 @@ void *load_thread(void *ptr) } else if (a.type == REGRESSION_DATA){ *a.d = load_data_regression(a.paths, a.n, a.m, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure); } else if (a.type == CLASSIFICATION_DATA){ - *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure); + *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, 
a.hierarchy, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.center); } else if (a.type == SUPER_DATA){ *a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale); } else if (a.type == WRITING_DATA){ @@ -890,18 +895,18 @@ data load_data_regression(char **paths, int n, int m, int min, int max, int size if(m) paths = get_random_paths(paths, n, m); data d = {0}; d.shallow = 0; - d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure); + d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure, 0); d.y = load_regression_labels_paths(paths, n); if(m) free(paths); return d; } -data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure) +data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure, int center) { if(m) paths = get_random_paths(paths, n, m); data d = {0}; d.shallow = 0; - d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure); + d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure, center); d.y = load_labels_paths(paths, n, labels, k, hierarchy); if(m) free(paths); return d; @@ -914,7 +919,7 @@ data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size d.w = size; d.h = size; d.shallow = 0; - d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure); + d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure, 0); d.y = load_tags_paths(paths, n, k); if(m) free(paths); return d; diff --git a/src/data.h b/src/data.h index 30e025c7..16b334dc 100644 --- a/src/data.h +++ b/src/data.h @@ -49,6 +49,7 @@ typedef 
struct load_args{ int classes; int background; int scale; + int center; float jitter; float angle; float aspect; @@ -80,9 +81,9 @@ data load_data_captcha_encode(char **paths, int n, int m, int w, int h); data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h); data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure); data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); -matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); +matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure, int center); data load_data_super(char **paths, int n, int m, int w, int h, int scale); -data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); +data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure, int center); data load_data_regression(char **paths, int n, int m, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); data load_go(char *filename); diff --git a/src/deconvolutional_kernels.cu b/src/deconvolutional_kernels.cu index 55aa162c..16694634 100644 --- a/src/deconvolutional_kernels.cu +++ b/src/deconvolutional_kernels.cu @@ -114,10 +114,10 @@ void update_deconvolutional_layer_gpu(layer l, int batch, float learning_rate, f int size = l.size*l.size*l.c*l.n; if(l.adam){ - adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, 
decay, learning_rate, size, batch); - adam_update_gpu(l.biases_gpu, l.bias_updates_gpu, l.bias_m_gpu, l.bias_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch); + adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, size, batch, l.t); + adam_update_gpu(l.biases_gpu, l.bias_updates_gpu, l.bias_m_gpu, l.bias_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch, l.t); if(l.scales_gpu){ - adam_update_gpu(l.scales_gpu, l.scale_updates_gpu, l.scale_m_gpu, l.scale_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch); + adam_update_gpu(l.scales_gpu, l.scale_updates_gpu, l.scale_m_gpu, l.scale_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch, l.t); } }else{ axpy_ongpu(size, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); diff --git a/src/demo.c b/src/demo.c index 24c02e0e..27fcb995 100644 --- a/src/demo.c +++ b/src/demo.c @@ -10,6 +10,7 @@ #include #define FRAMES 3 +#define DEMO 1 #ifdef OPENCV @@ -37,7 +38,13 @@ static float *avg; void *fetch_in_thread(void *ptr) { - in = get_image_from_stream(cap); + image raw = get_image_from_stream(cap); + if(DEMO){ + in = center_crop_image(raw, 1440, 1080); + free_image(raw); + }else{ + in = raw; + } if(!in.data){ error("Stream closed."); } @@ -65,7 +72,7 @@ void *detect_in_thread(void *ptr) } else { error("Last layer must produce detections\n"); } - if (nms > 0) do_nms(boxes, probs, l.w*l.h*l.n, l.classes, nms); + if (nms > 0) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms); printf("\033[2J"); printf("\033[1;1H"); printf("\nFPS:%.1f\n",fps); @@ -113,6 +120,11 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch cap = cvCaptureFromFile(filename); }else{ cap = cvCaptureFromCAM(cam_index); + if(DEMO){ + cvSetCaptureProperty(cap, CV_CAP_PROP_FRAME_WIDTH, 1920); + cvSetCaptureProperty(cap, CV_CAP_PROP_FRAME_HEIGHT, 1080); + cvSetCaptureProperty(cap, CV_CAP_PROP_FPS, 60); + } } if(!cap) error("Couldn't 
connect to webcam.\n"); diff --git a/src/demo.h b/src/demo.h index c3d6a61a..15f21f82 100644 --- a/src/demo.h +++ b/src/demo.h @@ -1,5 +1,5 @@ -#ifndef DEMO -#define DEMO +#ifndef DEMO_H +#define DEMO_H #include "image.h" void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh); diff --git a/src/detector.c b/src/detector.c index 8a429ed8..3884a0b3 100644 --- a/src/detector.c +++ b/src/detector.c @@ -625,8 +625,8 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam network_predict(net, X); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); get_region_boxes(l, 1, 1, thresh, probs, boxes, 0, 0, hier_thresh, 0); - if (l.softmax_tree && nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms); - else if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, l.classes, nms); + if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms); + //else if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, l.classes, nms); draw_detections(sized, l.w*l.h*l.n, thresh, boxes, probs, names, alphabet, l.classes); if(outfile){ save_image(sized, outfile); diff --git a/src/image.c b/src/image.c index 8b2fae90..6ddf7985 100644 --- a/src/image.c +++ b/src/image.c @@ -187,6 +187,7 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs, alphabet = 0; } + //printf("%d %s: %.0f%%\n", i, names[class], prob*100); printf("%s: %.0f%%\n", names[class], prob*100); int offset = class*123457 % classes; float red = get_color(2,offset,classes); @@ -641,6 +642,14 @@ void place_image(image im, int w, int h, int dx, int dy, image canvas) } } +image center_crop_image(image im, int w, int h) +{ + int m = (im.w < im.h) ? 
im.w : im.h; + image c = crop_image(im, (im.w - m) / 2, (im.h - m)/2, m, m); + image r = resize_image(c, w, h); + free_image(c); + return r; +} image rotate_crop_image(image im, float rad, float s, int w, int h, float dx, float dy, float aspect) { diff --git a/src/image.h b/src/image.h index 3109094e..fd4ca414 100644 --- a/src/image.h +++ b/src/image.h @@ -44,6 +44,7 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs, image image_distance(image a, image b); void scale_image(image m, float s); image crop_image(image im, int dx, int dy, int w, int h); +image center_crop_image(image im, int w, int h); image random_crop_image(image im, int w, int h); image random_augment_image(image im, float angle, float aspect, int low, int high, int size); void random_distort_image(image im, float hue, float saturation, float exposure); diff --git a/src/lsd.c b/src/lsd.c index 8801c70f..312679ae 100644 --- a/src/lsd.c +++ b/src/lsd.c @@ -599,14 +599,14 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, aloss_avg = aloss_avg*.9 + aloss*.1; printf("%d: adv: %f | adv_avg: %f, %f rate, %lf seconds, %d images\n", i, aloss, aloss_avg, get_current_rate(gnet), sec(clock()-time), i*imgs); - if(i%1000==0){ + if(i%10000==0){ char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); save_weights(gnet, buff); sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i); save_weights(anet, buff); } - if(i%100==0){ + if(i%1000==0){ char buff[256]; sprintf(buff, "%s/%s.backup", backup_directory, base); save_weights(gnet, buff); @@ -620,8 +620,7 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, #endif } -/* -void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int clear) +void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display) { #ifdef GPU //char *train_images = "/home/pjreddie/data/coco/train1.txt"; @@ -668,31 +667,19 @@ 
void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle pthread_t load_thread = load_data_in_thread(args); clock_t time; - network_state gstate = {0}; - gstate.index = 0; - gstate.net = net; - int x_size = get_network_input_size(net)*net.batch; + int x_size = net.inputs*net.batch; int y_size = x_size; - gstate.input = cuda_make_array(0, x_size); - gstate.truth = cuda_make_array(0, y_size); - gstate.delta = 0; - gstate.train = 1; + net.delta = 0; + net.train = 1; float *pixs = calloc(x_size, sizeof(float)); float *graypixs = calloc(x_size, sizeof(float)); float *y = calloc(y_size, sizeof(float)); - network_state astate = {0}; - astate.index = 0; - astate.net = anet; - int ay_size = get_network_output_size(anet)*anet.batch; - astate.input = 0; - astate.truth = 0; - astate.delta = 0; - astate.train = 1; + int ay_size = anet.outputs*anet.batch; + anet.delta = 0; + anet.train = 1; float *imerror = cuda_make_array(0, imlayer.outputs*imlayer.batch); - float *ones_gpu = cuda_make_array(0, ay_size); - fill_ongpu(ay_size, 1, ones_gpu, 1); float aloss_avg = -1; float gloss_avg = -1; @@ -712,8 +699,8 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle for(j = 0; j < imgs; ++j){ image gim = float_to_image(net.w, net.h, net.c, gray.X.vals[j]); grayscale_image_3c(gim); - train.y.vals[j][0] = 1; - gray.y.vals[j][0] = 0; + train.y.vals[j][0] = .95; + gray.y.vals[j][0] = .05; } time=clock(); float gloss = 0; @@ -721,9 +708,8 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle for(j = 0; j < net.subdivisions; ++j){ get_next_batch(train, net.batch, j*net.batch, pixs, 0); get_next_batch(gray, net.batch, j*net.batch, graypixs, 0); - cuda_push_array(gstate.input, graypixs, x_size); - cuda_push_array(gstate.truth, pixs, y_size); - */ + cuda_push_array(net.input_gpu, graypixs, net.inputs*net.batch); + cuda_push_array(net.truth_gpu, pixs, net.truths*net.batch); /* image origi = float_to_image(net.w, 
net.h, 3, pixs); image grayi = float_to_image(net.w, net.h, 3, graypixs); @@ -731,16 +717,15 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle show_image(origi, "orig"); cvWaitKey(0); */ - /* *net.seen += net.batch; - forward_network_gpu(net, gstate); + forward_network_gpu(net); fill_ongpu(imlayer.outputs*imlayer.batch, 0, imerror, 1); - astate.input = imlayer.output_gpu; - astate.delta = imerror; - astate.truth = ones_gpu; - forward_network_gpu(anet, astate); - backward_network_gpu(anet, astate); + copy_ongpu(anet.inputs*anet.batch, imlayer.output_gpu, 1, anet.input_gpu, 1); + fill_ongpu(anet.inputs*anet.batch, .95, anet.truth_gpu, 1); + anet.delta_gpu = imerror; + forward_network_gpu(anet); + backward_network_gpu(anet); scal_ongpu(imlayer.outputs*imlayer.batch, 1./100., net.layers[net.n-1].delta_gpu, 1); @@ -751,12 +736,11 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle axpy_ongpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, net.layers[net.n-1].delta_gpu, 1); - backward_network_gpu(net, gstate); + backward_network_gpu(net); - gloss += get_network_cost(net) /(net.subdivisions*net.batch); + gloss += *net.cost /(net.subdivisions*net.batch); - cuda_pull_array(imlayer.output_gpu, imlayer.output, imlayer.outputs*imlayer.batch); for(k = 0; k < net.batch; ++k){ int index = j*net.batch + k; copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, gray.X.vals[index], 1); @@ -769,6 +753,16 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle float aloss = train_network(anet, merge); update_network_gpu(net); + + #ifdef OPENCV + if(display){ + image im = float_to_image(anet.w, anet.h, anet.c, gray.X.vals[0]); + image im2 = float_to_image(anet.w, anet.h, anet.c, train.X.vals[0]); + show_image(im, "gen"); + show_image(im2, "train"); + cvWaitKey(50); + } + #endif free_data(merge); free_data(train); free_data(gray); @@ -797,7 +791,6 @@ void train_colorizer(char *cfg, char 
*weight, char *acfg, char *aweight, int cle save_weights(net, buff); #endif } -*/ /* void train_lsd2(char *cfgfile, char *weightfile, char *acfgfile, char *aweightfile, int clear) @@ -1136,6 +1129,7 @@ void run_lsd(int argc, char **argv) //else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear); //else if(0==strcmp(argv[2], "train3")) train_lsd3(argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], clear); if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear, display, file); + else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear, display); else if(0==strcmp(argv[2], "gan")) test_dcgan(cfg, weights); else if(0==strcmp(argv[2], "test")) test_lsd(cfg, weights, filename, 0); else if(0==strcmp(argv[2], "color")) test_lsd(cfg, weights, filename, 1); diff --git a/src/network.c b/src/network.c index abf1b8aa..2b21338d 100644 --- a/src/network.c +++ b/src/network.c @@ -42,6 +42,7 @@ load_args get_base_args(network net) args.angle = net.angle; args.aspect = net.aspect; args.exposure = net.exposure; + args.center = net.center; args.saturation = net.saturation; args.hue = net.hue; return args; @@ -385,7 +386,7 @@ image get_network_image_layer(network net, int i) { layer l = net.layers[i]; #ifdef GPU - cuda_pull_array(l.output_gpu, l.output, l.outputs); + //cuda_pull_array(l.output_gpu, l.output, l.outputs); #endif if (l.out_w && l.out_h && l.out_c){ return float_to_image(l.out_w, l.out_h, l.out_c, l.output); diff --git a/src/network.h b/src/network.h index d27119ed..64d0b61c 100644 --- a/src/network.h +++ b/src/network.h @@ -47,6 +47,7 @@ typedef struct network{ int h, w, c; int max_crop; int min_crop; + int center; float angle; float aspect; float exposure; diff --git a/src/nightmare.c b/src/nightmare.c index 4bcf1877..3ab735ce 100644 --- a/src/nightmare.c +++ b/src/nightmare.c @@ -2,6 +2,7 @@ #include "parser.h" #include "blas.h" #include "utils.h" +#include 
"region_layer.h" // ./darknet nightmare cfg/extractor.recon.cfg ~/trained/yolo-coco.conv frame6.png -reconstruct -iters 500 -i 3 -lambda .1 -rate .01 -smooth 2 @@ -137,11 +138,11 @@ void reconstruct_picture(network net, float *features, image recon, image update #ifdef GPU cuda_push_array(net.input_gpu, recon.data, recon.w*recon.h*recon.c); - cuda_push_array(net.truth_gpu, features, net.truths); + //cuda_push_array(net.truth_gpu, features, net.truths); net.delta_gpu = cuda_make_array(delta.data, delta.w*delta.h*delta.c); forward_network_gpu(net); - copy_ongpu(l.outputs, net.truth_gpu, 1, l.delta_gpu, 1); + cuda_push_array(l.delta_gpu, features, l.outputs); axpy_ongpu(l.outputs, -1, l.output_gpu, 1, l.delta_gpu, 1); backward_network_gpu(net); @@ -157,13 +158,15 @@ void reconstruct_picture(network net, float *features, image recon, image update backward_network(net); #endif + //normalize_array(delta.data, delta.w*delta.h*delta.c); axpy_cpu(recon.w*recon.h*recon.c, 1, delta.data, 1, update.data, 1); - smooth(recon, update, lambda, smooth_size); + //smooth(recon, update, lambda, smooth_size); axpy_cpu(recon.w*recon.h*recon.c, rate, update.data, 1, recon.data, 1); scal_cpu(recon.w*recon.h*recon.c, momentum, update.data, 1); - //float mag = mag_array(recon.data, recon.w*recon.h*recon.c); + float mag = mag_array(delta.data, recon.w*recon.h*recon.c); + printf("mag: %f\n", mag); //scal_cpu(recon.w*recon.h*recon.c, 600/mag, recon.data, 1); constrain_image(recon); @@ -330,28 +333,33 @@ void run_nightmare(int argc, char **argv) image update; if (reconstruct){ net.n = max_layer; - resize_network(&net, im.w, im.h); + im = letterbox_image(im, net.w, net.h); + //resize_network(&net, im.w, im.h); - int zz = 0; network_predict(net, im.data); - image out_im = get_network_image(net); - image crop = crop_image(out_im, zz, zz, out_im.w-2*zz, out_im.h-2*zz); + if(net.layers[net.n-1].type == REGION){ + printf("region!\n"); + zero_objectness(net.layers[net.n-1]); + } + image out_im = 
copy_image(get_network_image(net)); + /* + image crop = crop_image(out_im, zz, zz, out_im.w-2*zz, out_im.h-2*zz); //flip_image(crop); image f_im = resize_image(crop, out_im.w, out_im.h); free_image(crop); + */ printf("%d features\n", out_im.w*out_im.h*out_im.c); + features = out_im.data; - im = resize_image(im, im.w, im.h); - f_im = resize_image(f_im, f_im.w, f_im.h); - features = f_im.data; - + /* int i; - for(i = 0; i < 14*14*512; ++i){ - //features[i] += rand_uniform(-.19, .19); + for(i = 0; i < 14*14*512; ++i){ + //features[i] += rand_uniform(-.19, .19); } free_image(im); im = make_random_image(im.w, im.h, im.c); + */ update = make_image(im.w, im.h, im.c); } diff --git a/src/parser.c b/src/parser.c index 2c03c7f5..4add3573 100644 --- a/src/parser.c +++ b/src/parser.c @@ -558,6 +558,7 @@ void parse_net_options(list *options, network *net) net->inputs = option_find_int_quiet(options, "inputs", net->h * net->w * net->c); net->max_crop = option_find_int_quiet(options, "max_crop",net->w*2); net->min_crop = option_find_int_quiet(options, "min_crop",net->w); + net->center = option_find_int_quiet(options, "center",0); net->angle = option_find_float_quiet(options, "angle", 0); net->aspect = option_find_float_quiet(options, "aspect", 1); @@ -835,8 +836,8 @@ void save_convolutional_weights(layer l, FILE *fp) } fwrite(l.weights, sizeof(float), num, fp); if(l.adam){ - fwrite(l.m, sizeof(float), num, fp); - fwrite(l.v, sizeof(float), num, fp); + //fwrite(l.m, sizeof(float), num, fp); + //fwrite(l.v, sizeof(float), num, fp); } } @@ -1039,8 +1040,8 @@ void load_convolutional_weights(layer l, FILE *fp) } fread(l.weights, sizeof(float), num, fp); if(l.adam){ - fread(l.m, sizeof(float), num, fp); - fread(l.v, sizeof(float), num, fp); + //fread(l.m, sizeof(float), num, fp); + //fread(l.v, sizeof(float), num, fp); } //if(l.c == 3) scal_cpu(num, 1./256, l.weights, 1); if (l.flipped) { diff --git a/src/region_layer.c b/src/region_layer.c index 3931b988..275f2323 100644 --- 
a/src/region_layer.c +++ b/src/region_layer.c @@ -384,14 +384,24 @@ void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *b probs[index][l.classes] = scale; } } else { + float max = 0; for(j = 0; j < l.classes; ++j){ int class_index = entry_index(l, 0, n*l.w*l.h + i, 5 + j); float prob = scale*predictions[class_index]; probs[index][j] = (prob > thresh) ? prob : 0; + if(prob > max) max = prob; // TODO REMOVE // if (j != 15 && j != 16) probs[index][j] = 0; - // if (j != 0) probs[index][j] = 0; + /* + if (j != 0) probs[index][j] = 0; + int blacklist[] = {121, 497, 482, 504, 122, 518,481, 418, 542, 491, 914, 478, 120, 510,500}; + int bb; + for (bb = 0; bb < sizeof(blacklist)/sizeof(int); ++bb){ + if(index == blacklist[bb]) probs[index][j] = 0; + } + */ } + probs[index][l.classes] = max; } if(only_objectness){ probs[index][0] = scale; @@ -461,3 +471,14 @@ void backward_region_layer_gpu(const layer l, network net) } #endif +void zero_objectness(layer l) +{ + int i, n; + for (i = 0; i < l.w*l.h; ++i){ + for(n = 0; n < l.n; ++n){ + int obj_index = entry_index(l, 0, n*l.w*l.h + i, 4); + l.output[obj_index] = 0; + } + } +} + diff --git a/src/region_layer.h b/src/region_layer.h index 6375445f..5177d3be 100644 --- a/src/region_layer.h +++ b/src/region_layer.h @@ -9,6 +9,7 @@ void forward_region_layer(const layer l, network net); void backward_region_layer(const layer l, network net); void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *boxes, int only_objectness, int *map, float tree_thresh, int nomult); void resize_region_layer(layer *l, int w, int h); +void zero_objectness(layer l); #ifdef GPU void forward_region_layer_gpu(const layer l, network net); diff --git a/src/utils.c b/src/utils.c index b5181d78..9fa2d6bf 100644 --- a/src/utils.c +++ b/src/utils.c @@ -9,6 +9,29 @@ #include "utils.h" +int *read_intlist(char *gpu_list, int *ngpus, int d) +{ + int *gpus = 0; + if(gpu_list){ + int len = strlen(gpu_list); + *ngpus = 1; + 
int i; + for(i = 0; i < len; ++i){ + if (gpu_list[i] == ',') ++*ngpus; + } + gpus = calloc(*ngpus, sizeof(int)); + for(i = 0; i < *ngpus; ++i){ + gpus[i] = atoi(gpu_list); + if(i+1 < *ngpus) gpu_list = strchr(gpu_list, ',')+1; /* no comma follows the last entry; avoid NULL+1 */ + } + } else { + gpus = calloc(1, sizeof(int)); /* single-element list holding the default d */ + *gpus = d; + *ngpus = 1; + } + return gpus; +} + int *read_map(char *filename) { int n = 0; diff --git a/src/utils.h b/src/utils.h index bbc67654..ab4c6959 100644 --- a/src/utils.h +++ b/src/utils.h @@ -7,6 +7,7 @@ #define SECRET_NUM -1234 #define TWO_PI 6.2831853071795864769252866 +int *read_intlist(char *s, int *n, int d); int *read_map(char *filename); void shuffle(void *arr, size_t n, size_t size); void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections);