It's time, to du-du-du-du-DU-DU-DUEL!!

https://www.youtube.com/watch?v=IVmtUK_1jh4
This commit is contained in:
Joseph Redmon 2015-04-20 08:43:54 -07:00
parent f199fd3b64
commit 451ef0a0a6
5 changed files with 82 additions and 37 deletions

View File

@ -408,6 +408,31 @@ pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int
return thread;
}
/*
 * Build a matrix whose rows are the rows of m1 followed by the rows of m2.
 *
 * Only the table of row pointers is newly allocated; every entry aliases the
 * corresponding row of m1 or m2, so no row contents are copied.
 * The column count is taken from m1; m2.cols is not consulted.
 */
matrix concat_matrix(matrix m1, matrix m2)
{
    matrix parts[2];
    matrix out;
    int p, r;
    int next = 0;   /* next free slot in out.vals */

    parts[0] = m1;
    parts[1] = m2;

    out.cols = m1.cols;
    out.rows = m1.rows + m2.rows;
    out.vals = calloc(out.rows, sizeof(float*));

    /* Append the row pointers of each input in order: m1 first, then m2. */
    for(p = 0; p < 2; ++p){
        for(r = 0; r < parts[p].rows; ++r){
            out.vals[next] = parts[p].vals[r];
            ++next;
        }
    }
    return out;
}
/*
 * Concatenate two data sets row-wise.
 *
 * The result's X and y are concat_matrix(d1.X, d2.X) and
 * concat_matrix(d1.y, d2.y): fresh row-pointer tables whose entries alias the
 * rows of d1 and d2. Because the row storage is shared rather than copied,
 * the result is flagged with shallow = 1.
 * NOTE(review): presumably free_data() uses `shallow` to avoid freeing the
 * shared rows - confirm against its definition.
 */
data concat_data(data d1, data d2)
{
    /* Zero-initialise so that any member of `data` not assigned below does
     * not reach the caller with an indeterminate value. */
    data d = {0};
    d.shallow = 1;
    d.X = concat_matrix(d1.X, d2.X);
    d.y = concat_matrix(d1.y, d2.y);
    return d;
}
data load_categorical_data_csv(char *filename, int target, int k)
{
data d;

View File

@ -50,5 +50,6 @@ void scale_data_rows(data d, float s);
void translate_data_rows(data d, float s);
void randomize_data(data d);
data *split_data(data d, int part, int total);
/* Row-wise concatenation: the result's X and y hold the rows of d1 followed
 * by the rows of d2, and its `shallow` flag is set to 1. */
data concat_data(data d1, data d2);
#endif

View File

@ -81,9 +81,9 @@ void train_detection(char *cfgfile, char *weightfile)
if (imgnet){
plist = get_paths("/home/pjreddie/data/imagenet/det.train.list");
}else{
plist = get_paths("/home/pjreddie/data/voc/trainall.txt");
//plist = get_paths("/home/pjreddie/data/voc/trainall.txt");
//plist = get_paths("/home/pjreddie/data/coco/trainval.txt");
//plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt");
plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt");
}
paths = (char **)list_to_array(plist);
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
@ -118,6 +118,34 @@ void train_detection(char *cfgfile, char *weightfile)
}
}
/*
 * Run `net` on `d` and print every detection whose score is at least
 * `threshold` to stdout, one per line:
 *
 *     <offset + row index> <class> <prob> <y> <x> <h> <w>
 *
 * Each prediction row is read as consecutive groups of `per_box` values.
 * Within the group starting at column k:
 *   - [k]                                  read only when `nuisance` is set;
 *                                          class scores are scaled by 1 - it
 *   - [k + background + nuisance + c]      score of class c, 0 <= c < classes
 *   - [k + background + nuisance + classes + 0..3]
 *                                          y offset, x offset, h, w; h and w
 *                                          are squared before printing
 * The group index is split into (row, col) using `num_boxes` as the grid
 * width, and y / x are (offset + row-or-col) / num_boxes.
 *
 * If `per_box`, `num_boxes` or `classes` is not positive nothing is printed:
 * a zero `per_box` or `num_boxes` would divide by zero, a negative `per_box`
 * would never terminate the column loop, and with no classes there is nothing
 * to report.
 */
