From f4825906916130ac1c32eb69c9fb812aacde1838 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Wed, 27 Mar 2019 01:34:06 +0300 Subject: [PATCH] Added -points flag (map_points) for mAP calculation for MSCOCO, PascalVOC2007 / 2010-2012, ImageNet --- build/darknet/x64/calc_mAP.cmd | 4 +- build/darknet/x64/calc_mAP_coco.cmd | 16 ++++ include/darknet.h | 2 +- src/detector.c | 114 ++++++++++++++++++---------- 4 files changed, 91 insertions(+), 45 deletions(-) create mode 100644 build/darknet/x64/calc_mAP_coco.cmd diff --git a/build/darknet/x64/calc_mAP.cmd b/build/darknet/x64/calc_mAP.cmd index 9b4fafb3..4622c525 100644 --- a/build/darknet/x64/calc_mAP.cmd +++ b/build/darknet/x64/calc_mAP.cmd @@ -1,10 +1,10 @@ rem # How to calculate mAP (mean average precision) -rem darknet.exe detector map data/voc.data cfg/yolov2-tiny-voc.cfg yolov2-tiny-voc.weights -11points +rem darknet.exe detector map data/voc.data cfg/yolov2-tiny-voc.cfg yolov2-tiny-voc.weights -points 11 -darknet.exe detector map data/voc.data cfg/yolov2-voc.cfg yolo-voc.weights -11points +darknet.exe detector map data/voc.data cfg/yolov2-voc.cfg yolo-voc.weights -points 11 diff --git a/build/darknet/x64/calc_mAP_coco.cmd b/build/darknet/x64/calc_mAP_coco.cmd new file mode 100644 index 00000000..e4bd32de --- /dev/null +++ b/build/darknet/x64/calc_mAP_coco.cmd @@ -0,0 +1,16 @@ +rem # How to calculate Yolo v3 mAP on MS COCO + +rem darknet.exe detector map cfg/coco.data cfg/yolov3-tiny.cfg yolov3-tiny.weights -points 101 + + +darknet.exe detector map cfg/coco.data cfg/yolov3-spp.cfg yolov3-spp.weights -points 101 + + +rem darknet.exe detector map cfg/coco.data cfg/yolov3.cfg yolov3.weights -points 101 + + +rem darknet.exe detector map cfg/coco.data cfg/yolov3.cfg yolov3.weights -iou_thresh 0.75 -points 101 + + + +pause diff --git a/include/darknet.h b/include/darknet.h index fc56a4f1..e0479641 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -767,7 +767,7 @@ LIB_API layer* get_network_layer(network* net, int i); 
LIB_API detection *make_network_boxes(network *net, float thresh, int *num); LIB_API void reset_rnn(network *net); LIB_API float *network_predict_image(network *net, image im); -LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_11_points, network *existing_net); +LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net); LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port); LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile); diff --git a/src/detector.c b/src/detector.c index b28783ff..d189d6b8 100644 --- a/src/detector.c +++ b/src/detector.c @@ -674,7 +674,7 @@ int detections_comparator(const void *pa, const void *pb) return 0; } -float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_11_points, network *existing_net) +float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net) { int j; list *options = read_data_cfg(datacfg); @@ -752,18 +752,12 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa int unique_truth_count = 0; int* truth_classes_count = (int*)calloc(classes, sizeof(int)); - + // For multi-class precision and recall computation - float avg_iou_per_class[classes]; - int tp_for_thresh_per_class[classes]; - int fp_for_thresh_per_class[classes]; - int unique_truth_count_per_class[classes]; - - memset(avg_iou_per_class, 0.0, classes * sizeof(float)); - memset(tp_for_thresh_per_class, 
0, classes * sizeof(int)); - memset(fp_for_thresh_per_class, 0, classes * sizeof(int)); - memset(unique_truth_count_per_class, 0, classes * sizeof(int)); - + float *avg_iou_per_class = (int*)calloc(classes, sizeof(int)); + int *tp_for_thresh_per_class = (int*)calloc(classes, sizeof(int)); + int *fp_for_thresh_per_class = (int*)calloc(classes, sizeof(int)); + for (t = 0; t < nthreads; ++t) { args.path = paths[i + t]; args.im = &buf[t]; @@ -812,7 +806,6 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa int i, j; for (j = 0; j < num_labels; ++j) { truth_classes_count[truth[j].id]++; - unique_truth_count_per_class[truth[j].id]++; } // difficult @@ -881,10 +874,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa // calc avg IoU, true-positives, false-positives for required Threshold if (prob > thresh_calc_avg_iou) { int z, found = 0; - for (z = checkpoint_detections_count; z < detections_count - 1; ++z) + for (z = checkpoint_detections_count; z < detections_count - 1; ++z) { if (detections[z].unique_truth_index == truth_index) { found = 1; break; } + } if (truth_index > -1 && found == 0) { avg_iou += max_iou; @@ -922,11 +916,12 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa if ((tp_for_thresh + fp_for_thresh) > 0) avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh); - for(int class_id = 0; class_id < classes; class_id++){ + int class_id; + for(class_id = 0; class_id < classes; class_id++){ if ((tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]) > 0) avg_iou_per_class[class_id] = avg_iou_per_class[class_id] / (tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]); } - + // SORT(detections) qsort(detections, detections_count, sizeof(box_prob), detections_comparator); @@ -1004,47 +999,74 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa for (i = 0; i < classes; ++i) { double avg_precision = 0; - 
float class_precision = 0.0; - float class_recall = 0.0; - int point; - for (point = 0; point < 11; ++point) { - double cur_recall = point * 0.1; - double cur_precision = 0; - for (rank = 0; rank < detections_count; ++rank) + + // MS COCO - uses 101-Recall-points on PR-chart. + // PascalVOC2007 - uses 11-Recall-points on PR-chart. + // PascalVOC2010–2012 - uses Area-Under-Curve on PR-chart. + // ImageNet - uses Area-Under-Curve on PR-chart. + + // correct mAP calculation: ImageNet, PascalVOC 2010-2012 + if (map_points == 0) + { + double last_recall = pr[i][detections_count - 1].recall; + double last_precision = pr[i][detections_count - 1].precision; + for (rank = detections_count - 2; rank >= 0; --rank) { - if (pr[i][rank].recall >= cur_recall) { // > or >= - if (pr[i][rank].precision > cur_precision) { - cur_precision = pr[i][rank].precision; + double delta_recall = last_recall - pr[i][rank].recall; + last_recall = pr[i][rank].recall; + + if (pr[i][rank].precision > last_precision) { + last_precision = pr[i][rank].precision; + } + + avg_precision += delta_recall * last_precision; + } + } + // MSCOCO - 101 Recall-points, PascalVOC - 11 Recall-points + else + { + int point; + for (point = 0; point < map_points; ++point) { + double cur_recall = point * 1.0 / (map_points-1); + double cur_precision = 0; + for (rank = 0; rank < detections_count; ++rank) + { + if (pr[i][rank].recall >= cur_recall) { // > or >= + if (pr[i][rank].precision > cur_precision) { + cur_precision = pr[i][rank].precision; + } } } - } - //printf("class_id = %d, point = %d, cur_recall = %.4f, cur_precision = %.4f \n", i, point, cur_recall, cur_precision); + //printf("class_id = %d, point = %d, cur_recall = %.4f, cur_precision = %.4f \n", i, point, cur_recall, cur_precision); - avg_precision += cur_precision; + avg_precision += cur_precision; + } + avg_precision = avg_precision / map_points; } - avg_precision = avg_precision / 11; - - printf("\nTP = %d, FP = %d \n", tp_for_thresh_per_class[i], 
fp_for_thresh_per_class[i]); - class_precision = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)fp_for_thresh_per_class[i]); - class_recall = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)(unique_truth_count_per_class[i] - tp_for_thresh_per_class[i])); - - printf("class_id = %d, name = %s, \t P = %1.2f, \t R = %1.2f, \t avg IOU = %2.2f %%, \t ap = %2.2f %% \n", i, names[i], class_precision, class_recall, avg_iou_per_class[i], avg_precision * 100); + + printf("class_id = %d, name = %s, ap = %2.2f%% \t (TP = %d, FP = %d) \n", + i, names[i], avg_precision * 100, tp_for_thresh_per_class[i], fp_for_thresh_per_class[i]); + + float class_precision = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)fp_for_thresh_per_class[i]); + float class_recall = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)(truth_classes_count[i] - tp_for_thresh_per_class[i])); + //printf("Precision = %1.2f, Recall = %1.2f, avg IOU = %2.2f%% \n\n", class_precision, class_recall, avg_iou_per_class[i]); + mean_average_precision += avg_precision; } - printf("TP = %1.2f, FP = %1.2f \n", (float)tp_for_thresh, (float)fp_for_thresh); - const float cur_precision = (float)tp_for_thresh / ((float)tp_for_thresh + (float)fp_for_thresh); const float cur_recall = (float)tp_for_thresh / ((float)tp_for_thresh + (float)(unique_truth_count - tp_for_thresh)); const float f1_score = 2.F * cur_precision * cur_recall / (cur_precision + cur_recall); - printf(" for thresh = %1.2f, precision = %1.2f, recall = %1.2f, F1-score = %1.2f \n", + printf("\n for thresh = %1.2f, precision = %1.2f, recall = %1.2f, F1-score = %1.2f \n", thresh_calc_avg_iou, cur_precision, cur_recall, f1_score); printf(" for thresh = %0.2f, TP = %d, FP = %d, FN = %d, average IoU = %2.2f %% \n", thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100); mean_average_precision = 
mean_average_precision / classes; - printf("\n IoU threshold = %2.0f %% \n", iou_thresh * 100); + printf("\n IoU threshold = %2.0f %%, ", iou_thresh * 100); + if (map_points) printf("used %d Recall-points \n", map_points); + else printf("used Area-Under-Curve for each unique Recall \n"); printf(" mean average precision (mAP@%0.2f) = %f, or %2.2f %% \n", iou_thresh, mean_average_precision, mean_average_precision * 100); @@ -1056,7 +1078,15 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa free(truth_classes_count); free(detection_per_class_count); + free(avg_iou_per_class); + free(tp_for_thresh_per_class); + free(fp_for_thresh_per_class); + fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); + printf("\nSet -map_points flag:\n"); + printf(" `-map_points 101` for MS COCO \n"); + printf(" `-map_points 11` for PascalVOC 2007 (uncomment `difficult` in voc.data) \n"); + printf(" `-map_points 0` (AUC) for ImageNet, PascalVOC 2010-2012, your custom dataset\n"); if (reinforcement_fd != NULL) fclose(reinforcement_fd); // free memory @@ -1449,7 +1479,7 @@ void run_detector(int argc, char **argv) int dont_show = find_arg(argc, argv, "-dont_show"); int show = find_arg(argc, argv, "-show"); int calc_map = find_arg(argc, argv, "-map"); - int map_11_points = find_arg(argc, argv, "-11points"); + int map_points = find_int_arg(argc, argv, "-points", 0); check_mistakes = find_arg(argc, argv, "-check_mistakes"); int mjpeg_port = find_int_arg(argc, argv, "-mjpeg_port", -1); int json_port = find_int_arg(argc, argv, "-json_port", -1); @@ -1509,7 +1539,7 @@ void run_detector(int argc, char **argv) else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port); else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile); else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); - else if (0 == 
strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_11_points, NULL); + else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, NULL); else if (0 == strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show); else if (0 == strcmp(argv[2], "demo")) { list *options = read_data_cfg(datacfg);