mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Fixes for training Yolo on small objects
This commit is contained in:
17
src/data.c
17
src/data.c
@ -292,7 +292,7 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int
|
|||||||
free(boxes);
|
free(boxes);
|
||||||
}
|
}
|
||||||
|
|
||||||
void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
|
void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy, int small_object)
|
||||||
{
|
{
|
||||||
char labelpath[4096];
|
char labelpath[4096];
|
||||||
find_replace(path, "images", "labels", labelpath);
|
find_replace(path, "images", "labels", labelpath);
|
||||||
@ -305,6 +305,12 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
|
|||||||
find_replace(labelpath, ".JPEG", ".txt", labelpath);
|
find_replace(labelpath, ".JPEG", ".txt", labelpath);
|
||||||
int count = 0;
|
int count = 0;
|
||||||
box_label *boxes = read_boxes(labelpath, &count);
|
box_label *boxes = read_boxes(labelpath, &count);
|
||||||
|
if (small_object == 1) {
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
if (boxes[i].w < 0.01) boxes[i].w = 0.01;
|
||||||
|
if (boxes[i].h < 0.01) boxes[i].h = 0.01;
|
||||||
|
}
|
||||||
|
}
|
||||||
randomize_boxes(boxes, count);
|
randomize_boxes(boxes, count);
|
||||||
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
|
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
|
||||||
if(count > num_boxes) count = num_boxes;
|
if(count > num_boxes) count = num_boxes;
|
||||||
@ -319,7 +325,8 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
|
|||||||
h = boxes[i].h;
|
h = boxes[i].h;
|
||||||
id = boxes[i].id;
|
id = boxes[i].id;
|
||||||
|
|
||||||
if ((w < .01 || h < .01)) continue;
|
// not detect small objects
|
||||||
|
if ((w < 0.01 || h < 0.01)) continue;
|
||||||
|
|
||||||
truth[i*5+0] = x;
|
truth[i*5+0] = x;
|
||||||
truth[i*5+1] = y;
|
truth[i*5+1] = y;
|
||||||
@ -661,7 +668,7 @@ data load_data_swag(char **paths, int n, int classes, float jitter)
|
|||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
|
|
||||||
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure)
|
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure, int small_object)
|
||||||
{
|
{
|
||||||
char **random_paths = get_random_paths(paths, n, m);
|
char **random_paths = get_random_paths(paths, n, m);
|
||||||
int i;
|
int i;
|
||||||
@ -704,7 +711,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in
|
|||||||
random_distort_image(sized, hue, saturation, exposure);
|
random_distort_image(sized, hue, saturation, exposure);
|
||||||
d.X.vals[i] = sized.data;
|
d.X.vals[i] = sized.data;
|
||||||
|
|
||||||
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy);
|
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, small_object);
|
||||||
|
|
||||||
free_image(orig);
|
free_image(orig);
|
||||||
free_image(cropped);
|
free_image(cropped);
|
||||||
@ -734,7 +741,7 @@ void *load_thread(void *ptr)
|
|||||||
} else if (a.type == REGION_DATA){
|
} else if (a.type == REGION_DATA){
|
||||||
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
|
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
|
||||||
} else if (a.type == DETECTION_DATA){
|
} else if (a.type == DETECTION_DATA){
|
||||||
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
|
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure, a.small_object);
|
||||||
} else if (a.type == SWAG_DATA){
|
} else if (a.type == SWAG_DATA){
|
||||||
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
|
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
|
||||||
} else if (a.type == COMPARE_DATA){
|
} else if (a.type == COMPARE_DATA){
|
||||||
|
@ -53,6 +53,7 @@ typedef struct load_args{
|
|||||||
int classes;
|
int classes;
|
||||||
int background;
|
int background;
|
||||||
int scale;
|
int scale;
|
||||||
|
int small_object;
|
||||||
float jitter;
|
float jitter;
|
||||||
float angle;
|
float angle;
|
||||||
float aspect;
|
float aspect;
|
||||||
|
@ -82,6 +82,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
|
|||||||
args.classes = classes;
|
args.classes = classes;
|
||||||
args.jitter = jitter;
|
args.jitter = jitter;
|
||||||
args.num_boxes = l.max_boxes;
|
args.num_boxes = l.max_boxes;
|
||||||
|
args.small_object = l.small_object;
|
||||||
args.d = &buffer;
|
args.d = &buffer;
|
||||||
args.type = DETECTION_DATA;
|
args.type = DETECTION_DATA;
|
||||||
args.threads = 4;// 8;
|
args.threads = 4;// 8;
|
||||||
|
@ -63,6 +63,7 @@ struct layer{
|
|||||||
int out_h, out_w, out_c;
|
int out_h, out_w, out_c;
|
||||||
int n;
|
int n;
|
||||||
int max_boxes;
|
int max_boxes;
|
||||||
|
int small_object;
|
||||||
int groups;
|
int groups;
|
||||||
int size;
|
int size;
|
||||||
int side;
|
int side;
|
||||||
|
@ -245,6 +245,7 @@ layer parse_region(list *options, size_params params)
|
|||||||
l.log = option_find_int_quiet(options, "log", 0);
|
l.log = option_find_int_quiet(options, "log", 0);
|
||||||
l.sqrt = option_find_int_quiet(options, "sqrt", 0);
|
l.sqrt = option_find_int_quiet(options, "sqrt", 0);
|
||||||
|
|
||||||
|
l.small_object = option_find_int(options, "small_object", 0);
|
||||||
l.softmax = option_find_int(options, "softmax", 0);
|
l.softmax = option_find_int(options, "softmax", 0);
|
||||||
l.max_boxes = option_find_int_quiet(options, "max",30);
|
l.max_boxes = option_find_int_quiet(options, "max",30);
|
||||||
l.jitter = option_find_float(options, "jitter", .2);
|
l.jitter = option_find_float(options, "jitter", .2);
|
||||||
|
Reference in New Issue
Block a user