diff --git a/Makefile b/Makefile index b1f93c4c..a1e4604f 100644 --- a/Makefile +++ b/Makefile @@ -57,8 +57,8 @@ CFLAGS+= -DCUDNN LDFLAGS+= -lcudnn endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o l2norm_layer.o yolo_layer.o -EXECOBJA=captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o darknet.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o l2norm_layer.o yolo_layer.o iseg_layer.o +EXECOBJA=captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o instance-segmenter.o darknet.o ifeq ($(GPU), 1) LDFLAGS+= -lstdc++ OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o avgpool_layer_kernels.o diff --git a/examples/art.c b/examples/art.c index 7d58c5d9..ce885540 100644 --- a/examples/art.c +++ b/examples/art.c @@ -24,7 +24,6 @@ void demo_art(char *cfgfile, char *weightfile, int cam_index) while(1){ image in = get_image_from_stream(cap); image in_s = resize_image(in, net->w, net->h); - show_image(in, window); float *p = network_predict(net, in_s.data); @@ -45,10 +44,9 @@ void demo_art(char *cfgfile, char *weightfile, int cam_index) } printf("]\n"); + show_image(in, window, 1); free_image(in_s); free_image(in); - - cvWaitKey(1); } #endif } diff --git a/examples/classifier.c b/examples/classifier.c index d118ea58..2a89da62 100644 --- a/examples/classifier.c +++ b/examples/classifier.c @@ -645,6 +645,45 @@ void label_classifier(char *datacfg, char *filename, char *weightfile) } } +void csv_classifier(char *datacfg, char *cfgfile, char *weightfile) +{ + int i,j; + network *net = load_network(cfgfile, weightfile, 0); + srand(time(0)); + + list *options = read_data_cfg(datacfg); + + char *test_list = option_find_str(options, "test", "data/test.list"); + int top = option_find_int(options, "top", 1); + + list *plist = get_paths(test_list); + + char **paths = (char **)list_to_array(plist); + int m = plist->size; + free_list(plist); + int *indexes = calloc(top, sizeof(int)); + + for(i = 0; i < m; ++i){ + double time = what_time_is_it_now(); + char *path = paths[i]; + image im = load_image_color(path, 0, 0); + image r = letterbox_image(im, net->w, net->h); + float *predictions = network_predict(net, r.data); + if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1); + top_k(predictions, net->outputs, top, indexes); + + printf("%s", path); + for(j = 0; j < top; ++j){ + printf("\t%d", indexes[j]); + } + printf("\n"); + + free_image(im); + free_image(r); + + fprintf(stderr, "%lf seconds, %d images, %d total\n", what_time_is_it_now() - time, i+1, m); + } +} void test_classifier(char *datacfg, char *cfgfile, char *weightfile, int target_layer) { @@ -869,8 +908,7 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i } if(1){ - show_image(out, "Threat"); - cvWaitKey(10); + show_image(out, "Threat", 10); } free_image(in_s); free_image(in); @@ -922,7 +960,6 @@ void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_inde image in = get_image_from_stream(cap); image in_s = resize_image(in, net->w, net->h); - show_image(in, "Threat Detection"); float *predictions = network_predict(net, in_s.data); top_predictions(net, top, indexes); @@ -947,11 +984,10 @@ void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_inde } } + show_image(in, "Threat Detection", 10); free_image(in_s); free_image(in); - cvWaitKey(10); - gettimeofday(&tval_after, NULL); timersub(&tval_after, &tval_before, &tval_result); float curr = 1000000.f/((long int)tval_result.tv_usec); @@ -1036,12 +1072,10 @@ void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_ind free_image(label); } - show_image(in, base); + show_image(in, base, 10); free_image(in_s); free_image(in); - cvWaitKey(10); - gettimeofday(&tval_after, NULL); timersub(&tval_after, &tval_before, &tval_result); float curr = 1000000.f/((long int)tval_result.tv_usec); @@ -1080,6 +1114,7 @@ void run_classifier(int argc, char **argv) else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename); else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename); else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer); + else if(0==strcmp(argv[2], "csv")) csv_classifier(data, cfg, weights); else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights); else if(0==strcmp(argv[2], "valid")) validate_classifier_single(data, cfg, weights); else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights); diff --git a/examples/coco.c b/examples/coco.c index 6b01dcd6..6a50b89a 100644 --- a/examples/coco.c +++ b/examples/coco.c @@ -325,14 +325,10 @@ void test_coco(char *cfgfile, char *weightfile, char *filename, float thresh) draw_detections(im, dets, l.side*l.side*l.n, thresh, coco_classes, alphabet, 80); save_image(im, "prediction"); - show_image(im, "predictions"); + show_image(im, "predictions", 0); free_detections(dets, nboxes); free_image(im); free_image(sized); -#ifdef OPENCV - cvWaitKey(0); - cvDestroyAllWindows(); -#endif if (filename) break; } } diff --git a/examples/darknet.c b/examples/darknet.c index 51e04c13..99b0d64f 100644 --- a/examples/darknet.c +++ b/examples/darknet.c @@ -14,6 +14,7 @@ extern void run_nightmare(int argc, char **argv); extern void run_classifier(int argc, char **argv); extern void run_regressor(int argc, char **argv); extern void run_segmenter(int argc, char **argv); +extern void run_isegmenter(int argc, char **argv); extern void run_char_rnn(int argc, char **argv); extern void run_tag(int argc, char **argv); extern void run_cifar(int argc, char **argv); @@ -452,6 +453,8 @@ int main(int argc, char **argv) run_classifier(argc, argv); } else if (0 == strcmp(argv[1], "regressor")){ run_regressor(argc, argv); + } else if (0 == strcmp(argv[1], "isegmenter")){ + run_isegmenter(argc, argv); } else if (0 == strcmp(argv[1], "segmenter")){ run_segmenter(argc, argv); } else if (0 == strcmp(argv[1], "art")){ diff --git a/examples/detector.c b/examples/detector.c index 326e07f1..b503fc60 100644 --- a/examples/detector.c +++ b/examples/detector.c @@ -613,9 +613,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam if(fullscreen){ cvSetWindowProperty("predictions", CV_WND_PROP_FULLSCREEN, CV_WINDOW_FULLSCREEN); } - show_image(im, "predictions"); - cvWaitKey(0); - cvDestroyAllWindows(); + show_image(im, "predictions", 0); #endif } diff --git a/examples/instance-segmenter.c b/examples/instance-segmenter.c new file mode 100644 index 00000000..96d8545f --- /dev/null +++ b/examples/instance-segmenter.c @@ -0,0 +1,265 @@ +#include "darknet.h" +#include +#include + +void train_isegmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int display) +{ + int i; + + float avg_loss = -1; + char *base = basecfg(cfgfile); + printf("%s\n", base); + printf("%d\n", ngpus); + network **nets = calloc(ngpus, sizeof(network*)); + + srand(time(0)); + int seed = rand(); + for(i = 0; i < ngpus; ++i){ + srand(seed); +#ifdef GPU + cuda_set_device(gpus[i]); +#endif + nets[i] = load_network(cfgfile, weightfile, clear); + nets[i]->learning_rate *= ngpus; + } + srand(time(0)); + network *net = nets[0]; + image pred = get_network_image(net); + + int div = net->w/pred.w; + assert(pred.w * div == net->w); + assert(pred.h * div == net->h); + + int imgs = net->batch * net->subdivisions * ngpus; + + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay); + list *options = read_data_cfg(datacfg); + + char *backup_directory = option_find_str(options, "backup", "/backup/"); + char *train_list = option_find_str(options, "train", "data/train.list"); + + list *plist = get_paths(train_list); + char **paths = (char **)list_to_array(plist); + printf("%d\n", plist->size); + int N = plist->size; + + load_args args = {0}; + args.w = net->w; + args.h = net->h; + args.threads = 32; + args.scale = div; + args.num_boxes = 90; + + args.min = net->min_crop; + args.max = net->max_crop; + args.angle = net->angle; + args.aspect = net->aspect; + args.exposure = net->exposure; + args.saturation = net->saturation; + args.hue = net->hue; + args.size = net->w; + args.classes = 80; + + args.paths = paths; + args.n = imgs; + args.m = N; + args.type = ISEG_DATA; + + data train; + data buffer; + pthread_t load_thread; + args.d = &buffer; + load_thread = load_data(args); + + int epoch = (*net->seen)/N; + while(get_current_batch(net) < net->max_batches || net->max_batches == 0){ + double time = what_time_is_it_now(); + + pthread_join(load_thread, 0); + train = buffer; + load_thread = load_data(args); + + printf("Loaded: %lf seconds\n", what_time_is_it_now()-time); + time = what_time_is_it_now(); + + float loss = 0; +#ifdef GPU + if(ngpus == 1){ + loss = train_network(net, train); + } else { + loss = train_networks(nets, ngpus, train, 4); + } +#else + loss = train_network(net, train); +#endif + if(display){ + image tr = float_to_image(net->w/div, net->h/div, 80, train.y.vals[net->batch*(net->subdivisions-1)]); + image im = float_to_image(net->w, net->h, net->c, train.X.vals[net->batch*(net->subdivisions-1)]); + pred.c = 80; + image mask = mask_to_rgb(tr); + image prmask = mask_to_rgb(pred); + show_image(im, "input", 1); + show_image(prmask, "pred", 1); + show_image(mask, "truth", 100); + free_image(mask); + free_image(prmask); + } + if(avg_loss == -1) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen); + free_data(train); + if(*net->seen/N > epoch){ + epoch = *net->seen/N; + char buff[256]; + sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); + save_weights(net, buff); + } + if(get_current_batch(net)%100 == 0){ + char buff[256]; + sprintf(buff, "%s/%s.backup",backup_directory,base); + save_weights(net, buff); + } + } + char buff[256]; + sprintf(buff, "%s/%s.weights", backup_directory, base); + save_weights(net, buff); + + free_network(net); + free_ptrs((void**)paths, plist->size); + free_list(plist); + free(base); +} + +void predict_isegmenter(char *datafile, char *cfg, char *weights, char *filename) +{ + network *net = load_network(cfg, weights, 0); + set_batch_network(net, 1); + srand(2222222); + + clock_t time; + char buff[256]; + char *input = buff; + while(1){ + if(filename){ + strncpy(input, filename, 256); + }else{ + printf("Enter Image Path: "); + fflush(stdout); + input = fgets(input, 256, stdin); + if(!input) return; + strtok(input, "\n"); + } + image im = load_image_color(input, 0, 0); + image sized = letterbox_image(im, net->w, net->h); + + float *X = sized.data; + time=clock(); + float *predictions = network_predict(net, X); + image pred = get_network_image(net); + image prmask = mask_to_rgb(pred); + printf("Predicted: %f\n", predictions[0]); + printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + show_image(sized, "orig", 1); + show_image(prmask, "pred", 0); + free_image(im); + free_image(sized); + free_image(prmask); + if (filename) break; + } +} + + +void demo_isegmenter(char *datacfg, char *cfg, char *weights, int cam_index, const char *filename) +{ +#ifdef OPENCV + printf("Classifier Demo\n"); + network *net = load_network(cfg, weights, 0); + set_batch_network(net, 1); + + srand(2222222); + CvCapture * cap; + + if(filename){ + cap = cvCaptureFromFile(filename); + }else{ + cap = cvCaptureFromCAM(cam_index); + } + + if(!cap) error("Couldn't connect to webcam.\n"); + cvNamedWindow("Segmenter", CV_WINDOW_NORMAL); + cvResizeWindow("Segmenter", 512, 512); + float fps = 0; + + while(1){ + struct timeval tval_before, tval_after, tval_result; + gettimeofday(&tval_before, NULL); + + image in = get_image_from_stream(cap); + image in_s = letterbox_image(in, net->w, net->h); + + network_predict(net, in_s.data); + + printf("\033[2J"); + printf("\033[1;1H"); + printf("\nFPS:%.0f\n",fps); + + image pred = get_network_image(net); + image prmask = mask_to_rgb(pred); + show_image(prmask, "Segmenter", 10); + + free_image(in_s); + free_image(in); + free_image(prmask); + + gettimeofday(&tval_after, NULL); + timersub(&tval_after, &tval_before, &tval_result); + float curr = 1000000.f/((long int)tval_result.tv_usec); + fps = .9*fps + .1*curr; + } +#endif +} + + +void run_isegmenter(int argc, char **argv) +{ + if(argc < 4){ + fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); + return; + } + + char *gpu_list = find_char_arg(argc, argv, "-gpus", 0); + int *gpus = 0; + int gpu = 0; + int ngpus = 0; + if(gpu_list){ + printf("%s\n", gpu_list); + int len = strlen(gpu_list); + ngpus = 1; + int i; + for(i = 0; i < len; ++i){ + if (gpu_list[i] == ',') ++ngpus; + } + gpus = calloc(ngpus, sizeof(int)); + for(i = 0; i < ngpus; ++i){ + gpus[i] = atoi(gpu_list); + gpu_list = strchr(gpu_list, ',')+1; + } + } else { + gpu = gpu_index; + gpus = &gpu; + ngpus = 1; + } + + int cam_index = find_int_arg(argc, argv, "-c", 0); + int clear = find_arg(argc, argv, "-clear"); + int display = find_arg(argc, argv, "-display"); + char *data = argv[3]; + char *cfg = argv[4]; + char *weights = (argc > 5) ? argv[5] : 0; + char *filename = (argc > 6) ? argv[6]: 0; + if(0==strcmp(argv[2], "test")) predict_isegmenter(data, cfg, weights, filename); + else if(0==strcmp(argv[2], "train")) train_isegmenter(data, cfg, weights, gpus, ngpus, clear, display); + else if(0==strcmp(argv[2], "demo")) demo_isegmenter(data, cfg, weights, cam_index, filename); +} + + diff --git a/examples/lsd.c b/examples/lsd.c index ef46c459..4ab944c8 100644 --- a/examples/lsd.c +++ b/examples/lsd.c @@ -460,13 +460,9 @@ void inter_dcgan(char *cfgfile, char *weightfile) printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); //char buff[256]; sprintf(buff, "out%05d", c); - show_image(out, "out"); save_image(out, "out"); save_image(out, buff); -#ifdef OPENCV - //cvWaitKey(0); -#endif - + show_image(out, "out", 0); } } @@ -499,11 +495,8 @@ void test_dcgan(char *cfgfile, char *weightfile) //yuv_to_rgb(out); normalize_image(out); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - show_image(out, "out"); save_image(out, "out"); -#ifdef OPENCV - cvWaitKey(0); -#endif + show_image(out, "out", 0); free_image(im); } @@ -639,11 +632,10 @@ void train_prog(char *cfg, char *weight, char *acfg, char *aweight, int clear, i if(display){ image im = float_to_image(anet->w, anet->h, anet->c, gen.X.vals[0]); image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]); - show_image(im, "gen"); - show_image(im2, "train"); + show_image(im, "gen", 1); + show_image(im2, "train", 1); save_image(im, "gen"); save_image(im2, "train"); - cvWaitKey(1); } #endif @@ -826,11 +818,10 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, if(display){ image im = float_to_image(anet->w, anet->h, anet->c, gen.X.vals[0]); image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]); - show_image(im, "gen"); - show_image(im2, "train"); + show_image(im, "gen", 1); + show_image(im2, "train", 1); save_image(im, "gen"); save_image(im2, "train"); - cvWaitKey(1); } #endif @@ -1010,9 +1001,8 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle if(display){ image im = float_to_image(anet->w, anet->h, anet->c, gray.X.vals[0]); image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]); - show_image(im, "gen"); - show_image(im2, "train"); - cvWaitKey(1); + show_image(im, "gen", 1); + show_image(im2, "train", 1); } #endif free_data(merge); @@ -1342,12 +1332,9 @@ void test_lsd(char *cfg, char *weights, char *filename, int gray) //yuv_to_rgb(out); constrain_image(out); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - show_image(out, "out"); - show_image(crop, "crop"); save_image(out, "out"); -#ifdef OPENCV - cvWaitKey(0); -#endif + show_image(out, "out", 1); + show_image(crop, "crop", 0); free_image(im); free_image(resized); diff --git a/examples/nightmare.c b/examples/nightmare.c index 8ec6e966..2978eb61 100644 --- a/examples/nightmare.c +++ b/examples/nightmare.c @@ -376,10 +376,7 @@ void run_nightmare(int argc, char **argv) if(reconstruct){ reconstruct_picture(net, features, im, update, rate, momentum, lambda, smooth_size, 1); //if ((n+1)%30 == 0) rate *= .5; - show_image(im, "reconstruction"); -#ifdef OPENCV - cvWaitKey(10); -#endif + show_image(im, "reconstruction", 10); }else{ int layer = max_layer + rand()%range - range/2; int octave = rand()%octaves; @@ -400,8 +397,7 @@ void run_nightmare(int argc, char **argv) } printf("%d %s\n", e, buff); save_image(im, buff); - //show_image(im, buff); - //cvWaitKey(0); + //show_image(im, buff, 0); if(rotate){ image rot = rotate_image(im, rotate); diff --git a/examples/regressor.c b/examples/regressor.c index 60a9f2b9..0bd7aba0 100644 --- a/examples/regressor.c +++ b/examples/regressor.c @@ -179,7 +179,6 @@ void demo_regressor(char *datacfg, char *cfgfile, char *weightfile, int cam_inde image in = get_image_from_stream(cap); image crop = center_crop_image(in, net->w, net->h); grayscale_image_3c(crop); - show_image(crop, "Regressor"); float *predictions = network_predict(net, crop.data); @@ -192,11 +191,10 @@ void demo_regressor(char *datacfg, char *cfgfile, char *weightfile, int cam_inde printf("%s: %f\n", names[i], predictions[i]); } + show_image(crop, "Regressor", 10); free_image(in); free_image(crop); - cvWaitKey(10); - gettimeofday(&tval_after, NULL); timersub(&tval_after, &tval_before, &tval_result); float curr = 1000000.f/((long int)tval_result.tv_usec); diff --git a/examples/segmenter.c b/examples/segmenter.c index d73aceb3..eadfd69a 100644 --- a/examples/segmenter.c +++ b/examples/segmenter.c @@ -42,7 +42,6 @@ void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, char **paths = (char **)list_to_array(plist); printf("%d\n", plist->size); int N = plist->size; - clock_t time; load_args args = {0}; args.w = net->w; @@ -73,14 +72,14 @@ void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int epoch = (*net->seen)/N; while(get_current_batch(net) < net->max_batches || net->max_batches == 0){ - time=clock(); + double time = what_time_is_it_now(); pthread_join(load_thread, 0); train = buffer; load_thread = load_data(args); - printf("Loaded: %lf seconds\n", sec(clock()-time)); - time=clock(); + printf("Loaded: %lf seconds\n", what_time_is_it_now()-time); + time = what_time_is_it_now(); float loss = 0; #ifdef GPU @@ -97,18 +96,15 @@ void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, image im = float_to_image(net->w, net->h, net->c, train.X.vals[net->batch*(net->subdivisions-1)]); image mask = mask_to_rgb(tr); image prmask = mask_to_rgb(pred); - show_image(im, "input"); - show_image(prmask, "pred"); - show_image(mask, "truth"); -#ifdef OPENCV - cvWaitKey(100); -#endif + show_image(im, "input", 1); + show_image(prmask, "pred", 1); + show_image(mask, "truth", 100); free_image(mask); free_image(prmask); } if(avg_loss == -1) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; - printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen); + printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen); free_data(train); if(*net->seen/N > epoch){ epoch = *net->seen/N; @@ -159,13 +155,10 @@ void predict_segmenter(char *datafile, char *cfg, char *weights, char *filename) float *predictions = network_predict(net, X); image pred = get_network_image(net); image prmask = mask_to_rgb(pred); - show_image(sized, "orig"); - show_image(prmask, "pred"); -#ifdef OPENCV - cvWaitKey(0); -#endif printf("Predicted: %f\n", predictions[0]); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + show_image(sized, "orig", 1); + show_image(prmask, "pred", 0); free_image(im); free_image(sized); free_image(prmask); @@ -210,14 +203,12 @@ void demo_segmenter(char *datacfg, char *cfg, char *weights, int cam_index, cons image pred = get_network_image(net); image prmask = mask_to_rgb(pred); - show_image(prmask, "Segmenter"); + show_image(prmask, "Segmenter", 10); free_image(in_s); free_image(in); free_image(prmask); - cvWaitKey(10); - gettimeofday(&tval_after, NULL); timersub(&tval_after, &tval_before, &tval_result); float curr = 1000000.f/((long int)tval_result.tv_usec); diff --git a/examples/super.c b/examples/super.c index 89a8e562..d34406b1 100644 --- a/examples/super.c +++ b/examples/super.c @@ -93,7 +93,7 @@ void test_super(char *cfgfile, char *weightfile, char *filename) image out = get_network_image(net); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); save_image(out, "out"); - show_image(out, "out"); + show_image(out, "out", 0); free_image(im); if (filename) break; diff --git a/examples/yolo.c b/examples/yolo.c index aa728163..4ddb69a3 100644 --- a/examples/yolo.c +++ b/examples/yolo.c @@ -296,14 +296,10 @@ void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh) draw_detections(im, dets, l.side*l.side*l.n, thresh, voc_names, alphabet, 20); save_image(im, "predictions"); - show_image(im, "predictions"); + show_image(im, "predictions", 0); free_detections(dets, nboxes); free_image(im); free_image(sized); -#ifdef OPENCV - cvWaitKey(0); - cvDestroyAllWindows(); -#endif if (filename) break; } } diff --git a/include/darknet.h b/include/darknet.h index 9327ea8c..abe4e5bf 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -86,6 +86,7 @@ typedef enum { XNOR, REGION, YOLO, + ISEG, REORG, UPSAMPLE, LOGXENT, @@ -166,6 +167,7 @@ struct layer{ float ratio; float learning_rate_scale; float clip; + int noloss; int softmax; int classes; int coords; @@ -203,6 +205,7 @@ struct layer{ int dontload; int dontsave; int dontloadscales; + int numload; float temperature; float probability; @@ -213,6 +216,8 @@ struct layer{ int * input_layers; int * input_sizes; int * map; + int * counts; + float ** sums; float * rand; float * cost; float * state; @@ -540,7 +545,7 @@ typedef struct{ } data; typedef enum { - CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA, INSTANCE_DATA + CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA, INSTANCE_DATA, ISEG_DATA } data_type; typedef struct load_args{ @@ -705,7 +710,7 @@ int resize_network(network *net, int w, int h); void free_matrix(matrix m); void test_resize(char *filename); void save_image(image p, const char *name); -void show_image(image p, const char *name); +int show_image(image p, const char *name, int ms); image copy_image(image p); void draw_box_width(image a, int x1, int y1, int x2, int y2, int w, float r, float g, float b); float get_current_rate(network *net); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 5ac9ef0d..1fb58b09 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -151,7 +151,7 @@ void cudnn_convolutional_setup(layer *l) l->convDesc, l->dstTensorDesc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - 4000000000, + 2000000000, &l->fw_algo); cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(), l->weightDesc, @@ -159,7 +159,7 @@ void cudnn_convolutional_setup(layer *l) l->convDesc, l->dsrcTensorDesc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - 4000000000, + 2000000000, &l->bd_algo); cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(), l->srcTensorDesc, @@ -167,7 +167,7 @@ void cudnn_convolutional_setup(layer *l) l->convDesc, l->dweightDesc, CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - 4000000000, + 2000000000, &l->bf_algo); } #endif diff --git a/src/data.c b/src/data.c index 51900f26..a5d69f26 100644 --- a/src/data.c +++ b/src/data.c @@ -361,6 +361,44 @@ box bound_image(image im) } void fill_truth_iseg(char *path, int num_boxes, float *truth, int classes, int w, int h, augment_args aug, int flip, int mw, int mh) +{ + char labelpath[4096]; + find_replace(path, "images", "mask", labelpath); + find_replace(labelpath, "JPEGImages", "mask", labelpath); + find_replace(labelpath, ".jpg", ".txt", labelpath); + find_replace(labelpath, ".JPG", ".txt", labelpath); + find_replace(labelpath, ".JPEG", ".txt", labelpath); + FILE *file = fopen(labelpath, "r"); + if(!file) file_error(labelpath); + char buff[32788]; + int id; + int i = 0; + int j; + image part = make_image(w, h, 1); + while((fscanf(file, "%d %s", &id, buff) == 2) && i < num_boxes){ + int n = 0; + int *rle = read_intlist(buff, &n, 0); + load_rle(part, rle, n); + image sized = rotate_crop_image(part, aug.rad, aug.scale, aug.w, aug.h, aug.dx, aug.dy, aug.aspect); + if(flip) flip_image(sized); + + image mask = resize_image(sized, mw, mh); + truth[i*(mw*mh+1)] = id; + for(j = 0; j < mw*mh; ++j){ + truth[i*(mw*mh + 1) + 1 + j] = mask.data[j]; + } + ++i; + + free_image(mask); + free_image(sized); + free(rle); + } + if(i < num_boxes) truth[i*(mw*mh+1)] = -1; + fclose(file); + free_image(part); +} + +void fill_truth_mask(char *path, int num_boxes, float *truth, int classes, int w, int h, augment_args aug, int flip, int mw, int mh) { char labelpath[4096]; find_replace(path, "images", "mask", labelpath); @@ -743,7 +781,47 @@ data load_data_seg(int n, char **paths, int m, int w, int h, int classes, int mi return d; } -data load_data_iseg(int n, char **paths, int m, int w, int h, int classes, int boxes, int coords, int min, int max, float angle, float aspect, float hue, float saturation, float exposure) +data load_data_iseg(int n, char **paths, int m, int w, int h, int classes, int boxes, int div, int min, int max, float angle, float aspect, float hue, float saturation, float exposure) +{ + char **random_paths = get_random_paths(paths, n, m); + int i; + data d = {0}; + d.shallow = 0; + + d.X.rows = n; + d.X.vals = calloc(d.X.rows, sizeof(float*)); + d.X.cols = h*w*3; + + d.y = make_matrix(n, (((w/div)*(h/div))+1)*boxes); + + for(i = 0; i < n; ++i){ + image orig = load_image_color(random_paths[i], 0, 0); + augment_args a = random_augment_args(orig, angle, aspect, min, max, w, h); + image sized = rotate_crop_image(orig, a.rad, a.scale, a.w, a.h, a.dx, a.dy, a.aspect); + + int flip = rand()%2; + if(flip) flip_image(sized); + random_distort_image(sized, hue, saturation, exposure); + d.X.vals[i] = sized.data; + //show_image(sized, "image"); + + fill_truth_iseg(random_paths[i], boxes, d.y.vals[i], classes, orig.w, orig.h, a, flip, w/div, h/div); + + free_image(orig); + + /* + image rgb = mask_to_rgb(sized_m, classes); + show_image(rgb, "part"); + show_image(sized, "orig"); + cvWaitKey(0); + free_image(rgb); + */ + } + free(random_paths); + return d; +} + +data load_data_mask(int n, char **paths, int m, int w, int h, int classes, int boxes, int coords, int min, int max, float angle, float aspect, float hue, float saturation, float exposure) { char **random_paths = get_random_paths(paths, n, m); int i; @@ -767,7 +845,7 @@ data load_data_iseg(int n, char **paths, int m, int w, int h, int classes, int b d.X.vals[i] = sized.data; //show_image(sized, "image"); - fill_truth_iseg(random_paths[i], boxes, d.y.vals[i], classes, orig.w, orig.h, a, flip, 14, 14); + fill_truth_mask(random_paths[i], boxes, d.y.vals[i], classes, orig.w, orig.h, a, flip, 14, 14); free_image(orig); @@ -975,7 +1053,8 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in float dh = jitter * orig.h; float new_ar = (orig.w + rand_uniform(-dw, dw)) / (orig.h + rand_uniform(-dh, dh)); - float scale = rand_uniform(.25, 2); + //float scale = rand_uniform(.25, 2); + float scale = 1; float nw, nh; @@ -1025,8 +1104,10 @@ void *load_thread(void *ptr) *a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale); } else if (a.type == WRITING_DATA){ *a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h); + } else if (a.type == ISEG_DATA){ + *a.d = load_data_iseg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.scale, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure); } else if (a.type == INSTANCE_DATA){ - *a.d = load_data_iseg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.coords, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure); + *a.d = load_data_mask(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.coords, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure); } else if (a.type == SEGMENTATION_DATA){ *a.d = load_data_seg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.scale); } else if (a.type == REGION_DATA){ @@ -1212,7 +1293,7 @@ data *tile_data(data orig, int divs, int size) { data *ds = calloc(divs*divs, sizeof(data)); int i, j; - #pragma omp parallel for +#pragma omp parallel for for(i = 0; i < divs*divs; ++i){ data d; d.shallow = 0; @@ -1223,7 +1304,7 @@ data *tile_data(data orig, int divs, int size) d.X.vals = calloc(d.X.rows, sizeof(float*)); d.y = copy_matrix(orig.y); - #pragma omp parallel for +#pragma omp parallel for for(j = 0; j < orig.X.rows; ++j){ int x = (i%divs) * orig.w / divs - (d.w - orig.w/divs)/2; int y = (i/divs) * orig.h / divs - (d.h - orig.h/divs)/2; @@ -1247,7 +1328,7 @@ data resize_data(data orig, int w, int h) d.X.vals = calloc(d.X.rows, sizeof(float*)); d.y = copy_matrix(orig.y); - #pragma omp parallel for +#pragma omp parallel for for(i = 0; i < orig.X.rows; ++i){ image im = float_to_image(orig.w, orig.h, 3, orig.X.vals[i]); d.X.vals[i] = resize_image(im, w, h).data; diff --git a/src/image.c b/src/image.c index 786acb08..c112dd81 100644 --- a/src/image.c +++ b/src/image.c @@ -572,7 +572,7 @@ void show_image_cv(image p, const char *name, IplImage *disp) } #endif -void show_image(image p, const char *name) +int show_image(image p, const char *name, int ms) { #ifdef OPENCV IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c); @@ -581,9 +581,13 @@ void show_image(image p, const char *name) show_image_cv(copy, name, disp); free_image(copy); cvReleaseImage(&disp); + int c = cvWaitKey(ms); + if (c != -1) c = c%256; + return c; #else fprintf(stderr, "Not compiled with OpenCV, saving to %s.png instead\n", name); save_image(p, name); + return 0; #endif } @@ -727,7 +731,7 @@ void show_image_layers(image p, char *name) for(i = 0; i < p.c; ++i){ sprintf(buff, "%s - Layer %d", name, i); image layer = get_image_layer(p, i); - show_image(layer, buff); + show_image(layer, buff, 1); free_image(layer); } } @@ -735,7 +739,7 @@ void show_image_layers(image p, char *name) void show_image_collapsed(image p, char *name) { image c = collapse_image_layers(p, 1); - show_image(c, name); + show_image(c, name, 1); free_image(c); } @@ -1406,16 +1410,16 @@ void test_resize(char *filename) distort_image(c4, .1, .66666, 1.5); - show_image(im, "Original"); - show_image(gray, "Gray"); - show_image(c1, "C1"); - show_image(c2, "C2"); - show_image(c3, "C3"); - show_image(c4, "C4"); + show_image(im, "Original", 1); + show_image(gray, "Gray", 1); + show_image(c1, "C1", 1); + show_image(c2, "C2", 1); + show_image(c3, "C3", 1); + show_image(c4, "C4", 1); #ifdef OPENCV while(1){ image aug = random_augment_image(im, 0, .75, 320, 448, 320, 320); - show_image(aug, "aug"); + show_image(aug, "aug", 1); free_image(aug); @@ -1430,7 +1434,7 @@ void test_resize(char *filename) float dhue = rand_uniform(-hue, hue); distort_image(c, dhue, dsat, dexp); - show_image(c, "rand"); + show_image(c, "rand", 1); printf("%f %f %f\n", dhue, dsat, dexp); free_image(c); cvWaitKey(0); @@ -1585,7 +1589,7 @@ void show_image_normalized(image im, const char *name) { image c = copy_image(im); normalize_image(c); - show_image(c, name); + show_image(c, name, 1); free_image(c); } @@ -1603,7 +1607,7 @@ void show_images(image *ims, int n, char *window) */ normalize_image(m); save_image(m, window); - show_image(m, window); + show_image(m, window, 1); free_image(m); } diff --git a/src/iseg_layer.c b/src/iseg_layer.c new file mode 100644 index 00000000..7c31d0d5 --- /dev/null +++ b/src/iseg_layer.c @@ -0,0 +1,219 @@ +#include "iseg_layer.h" +#include "activations.h" +#include "blas.h" +#include "box.h" +#include "cuda.h" +#include "utils.h" + +#include +#include +#include +#include + +layer make_iseg_layer(int batch, int w, int h, int classes, int ids) +{ + layer l = {0}; + l.type = ISEG; + + l.h = h; + l.w = w; + l.c = classes + ids; + l.out_w = l.w; + l.out_h = l.h; + l.out_c = l.c; + l.classes = classes; + l.batch = batch; + l.extra = ids; + l.cost = calloc(1, sizeof(float)); + l.outputs = h*w*l.c; + l.inputs = l.outputs; + l.truths = 90*(l.w*l.h+1); + l.delta = calloc(batch*l.outputs, sizeof(float)); + l.output = calloc(batch*l.outputs, sizeof(float)); + + l.counts = calloc(90, sizeof(int)); + l.sums = calloc(90, sizeof(float*)); + if(ids){ + int i; + for(i = 0; i < 90; ++i){ + l.sums[i] = calloc(ids, sizeof(float)); + } + } + + l.forward = forward_iseg_layer; + l.backward = backward_iseg_layer; +#ifdef GPU + l.forward_gpu = forward_iseg_layer_gpu; + l.backward_gpu = backward_iseg_layer_gpu; + l.output_gpu = cuda_make_array(l.output, batch*l.outputs); + l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs); +#endif + + fprintf(stderr, "iseg\n"); + srand(0); + + return l; +} + +void resize_iseg_layer(layer *l, int w, int h) +{ + l->w = w; + l->h = h; + + l->outputs = h*w*l->c; + l->inputs = l->outputs; + + l->output = realloc(l->output, l->batch*l->outputs*sizeof(float)); + l->delta = realloc(l->delta, l->batch*l->outputs*sizeof(float)); + +#ifdef GPU + cuda_free(l->delta_gpu); + cuda_free(l->output_gpu); + + l->delta_gpu = cuda_make_array(l->delta, l->batch*l->outputs); + l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs); +#endif +} + +void forward_iseg_layer(const layer l, network net) +{ + + double time = what_time_is_it_now(); + int i,b,j,k; + int ids = l.extra; + memcpy(l.output, net.input, l.outputs*l.batch*sizeof(float)); + memset(l.delta, 0, l.outputs * l.batch * sizeof(float)); + +#ifndef GPU + for (b = 0; b < l.batch; ++b){ + int index = b*l.outputs; + activate_array(l.output + index, l.classes*l.w*l.h, LOGISTIC); + } +#endif + + for (b = 0; b < l.batch; ++b){ + // a priori, each pixel has no class + for(i = 0; i < l.classes; ++i){ + for(k = 0; k < l.w*l.h; ++k){ + int index = b*l.outputs + i*l.w*l.h + k; + l.delta[index] = 0 - l.output[index]; + } + } + + // a priori, embedding should be small magnitude + for(i = 0; i < ids; ++i){ + for(k = 0; k < l.w*l.h; ++k){ + int index = b*l.outputs + (i+l.classes)*l.w*l.h + k; + l.delta[index] = .1 * (0 - l.output[index]); + } + } + + + memset(l.counts, 0, 90*sizeof(float)); + for(i = 0; i < 90; ++i){ + l.counts[i] = 0; + fill_cpu(ids, 0, l.sums[i], 1); + + int c = net.truth[b*l.truths + i*(l.w*l.h+1)]; + if(c < 0) break; + // add up metric embeddings for each instance + for(k = 0; k < l.w*l.h; ++k){ + int index = b*l.outputs + c*l.w*l.h + k; + float v = net.truth[b*l.truths + i*(l.w*l.h + 1) + 1 + k]; + if(v){ + l.delta[index] = v - l.output[index]; + axpy_cpu(ids, 1, l.output + b*l.outputs + l.classes*l.w*l.h + k, l.w*l.h, l.sums[i], 1); + ++l.counts[i]; + } + } + } + + float *mse = calloc(90, sizeof(float)); + for(i = 0; i < 90; ++i){ + int c = net.truth[b*l.truths + i*(l.w*l.h+1)]; + if(c < 0) break; + for(k = 0; k < l.w*l.h; ++k){ + float v = net.truth[b*l.truths + i*(l.w*l.h + 1) + 1 + k]; + if(v){ + int z; + float sum = 0; + for(z = 0; z < ids; ++z){ + int index = b*l.outputs + (l.classes + z)*l.w*l.h + k; + sum += pow(l.sums[i][z]/l.counts[i] - l.output[index], 2); + } + mse[i] += sum; + } + } + mse[i] /= l.counts[i]; + } + + // Calculate average embedding + for(i = 0; i < 90; ++i){ + if(!l.counts[i]) continue; + scal_cpu(ids, 1.f/l.counts[i], l.sums[i], 1); + if(b == 0 && net.gpu_index == 0){ + printf("%4d, %6.3f, ", l.counts[i], mse[i]); + for(j = 0; j < ids/4; ++j){ + printf("%6.3f,", l.sums[i][j]); + } + printf("\n"); + } + } + free(mse); + + // Calculate embedding loss + for(i = 0; i < 90; ++i){ + if(!l.counts[i]) continue; + for(k = 0; k < l.w*l.h; ++k){ + float v = net.truth[b*l.truths + i*(l.w*l.h + 1) + 1 + k]; + if(v){ + for(j = 0; j < 90; ++j){ + if(!l.counts[j])continue; + int z; + for(z = 0; z < ids; ++z){ + int index = b*l.outputs + (l.classes + z)*l.w*l.h + k; + float diff = l.sums[j][z] - l.output[index]; + if (j == i) l.delta[index] += diff < 0? -.1 : .1; + else l.delta[index] += -(diff < 0? -.1 : .1); + } + } + } + } + } + } + + *(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2); + printf("took %lf sec\n", what_time_is_it_now() - time); +} + +void backward_iseg_layer(const layer l, network net) +{ + axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, net.delta, 1); +} + +#ifdef GPU + +void forward_iseg_layer_gpu(const layer l, network net) +{ + copy_gpu(l.batch*l.inputs, net.input_gpu, 1, l.output_gpu, 1); + int b; + for (b = 0; b < l.batch; ++b){ + activate_array_gpu(l.output_gpu + b*l.outputs, l.classes*l.w*l.h, LOGISTIC); + //if(l.extra) activate_array_gpu(l.output_gpu + b*l.outputs + l.classes*l.w*l.h, l.extra*l.w*l.h, LOGISTIC); + } + + cuda_pull_array(l.output_gpu, net.input, l.batch*l.inputs); + forward_iseg_layer(l, net); + cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs); +} + +void backward_iseg_layer_gpu(const layer l, network net) +{ + int b; + for (b = 0; b < l.batch; ++b){ + //if(l.extra) gradient_array_gpu(l.output_gpu + b*l.outputs + l.classes*l.w*l.h, l.extra*l.w*l.h, LOGISTIC, l.delta_gpu + b*l.outputs + l.classes*l.w*l.h); + } + axpy_gpu(l.batch*l.inputs, 1, l.delta_gpu, 1, net.delta_gpu, 1); +} +#endif + diff --git a/src/iseg_layer.h b/src/iseg_layer.h new file mode 100644 index 00000000..dd8e64e0 --- /dev/null +++ b/src/iseg_layer.h @@ -0,0 +1,19 @@ +#ifndef ISEG_LAYER_H +#define ISEG_LAYER_H + +#include "darknet.h" +#include "layer.h" +#include "network.h" + +layer make_iseg_layer(int batch, int w, int h, int classes, int ids); +void forward_iseg_layer(const layer l, network net); +void backward_iseg_layer(const layer l, network net); +void resize_iseg_layer(layer *l, int w, int h); +int iseg_num_detections(layer l, float thresh); + +#ifdef GPU +void forward_iseg_layer_gpu(const layer l, network net); +void backward_iseg_layer_gpu(layer l, network net); +#endif + +#endif diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c index 202c0295..971272b1 100644 --- a/src/maxpool_layer.c +++ b/src/maxpool_layer.c @@ -27,8 +27,8 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s l.w = w; l.c = c; l.pad = padding; - l.out_w = (w + 2*padding - size)/stride + 1; - l.out_h = (h + 2*padding - size)/stride + 1; + l.out_w = (w + padding - size)/stride + 1; + l.out_h = (h + padding - size)/stride + 1; l.out_c = c; l.outputs = l.out_h * l.out_w * l.out_c; l.inputs = h*w*c; @@ -57,8 +57,8 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h) l->w = w; l->inputs = h*w*l->c; - l->out_w = (w + 2*l->pad - l->size)/l->stride + 1; - l->out_h = (h + 2*l->pad - l->size)/l->stride + 1; + l->out_w = (w + l->pad - l->size)/l->stride + 1; + l->out_h = (h + l->pad - l->size)/l->stride + 1; l->outputs = l->out_w * l->out_h * l->c; int output_size = l->outputs * l->batch; @@ -79,8 +79,8 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h) void forward_maxpool_layer(const maxpool_layer l, network net) { int b,i,j,k,m,n; - int w_offset = -l.pad; - int h_offset = -l.pad; + int w_offset = -l.pad/l.stride; + int h_offset = -l.pad/l.stride; int h = l.out_h; int w = l.out_w; diff --git a/src/maxpool_layer_kernels.cu b/src/maxpool_layer_kernels.cu index e294e1e8..869ef466 100644 --- a/src/maxpool_layer_kernels.cu +++ b/src/maxpool_layer_kernels.cu @@ -9,8 +9,8 @@ extern "C" { __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *input, float *output, int *indexes) { - int h = (in_h + 2*pad - size)/stride + 1; - int w = (in_w + 2*pad - size)/stride + 1; + int h = (in_h + pad - size)/stride + 1; + int w = (in_w + pad - size)/stride + 1; int c = in_c; int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -24,8 +24,8 @@ __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c id /= c; int b = id; - int w_offset = -pad; - int h_offset = -pad; + int w_offset = -pad/2; + int h_offset = -pad/2; int out_index = j + w*(i + h*(k + c*b)); float max = -INFINITY; @@ -49,8 +49,8 @@ __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c __global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *delta, float *prev_delta, int *indexes) { - int h = (in_h + 2*pad - size)/stride + 1; - int w = (in_w + 2*pad - size)/stride + 1; + int h = (in_h + pad - size)/stride + 1; + int w = (in_w + pad - size)/stride + 1; int c = in_c; int area = (size-1)/stride; @@ -66,8 +66,8 @@ __global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_ id /= in_c; int b = id; - int w_offset = -pad; - int h_offset = -pad; + int w_offset = -pad/2; + int h_offset = -pad/2; float d = 0; int l, m; diff --git a/src/parser.c b/src/parser.c index 8b43f3c5..c8141c9f 100644 --- a/src/parser.c +++ b/src/parser.c @@ -27,6 +27,7 @@ #include "parser.h" #include "region_layer.h" #include "yolo_layer.h" +#include "iseg_layer.h" #include "reorg_layer.h" #include "rnn_layer.h" #include "route_layer.h" @@ -52,6 +53,7 @@ LAYER_TYPE string_to_layer_type(char * type) if (strcmp(type, "[detection]")==0) return DETECTION; if (strcmp(type, "[region]")==0) return REGION; if (strcmp(type, "[yolo]")==0) return YOLO; + if (strcmp(type, "[iseg]")==0) return ISEG; if (strcmp(type, "[local]")==0) return LOCAL; if (strcmp(type, "[conv]")==0 || strcmp(type, "[convolutional]")==0) return CONVOLUTIONAL; @@ -265,18 +267,19 @@ layer parse_connected(list *options, size_params params) return l; } -softmax_layer parse_softmax(list *options, size_params params) +layer parse_softmax(list *options, size_params params) { int groups = option_find_int_quiet(options, "groups",1); - softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups); - layer.temperature = option_find_float_quiet(options, "temperature", 1); + layer l = make_softmax_layer(params.batch, params.inputs, groups); + l.temperature = option_find_float_quiet(options, "temperature", 1); char *tree_file = option_find_str(options, "tree", 0); - if (tree_file) layer.softmax_tree = read_tree(tree_file); - layer.w = params.w; - layer.h = params.h; - layer.c = params.c; - layer.spatial = option_find_float_quiet(options, "spatial", 0); - return layer; + if (tree_file) l.softmax_tree = read_tree(tree_file); + l.w = params.w; + l.h = params.h; + l.c = params.c; + l.spatial = option_find_float_quiet(options, "spatial", 0); + l.noloss = option_find_int_quiet(options, "noloss", 0); + return l; } int *parse_yolo_mask(char *a, int *num) @@ -338,6 +341,15 @@ layer parse_yolo(list *options, size_params params) return l; } +layer parse_iseg(list *options, size_params params) +{ + int classes = option_find_int(options, "classes", 20); + int ids = option_find_int(options, "ids", 32); + layer l = make_iseg_layer(params.batch, params.w, params.h, classes, ids); + assert(l.outputs == params.inputs); + return l; +} + layer parse_region(list *options, size_params params) { int coords = option_find_int(options, "coords", 4); @@ -472,7 +484,7 @@ maxpool_layer parse_maxpool(list *options, size_params params) { int stride = option_find_int(options, "stride",1); int size = option_find_int(options, "size",stride); - int padding = option_find_int_quiet(options, "padding", (size-1)/2); + int padding = option_find_int_quiet(options, "padding", size-1); int batch,h,w,c; h = params.h; @@ -791,6 +803,8 @@ network *parse_network_cfg(char *filename) l = parse_region(options, params); }else if(lt == YOLO){ l = parse_yolo(options, params); + }else if(lt == ISEG){ + l = parse_iseg(options, params); }else if(lt == DETECTION){ l = parse_detection(options, params); }else if(lt == SOFTMAX){ @@ -829,6 +843,7 @@ network *parse_network_cfg(char *filename) l.stopbackward = option_find_int_quiet(options, "stopbackward", 0); l.dontsave = option_find_int_quiet(options, "dontsave", 0); l.dontload = option_find_int_quiet(options, "dontload", 0); + l.numload = option_find_int_quiet(options, "numload", 0); l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0); l.learning_rate_scale = option_find_float_quiet(options, "learning_rate", 1); l.smooth = option_find_float_quiet(options, "smooth", 0); @@ -1152,7 +1167,8 @@ void load_convolutional_weights(layer l, FILE *fp) //load_convolutional_weights_binary(l, fp); //return; } - int num = l.nweights; + if(l.numload) l.n = l.numload; + int num = l.c/l.groups*l.n*l.size*l.size; fread(l.biases, sizeof(float), l.n, fp); if (l.batch_normalize && (!l.dontloadscales)){ fread(l.scales, sizeof(float), l.n, fp); diff --git a/src/softmax_layer.c b/src/softmax_layer.c index afcc6342..9cbc6be1 100644 --- a/src/softmax_layer.c +++ b/src/softmax_layer.c @@ -50,7 +50,7 @@ void forward_softmax_layer(const softmax_layer l, network net) softmax_cpu(net.input, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output); } - if(net.truth){ + if(net.truth && !l.noloss){ softmax_x_ent_cpu(l.batch*l.inputs, l.output, net.truth, l.delta, l.loss); l.cost[0] = sum_array(l.loss, l.batch*l.inputs); } @@ -88,7 +88,7 @@ void forward_softmax_layer_gpu(const softmax_layer l, network net) softmax_gpu(net.input_gpu, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output_gpu); } } - if(net.truth){ + if(net.truth && !l.noloss){ softmax_x_ent_gpu(l.batch*l.inputs, l.output_gpu, net.truth_gpu, l.delta_gpu, l.loss_gpu); if(l.softmax_tree){ mask_gpu(l.batch*l.inputs, l.delta_gpu, SECRET_NUM, net.truth_gpu, 0);