mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Added label_smooth_eps=0.1 for [net] layer for Label Smoothing for Classifier
This commit is contained in:
@ -638,6 +638,7 @@ typedef struct network {
|
||||
int flip; // horizontal flip 50% probability augmentaiont for classifier training (default = 1)
|
||||
int blur;
|
||||
int mixup;
|
||||
float label_smooth_eps;
|
||||
int letter_box;
|
||||
float angle;
|
||||
float aspect;
|
||||
@ -813,6 +814,7 @@ typedef struct load_args {
|
||||
int flip;
|
||||
int blur;
|
||||
int mixup;
|
||||
float label_smooth_eps;
|
||||
float angle;
|
||||
float aspect;
|
||||
float saturation;
|
||||
|
@ -92,6 +92,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
|
||||
args.hue = net.hue;
|
||||
args.size = net.w > net.h ? net.w : net.h;
|
||||
|
||||
args.label_smooth_eps = net.label_smooth_eps;
|
||||
args.mixup = net.mixup;
|
||||
if (dont_show && show_imgs) show_imgs = 2;
|
||||
args.show_imgs = show_imgs;
|
||||
|
44
src/data.c
44
src/data.c
@ -516,6 +516,32 @@ void fill_truth(char *path, char **labels, int k, float *truth)
|
||||
}
|
||||
}
|
||||
|
||||
void fill_truth_smooth(char *path, char **labels, int k, float *truth, float label_smooth_eps)
|
||||
{
|
||||
int i;
|
||||
memset(truth, 0, k * sizeof(float));
|
||||
int count = 0;
|
||||
for (i = 0; i < k; ++i) {
|
||||
if (strstr(path, labels[i])) {
|
||||
truth[i] = (1 - label_smooth_eps);
|
||||
++count;
|
||||
}
|
||||
else {
|
||||
truth[i] = label_smooth_eps / (k - 1);
|
||||
}
|
||||
}
|
||||
if (count != 1) {
|
||||
printf("Too many or too few labels: %d, %s\n", count, path);
|
||||
count = 0;
|
||||
for (i = 0; i < k; ++i) {
|
||||
if (strstr(path, labels[i])) {
|
||||
printf("\t label %d: %s \n", count, labels[i]);
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fill_hierarchy(float *truth, int k, tree *hierarchy)
|
||||
{
|
||||
int j;
|
||||
@ -548,12 +574,12 @@ void fill_hierarchy(float *truth, int k, tree *hierarchy)
|
||||
}
|
||||
}
|
||||
|
||||
matrix load_labels_paths(char **paths, int n, char **labels, int k, tree *hierarchy)
|
||||
matrix load_labels_paths(char **paths, int n, char **labels, int k, tree *hierarchy, float label_smooth_eps)
|
||||
{
|
||||
matrix y = make_matrix(n, k);
|
||||
int i;
|
||||
for(i = 0; i < n && labels; ++i){
|
||||
fill_truth(paths[i], labels, k, y.vals[i]);
|
||||
fill_truth_smooth(paths[i], labels, k, y.vals[i], label_smooth_eps);
|
||||
if(hierarchy){
|
||||
fill_hierarchy(y.vals[i], k, hierarchy);
|
||||
}
|
||||
@ -1336,7 +1362,7 @@ void *load_thread(void *ptr)
|
||||
if (a.type == OLD_CLASSIFICATION_DATA){
|
||||
*a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
|
||||
} else if (a.type == CLASSIFICATION_DATA){
|
||||
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.flip, a.min, a.max, a.w, a.h, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.mixup, a.blur, a.show_imgs);
|
||||
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.flip, a.min, a.max, a.w, a.h, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.mixup, a.blur, a.show_imgs, a.label_smooth_eps);
|
||||
} else if (a.type == SUPER_DATA){
|
||||
*a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
|
||||
} else if (a.type == WRITING_DATA){
|
||||
@ -1432,7 +1458,7 @@ data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int
|
||||
data d = {0};
|
||||
d.shallow = 0;
|
||||
d.X = load_image_paths(paths, n, w, h);
|
||||
d.y = load_labels_paths(paths, n, labels, k, 0);
|
||||
d.y = load_labels_paths(paths, n, labels, k, 0, 0);
|
||||
if(m) free(paths);
|
||||
return d;
|
||||
}
|
||||
@ -1481,21 +1507,21 @@ data load_data_super(char **paths, int n, int m, int w, int h, int scale)
|
||||
return d;
|
||||
}
|
||||
|
||||
data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs)
|
||||
data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs, float label_smooth_eps)
|
||||
{
|
||||
char **paths_stored = paths;
|
||||
if(m) paths = get_random_paths(paths, n, m);
|
||||
data d = {0};
|
||||
d.shallow = 0;
|
||||
d.X = load_image_augment_paths(paths, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
|
||||
d.y = load_labels_paths(paths, n, labels, k, hierarchy);
|
||||
d.y = load_labels_paths(paths, n, labels, k, hierarchy, label_smooth_eps);
|
||||
|
||||
if (mixup && rand_int(0, 1)) {
|
||||
char **paths_mix = get_random_paths(paths_stored, n, m);
|
||||
data d2 = { 0 };
|
||||
d2.shallow = 0;
|
||||
d2.X = load_image_augment_paths(paths_mix, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
|
||||
d2.y = load_labels_paths(paths_mix, n, labels, k, hierarchy);
|
||||
d2.y = load_labels_paths(paths_mix, n, labels, k, hierarchy, label_smooth_eps);
|
||||
free(paths_mix);
|
||||
|
||||
data d3 = { 0 };
|
||||
@ -1505,12 +1531,12 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
|
||||
if (mixup >= 3) {
|
||||
char **paths_mix3 = get_random_paths(paths_stored, n, m);
|
||||
d3.X = load_image_augment_paths(paths_mix3, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
|
||||
d3.y = load_labels_paths(paths_mix3, n, labels, k, hierarchy);
|
||||
d3.y = load_labels_paths(paths_mix3, n, labels, k, hierarchy, label_smooth_eps);
|
||||
free(paths_mix3);
|
||||
|
||||
char **paths_mix4 = get_random_paths(paths_stored, n, m);
|
||||
d4.X = load_image_augment_paths(paths_mix4, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
|
||||
d4.y = load_labels_paths(paths_mix4, n, labels, k, hierarchy);
|
||||
d4.y = load_labels_paths(paths_mix4, n, labels, k, hierarchy, label_smooth_eps);
|
||||
free(paths_mix4);
|
||||
}
|
||||
|
||||
|
@ -91,7 +91,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
||||
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure);
|
||||
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure);
|
||||
data load_data_super(char **paths, int n, int m, int w, int h, int scale);
|
||||
data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs);
|
||||
data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs, float label_smooth_eps);
|
||||
data load_go(char *filename);
|
||||
|
||||
box_label *read_boxes(char *filename, int *n);
|
||||
@ -116,6 +116,7 @@ data *split_data(data d, int part, int total);
|
||||
data concat_data(data d1, data d2);
|
||||
data concat_datas(data *d, int n);
|
||||
void fill_truth(char *path, char **labels, int k, float *truth);
|
||||
void fill_truth_smooth(char *path, char **labels, int k, float *truth, float label_smooth_eps);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
|
@ -930,6 +930,7 @@ void parse_net_options(list *options, network *net)
|
||||
else if (cutmix) net->mixup = 2;
|
||||
else if (mosaic) net->mixup = 3;
|
||||
net->letter_box = option_find_int_quiet(options, "letter_box", 0);
|
||||
net->label_smooth_eps = option_find_float_quiet(options, "label_smooth_eps", 0.0f);
|
||||
|
||||
net->angle = option_find_float_quiet(options, "angle", 0);
|
||||
net->aspect = option_find_float_quiet(options, "aspect", 1);
|
||||
|
Reference in New Issue
Block a user