mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Added letter_box=1 param in [net] section (cfg-file) for keeping aspect ratio during training
This commit is contained in:
@ -600,6 +600,7 @@ typedef struct network {
|
||||
int flip; // horizontal flip 50% probability augmentation for classifier training (default = 1)
|
||||
int blur;
|
||||
int mixup;
|
||||
int letter_box;
|
||||
float angle;
|
||||
float aspect;
|
||||
float exposure;
|
||||
@ -760,6 +761,7 @@ typedef struct load_args {
|
||||
int mini_batch;
|
||||
int track;
|
||||
int augment_speed;
|
||||
int letter_box;
|
||||
int show_imgs;
|
||||
float jitter;
|
||||
int flip;
|
||||
@ -827,7 +829,7 @@ LIB_API layer* get_network_layer(network* net, int i);
|
||||
LIB_API detection *make_network_boxes(network *net, float thresh, int *num);
|
||||
LIB_API void reset_rnn(network *net);
|
||||
LIB_API float *network_predict_image(network *net, image im);
|
||||
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net);
|
||||
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, int letter_box, network *existing_net);
|
||||
LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs);
|
||||
LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
|
||||
float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box);
|
||||
|
72
src/data.c
72
src/data.c
@ -804,8 +804,8 @@ void blend_truth(float *new_truth, int boxes, float *old_truth)
|
||||
|
||||
#include "http_stream.h"
|
||||
|
||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup, float jitter,
|
||||
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
|
||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup,
|
||||
float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int letter_box, int show_imgs)
|
||||
{
|
||||
const int random_index = random_gen();
|
||||
c = c ? c : 3;
|
||||
@ -828,7 +828,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
d.X.vals = (float**)calloc(d.X.rows, sizeof(float*));
|
||||
d.X.cols = h*w*c;
|
||||
|
||||
float r1 = 0, r2 = 0, r3 = 0, r4 = 0;
|
||||
float r1 = 0, r2 = 0, r3 = 0, r4 = 0, r_scale = 0;
|
||||
float dhue = 0, dsat = 0, dexp = 0, flip = 0, blur = 0;
|
||||
int augmentation_calculated = 0;
|
||||
|
||||
@ -862,6 +862,8 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
r3 = random_float();
|
||||
r4 = random_float();
|
||||
|
||||
r_scale = random_float();
|
||||
|
||||
dhue = rand_uniform_strong(-hue, hue);
|
||||
dsat = rand_scale(saturation);
|
||||
dexp = rand_scale(exposure);
|
||||
@ -874,6 +876,33 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
int pright = rand_precalc_random(-dw, dw, r2);
|
||||
int ptop = rand_precalc_random(-dh, dh, r3);
|
||||
int pbot = rand_precalc_random(-dh, dh, r4);
|
||||
//printf("\n pleft = %d, pright = %d, ptop = %d, pbot = %d, ow = %d, oh = %d \n", pleft, pright, ptop, pbot, ow, oh);
|
||||
|
||||
float scale = rand_precalc_random(.25, 2, r_scale); // unused currently
|
||||
|
||||
if (letter_box)
|
||||
{
|
||||
float img_ar = (float)ow / (float)oh;
|
||||
float net_ar = (float)w / (float)h;
|
||||
float result_ar = img_ar / net_ar;
|
||||
//printf(" ow = %d, oh = %d, w = %d, h = %d, img_ar = %f, net_ar = %f, result_ar = %f \n", ow, oh, w, h, img_ar, net_ar, result_ar);
|
||||
if (result_ar > 1) // sheight - should be increased
|
||||
{
|
||||
float oh_tmp = ow / net_ar;
|
||||
float delta_h = (oh_tmp - oh)/2;
|
||||
ptop = ptop - delta_h;
|
||||
pbot = pbot - delta_h;
|
||||
//printf(" result_ar = %f, oh_tmp = %f, delta_h = %f, ptop = %d, pbot = %d \n", result_ar, oh_tmp, delta_h, ptop, pbot);
|
||||
}
|
||||
else // swidth - should be increased
|
||||
{
|
||||
float ow_tmp = oh * net_ar;
|
||||
float delta_w = (ow_tmp - ow)/2;
|
||||
pleft = pleft - delta_w;
|
||||
pright = pright - delta_w;
|
||||
//printf(" result_ar = %f, ow_tmp = %f, delta_w = %f, pleft = %d, pright = %d \n", result_ar, ow_tmp, delta_w, pleft, pright);
|
||||
}
|
||||
}
|
||||
|
||||
int swidth = ow - pleft - pright;
|
||||
int sheight = oh - ptop - pbot;
|
||||
@ -884,9 +913,10 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
float dx = ((float)pleft / ow) / sx;
|
||||
float dy = ((float)ptop / oh) / sy;
|
||||
|
||||
|
||||
fill_truth_detection(filename, boxes, truth, classes, flip, dx, dy, 1. / sx, 1. / sy, w, h);
|
||||
|
||||
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp,
|
||||
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, dhue, dsat, dexp,
|
||||
blur, boxes, d.y.vals[i]);
|
||||
|
||||
if (i_mixup) {
|
||||
@ -947,7 +977,7 @@ void blend_images(image new_img, float alpha, image old_img, float beta)
|
||||
}
|
||||
|
||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup, float jitter,
|
||||
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
|
||||
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int letter_box, int show_imgs)
|
||||
{
|
||||
const int random_index = random_gen();
|
||||
c = c ? c : 3;
|
||||
@ -971,7 +1001,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
d.X.vals = (float**)calloc(d.X.rows, sizeof(float*));
|
||||
d.X.cols = h*w*c;
|
||||
|
||||
float r1 = 0, r2 = 0, r3 = 0, r4 = 0;
|
||||
float r1 = 0, r2 = 0, r3 = 0, r4 = 0, r_scale;
|
||||
float dhue = 0, dsat = 0, dexp = 0, flip = 0;
|
||||
int augmentation_calculated = 0;
|
||||
|
||||
@ -999,6 +1029,8 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
r3 = random_float();
|
||||
r4 = random_float();
|
||||
|
||||
r_scale = random_float();
|
||||
|
||||
dhue = rand_uniform_strong(-hue, hue);
|
||||
dsat = rand_scale(saturation);
|
||||
dexp = rand_scale(exposure);
|
||||
@ -1011,6 +1043,32 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
int ptop = rand_precalc_random(-dh, dh, r3);
|
||||
int pbot = rand_precalc_random(-dh, dh, r4);
|
||||
|
||||
float scale = rand_precalc_random(.25, 2, r_scale); // unused currently
|
||||
|
||||
if (letter_box)
|
||||
{
|
||||
float img_ar = (float)ow / (float)oh;
|
||||
float net_ar = (float)w / (float)h;
|
||||
float result_ar = img_ar / net_ar;
|
||||
//printf(" ow = %d, oh = %d, w = %d, h = %d, img_ar = %f, net_ar = %f, result_ar = %f \n", ow, oh, w, h, img_ar, net_ar, result_ar);
|
||||
if (result_ar > 1) // sheight - should be increased
|
||||
{
|
||||
float oh_tmp = ow / net_ar;
|
||||
float delta_h = (oh_tmp - oh) / 2;
|
||||
ptop = ptop - delta_h;
|
||||
pbot = pbot - delta_h;
|
||||
//printf(" result_ar = %f, oh_tmp = %f, delta_h = %f, ptop = %d, pbot = %d \n", result_ar, oh_tmp, delta_h, ptop, pbot);
|
||||
}
|
||||
else // swidth - should be increased
|
||||
{
|
||||
float ow_tmp = oh * net_ar;
|
||||
float delta_w = (ow_tmp - ow) / 2;
|
||||
pleft = pleft - delta_w;
|
||||
pright = pright - delta_w;
|
||||
//printf(" result_ar = %f, ow_tmp = %f, delta_w = %f, pleft = %d, pright = %d \n", result_ar, ow_tmp, delta_w, pleft, pright);
|
||||
}
|
||||
}
|
||||
|
||||
int swidth = ow - pleft - pright;
|
||||
int sheight = oh - ptop - pbot;
|
||||
|
||||
@ -1100,7 +1158,7 @@ void *load_thread(void *ptr)
|
||||
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
|
||||
} else if (a.type == DETECTION_DATA){
|
||||
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.blur, a.mixup, a.jitter,
|
||||
a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.show_imgs);
|
||||
a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.letter_box, a.show_imgs);
|
||||
} else if (a.type == SWAG_DATA){
|
||||
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
|
||||
} else if (a.type == COMPARE_DATA){
|
||||
|
@ -86,8 +86,8 @@ void print_letters(float *pred, int n);
|
||||
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
|
||||
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
|
||||
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
|
||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup, float jitter,
|
||||
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs);
|
||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, int use_mixup,
|
||||
float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int letter_box, int show_imgs);
|
||||
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
|
||||
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
|
||||
data load_data_super(char **paths, int n, int m, int w, int h, int scale);
|
||||
|
@ -133,6 +133,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
|
||||
args.exposure = net.exposure;
|
||||
args.saturation = net.saturation;
|
||||
args.hue = net.hue;
|
||||
args.letter_box = net.letter_box;
|
||||
if (dont_show && show_imgs) show_imgs = 2;
|
||||
args.show_imgs = show_imgs;
|
||||
|
||||
@ -275,7 +276,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
|
||||
//network net_combined = combine_train_valid_networks(net, net_map);
|
||||
|
||||
iter_map = i;
|
||||
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, &net_map);// &net_combined);
|
||||
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
|
||||
printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
|
||||
if (mean_average_precision > best_map) {
|
||||
best_map = mean_average_precision;
|
||||
@ -660,7 +661,7 @@ int detections_comparator(const void *pa, const void *pb)
|
||||
return 0;
|
||||
}
|
||||
|
||||
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net)
|
||||
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, int letter_box, network *existing_net)
|
||||
{
|
||||
int j;
|
||||
list *options = read_data_cfg(datacfg);
|
||||
@ -733,8 +734,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.c = net.c;
|
||||
args.type = IMAGE_DATA;
|
||||
//args.type = LETTERBOX_DATA;
|
||||
if (letter_box) args.type = LETTERBOX_DATA;
|
||||
else args.type = IMAGE_DATA;
|
||||
|
||||
//const float thresh_calc_avg_iou = 0.24;
|
||||
float avg_iou = 0;
|
||||
@ -783,14 +784,12 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
|
||||
float hier_thresh = 0;
|
||||
detection *dets;
|
||||
if (args.type == LETTERBOX_DATA) {
|
||||
int letterbox = 1;
|
||||
dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
|
||||
dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letter_box);
|
||||
}
|
||||
else {
|
||||
int letterbox = 0;
|
||||
dets = get_network_boxes(&net, 1, 1, thresh, hier_thresh, 0, 0, &nboxes, letterbox);
|
||||
dets = get_network_boxes(&net, 1, 1, thresh, hier_thresh, 0, 0, &nboxes, letter_box);
|
||||
}
|
||||
//detection *dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letterbox); // for letterbox=1
|
||||
//detection *dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letter_box); // for letter_box=1
|
||||
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
|
||||
|
||||
char labelpath[4096];
|
||||
@ -1486,7 +1485,7 @@ void run_detector(int argc, char **argv)
|
||||
else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs);
|
||||
else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
|
||||
else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
|
||||
else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, NULL);
|
||||
else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, letter_box, NULL);
|
||||
else if (0 == strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show);
|
||||
else if (0 == strcmp(argv[2], "demo")) {
|
||||
list *options = read_data_cfg(datacfg);
|
||||
|
@ -1137,7 +1137,7 @@ static box float_to_box_stride(float *f, int stride)
|
||||
|
||||
image image_data_augmentation(mat_cv* mat, int w, int h,
|
||||
int pleft, int ptop, int swidth, int sheight, int flip,
|
||||
float jitter, float dhue, float dsat, float dexp,
|
||||
float dhue, float dsat, float dexp,
|
||||
int blur, int num_boxes, float *truth)
|
||||
{
|
||||
image out;
|
||||
|
@ -95,7 +95,7 @@ void draw_train_loss(mat_cv* img, int img_size, float avg_loss, float max_img_lo
|
||||
// Data augmentation
|
||||
image image_data_augmentation(mat_cv* mat, int w, int h,
|
||||
int pleft, int ptop, int swidth, int sheight, int flip,
|
||||
float jitter, float dhue, float dsat, float dexp,
|
||||
float dhue, float dsat, float dexp,
|
||||
int blur, int num_boxes, float *truth);
|
||||
|
||||
// blend two images with (alpha and beta)
|
||||
|
@ -153,7 +153,7 @@ float get_network_cost(network net);
|
||||
//LIB_API network *load_network_custom(char *cfg, char *weights, int clear, int batch);
|
||||
//LIB_API network *load_network(char *cfg, char *weights, int clear);
|
||||
//LIB_API float *network_predict_image(network *net, image im);
|
||||
//LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, network *existing_net);
|
||||
//LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, int map_points, int letter_box, network *existing_net);
|
||||
//LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port);
|
||||
//LIB_API int network_width(network *net);
|
||||
//LIB_API int network_height(network *net);
|
||||
|
@ -738,6 +738,7 @@ void parse_net_options(list *options, network *net)
|
||||
net->flip = option_find_int_quiet(options, "flip", 1);
|
||||
net->blur = option_find_int_quiet(options, "blur", 0);
|
||||
net->mixup = option_find_int_quiet(options, "mixup", 0);
|
||||
net->letter_box = option_find_int_quiet(options, "letter_box", 0);
|
||||
|
||||
net->angle = option_find_float_quiet(options, "angle", 0);
|
||||
net->aspect = option_find_float_quiet(options, "aspect", 1);
|
||||
|
Reference in New Issue
Block a user