void predict_detections(network net, data d, float threshold, int offset, int classes, int nuisance, int background, int num_boxes, int per_box)
{
    matrix pred = network_predict_data(net, d);
    int j, k, c;

    if(per_box <= 0 || num_boxes <= 0 || classes <= 0){
        free_matrix(pred);
        return;
    }

    for(j = 0; j < pred.rows; ++j){
        for(k = 0; k < pred.cols; k += per_box){
            int index = k/per_box;
            int row = index / num_boxes;
            int col = index % num_boxes;
            int pi = k+background+nuisance;   /* first class score */
            int ci = pi+classes;              /* first box coordinate */

            /* The box does not depend on the class, so decode it once per
             * group instead of once per class. */
            float y = (pred.vals[j][ci + 0] + row)/num_boxes;
            float x = (pred.vals[j][ci + 1] + col)/num_boxes;
            float h = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes);
            float w = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes);
            float scale = 1.;

            h = h*h;
            w = w*w;
            if (nuisance) scale = 1.-pred.vals[j][k];

            for(c = 0; c < classes; ++c){
                float prob = scale*pred.vals[j][pi+c];
                if(prob < threshold) continue;
                printf("%d %d %f %f %f %f %f\n", offset + j, c, prob, y, x, h, w);
            }
        }
    }
    free_matrix(pred);
}
void validate_detection(char *cfgfile, char *weightfile)
{
network net = parse_network_cfg(cfgfile);
@ -144,47 +172,37 @@ void validate_detection(char *cfgfile, char *weightfile)
int m = plist->size;
int i = 0;
int splits = 100;
int num = (i+1)*m/splits - i*m/splits;
fprintf(stderr, "%d\n", m);
data val, buffer;
pthread_t load_thread = load_data_thread(paths, num, 0, 0, num_output, net.w, net.h, &buffer);
int nthreads = 4;
int t;
data *val = calloc(nthreads, sizeof(data));
data *buf = calloc(nthreads, sizeof(data));
pthread_t *thr = calloc(nthreads, sizeof(data));
for(t = 0; t < nthreads; ++t){
int num = (i+1+t)*m/splits - (i+t)*m/splits;
char **part = paths+((i+t)*m/splits);
thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t]));
}
clock_t time;
for(i = 1; i <= splits; ++i){
for(i = nthreads; i <= splits; i += nthreads){
time=clock();
pthread_join(load_thread, 0);
val = buffer;
num = (i+1)*m/splits - i*m/splits;
char **part = paths+(i*m/splits);
if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &buffer);
for(t = 0; t < nthreads; ++t){
pthread_join(thr[t], 0);
val[t] = buf[t];
}
for(t = 0; t < nthreads && i < splits; ++t){
int num = (i+1+t)*m/splits - (i+t)*m/splits;
char **part = paths+((i+t)*m/splits);
thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t]));
}
fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
matrix pred = network_predict_data(net, val);
int j, k, class;
for(j = 0; j < pred.rows; ++j){
for(k = 0; k < pred.cols; k += per_box){
float scale = 1.;
int index = k/per_box;
int row = index / num_boxes;
int col = index % num_boxes;
if (nuisance) scale = 1.-pred.vals[j][k];
for (class = 0; class < classes; ++class){
int ci = k+classes+background+nuisance;
float y = (pred.vals[j][ci + 0] + row)/num_boxes;
float x = (pred.vals[j][ci + 1] + col)/num_boxes;
float h = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes);
h = h*h;
float w = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes);
w = w*w;
float prob = scale*pred.vals[j][k+class+background+nuisance];
if(prob < .001) continue;
printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, prob, y, x, h, w);
}
}
for(t = 0; t < nthreads; ++t){
predict_detections(net, val[t], .01, (i-nthreads+t)*m/splits, classes, nuisance, background, num_boxes, per_box);
free_data(val[t]);
}
time=clock();
free_data(val);
}
}

View File

@ -603,12 +603,12 @@ image load_image_color(char *filename, int w, int h)
exit(0);
}
image out = ipl_to_image(src);
cvReleaseImage(&src);
if((h && w) && (h != out.h || w != out.w)){
image resized = resize_image(out, w, h);
free_image(out);
out = resized;
}
cvReleaseImage(&src);
return out;
}

View File

@ -623,6 +623,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
fread(&net->momentum, sizeof(float), 1, fp);
fread(&net->decay, sizeof(float), 1, fp);
fread(&net->seen, sizeof(int), 1, fp);
fprintf(stderr, "%f %f %f %d\n", net->learning_rate, net->momentum, net->decay, net->seen);
int i;
for(i = 0; i < net->n && i < cutoff; ++i){