From 451ef0a0a6b595bb8e4a945633659b4d31f0a372 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Mon, 20 Apr 2015 08:43:54 -0700 Subject: [PATCH] It's time, to du-du-du-du-DU-DU-DUEL!! https://www.youtube.com/watch?v=IVmtUK_1jh4 --- src/data.c | 25 ++++++++++++++ src/data.h | 1 + src/detection.c | 90 +++++++++++++++++++++++++++++-------------------- src/image.c | 2 +- src/parser.c | 1 + 5 files changed, 82 insertions(+), 37 deletions(-) diff --git a/src/data.c b/src/data.c index 012d7cfd..2b74386f 100644 --- a/src/data.c +++ b/src/data.c @@ -408,6 +408,31 @@ pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int return thread; } +matrix concat_matrix(matrix m1, matrix m2) +{ + int i, count = 0; + matrix m; + m.cols = m1.cols; + m.rows = m1.rows+m2.rows; + m.vals = calloc(m1.rows + m2.rows, sizeof(float*)); + for(i = 0; i < m1.rows; ++i){ + m.vals[count++] = m1.vals[i]; + } + for(i = 0; i < m2.rows; ++i){ + m.vals[count++] = m2.vals[i]; + } + return m; +} + +data concat_data(data d1, data d2) +{ + data d; + d.shallow = 1; + d.X = concat_matrix(d1.X, d2.X); + d.y = concat_matrix(d1.y, d2.y); + return d; +} + data load_categorical_data_csv(char *filename, int target, int k) { data d; diff --git a/src/data.h b/src/data.h index 8e3e1d91..22fd248e 100644 --- a/src/data.h +++ b/src/data.h @@ -50,5 +50,6 @@ void scale_data_rows(data d, float s); void translate_data_rows(data d, float s); void randomize_data(data d); data *split_data(data d, int part, int total); +data concat_data(data d1, data d2); #endif diff --git a/src/detection.c b/src/detection.c index cba3d181..2610b486 100644 --- a/src/detection.c +++ b/src/detection.c @@ -81,9 +81,9 @@ void train_detection(char *cfgfile, char *weightfile) if (imgnet){ plist = get_paths("/home/pjreddie/data/imagenet/det.train.list"); }else{ - plist = get_paths("/home/pjreddie/data/voc/trainall.txt"); + //plist = get_paths("/home/pjreddie/data/voc/trainall.txt"); //plist = get_paths("/home/pjreddie/data/coco/trainval.txt"); - //plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt"); + plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt"); } paths = (char **)list_to_array(plist); pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); @@ -118,6 +118,34 @@ void train_detection(char *cfgfile, char *weightfile) } } +void predict_detections(network net, data d, float threshold, int offset, int classes, int nuisance, int background, int num_boxes, int per_box) +{ + matrix pred = network_predict_data(net, d); + int j, k, class; + for(j = 0; j < pred.rows; ++j){ + for(k = 0; k < pred.cols; k += per_box){ + float scale = 1.; + int index = k/per_box; + int row = index / num_boxes; + int col = index % num_boxes; + if (nuisance) scale = 1.-pred.vals[j][k]; + for (class = 0; class < classes; ++class){ + int ci = k+classes+background+nuisance; + float y = (pred.vals[j][ci + 0] + row)/num_boxes; + float x = (pred.vals[j][ci + 1] + col)/num_boxes; + float h = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes); + h = h*h; + float w = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes); + w = w*w; + float prob = scale*pred.vals[j][k+class+background+nuisance]; + if(prob < threshold) continue; + printf("%d %d %f %f %f %f %f\n", offset + j, class, prob, y, x, h, w); + } + } + } + free_matrix(pred); +} + void validate_detection(char *cfgfile, char *weightfile) { network net = parse_network_cfg(cfgfile); @@ -144,47 +172,37 @@ void validate_detection(char *cfgfile, char *weightfile) int m = plist->size; int i = 0; int splits = 100; - int num = (i+1)*m/splits - i*m/splits; - fprintf(stderr, "%d\n", m); - data val, buffer; - pthread_t load_thread = load_data_thread(paths, num, 0, 0, num_output, net.w, net.h, &buffer); + int nthreads = 4; + int t; + data *val = calloc(nthreads, sizeof(data)); + data *buf = calloc(nthreads, sizeof(data)); + pthread_t *thr = calloc(nthreads, sizeof(data)); + for(t = 0; t < nthreads; ++t){ + int num = (i+1+t)*m/splits - (i+t)*m/splits; + char **part = paths+((i+t)*m/splits); + thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t])); + } + clock_t time; - for(i = 1; i <= splits; ++i){ + for(i = nthreads; i <= splits; i += nthreads){ time=clock(); - pthread_join(load_thread, 0); - val = buffer; - - num = (i+1)*m/splits - i*m/splits; - char **part = paths+(i*m/splits); - if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &buffer); + for(t = 0; t < nthreads; ++t){ + pthread_join(thr[t], 0); + val[t] = buf[t]; + } + for(t = 0; t < nthreads && i < splits; ++t){ + int num = (i+1+t)*m/splits - (i+t)*m/splits; + char **part = paths+((i+t)*m/splits); + thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t])); + } fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time)); - matrix pred = network_predict_data(net, val); - int j, k, class; - for(j = 0; j < pred.rows; ++j){ - for(k = 0; k < pred.cols; k += per_box){ - float scale = 1.; - int index = k/per_box; - int row = index / num_boxes; - int col = index % num_boxes; - if (nuisance) scale = 1.-pred.vals[j][k]; - for (class = 0; class < classes; ++class){ - int ci = k+classes+background+nuisance; - float y = (pred.vals[j][ci + 0] + row)/num_boxes; - float x = (pred.vals[j][ci + 1] + col)/num_boxes; - float h = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes); - h = h*h; - float w = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes); - w = w*w; - float prob = scale*pred.vals[j][k+class+background+nuisance]; - if(prob < .001) continue; - printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, prob, y, x, h, w); - } - } + for(t = 0; t < nthreads; ++t){ + predict_detections(net, val[t], .01, (i-nthreads+t)*m/splits, classes, nuisance, background, num_boxes, per_box); + free_data(val[t]); } time=clock(); - free_data(val); } } diff --git a/src/image.c b/src/image.c index 1daea278..7509ce53 100644 --- a/src/image.c +++ b/src/image.c @@ -603,12 +603,12 @@ image load_image_color(char *filename, int w, int h) exit(0); } image out = ipl_to_image(src); + cvReleaseImage(&src); if((h && w) && (h != out.h || w != out.w)){ image resized = resize_image(out, w, h); free_image(out); out = resized; } - cvReleaseImage(&src); return out; } diff --git a/src/parser.c b/src/parser.c index ca60ef78..0f13d777 100644 --- a/src/parser.c +++ b/src/parser.c @@ -623,6 +623,7 @@ void load_weights_upto(network *net, char *filename, int cutoff) fread(&net->momentum, sizeof(float), 1, fp); fread(&net->decay, sizeof(float), 1, fp); fread(&net->seen, sizeof(int), 1, fp); + fprintf(stderr, "%f %f %f %d\n", net->learning_rate, net->momentum, net->decay, net->seen); int i; for(i = 0; i < net->n && i < cutoff; ++i){