diff --git a/src/detector.c b/src/detector.c index 4f421593..b28783ff 100644 --- a/src/detector.c +++ b/src/detector.c @@ -752,7 +752,18 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa int unique_truth_count = 0; int* truth_classes_count = (int*)calloc(classes, sizeof(int)); - + + // For multi-class precision and recall computation + float avg_iou_per_class[classes]; + int tp_for_thresh_per_class[classes]; + int fp_for_thresh_per_class[classes]; + int unique_truth_count_per_class[classes]; + + memset(avg_iou_per_class, 0.0, classes * sizeof(float)); + memset(tp_for_thresh_per_class, 0, classes * sizeof(int)); + memset(fp_for_thresh_per_class, 0, classes * sizeof(int)); + memset(unique_truth_count_per_class, 0, classes * sizeof(int)); + for (t = 0; t < nthreads; ++t) { args.path = paths[i + t]; args.im = &buf[t]; @@ -801,6 +812,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa int i, j; for (j = 0; j < num_labels; ++j) { truth_classes_count[truth[j].id]++; + unique_truth_count_per_class[truth[j].id]++; } // difficult @@ -877,9 +889,13 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa if (truth_index > -1 && found == 0) { avg_iou += max_iou; ++tp_for_thresh; + avg_iou_per_class[class_id] += max_iou; + tp_for_thresh_per_class[class_id]++; } - else + else{ fp_for_thresh++; + fp_for_thresh_per_class[class_id]++; + } } } } @@ -906,7 +922,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa if ((tp_for_thresh + fp_for_thresh) > 0) avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh); - + for(int class_id = 0; class_id < classes; class_id++){ + if ((tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]) > 0) + avg_iou_per_class[class_id] = avg_iou_per_class[class_id] / (tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]); + } + // SORT(detections) qsort(detections, detections_count, sizeof(box_prob), detections_comparator); @@ -984,6 +1004,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa for (i = 0; i < classes; ++i) { double avg_precision = 0; + float class_precision = 0.0; + float class_recall = 0.0; int point; for (point = 0; point < 11; ++point) { double cur_recall = point * 0.1; @@ -1001,10 +1023,17 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa avg_precision += cur_precision; } avg_precision = avg_precision / 11; - printf("class_id = %d, name = %s, \t ap = %2.2f %% \n", i, names[i], avg_precision * 100); + + printf("\nTP = %d, FP = %d \n", tp_for_thresh_per_class[i], fp_for_thresh_per_class[i]); + class_precision = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)fp_for_thresh_per_class[i]); + class_recall = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)(unique_truth_count_per_class[i] - tp_for_thresh_per_class[i])); + + printf("class_id = %d, name = %s, \t P = %1.2f, \t R = %1.2f, \t avg IOU = %2.2f %%, \t ap = %2.2f %% \n", i, names[i], class_precision, class_recall, avg_iou_per_class[i], avg_precision * 100); mean_average_precision += avg_precision; } + printf("TP = %1.2f, FP = %1.2f \n", (float)tp_for_thresh, (float)fp_for_thresh); + const float cur_precision = (float)tp_for_thresh / ((float)tp_for_thresh + (float)fp_for_thresh); const float cur_recall = (float)tp_for_thresh / ((float)tp_for_thresh + (float)(unique_truth_count - tp_for_thresh)); const float f1_score = 2.F * cur_precision * cur_recall / (cur_precision + cur_recall);