From 198d169d222b6f2da845ccc9c381d6922c3fdaaf Mon Sep 17 00:00:00 2001 From: agirbau Date: Fri, 22 Mar 2019 12:09:50 +0100 Subject: [PATCH] Compute Precision and recall per class Doing ./darknet detector map ... now returns precision and recall per class instead of a global precision and recall --- src/detector.c | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/detector.c b/src/detector.c index 62a8a078..2a1cc151 100644 --- a/src/detector.c +++ b/src/detector.c @@ -751,7 +751,18 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa int unique_truth_count = 0; int* truth_classes_count = (int*)calloc(classes, sizeof(int)); - + + // For multi-class precision and recall computation + float avg_iou_per_class[classes]; + int tp_for_thresh_per_class[classes]; + int fp_for_thresh_per_class[classes]; + int unique_truth_count_per_class[classes]; + + memset(avg_iou_per_class, 0.0, classes * sizeof(float)); + memset(tp_for_thresh_per_class, 0, classes * sizeof(int)); + memset(fp_for_thresh_per_class, 0, classes * sizeof(int)); + memset(unique_truth_count_per_class, 0, classes * sizeof(int)); + for (t = 0; t < nthreads; ++t) { args.path = paths[i + t]; args.im = &buf[t]; @@ -800,6 +811,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa int i, j; for (j = 0; j < num_labels; ++j) { truth_classes_count[truth[j].id]++; + unique_truth_count_per_class[truth[j].id]++; } // difficult @@ -876,9 +888,13 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa if (truth_index > -1 && found == 0) { avg_iou += max_iou; ++tp_for_thresh; + avg_iou_per_class[class_id] += max_iou; + tp_for_thresh_per_class[class_id]++; } - else + else{ fp_for_thresh++; + fp_for_thresh_per_class[class_id]++; + } } } } @@ -905,7 +921,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa if ((tp_for_thresh + fp_for_thresh) > 0) avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh); - + for(int class_id = 0; class_id < classes; class_id++){ + if ((tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]) > 0) + avg_iou_per_class[class_id] = avg_iou_per_class[class_id] / (tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]); + } + // SORT(detections) qsort(detections, detections_count, sizeof(box_prob), detections_comparator); @@ -983,6 +1003,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa for (i = 0; i < classes; ++i) { double avg_precision = 0; + float class_precision = 0.0; + float class_recall = 0.0; int point; for (point = 0; point < 11; ++point) { double cur_recall = point * 0.1; @@ -1000,10 +1022,17 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa avg_precision += cur_precision; } avg_precision = avg_precision / 11; - printf("class_id = %d, name = %s, \t ap = %2.2f %% \n", i, names[i], avg_precision * 100); + + printf("\nTP = %d, FP = %d \n", tp_for_thresh_per_class[i], fp_for_thresh_per_class[i]); + class_precision = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)fp_for_thresh_per_class[i]); + class_recall = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)(unique_truth_count_per_class[i] - tp_for_thresh_per_class[i])); + + printf("class_id = %d, name = %s, \t P = %1.2f, \t R = %1.2f, \t avg IOU = %2.2f %%, \t ap = %2.2f %% \n", i, names[i], class_precision, class_recall, avg_iou_per_class[i], avg_precision * 100); mean_average_precision += avg_precision; } + printf("TP = %1.2f, FP = %1.2f \n", (float)tp_for_thresh, (float)fp_for_thresh); + const float cur_precision = (float)tp_for_thresh / ((float)tp_for_thresh + (float)fp_for_thresh); const float cur_recall = (float)tp_for_thresh / ((float)tp_for_thresh + (float)(unique_truth_count - tp_for_thresh)); const float f1_score = 2.F * cur_precision * cur_recall / (cur_precision + cur_recall);