mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Added flag -show_imgs for Training, to show augmented images with bboxes
This commit is contained in:
@ -706,6 +706,7 @@ typedef struct load_args {
|
|||||||
int mini_batch;
|
int mini_batch;
|
||||||
int track;
|
int track;
|
||||||
int augment_speed;
|
int augment_speed;
|
||||||
|
int show_imgs;
|
||||||
float jitter;
|
float jitter;
|
||||||
int flip;
|
int flip;
|
||||||
float angle;
|
float angle;
|
||||||
@ -771,7 +772,7 @@ LIB_API detection *make_network_boxes(network *net, float thresh, int *num);
|
|||||||
LIB_API void reset_rnn(network *net);
|
LIB_API void reset_rnn(network *net);
|
||||||
LIB_API float *network_predict_image(network *net, image im);
|
LIB_API float *network_predict_image(network *net, image im);
|
||||||
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net);
|
LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net);
|
||||||
LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port);
|
LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs);
|
||||||
LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
|
LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
|
||||||
float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile);
|
float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile);
|
||||||
LIB_API int network_width(network *net);
|
LIB_API int network_width(network *net);
|
||||||
|
28
src/data.c
28
src/data.c
@ -769,7 +769,8 @@ static box float_to_box_stride(float *f, int stride)
|
|||||||
|
|
||||||
#include "http_stream.h"
|
#include "http_stream.h"
|
||||||
|
|
||||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed)
|
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
|
||||||
|
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
|
||||||
{
|
{
|
||||||
c = c ? c : 3;
|
c = c ? c : 3;
|
||||||
char **random_paths;
|
char **random_paths;
|
||||||
@ -844,10 +845,9 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
|||||||
|
|
||||||
fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, w, h);
|
fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, w, h);
|
||||||
|
|
||||||
const int show_augmented_images = 0;
|
if(show_imgs)
|
||||||
if(show_augmented_images)
|
|
||||||
{
|
{
|
||||||
char buff[10];
|
char buff[1000];
|
||||||
sprintf(buff, "aug_%s_%d", random_paths[i], random_gen());
|
sprintf(buff, "aug_%s_%d", random_paths[i], random_gen());
|
||||||
int t;
|
int t;
|
||||||
for (t = 0; t < boxes; ++t) {
|
for (t = 0; t < boxes; ++t) {
|
||||||
@ -860,18 +860,20 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
|||||||
draw_box_width(ai, left, top, right, bot, 3, 150, 100, 50); // 3 channels RGB
|
draw_box_width(ai, left, top, right, bot, 3, 150, 100, 50); // 3 channels RGB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
save_image(ai, buff);
|
||||||
show_image(ai, buff);
|
show_image(ai, buff);
|
||||||
wait_until_press_key_cv();
|
wait_until_press_key_cv();
|
||||||
|
printf("\nYou use flag -show_imgs, so will be saved aug_...jpg images. Click on window and press ESC button \n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
release_ipl(&src);
|
release_ipl(&src);
|
||||||
}
|
}
|
||||||
free(random_paths);
|
free(random_paths);
|
||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
#else // OPENCV
|
#else // OPENCV
|
||||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed)
|
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
|
||||||
|
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
|
||||||
{
|
{
|
||||||
c = c ? c : 3;
|
c = c ? c : 3;
|
||||||
char **random_paths;
|
char **random_paths;
|
||||||
@ -938,10 +940,10 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
|||||||
|
|
||||||
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1. / sx, 1. / sy, w, h);
|
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1. / sx, 1. / sy, w, h);
|
||||||
|
|
||||||
/*
|
if(show_imgs)
|
||||||
{
|
{
|
||||||
char buff[10];
|
char buff[1000];
|
||||||
sprintf(buff, "aug_%s_%d", random_paths[i], random_gen());
|
sprintf(buff, "aug_%s_%d", basecfg(random_paths[i]), random_gen());
|
||||||
int t;
|
int t;
|
||||||
for (t = 0; t < boxes; ++t) {
|
for (t = 0; t < boxes; ++t) {
|
||||||
box b = float_to_box_stride(d.y.vals[i] + t*(4 + 1), 1);
|
box b = float_to_box_stride(d.y.vals[i] + t*(4 + 1), 1);
|
||||||
@ -954,8 +956,11 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
show_image(sized, buff);
|
show_image(sized, buff);
|
||||||
|
save_image(sized, buff);
|
||||||
wait_until_press_key_cv();
|
wait_until_press_key_cv();
|
||||||
}*/
|
printf("\nYou use flag -show_imgs, so will be saved aug_...jpg images. Press Enter: \n");
|
||||||
|
getchar();
|
||||||
|
}
|
||||||
|
|
||||||
free_image(orig);
|
free_image(orig);
|
||||||
free_image(cropped);
|
free_image(cropped);
|
||||||
@ -985,7 +990,8 @@ void *load_thread(void *ptr)
|
|||||||
} else if (a.type == REGION_DATA){
|
} else if (a.type == REGION_DATA){
|
||||||
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
|
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
|
||||||
} else if (a.type == DETECTION_DATA){
|
} else if (a.type == DETECTION_DATA){
|
||||||
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter, a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed);
|
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter,
|
||||||
|
a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.show_imgs);
|
||||||
} else if (a.type == SWAG_DATA){
|
} else if (a.type == SWAG_DATA){
|
||||||
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
|
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
|
||||||
} else if (a.type == COMPARE_DATA){
|
} else if (a.type == COMPARE_DATA){
|
||||||
|
@ -86,7 +86,8 @@ void print_letters(float *pred, int n);
|
|||||||
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
|
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
|
||||||
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
|
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
|
||||||
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
|
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
|
||||||
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed);
|
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
|
||||||
|
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs);
|
||||||
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
|
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
|
||||||
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
|
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
|
||||||
data load_data_super(char **paths, int n, int m, int w, int h, int scale);
|
data load_data_super(char **paths, int n, int m, int w, int h, int scale);
|
||||||
|
@ -22,7 +22,7 @@ int check_mistakes;
|
|||||||
|
|
||||||
static int coco_ids[] = { 1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90 };
|
static int coco_ids[] = { 1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90 };
|
||||||
|
|
||||||
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port)
|
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs)
|
||||||
{
|
{
|
||||||
list *options = read_data_cfg(datacfg);
|
list *options = read_data_cfg(datacfg);
|
||||||
char *train_images = option_find_str(options, "train", "data/train.txt");
|
char *train_images = option_find_str(options, "train", "data/train.txt");
|
||||||
@ -127,6 +127,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
|
|||||||
args.exposure = net.exposure;
|
args.exposure = net.exposure;
|
||||||
args.saturation = net.saturation;
|
args.saturation = net.saturation;
|
||||||
args.hue = net.hue;
|
args.hue = net.hue;
|
||||||
|
args.show_imgs = show_imgs;
|
||||||
|
|
||||||
#ifdef OPENCV
|
#ifdef OPENCV
|
||||||
args.threads = 3 * ngpus; // Amazon EC2 Tesla V100: p3.2xlarge (8 logical cores) - p3.16xlarge
|
args.threads = 3 * ngpus; // Amazon EC2 Tesla V100: p3.2xlarge (8 logical cores) - p3.16xlarge
|
||||||
@ -1388,6 +1389,7 @@ void run_detector(int argc, char **argv)
|
|||||||
int calc_map = find_arg(argc, argv, "-map");
|
int calc_map = find_arg(argc, argv, "-map");
|
||||||
int map_points = find_int_arg(argc, argv, "-points", 0);
|
int map_points = find_int_arg(argc, argv, "-points", 0);
|
||||||
check_mistakes = find_arg(argc, argv, "-check_mistakes");
|
check_mistakes = find_arg(argc, argv, "-check_mistakes");
|
||||||
|
int show_imgs = find_arg(argc, argv, "-show_imgs");
|
||||||
int mjpeg_port = find_int_arg(argc, argv, "-mjpeg_port", -1);
|
int mjpeg_port = find_int_arg(argc, argv, "-mjpeg_port", -1);
|
||||||
int json_port = find_int_arg(argc, argv, "-json_port", -1);
|
int json_port = find_int_arg(argc, argv, "-json_port", -1);
|
||||||
char *out_filename = find_char_arg(argc, argv, "-out_filename", 0);
|
char *out_filename = find_char_arg(argc, argv, "-out_filename", 0);
|
||||||
@ -1443,7 +1445,7 @@ void run_detector(int argc, char **argv)
|
|||||||
if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
|
if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
|
||||||
char *filename = (argc > 6) ? argv[6] : 0;
|
char *filename = (argc > 6) ? argv[6] : 0;
|
||||||
if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile);
|
if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile);
|
||||||
else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port);
|
else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs);
|
||||||
else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
|
else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
|
||||||
else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
|
else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
|
||||||
else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, NULL);
|
else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, NULL);
|
||||||
|
Reference in New Issue
Block a user