diff --git a/cfg/gru.cfg b/cfg/gru.cfg index cc331dcf..a68d3fc5 100644 --- a/cfg/gru.cfg +++ b/cfg/gru.cfg @@ -10,7 +10,7 @@ adam=1 policy=constant power=4 -max_batches=400000 +max_batches=1000000 [gru] output = 1024 diff --git a/examples/classifier.c b/examples/classifier.c index 0d28a58b..ee3f4008 100644 --- a/examples/classifier.c +++ b/examples/classifier.c @@ -33,6 +33,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus, cuda_set_device(gpus[i]); #endif nets[i] = load_network(cfgfile, weightfile, clear); + nets[i].learning_rate *= ngpus; } srand(time(0)); network net = nets[0]; diff --git a/examples/darknet.c b/examples/darknet.c index 92f42bca..a7ce1482 100644 --- a/examples/darknet.c +++ b/examples/darknet.c @@ -337,7 +337,7 @@ void denormalize_net(char *cfgfile, char *weightfile, char *outfile) int i; for (i = 0; i < net.n; ++i) { layer l = net.layers[i]; - if (l.type == CONVOLUTIONAL && l.batch_normalize) { + if ((l.type == DECONVOLUTIONAL || l.type == CONVOLUTIONAL) && l.batch_normalize) { denormalize_convolutional_layer(l); net.layers[i].batch_normalize=0; } diff --git a/examples/detector.c b/examples/detector.c index 4bec0507..8debcc61 100644 --- a/examples/detector.c +++ b/examples/detector.c @@ -23,6 +23,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i cuda_set_device(gpus[i]); #endif nets[i] = load_network(cfgfile, weightfile, clear); + nets[i].learning_rate *= ngpus; } srand(time(0)); network net = nets[0]; diff --git a/examples/rnn.c b/examples/rnn.c index 8d1fa242..af0f8985 100644 --- a/examples/rnn.c +++ b/examples/rnn.c @@ -43,6 +43,7 @@ char **read_tokens(char *filename, size_t *read) size = size*2; d = realloc(d, size*sizeof(char *)); } + if(0==strcmp(line, "")) line = "\n"; d[count-1] = line; } fclose(fp); @@ -190,7 +191,7 @@ void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, for(j = 0; j < streams; ++j){ //printf("%d\n", j); - if(rand()%10 == 0){ + if(rand()%64 == 0){ //fprintf(stderr, "Reset\n"); offsets[j] = rand_size_t()%size; reset_rnn_state(net, j); diff --git a/examples/segmenter.c b/examples/segmenter.c index e3804d37..2c1979d4 100644 --- a/examples/segmenter.c +++ b/examples/segmenter.c @@ -2,7 +2,7 @@ #include #include -void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear) +void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int display) { int i; @@ -95,9 +95,9 @@ void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, #else loss = train_network(net, train); #endif - if(1){ - image tr = float_to_image(net.w/div, net.h/div, 80, train.y.vals[net.batch]); - image im = float_to_image(net.w, net.h, net.c, train.X.vals[net.batch]); + if(display){ + image tr = float_to_image(net.w/div, net.h/div, 80, train.y.vals[net.batch*(net.subdivisions-1)]); + image im = float_to_image(net.w, net.h, net.c, train.X.vals[net.batch*(net.subdivisions-1)]); image mask = mask_to_rgb(tr); image prmask = mask_to_rgb(pred); show_image(im, "input"); @@ -163,10 +163,10 @@ void predict_segmenter(char *datafile, char *cfgfile, char *weightfile, char *fi float *X = sized.data; time=clock(); float *predictions = network_predict(net, X); - image m = float_to_image(sized.w, sized.h, 81, predictions); - image rgb = mask_to_rgb(m); + image pred = get_network_image(net); + image prmask = mask_to_rgb(pred); show_image(sized, "orig"); - show_image(rgb, "pred"); + show_image(prmask, "pred"); #ifdef OPENCV cvWaitKey(0); #endif @@ -174,7 +174,7 @@ void predict_segmenter(char *datafile, char *cfgfile, char *weightfile, char *fi printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); free_image(im); free_image(sized); - free_image(rgb); + free_image(prmask); if (filename) break; } } @@ -183,7 +183,7 @@ void predict_segmenter(char *datafile, char *cfgfile, char *weightfile, char *fi void demo_segmenter(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename) { #ifdef OPENCV - printf("Regressor Demo\n"); + printf("Classifier Demo\n"); network net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); @@ -200,8 +200,8 @@ void demo_segmenter(char *datacfg, char *cfgfile, char *weightfile, int cam_inde } if(!cap) error("Couldn't connect to webcam.\n"); - cvNamedWindow("Regressor", CV_WINDOW_NORMAL); - cvResizeWindow("Regressor", 512, 512); + cvNamedWindow("Segmenter", CV_WINDOW_NORMAL); + cvResizeWindow("Segmenter", 512, 512); float fps = 0; while(1){ @@ -210,7 +210,6 @@ void demo_segmenter(char *datacfg, char *cfgfile, char *weightfile, int cam_inde image in = get_image_from_stream(cap); image in_s = letterbox_image(in, net.w, net.h); - show_image(in, "Regressor"); float *predictions = network_predict(net, in_s.data); @@ -218,10 +217,13 @@ void demo_segmenter(char *datacfg, char *cfgfile, char *weightfile, int cam_inde printf("\033[1;1H"); printf("\nFPS:%.0f\n",fps); - printf("People: %f\n", predictions[0]); - + image pred = get_network_image(net); + image prmask = mask_to_rgb(pred); + show_image(prmask, "Segmenter"); + free_image(in_s); free_image(in); + free_image(prmask); cvWaitKey(10); @@ -266,12 +268,13 @@ void run_segmenter(int argc, char **argv) int cam_index = find_int_arg(argc, argv, "-c", 0); int clear = find_arg(argc, argv, "-clear"); + int display = find_arg(argc, argv, "-display"); char *data = argv[3]; char *cfg = argv[4]; char *weights = (argc > 5) ? argv[5] : 0; char *filename = (argc > 6) ? argv[6]: 0; if(0==strcmp(argv[2], "test")) predict_segmenter(data, cfg, weights, filename); - else if(0==strcmp(argv[2], "train")) train_segmenter(data, cfg, weights, gpus, ngpus, clear); + else if(0==strcmp(argv[2], "train")) train_segmenter(data, cfg, weights, gpus, ngpus, clear, display); else if(0==strcmp(argv[2], "demo")) demo_segmenter(data, cfg, weights, cam_index, filename); } diff --git a/src/data.c b/src/data.c index 80ca1313..0f566060 100644 --- a/src/data.c +++ b/src/data.c @@ -623,6 +623,54 @@ data load_data_seg(int n, char **paths, int m, int w, int h, int classes, int mi d.X.cols = h*w*3; + d.y.rows = n; + d.y.cols = h*w*classes/div/div; + d.y.vals = calloc(d.X.rows, sizeof(float*)); + + for(i = 0; i < n; ++i){ + image orig = load_image_color(random_paths[i], 0, 0); + augment_args a = random_augment_args(orig, angle, aspect, min, max, w, h); + image sized = rotate_crop_image(orig, a.rad, a.scale, a.w, a.h, a.dx, a.dy, a.aspect); + + int flip = rand()%2; + if(flip) flip_image(sized); + random_distort_image(sized, hue, saturation, exposure); + d.X.vals[i] = sized.data; + + image mask = get_segmentation_image(random_paths[i], orig.w, orig.h, classes); + //image mask = make_image(orig.w, orig.h, classes+1); + image sized_m = rotate_crop_image(mask, a.rad, a.scale/div, a.w/div, a.h/div, a.dx/div, a.dy/div, a.aspect); + + if(flip) flip_image(sized_m); + d.y.vals[i] = sized_m.data; + + free_image(orig); + free_image(mask); + + /* + image rgb = mask_to_rgb(sized_m, classes); + show_image(rgb, "part"); + show_image(sized, "orig"); + cvWaitKey(0); + free_image(rgb); + */ + } + free(random_paths); + return d; +} + +data load_data_iseg(int n, char **paths, int m, int w, int h, int classes, int min, int max, float angle, float aspect, float hue, float saturation, float exposure, int div) +{ + char **random_paths = get_random_paths(paths, n, m); + int i; + data d = {0}; + d.shallow = 0; + + d.X.rows = n; + d.X.vals = calloc(d.X.rows, sizeof(float*)); + d.X.cols = h*w*3; + + d.y.rows = n; d.y.cols = h*w*classes/div/div; d.y.vals = calloc(d.X.rows, sizeof(float*)); diff --git a/src/deconvolutional_layer.c b/src/deconvolutional_layer.c index 0959d738..674ce6b3 100644 --- a/src/deconvolutional_layer.c +++ b/src/deconvolutional_layer.c @@ -112,20 +112,20 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size l.output_gpu = cuda_make_array(l.output, l.batch*l.out_h*l.out_w*n); if(batch_normalize){ - l.mean_gpu = cuda_make_array(l.mean, n); - l.variance_gpu = cuda_make_array(l.variance, n); + l.mean_gpu = cuda_make_array(0, n); + l.variance_gpu = cuda_make_array(0, n); - l.rolling_mean_gpu = cuda_make_array(l.mean, n); - l.rolling_variance_gpu = cuda_make_array(l.variance, n); + l.rolling_mean_gpu = cuda_make_array(0, n); + l.rolling_variance_gpu = cuda_make_array(0, n); - l.mean_delta_gpu = cuda_make_array(l.mean, n); - l.variance_delta_gpu = cuda_make_array(l.variance, n); + l.mean_delta_gpu = cuda_make_array(0, n); + l.variance_delta_gpu = cuda_make_array(0, n); - l.scales_gpu = cuda_make_array(l.scales, n); - l.scale_updates_gpu = cuda_make_array(l.scale_updates, n); + l.scales_gpu = cuda_make_array(0, n); + l.scale_updates_gpu = cuda_make_array(0, n); - l.x_gpu = cuda_make_array(l.output, l.batch*l.out_h*l.out_w*n); - l.x_norm_gpu = cuda_make_array(l.output, l.batch*l.out_h*l.out_w*n); + l.x_gpu = cuda_make_array(0, l.batch*l.out_h*l.out_w*n); + l.x_norm_gpu = cuda_make_array(0, l.batch*l.out_h*l.out_w*n); } } #ifdef CUDNN @@ -144,6 +144,21 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size return l; } +void denormalize_deconvolutional_layer(layer l) +{ + int i, j; + for(i = 0; i < l.n; ++i){ + float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001); + for(j = 0; j < l.c*l.size*l.size; ++j){ + l.weights[i*l.c*l.size*l.size + j] *= scale; + } + l.biases[i] -= l.rolling_mean[i] * scale; + l.scales[i] = 1; + l.rolling_mean[i] = 0; + l.rolling_variance[i] = 1; + } +} + void resize_deconvolutional_layer(layer *l, int h, int w) { l->h = h; diff --git a/src/image.c b/src/image.c index 146ec355..c28fd5c0 100644 --- a/src/image.c +++ b/src/image.c @@ -199,7 +199,7 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs, float prob = probs[i][class]; if(prob > thresh){ - int width = im.h * .012; + int width = im.h * .006; if(0){ width = pow(prob, 1./2.)*10+1; diff --git a/src/network.c b/src/network.c index 0d2773e8..424057b9 100644 --- a/src/network.c +++ b/src/network.c @@ -312,6 +312,11 @@ void set_batch_network(network *net, int b) if(net->layers[i].type == CONVOLUTIONAL){ cudnn_convolutional_setup(net->layers + i); } + if(net->layers[i].type == DECONVOLUTIONAL){ + layer *l = net->layers + i; + cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, l->out_h, l->out_w); + cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1); + } #endif } } diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 5af37608..df859d70 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -383,7 +383,6 @@ float train_networks(network *nets, int n, data d, int interval) float sum = 0; for(i = 0; i < n; ++i){ - nets[i].learning_rate *= n; data p = get_data_part(d, i, n); threads[i] = train_network_in_thread(nets[i], p, errors + i); } diff --git a/src/region_layer.c b/src/region_layer.c index c090075e..86ec4d47 100644 --- a/src/region_layer.c +++ b/src/region_layer.c @@ -152,13 +152,13 @@ void forward_region_layer(const layer l, network net) for(n = 0; n < l.n; ++n){ int index = entry_index(l, b, n*l.w*l.h, 0); activate_array(l.output + index, 2*l.w*l.h, LOGISTIC); - index = entry_index(l, b, n*l.w*l.h, 4); + index = entry_index(l, b, n*l.w*l.h, l.coords); if(!l.background) activate_array(l.output + index, l.w*l.h, LOGISTIC); } } if (l.softmax_tree){ int i; - int count = 5; + int count = l.coords + 1; for (i = 0; i < l.softmax_tree->groups; ++i) { int group_size = l.softmax_tree->group_size[i]; softmax_cpu(net.input + count, group_size, l.batch, l.inputs, l.n*l.w*l.h, 1, l.n*l.w*l.h, l.temperature, l.output + count); @@ -186,13 +186,13 @@ void forward_region_layer(const layer l, network net) for(t = 0; t < 30; ++t){ box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1); if(!truth.x) break; - int class = net.truth[t*(l.coords + 1) + b*l.truths + 4]; + int class = net.truth[t*(l.coords + 1) + b*l.truths + l.coords]; float maxp = 0; int maxi = 0; if(truth.x > 100000 && truth.y > 100000){ for(n = 0; n < l.n*l.w*l.h; ++n){ - int class_index = entry_index(l, b, n, 5); - int obj_index = entry_index(l, b, n, 4); + int class_index = entry_index(l, b, n, l.coords + 1); + int obj_index = entry_index(l, b, n, l.coords); float scale = l.output[obj_index]; l.delta[obj_index] = l.noobject_scale * (0 - l.output[obj_index]); float p = scale*get_hierarchy_probability(l.output + class_index, l.softmax_tree, class, l.w*l.h); @@ -201,8 +201,8 @@ void forward_region_layer(const layer l, network net) maxi = n; } } - int class_index = entry_index(l, b, maxi, 5); - int obj_index = entry_index(l, b, maxi, 4); + int class_index = entry_index(l, b, maxi, l.coords + 1); + int obj_index = entry_index(l, b, maxi, l.coords); delta_region_class(l.output, l.delta, class_index, class, l.classes, l.softmax_tree, l.class_scale, l.w*l.h, &avg_cat); if(l.output[obj_index] < .3) l.delta[obj_index] = l.object_scale * (.3 - l.output[obj_index]); else l.delta[obj_index] = 0; @@ -220,14 +220,14 @@ void forward_region_layer(const layer l, network net) box pred = get_region_box(l.output, l.biases, n, box_index, i, j, l.w, l.h, l.w*l.h); float best_iou = 0; for(t = 0; t < 30; ++t){ - box truth = float_to_box(net.truth + t*5 + b*l.truths, 1); + box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1); if(!truth.x) break; float iou = box_iou(pred, truth); if (iou > best_iou) { best_iou = iou; } } - int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4); + int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, l.coords); avg_anyobj += l.output[obj_index]; l.delta[obj_index] = l.noobject_scale * (0 - l.output[obj_index]); if(l.background) l.delta[obj_index] = l.noobject_scale * (1 - l.output[obj_index]); @@ -247,7 +247,7 @@ void forward_region_layer(const layer l, network net) } } for(t = 0; t < 30; ++t){ - box truth = float_to_box(net.truth + t*5 + b*l.truths, 1); + box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1); if(!truth.x) break; float best_iou = 0; @@ -356,7 +356,7 @@ void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, f for (j = 0; j < l.h; ++j) { for (i = 0; i < l.w/2; ++i) { for (n = 0; n < l.n; ++n) { - for(z = 0; z < l.classes + 5; ++z){ + for(z = 0; z < l.classes + l.coords + 1; ++z){ int i1 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + i; int i2 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + (l.w - i - 1); float swap = flip[i1]; @@ -382,7 +382,7 @@ void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, f for(j = 0; j < l.classes; ++j){ probs[index][j] = 0; } - int obj_index = entry_index(l, 0, n*l.w*l.h + i, 4); + int obj_index = entry_index(l, 0, n*l.w*l.h + i, l.coords); int box_index = entry_index(l, 0, n*l.w*l.h + i, 0); float scale = l.background ? 1 : predictions[obj_index]; boxes[index] = get_region_box(predictions, l.biases, n, box_index, col, row, l.w, l.h, l.w*l.h); @@ -393,7 +393,7 @@ void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, f hierarchy_predictions(predictions + class_index, l.classes, l.softmax_tree, 0, l.w*l.h); if(map){ for(j = 0; j < 200; ++j){ - int class_index = entry_index(l, 0, n*l.w*l.h + i, 5 + map[j]); + int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + 1 + map[j]); float prob = scale*predictions[class_index]; probs[index][j] = (prob > thresh) ? prob : 0; } @@ -405,7 +405,7 @@ void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, f } else { float max = 0; for(j = 0; j < l.classes; ++j){ - int class_index = entry_index(l, 0, n*l.w*l.h + i, 5 + j); + int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + 1 + j); float prob = scale*predictions[class_index]; probs[index][j] = (prob > thresh) ? prob : 0; if(prob > max) max = prob; @@ -454,7 +454,7 @@ void forward_region_layer_gpu(const layer l, network net) if (group_size > mmax) mmax = group_size; } printf("%d %d %d \n", l.softmax_tree->groups, mmin, mmax); - int index = entry_index(l, 0, 0, 5); + int index = entry_index(l, 0, 0, l.coords + 1); softmax_tree(net.input_gpu + index, l.w*l.h, l.batch*l.n, l.inputs/l.n, 1, l.output_gpu + index, *l.softmax_tree); /* // TIMING CODE @@ -559,7 +559,7 @@ void zero_objectness(layer l) int i, n; for (i = 0; i < l.w*l.h; ++i){ for(n = 0; n < l.n; ++n){ - int obj_index = entry_index(l, 0, n*l.w*l.h + i, 4); + int obj_index = entry_index(l, 0, n*l.w*l.h + i, l.coords); l.output[obj_index] = 0; } }