2017-06-02 06:31:13 +03:00
|
|
|
#include "darknet.h"
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2016-11-11 19:48:40 +03:00
|
|
|
/*
 * Maps darknet's contiguous class index (0..79) to the category_id written
 * into COCO result JSON. The COCO ids are not contiguous (12, 26, 29, 30, ...
 * are skipped), hence the lookup table. It is never written, so it is const.
 */
static const int coco_ids[] = {
     1,  2,  3,  4,  5,  6,  7,  8,  9, 10,
    11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
    22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
    35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
    46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
    56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
    67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
    80, 81, 82, 84, 85, 86, 87, 88, 89, 90
};
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
|
2016-11-11 19:48:40 +03:00
|
|
|
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
2016-08-06 01:27:07 +03:00
|
|
|
{
|
2016-11-06 00:09:21 +03:00
|
|
|
list *options = read_data_cfg(datacfg);
|
|
|
|
char *train_images = option_find_str(options, "train", "data/train.list");
|
|
|
|
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
|
|
|
|
2016-08-06 01:27:07 +03:00
|
|
|
srand(time(0));
|
|
|
|
char *base = basecfg(cfgfile);
|
|
|
|
printf("%s\n", base);
|
|
|
|
float avg_loss = -1;
|
2017-10-17 21:41:34 +03:00
|
|
|
network **nets = calloc(ngpus, sizeof(network));
|
2016-11-11 19:48:40 +03:00
|
|
|
|
|
|
|
srand(time(0));
|
|
|
|
int seed = rand();
|
|
|
|
int i;
|
|
|
|
for(i = 0; i < ngpus; ++i){
|
|
|
|
srand(seed);
|
|
|
|
#ifdef GPU
|
|
|
|
cuda_set_device(gpus[i]);
|
|
|
|
#endif
|
2017-06-08 23:47:31 +03:00
|
|
|
nets[i] = load_network(cfgfile, weightfile, clear);
|
2017-10-17 21:41:34 +03:00
|
|
|
nets[i]->learning_rate *= ngpus;
|
2016-08-06 01:27:07 +03:00
|
|
|
}
|
2016-11-11 19:48:40 +03:00
|
|
|
srand(time(0));
|
2017-10-17 21:41:34 +03:00
|
|
|
network *net = nets[0];
|
2016-11-11 19:48:40 +03:00
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
int imgs = net->batch * net->subdivisions * ngpus;
|
|
|
|
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
2016-08-06 01:27:07 +03:00
|
|
|
data train, buffer;
|
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
layer l = net->layers[net->n - 1];
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
int classes = l.classes;
|
|
|
|
float jitter = l.jitter;
|
|
|
|
|
|
|
|
list *plist = get_paths(train_images);
|
|
|
|
//int N = plist->size;
|
|
|
|
char **paths = (char **)list_to_array(plist);
|
|
|
|
|
2017-07-12 02:44:09 +03:00
|
|
|
load_args args = get_base_args(net);
|
|
|
|
args.coords = l.coords;
|
2016-08-06 01:27:07 +03:00
|
|
|
args.paths = paths;
|
|
|
|
args.n = imgs;
|
|
|
|
args.m = plist->size;
|
|
|
|
args.classes = classes;
|
|
|
|
args.jitter = jitter;
|
|
|
|
args.num_boxes = l.max_boxes;
|
|
|
|
args.d = &buffer;
|
|
|
|
args.type = DETECTION_DATA;
|
2017-07-12 02:44:09 +03:00
|
|
|
//args.type = INSTANCE_DATA;
|
2017-10-03 01:17:48 +03:00
|
|
|
args.threads = 64;
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2016-09-25 09:12:54 +03:00
|
|
|
pthread_t load_thread = load_data(args);
|
2017-10-03 01:17:48 +03:00
|
|
|
double time;
|
2016-11-16 09:53:58 +03:00
|
|
|
int count = 0;
|
2016-08-06 01:27:07 +03:00
|
|
|
//while(i*imgs < N*120){
|
2017-10-17 21:41:34 +03:00
|
|
|
while(get_current_batch(net) < net->max_batches){
|
2016-11-16 09:53:58 +03:00
|
|
|
if(l.random && count++%10 == 0){
|
|
|
|
printf("Resizing\n");
|
|
|
|
int dim = (rand() % 10 + 10) * 32;
|
2017-10-17 21:41:34 +03:00
|
|
|
if (get_current_batch(net)+200 > net->max_batches) dim = 608;
|
2016-11-16 09:53:58 +03:00
|
|
|
//int dim = (rand() % 4 + 16) * 32;
|
|
|
|
printf("%d\n", dim);
|
|
|
|
args.w = dim;
|
|
|
|
args.h = dim;
|
|
|
|
|
|
|
|
pthread_join(load_thread, 0);
|
|
|
|
train = buffer;
|
|
|
|
free_data(train);
|
|
|
|
load_thread = load_data(args);
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
#pragma omp parallel for
|
2016-11-16 09:53:58 +03:00
|
|
|
for(i = 0; i < ngpus; ++i){
|
2017-10-17 21:41:34 +03:00
|
|
|
resize_network(nets[i], dim, dim);
|
2016-11-16 09:53:58 +03:00
|
|
|
}
|
|
|
|
net = nets[0];
|
|
|
|
}
|
2017-10-03 01:17:48 +03:00
|
|
|
time=what_time_is_it_now();
|
2016-08-06 01:27:07 +03:00
|
|
|
pthread_join(load_thread, 0);
|
|
|
|
train = buffer;
|
2016-09-25 09:12:54 +03:00
|
|
|
load_thread = load_data(args);
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2016-11-11 19:48:40 +03:00
|
|
|
/*
|
2017-12-26 21:52:21 +03:00
|
|
|
int k;
|
|
|
|
for(k = 0; k < l.max_boxes; ++k){
|
|
|
|
box b = float_to_box(train.y.vals[10] + 1 + k*5);
|
|
|
|
if(!b.x) break;
|
|
|
|
printf("loaded: %f %f %f %f\n", b.x, b.y, b.w, b.h);
|
|
|
|
}
|
|
|
|
*/
|
2017-03-27 09:42:30 +03:00
|
|
|
/*
|
2017-12-26 21:52:21 +03:00
|
|
|
int zz;
|
|
|
|
for(zz = 0; zz < train.X.cols; ++zz){
|
|
|
|
image im = float_to_image(net->w, net->h, 3, train.X.vals[zz]);
|
|
|
|
int k;
|
|
|
|
for(k = 0; k < l.max_boxes; ++k){
|
|
|
|
box b = float_to_box(train.y.vals[zz] + k*5, 1);
|
|
|
|
printf("%f %f %f %f\n", b.x, b.y, b.w, b.h);
|
|
|
|
draw_bbox(im, b, 1, 1,0,0);
|
|
|
|
}
|
|
|
|
show_image(im, "truth11");
|
|
|
|
cvWaitKey(0);
|
|
|
|
save_image(im, "truth11");
|
|
|
|
}
|
|
|
|
*/
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2017-10-03 01:17:48 +03:00
|
|
|
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2017-10-03 01:17:48 +03:00
|
|
|
time=what_time_is_it_now();
|
2016-11-11 19:48:40 +03:00
|
|
|
float loss = 0;
|
|
|
|
#ifdef GPU
|
|
|
|
if(ngpus == 1){
|
|
|
|
loss = train_network(net, train);
|
|
|
|
} else {
|
|
|
|
loss = train_networks(nets, ngpus, train, 4);
|
|
|
|
}
|
|
|
|
#else
|
|
|
|
loss = train_network(net, train);
|
|
|
|
#endif
|
2016-08-06 01:27:07 +03:00
|
|
|
if (avg_loss < 0) avg_loss = loss;
|
|
|
|
avg_loss = avg_loss*.9 + loss*.1;
|
|
|
|
|
2016-11-11 19:48:40 +03:00
|
|
|
i = get_current_batch(net);
|
2017-10-03 01:17:48 +03:00
|
|
|
printf("%ld: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, i*imgs);
|
2017-07-12 02:44:09 +03:00
|
|
|
if(i%100==0){
|
2017-05-04 01:38:54 +03:00
|
|
|
#ifdef GPU
|
|
|
|
if(ngpus != 1) sync_nets(nets, ngpus, 0);
|
|
|
|
#endif
|
|
|
|
char buff[256];
|
|
|
|
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
|
|
|
save_weights(net, buff);
|
|
|
|
}
|
|
|
|
if(i%10000==0 || (i < 1000 && i%100 == 0)){
|
2016-11-16 11:15:46 +03:00
|
|
|
#ifdef GPU
|
2016-11-16 09:53:58 +03:00
|
|
|
if(ngpus != 1) sync_nets(nets, ngpus, 0);
|
2016-11-16 11:15:46 +03:00
|
|
|
#endif
|
2016-08-06 01:27:07 +03:00
|
|
|
char buff[256];
|
|
|
|
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
|
|
|
save_weights(net, buff);
|
|
|
|
}
|
|
|
|
free_data(train);
|
|
|
|
}
|
2016-11-16 11:15:46 +03:00
|
|
|
#ifdef GPU
|
2016-11-16 09:53:58 +03:00
|
|
|
if(ngpus != 1) sync_nets(nets, ngpus, 0);
|
2016-11-16 11:15:46 +03:00
|
|
|
#endif
|
2016-08-06 01:27:07 +03:00
|
|
|
char buff[256];
|
|
|
|
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
|
|
|
save_weights(net, buff);
|
|
|
|
}
|
|
|
|
|
2016-11-11 19:48:40 +03:00
|
|
|
|
|
|
|
/*
 * Extract the numeric COCO image id from an image path.
 *
 * COCO file names look like "COCO_val2014_000000000042.jpg". The id is the
 * integer after the last '_' of the file name (42 here). Only the base name
 * (the text after the last '/') is searched, so underscores in directory
 * names are ignored. If the base name has no '_', the number at its start is
 * used instead (e.g. "/data/42.jpg" -> 42). Returns 0 when no digits parse.
 */
static int get_coco_image_id(const char *filename)
{
    const char *base = strrchr(filename, '/');
    base = base ? base + 1 : filename;

    /* strrchr returns NULL when there is no '_'; never offset a NULL. */
    const char *underscore = strrchr(base, '_');
    const char *digits = underscore ? underscore + 1 : base;
    return (int)strtol(digits, NULL, 10);
}
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
/*
 * Append one COCO-format JSON object per (box, class) pair with a non-zero
 * probability to fp. Every object is followed by ",\n"; the caller writes the
 * surrounding "[ ... ]" and removes the trailing comma.
 *
 * Boxes in dets are center/size (x, y, w, h) in the same units as w and h.
 * They are clipped to the w x h image and written as [xmin, ymin, width,
 * height]. Class index j is mapped to a COCO category id via coco_ids.
 * Classes beyond the end of that table have no COCO id and are skipped
 * rather than read out of bounds.
 */
static void print_cocos(FILE *fp, char *image_path, detection *dets, int num_boxes, int classes, int w, int h)
{
    const int ncategories = (int)(sizeof(coco_ids) / sizeof(coco_ids[0]));
    const int nclasses = classes < ncategories ? classes : ncategories;
    int i, j;
    int image_id = get_coco_image_id(image_path);
    for(i = 0; i < num_boxes; ++i){
        float xmin = dets[i].bbox.x - dets[i].bbox.w/2.;
        float xmax = dets[i].bbox.x + dets[i].bbox.w/2.;
        float ymin = dets[i].bbox.y - dets[i].bbox.h/2.;
        float ymax = dets[i].bbox.y + dets[i].bbox.h/2.;

        if (xmin < 0) xmin = 0;
        if (ymin < 0) ymin = 0;
        if (xmax > w) xmax = w;
        if (ymax > h) ymax = h;

        float bx = xmin;
        float by = ymin;
        float bw = xmax - xmin;
        float bh = ymax - ymin;

        for(j = 0; j < nclasses; ++j){
            if (dets[i].prob[j]) fprintf(fp, "{\"image_id\":%d, \"category_id\":%d, \"bbox\":[%f, %f, %f, %f], \"score\":%f},\n", image_id, coco_ids[j], bx, by, bw, bh, dets[i].prob[j]);
        }
    }
}
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
/*
 * Emit VOC-style result lines, "<id> <prob> <xmin> <ymin> <xmax> <ymax>",
 * into the per-class files fps[class] for every (box, class) pair whose
 * probability is non-zero. Center/size boxes are turned into corners,
 * shifted to 1-based coordinates, and clamped to [1, w] x [1, h].
 */
void print_detector_detections(FILE **fps, char *id, detection *dets, int total, int classes, int w, int h)
{
    int n, c;
    for (n = 0; n < total; ++n) {
        const box b = dets[n].bbox;
        float left   = b.x - b.w/2. + 1;
        float right  = b.x + b.w/2. + 1;
        float top    = b.y - b.h/2. + 1;
        float bottom = b.y + b.h/2. + 1;

        if (left < 1) left = 1;
        if (top < 1) top = 1;
        if (right > w) right = w;
        if (bottom > h) bottom = h;

        for (c = 0; c < classes; ++c) {
            const float p = dets[n].prob[c];
            if (!p) continue;
            fprintf(fps[c], "%s %f %f %f %f %f\n", id, p, left, top, right, bottom);
        }
    }
}
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
/*
 * Append ImageNet-detection result lines to fp, one per (box, class) pair
 * with a non-zero probability:
 *     "<image id> <class index + 1> <prob> <xmin> <ymin> <xmax> <ymax>"
 * Center/size boxes are converted to corners and clamped to [0, w] x [0, h].
 */
void print_imagenet_detections(FILE *fp, int id, detection *dets, int total, int classes, int w, int h)
{
    int b, c;
    for (b = 0; b < total; ++b) {
        const detection *d = &dets[b];
        float left   = d->bbox.x - d->bbox.w/2.;
        float right  = d->bbox.x + d->bbox.w/2.;
        float top    = d->bbox.y - d->bbox.h/2.;
        float bottom = d->bbox.y + d->bbox.h/2.;

        if (left < 0) left = 0;
        if (top < 0) top = 0;
        if (right > w) right = w;
        if (bottom > h) bottom = h;

        for (c = 0; c < classes; ++c) {
            float p = d->prob[c];
            if (p) fprintf(fp, "%d %d %f %f %f %f %f\n", id, c+1, p, left, top, right, bottom);
        }
    }
}
|
|
|
|
|
2017-03-27 09:42:30 +03:00
|
|
|
void validate_detector_flip(char *datacfg, char *cfgfile, char *weightfile, char *outfile)
|
|
|
|
{
|
|
|
|
int j;
|
|
|
|
list *options = read_data_cfg(datacfg);
|
|
|
|
char *valid_images = option_find_str(options, "valid", "data/train.list");
|
|
|
|
char *name_list = option_find_str(options, "names", "data/names.list");
|
|
|
|
char *prefix = option_find_str(options, "results", "results");
|
|
|
|
char **names = get_labels(name_list);
|
|
|
|
char *mapf = option_find_str(options, "map", 0);
|
|
|
|
int *map = 0;
|
|
|
|
if (mapf) map = read_map(mapf);
|
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
network *net = load_network(cfgfile, weightfile, 0);
|
|
|
|
set_batch_network(net, 2);
|
|
|
|
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
2017-03-27 09:42:30 +03:00
|
|
|
srand(time(0));
|
|
|
|
|
|
|
|
list *plist = get_paths(valid_images);
|
|
|
|
char **paths = (char **)list_to_array(plist);
|
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
layer l = net->layers[net->n-1];
|
2017-03-27 09:42:30 +03:00
|
|
|
int classes = l.classes;
|
|
|
|
|
|
|
|
char buff[1024];
|
|
|
|
char *type = option_find_str(options, "eval", "voc");
|
|
|
|
FILE *fp = 0;
|
|
|
|
FILE **fps = 0;
|
|
|
|
int coco = 0;
|
|
|
|
int imagenet = 0;
|
|
|
|
if(0==strcmp(type, "coco")){
|
|
|
|
if(!outfile) outfile = "coco_results";
|
|
|
|
snprintf(buff, 1024, "%s/%s.json", prefix, outfile);
|
|
|
|
fp = fopen(buff, "w");
|
|
|
|
fprintf(fp, "[\n");
|
|
|
|
coco = 1;
|
|
|
|
} else if(0==strcmp(type, "imagenet")){
|
|
|
|
if(!outfile) outfile = "imagenet-detection";
|
|
|
|
snprintf(buff, 1024, "%s/%s.txt", prefix, outfile);
|
|
|
|
fp = fopen(buff, "w");
|
|
|
|
imagenet = 1;
|
|
|
|
classes = 200;
|
|
|
|
} else {
|
|
|
|
if(!outfile) outfile = "comp4_det_test_";
|
|
|
|
fps = calloc(classes, sizeof(FILE *));
|
|
|
|
for(j = 0; j < classes; ++j){
|
|
|
|
snprintf(buff, 1024, "%s/%s%s.txt", prefix, outfile, names[j]);
|
|
|
|
fps[j] = fopen(buff, "w");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
detection *dets = make_network_boxes(net);
|
2017-03-27 09:42:30 +03:00
|
|
|
|
|
|
|
int m = plist->size;
|
|
|
|
int i=0;
|
|
|
|
int t;
|
|
|
|
|
|
|
|
float thresh = .005;
|
|
|
|
float nms = .45;
|
|
|
|
|
|
|
|
int nthreads = 4;
|
|
|
|
image *val = calloc(nthreads, sizeof(image));
|
|
|
|
image *val_resized = calloc(nthreads, sizeof(image));
|
|
|
|
image *buf = calloc(nthreads, sizeof(image));
|
|
|
|
image *buf_resized = calloc(nthreads, sizeof(image));
|
|
|
|
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
image input = make_image(net->w, net->h, net->c*2);
|
2017-03-27 09:42:30 +03:00
|
|
|
|
|
|
|
load_args args = {0};
|
2017-10-17 21:41:34 +03:00
|
|
|
args.w = net->w;
|
|
|
|
args.h = net->h;
|
2017-03-27 09:42:30 +03:00
|
|
|
//args.type = IMAGE_DATA;
|
|
|
|
args.type = LETTERBOX_DATA;
|
|
|
|
|
|
|
|
for(t = 0; t < nthreads; ++t){
|
|
|
|
args.path = paths[i+t];
|
|
|
|
args.im = &buf[t];
|
|
|
|
args.resized = &buf_resized[t];
|
|
|
|
thr[t] = load_data_in_thread(args);
|
|
|
|
}
|
2017-10-03 01:17:48 +03:00
|
|
|
double start = what_time_is_it_now();
|
2017-03-27 09:42:30 +03:00
|
|
|
for(i = nthreads; i < m+nthreads; i += nthreads){
|
|
|
|
fprintf(stderr, "%d\n", i);
|
|
|
|
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
|
|
|
pthread_join(thr[t], 0);
|
|
|
|
val[t] = buf[t];
|
|
|
|
val_resized[t] = buf_resized[t];
|
|
|
|
}
|
|
|
|
for(t = 0; t < nthreads && i+t < m; ++t){
|
|
|
|
args.path = paths[i+t];
|
|
|
|
args.im = &buf[t];
|
|
|
|
args.resized = &buf_resized[t];
|
|
|
|
thr[t] = load_data_in_thread(args);
|
|
|
|
}
|
|
|
|
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
|
|
|
char *path = paths[i+t-nthreads];
|
|
|
|
char *id = basecfg(path);
|
2017-10-17 21:41:34 +03:00
|
|
|
copy_cpu(net->w*net->h*net->c, val_resized[t].data, 1, input.data, 1);
|
2017-03-27 09:42:30 +03:00
|
|
|
flip_image(val_resized[t]);
|
2017-10-17 21:41:34 +03:00
|
|
|
copy_cpu(net->w*net->h*net->c, val_resized[t].data, 1, input.data + net->w*net->h*net->c, 1);
|
2017-03-27 09:42:30 +03:00
|
|
|
|
|
|
|
network_predict(net, input.data);
|
|
|
|
int w = val[t].w;
|
|
|
|
int h = val[t].h;
|
2017-12-26 21:52:21 +03:00
|
|
|
fill_network_boxes(net, w, h, thresh, .5, map, 0, dets);
|
|
|
|
if (nms) do_nms_sort(dets, l.w*l.h*l.n, classes, nms);
|
2017-03-27 09:42:30 +03:00
|
|
|
if (coco){
|
2017-12-26 21:52:21 +03:00
|
|
|
print_cocos(fp, path, dets, l.w*l.h*l.n, classes, w, h);
|
2017-03-27 09:42:30 +03:00
|
|
|
} else if (imagenet){
|
2017-12-26 21:52:21 +03:00
|
|
|
print_imagenet_detections(fp, i+t-nthreads+1, dets, l.w*l.h*l.n, classes, w, h);
|
2017-03-27 09:42:30 +03:00
|
|
|
} else {
|
2017-12-26 21:52:21 +03:00
|
|
|
print_detector_detections(fps, id, dets, l.w*l.h*l.n, classes, w, h);
|
2017-03-27 09:42:30 +03:00
|
|
|
}
|
|
|
|
free(id);
|
|
|
|
free_image(val[t]);
|
|
|
|
free_image(val_resized[t]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for(j = 0; j < classes; ++j){
|
|
|
|
if(fps) fclose(fps[j]);
|
|
|
|
}
|
|
|
|
if(coco){
|
|
|
|
fseek(fp, -2, SEEK_CUR);
|
|
|
|
fprintf(fp, "\n]\n");
|
|
|
|
fclose(fp);
|
|
|
|
}
|
2017-10-03 01:17:48 +03:00
|
|
|
fprintf(stderr, "Total Detection Time: %f Seconds\n", what_time_is_it_now() - start);
|
2017-03-27 09:42:30 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2017-01-04 15:44:00 +03:00
|
|
|
void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *outfile)
|
2016-08-06 01:27:07 +03:00
|
|
|
{
|
2016-11-27 07:02:46 +03:00
|
|
|
int j;
|
2016-11-06 00:09:21 +03:00
|
|
|
list *options = read_data_cfg(datacfg);
|
|
|
|
char *valid_images = option_find_str(options, "valid", "data/train.list");
|
|
|
|
char *name_list = option_find_str(options, "names", "data/names.list");
|
2016-11-11 19:48:40 +03:00
|
|
|
char *prefix = option_find_str(options, "results", "results");
|
2016-11-06 00:09:21 +03:00
|
|
|
char **names = get_labels(name_list);
|
2016-11-16 09:53:58 +03:00
|
|
|
char *mapf = option_find_str(options, "map", 0);
|
|
|
|
int *map = 0;
|
|
|
|
if (mapf) map = read_map(mapf);
|
2016-11-06 00:09:21 +03:00
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
network *net = load_network(cfgfile, weightfile, 0);
|
|
|
|
set_batch_network(net, 1);
|
|
|
|
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
2016-11-27 07:02:46 +03:00
|
|
|
srand(time(0));
|
|
|
|
|
|
|
|
list *plist = get_paths(valid_images);
|
|
|
|
char **paths = (char **)list_to_array(plist);
|
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
layer l = net->layers[net->n-1];
|
2016-11-27 07:02:46 +03:00
|
|
|
int classes = l.classes;
|
2016-11-11 19:48:40 +03:00
|
|
|
|
|
|
|
char buff[1024];
|
2016-11-16 09:53:58 +03:00
|
|
|
char *type = option_find_str(options, "eval", "voc");
|
|
|
|
FILE *fp = 0;
|
2016-11-27 07:02:46 +03:00
|
|
|
FILE **fps = 0;
|
2016-11-16 09:53:58 +03:00
|
|
|
int coco = 0;
|
|
|
|
int imagenet = 0;
|
|
|
|
if(0==strcmp(type, "coco")){
|
2017-01-04 15:44:00 +03:00
|
|
|
if(!outfile) outfile = "coco_results";
|
|
|
|
snprintf(buff, 1024, "%s/%s.json", prefix, outfile);
|
2016-11-16 09:53:58 +03:00
|
|
|
fp = fopen(buff, "w");
|
|
|
|
fprintf(fp, "[\n");
|
|
|
|
coco = 1;
|
|
|
|
} else if(0==strcmp(type, "imagenet")){
|
2017-01-04 15:44:00 +03:00
|
|
|
if(!outfile) outfile = "imagenet-detection";
|
|
|
|
snprintf(buff, 1024, "%s/%s.txt", prefix, outfile);
|
2016-11-16 09:53:58 +03:00
|
|
|
fp = fopen(buff, "w");
|
|
|
|
imagenet = 1;
|
2016-11-27 07:02:46 +03:00
|
|
|
classes = 200;
|
|
|
|
} else {
|
2017-01-04 15:44:00 +03:00
|
|
|
if(!outfile) outfile = "comp4_det_test_";
|
2016-11-27 07:02:46 +03:00
|
|
|
fps = calloc(classes, sizeof(FILE *));
|
|
|
|
for(j = 0; j < classes; ++j){
|
2017-01-04 15:44:00 +03:00
|
|
|
snprintf(buff, 1024, "%s/%s%s.txt", prefix, outfile, names[j]);
|
2016-11-27 07:02:46 +03:00
|
|
|
fps[j] = fopen(buff, "w");
|
|
|
|
}
|
2016-11-11 19:48:40 +03:00
|
|
|
}
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
detection *dets = make_network_boxes(net);
|
|
|
|
int nboxes = num_boxes(net);
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
int m = plist->size;
|
|
|
|
int i=0;
|
|
|
|
int t;
|
|
|
|
|
2016-11-16 09:53:58 +03:00
|
|
|
float thresh = .005;
|
|
|
|
float nms = .45;
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2016-11-16 09:53:58 +03:00
|
|
|
int nthreads = 4;
|
2016-08-06 01:27:07 +03:00
|
|
|
image *val = calloc(nthreads, sizeof(image));
|
|
|
|
image *val_resized = calloc(nthreads, sizeof(image));
|
|
|
|
image *buf = calloc(nthreads, sizeof(image));
|
|
|
|
image *buf_resized = calloc(nthreads, sizeof(image));
|
|
|
|
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
|
|
|
|
|
|
|
load_args args = {0};
|
2017-10-17 21:41:34 +03:00
|
|
|
args.w = net->w;
|
|
|
|
args.h = net->h;
|
2017-03-27 09:42:30 +03:00
|
|
|
//args.type = IMAGE_DATA;
|
|
|
|
args.type = LETTERBOX_DATA;
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
for(t = 0; t < nthreads; ++t){
|
|
|
|
args.path = paths[i+t];
|
|
|
|
args.im = &buf[t];
|
|
|
|
args.resized = &buf_resized[t];
|
|
|
|
thr[t] = load_data_in_thread(args);
|
|
|
|
}
|
2017-10-03 01:17:48 +03:00
|
|
|
double start = what_time_is_it_now();
|
2016-08-06 01:27:07 +03:00
|
|
|
for(i = nthreads; i < m+nthreads; i += nthreads){
|
|
|
|
fprintf(stderr, "%d\n", i);
|
|
|
|
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
|
|
|
pthread_join(thr[t], 0);
|
|
|
|
val[t] = buf[t];
|
|
|
|
val_resized[t] = buf_resized[t];
|
|
|
|
}
|
|
|
|
for(t = 0; t < nthreads && i+t < m; ++t){
|
|
|
|
args.path = paths[i+t];
|
|
|
|
args.im = &buf[t];
|
|
|
|
args.resized = &buf_resized[t];
|
|
|
|
thr[t] = load_data_in_thread(args);
|
|
|
|
}
|
|
|
|
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
|
|
|
char *path = paths[i+t-nthreads];
|
|
|
|
char *id = basecfg(path);
|
|
|
|
float *X = val_resized[t].data;
|
2016-09-25 09:12:54 +03:00
|
|
|
network_predict(net, X);
|
2016-08-06 01:27:07 +03:00
|
|
|
int w = val[t].w;
|
|
|
|
int h = val[t].h;
|
2017-12-26 21:52:21 +03:00
|
|
|
fill_network_boxes(net, w, h, thresh, .5, map, 0, dets);
|
|
|
|
if (nms) do_nms_sort(dets, nboxes, classes, nms);
|
2016-11-16 09:53:58 +03:00
|
|
|
if (coco){
|
2017-12-26 21:52:21 +03:00
|
|
|
print_cocos(fp, path, dets, nboxes, classes, w, h);
|
2016-11-16 09:53:58 +03:00
|
|
|
} else if (imagenet){
|
2017-12-26 21:52:21 +03:00
|
|
|
print_imagenet_detections(fp, i+t-nthreads+1, dets, nboxes, classes, w, h);
|
2016-11-16 09:53:58 +03:00
|
|
|
} else {
|
2017-12-26 21:52:21 +03:00
|
|
|
print_detector_detections(fps, id, dets, nboxes, classes, w, h);
|
2016-11-11 19:48:40 +03:00
|
|
|
}
|
2016-08-06 01:27:07 +03:00
|
|
|
free(id);
|
|
|
|
free_image(val[t]);
|
|
|
|
free_image(val_resized[t]);
|
|
|
|
}
|
|
|
|
}
|
2016-09-08 08:27:56 +03:00
|
|
|
for(j = 0; j < classes; ++j){
|
2016-11-27 07:02:46 +03:00
|
|
|
if(fps) fclose(fps[j]);
|
2016-09-08 08:27:56 +03:00
|
|
|
}
|
2016-11-16 09:53:58 +03:00
|
|
|
if(coco){
|
|
|
|
fseek(fp, -2, SEEK_CUR);
|
|
|
|
fprintf(fp, "\n]\n");
|
|
|
|
fclose(fp);
|
2016-11-11 19:48:40 +03:00
|
|
|
}
|
2017-10-03 01:17:48 +03:00
|
|
|
fprintf(stderr, "Total Detection Time: %f Seconds\n", what_time_is_it_now() - start);
|
2016-08-06 01:27:07 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
void validate_detector_recall(char *cfgfile, char *weightfile)
|
|
|
|
{
|
2017-10-17 21:41:34 +03:00
|
|
|
network *net = load_network(cfgfile, weightfile, 0);
|
|
|
|
set_batch_network(net, 1);
|
|
|
|
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
2016-08-06 01:27:07 +03:00
|
|
|
srand(time(0));
|
|
|
|
|
2017-07-22 05:33:49 +03:00
|
|
|
list *plist = get_paths("data/coco_val_5k.list");
|
2016-08-06 01:27:07 +03:00
|
|
|
char **paths = (char **)list_to_array(plist);
|
|
|
|
|
2017-10-17 21:41:34 +03:00
|
|
|
layer l = net->layers[net->n-1];
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
int j, k;
|
2017-12-26 21:52:21 +03:00
|
|
|
detection *dets = make_network_boxes(net);
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
int m = plist->size;
|
|
|
|
int i=0;
|
|
|
|
|
|
|
|
float thresh = .001;
|
|
|
|
float iou_thresh = .5;
|
|
|
|
float nms = .4;
|
|
|
|
|
|
|
|
int total = 0;
|
|
|
|
int correct = 0;
|
|
|
|
int proposals = 0;
|
|
|
|
float avg_iou = 0;
|
2017-12-26 21:52:21 +03:00
|
|
|
int nboxes = num_boxes(net);
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
for(i = 0; i < m; ++i){
|
|
|
|
char *path = paths[i];
|
|
|
|
image orig = load_image_color(path, 0, 0);
|
2017-10-17 21:41:34 +03:00
|
|
|
image sized = resize_image(orig, net->w, net->h);
|
2016-08-06 01:27:07 +03:00
|
|
|
char *id = basecfg(path);
|
2016-09-25 09:12:54 +03:00
|
|
|
network_predict(net, sized.data);
|
2017-12-26 21:52:21 +03:00
|
|
|
fill_network_boxes(net, sized.w, sized.h, thresh, .5, 0, 1, dets);
|
|
|
|
if (nms) do_nms_obj(dets, nboxes, 1, nms);
|
2016-08-06 01:27:07 +03:00
|
|
|
|
2016-09-25 09:12:54 +03:00
|
|
|
char labelpath[4096];
|
|
|
|
find_replace(path, "images", "labels", labelpath);
|
|
|
|
find_replace(labelpath, "JPEGImages", "labels", labelpath);
|
|
|
|
find_replace(labelpath, ".jpg", ".txt", labelpath);
|
|
|
|
find_replace(labelpath, ".JPEG", ".txt", labelpath);
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
int num_labels = 0;
|
|
|
|
box_label *truth = read_boxes(labelpath, &num_labels);
|
2017-12-26 21:52:21 +03:00
|
|
|
for(k = 0; k < nboxes; ++k){
|
|
|
|
if(dets[k].objectness > thresh){
|
2016-08-06 01:27:07 +03:00
|
|
|
++proposals;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (j = 0; j < num_labels; ++j) {
|
|
|
|
++total;
|
|
|
|
box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
|
|
|
|
float best_iou = 0;
|
2016-09-25 09:12:54 +03:00
|
|
|
for(k = 0; k < l.w*l.h*l.n; ++k){
|
2017-12-26 21:52:21 +03:00
|
|
|
float iou = box_iou(dets[k].bbox, t);
|
|
|
|
if(dets[k].objectness > thresh && iou > best_iou){
|
2016-08-06 01:27:07 +03:00
|
|
|
best_iou = iou;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
avg_iou += best_iou;
|
|
|
|
if(best_iou > iou_thresh){
|
|
|
|
++correct;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
|
|
|
|
free(id);
|
|
|
|
free_image(orig);
|
|
|
|
free_image(sized);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
|
2017-04-18 03:18:08 +03:00
|
|
|
void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh, char *outfile, int fullscreen)
|
2016-08-06 01:27:07 +03:00
|
|
|
{
|
2016-11-06 00:09:21 +03:00
|
|
|
list *options = read_data_cfg(datacfg);
|
2016-11-11 19:48:40 +03:00
|
|
|
char *name_list = option_find_str(options, "names", "data/names.list");
|
|
|
|
char **names = get_labels(name_list);
|
2016-11-06 00:09:21 +03:00
|
|
|
|
|
|
|
image **alphabet = load_alphabet();
|
2017-10-17 21:41:34 +03:00
|
|
|
network *net = load_network(cfgfile, weightfile, 0);
|
|
|
|
set_batch_network(net, 1);
|
2016-08-06 01:27:07 +03:00
|
|
|
srand(2222222);
|
2017-07-12 02:44:09 +03:00
|
|
|
double time;
|
2016-08-06 01:27:07 +03:00
|
|
|
char buff[256];
|
|
|
|
char *input = buff;
|
2017-07-12 02:44:09 +03:00
|
|
|
float nms=.3;
|
2016-08-06 01:27:07 +03:00
|
|
|
while(1){
|
|
|
|
if(filename){
|
|
|
|
strncpy(input, filename, 256);
|
|
|
|
} else {
|
|
|
|
printf("Enter Image Path: ");
|
|
|
|
fflush(stdout);
|
|
|
|
input = fgets(input, 256, stdin);
|
|
|
|
if(!input) return;
|
|
|
|
strtok(input, "\n");
|
|
|
|
}
|
|
|
|
image im = load_image_color(input,0,0);
|
2017-10-17 21:41:34 +03:00
|
|
|
image sized = letterbox_image(im, net->w, net->h);
|
|
|
|
//image sized = resize_image(im, net->w, net->h);
|
|
|
|
//image sized2 = resize_max(im, net->w);
|
|
|
|
//image sized = crop_image(sized2, -((net->w - sized2.w)/2), -((net->h - sized2.h)/2), net->w, net->h);
|
|
|
|
//resize_network(net, sized.w, sized.h);
|
|
|
|
layer l = net->layers[net->n-1];
|
2016-11-19 08:51:36 +03:00
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
int nboxes = num_boxes(net);
|
|
|
|
printf("%d\n", nboxes);
|
2016-11-19 08:51:36 +03:00
|
|
|
|
2016-08-06 01:27:07 +03:00
|
|
|
float *X = sized.data;
|
2017-07-12 02:44:09 +03:00
|
|
|
time=what_time_is_it_now();
|
2016-09-25 09:12:54 +03:00
|
|
|
network_predict(net, X);
|
2017-07-12 02:44:09 +03:00
|
|
|
printf("%s: Predicted in %f seconds.\n", input, what_time_is_it_now()-time);
|
2017-12-26 21:52:21 +03:00
|
|
|
detection *dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, 0, 1);
|
2017-11-08 03:10:33 +03:00
|
|
|
//if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
|
2017-12-26 21:52:21 +03:00
|
|
|
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
|
|
|
|
draw_detections(im, dets, nboxes, thresh, names, alphabet, l.classes);
|
|
|
|
free_detections(dets, num_boxes(net));
|
2017-04-10 05:56:42 +03:00
|
|
|
if(outfile){
|
2017-04-18 00:53:48 +03:00
|
|
|
save_image(im, outfile);
|
2017-04-10 05:56:42 +03:00
|
|
|
}
|
|
|
|
else{
|
2017-04-18 00:53:48 +03:00
|
|
|
save_image(im, "predictions");
|
2017-04-10 05:56:42 +03:00
|
|
|
#ifdef OPENCV
|
2017-04-18 03:18:08 +03:00
|
|
|
cvNamedWindow("predictions", CV_WINDOW_NORMAL);
|
|
|
|
if(fullscreen){
|
|
|
|
cvSetWindowProperty("predictions", CV_WND_PROP_FULLSCREEN, CV_WINDOW_FULLSCREEN);
|
|
|
|
}
|
|
|
|
show_image(im, "predictions");
|
2017-04-10 05:56:42 +03:00
|
|
|
cvWaitKey(0);
|
|
|
|
cvDestroyAllWindows();
|
|
|
|
#endif
|
|
|
|
}
|
2016-08-06 01:27:07 +03:00
|
|
|
|
|
|
|
free_image(im);
|
|
|
|
free_image(sized);
|
|
|
|
if (filename) break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-12-26 21:52:21 +03:00
|
|
|
/*
 * Run net on im and store its detections in dets (which the caller
 * allocates). Boxes are filled for an image of im.w x im.h. When nms is
 * non-zero, per-class NMS is applied over all of the network's boxes.
 */
void network_detect(network *net, image im, float thresh, float hier_thresh, float nms, detection *dets)
{
    network_predict_image(net, im);

    const int nclasses = net->layers[net->n - 1].classes;
    const int total = num_boxes(net);

    fill_network_boxes(net, im.w, im.h, thresh, hier_thresh, 0, 0, dets);
    if (!nms) return;
    do_nms_sort(dets, total, nclasses, nms);
}
|
|
|
|
|
2016-08-06 01:27:07 +03:00
|
|
|
void run_detector(int argc, char **argv)
|
|
|
|
{
|
2016-09-25 09:12:54 +03:00
|
|
|
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
|
2016-11-27 08:21:04 +03:00
|
|
|
float thresh = find_float_arg(argc, argv, "-thresh", .24);
|
2017-01-04 15:44:00 +03:00
|
|
|
float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
|
2016-09-25 09:12:54 +03:00
|
|
|
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
|
|
|
int frame_skip = find_int_arg(argc, argv, "-s", 0);
|
2017-04-30 23:54:40 +03:00
|
|
|
int avg = find_int_arg(argc, argv, "-avg", 3);
|
2016-08-06 01:27:07 +03:00
|
|
|
if(argc < 4){
|
|
|
|
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
|
|
|
return;
|
|
|
|
}
|
2016-11-11 19:48:40 +03:00
|
|
|
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
2017-01-04 15:44:00 +03:00
|
|
|
char *outfile = find_char_arg(argc, argv, "-out", 0);
|
2016-11-11 19:48:40 +03:00
|
|
|
int *gpus = 0;
|
|
|
|
int gpu = 0;
|
|
|
|
int ngpus = 0;
|
|
|
|
if(gpu_list){
|
|
|
|
printf("%s\n", gpu_list);
|
|
|
|
int len = strlen(gpu_list);
|
|
|
|
ngpus = 1;
|
|
|
|
int i;
|
|
|
|
for(i = 0; i < len; ++i){
|
|
|
|
if (gpu_list[i] == ',') ++ngpus;
|
|
|
|
}
|
|
|
|
gpus = calloc(ngpus, sizeof(int));
|
|
|
|
for(i = 0; i < ngpus; ++i){
|
|
|
|
gpus[i] = atoi(gpu_list);
|
|
|
|
gpu_list = strchr(gpu_list, ',')+1;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
gpu = gpu_index;
|
|
|
|
gpus = &gpu;
|
|
|
|
ngpus = 1;
|
|
|
|
}
|
|
|
|
|
2016-11-06 00:09:21 +03:00
|
|
|
int clear = find_arg(argc, argv, "-clear");
|
2017-04-18 00:53:48 +03:00
|
|
|
int fullscreen = find_arg(argc, argv, "-fullscreen");
|
|
|
|
int width = find_int_arg(argc, argv, "-w", 0);
|
2017-04-18 02:23:50 +03:00
|
|
|
int height = find_int_arg(argc, argv, "-h", 0);
|
2017-04-18 00:53:48 +03:00
|
|
|
int fps = find_int_arg(argc, argv, "-fps", 0);
|
2016-11-06 00:09:21 +03:00
|
|
|
|
|
|
|
char *datacfg = argv[3];
|
|
|
|
char *cfg = argv[4];
|
|
|
|
char *weights = (argc > 5) ? argv[5] : 0;
|
|
|
|
char *filename = (argc > 6) ? argv[6]: 0;
|
2017-04-18 03:18:08 +03:00
|
|
|
if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, outfile, fullscreen);
|
2016-11-11 19:48:40 +03:00
|
|
|
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear);
|
2017-01-04 15:44:00 +03:00
|
|
|
else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
|
2017-03-27 09:42:30 +03:00
|
|
|
else if(0==strcmp(argv[2], "valid2")) validate_detector_flip(datacfg, cfg, weights, outfile);
|
2016-08-06 01:27:07 +03:00
|
|
|
else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);
|
2016-11-06 00:09:21 +03:00
|
|
|
else if(0==strcmp(argv[2], "demo")) {
|
|
|
|
list *options = read_data_cfg(datacfg);
|
|
|
|
int classes = option_find_int(options, "classes", 20);
|
|
|
|
char *name_list = option_find_str(options, "names", "data/names.list");
|
|
|
|
char **names = get_labels(name_list);
|
2017-04-30 23:54:40 +03:00
|
|
|
demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, avg, hier_thresh, width, height, fps, fullscreen);
|
2016-11-06 00:09:21 +03:00
|
|
|
}
|
2016-08-06 01:27:07 +03:00
|
|
|
}
|