mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Merge pull request #2703 from agirbau/patch-1
Compute Precision and recall per class
This commit is contained in:
@ -752,7 +752,18 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
int unique_truth_count = 0;
|
||||
|
||||
int* truth_classes_count = (int*)calloc(classes, sizeof(int));
|
||||
|
||||
|
||||
// For multi-class precision and recall computation
|
||||
float avg_iou_per_class[classes];
|
||||
int tp_for_thresh_per_class[classes];
|
||||
int fp_for_thresh_per_class[classes];
|
||||
int unique_truth_count_per_class[classes];
|
||||
|
||||
memset(avg_iou_per_class, 0.0, classes * sizeof(float));
|
||||
memset(tp_for_thresh_per_class, 0, classes * sizeof(int));
|
||||
memset(fp_for_thresh_per_class, 0, classes * sizeof(int));
|
||||
memset(unique_truth_count_per_class, 0, classes * sizeof(int));
|
||||
|
||||
for (t = 0; t < nthreads; ++t) {
|
||||
args.path = paths[i + t];
|
||||
args.im = &buf[t];
|
||||
@ -801,6 +812,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
int i, j;
|
||||
for (j = 0; j < num_labels; ++j) {
|
||||
truth_classes_count[truth[j].id]++;
|
||||
unique_truth_count_per_class[truth[j].id]++;
|
||||
}
|
||||
|
||||
// difficult
|
||||
@ -877,9 +889,13 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
if (truth_index > -1 && found == 0) {
|
||||
avg_iou += max_iou;
|
||||
++tp_for_thresh;
|
||||
avg_iou_per_class[class_id] += max_iou;
|
||||
tp_for_thresh_per_class[class_id]++;
|
||||
}
|
||||
else
|
||||
else{
|
||||
fp_for_thresh++;
|
||||
fp_for_thresh_per_class[class_id]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -906,7 +922,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
if ((tp_for_thresh + fp_for_thresh) > 0)
|
||||
avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);
|
||||
|
||||
|
||||
for(int class_id = 0; class_id < classes; class_id++){
|
||||
if ((tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]) > 0)
|
||||
avg_iou_per_class[class_id] = avg_iou_per_class[class_id] / (tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]);
|
||||
}
|
||||
|
||||
// SORT(detections)
|
||||
qsort(detections, detections_count, sizeof(box_prob), detections_comparator);
|
||||
|
||||
@ -984,6 +1004,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
|
||||
for (i = 0; i < classes; ++i) {
|
||||
double avg_precision = 0;
|
||||
float class_precision = 0.0;
|
||||
float class_recall = 0.0;
|
||||
int point;
|
||||
for (point = 0; point < 11; ++point) {
|
||||
double cur_recall = point * 0.1;
|
||||
@ -1001,10 +1023,17 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
avg_precision += cur_precision;
|
||||
}
|
||||
avg_precision = avg_precision / 11;
|
||||
printf("class_id = %d, name = %s, \t ap = %2.2f %% \n", i, names[i], avg_precision * 100);
|
||||
|
||||
printf("\nTP = %d, FP = %d \n", tp_for_thresh_per_class[i], fp_for_thresh_per_class[i]);
|
||||
class_precision = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)fp_for_thresh_per_class[i]);
|
||||
class_recall = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)(unique_truth_count_per_class[i] - tp_for_thresh_per_class[i]));
|
||||
|
||||
printf("class_id = %d, name = %s, \t P = %1.2f, \t R = %1.2f, \t avg IOU = %2.2f %%, \t ap = %2.2f %% \n", i, names[i], class_precision, class_recall, avg_iou_per_class[i], avg_precision * 100);
|
||||
mean_average_precision += avg_precision;
|
||||
}
|
||||
|
||||
printf("TP = %1.2f, FP = %1.2f \n", (float)tp_for_thresh, (float)fp_for_thresh);
|
||||
|
||||
const float cur_precision = (float)tp_for_thresh / ((float)tp_for_thresh + (float)fp_for_thresh);
|
||||
const float cur_recall = (float)tp_for_thresh / ((float)tp_for_thresh + (float)(unique_truth_count - tp_for_thresh));
|
||||
const float f1_score = 2.F * cur_precision * cur_recall / (cur_precision + cur_recall);
|
||||
|
Reference in New Issue
Block a user