diff --git a/examples/detector.c b/examples/detector.c index 318f7fbb..2f2e330b 100644 --- a/examples/detector.c +++ b/examples/detector.c @@ -486,14 +486,18 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out fprintf(stderr, "Total Detection Time: %f Seconds\n", what_time_is_it_now() - start); } -void validate_detector_recall(char *cfgfile, char *weightfile) +void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile) { network *net = load_network(cfgfile, weightfile, 0); set_batch_network(net, 1); fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay); srand(time(0)); - list *plist = get_paths("data/coco_val_5k.list"); + // list *plist = get_paths("data/coco_val_5k.list"); + list *options = read_data_cfg(datacfg); + char *test_images = option_find_str(options, "test", "data/test.list"); + list *plist = get_paths(test_images); + char **paths = (char **)list_to_array(plist); layer l = net->layers[net->n-1]; @@ -837,7 +841,7 @@ void run_detector(int argc, char **argv) else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear); else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile); else if(0==strcmp(argv[2], "valid2")) validate_detector_flip(datacfg, cfg, weights, outfile); - else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights); + else if(0==strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); else if(0==strcmp(argv[2], "demo")) { list *options = read_data_cfg(datacfg); int classes = option_find_int(options, "classes", 20);