this'll teach me to mess with maxpooling

This commit is contained in:
Joseph Redmon
2018-08-03 15:57:48 -07:00
parent e209b3bbbf
commit b13f67bfdd
23 changed files with 737 additions and 130 deletions

View File

@@ -645,6 +645,45 @@ void label_classifier(char *datacfg, char *filename, char *weightfile)
}
}
void csv_classifier(char *datacfg, char *cfgfile, char *weightfile)
{
int i,j;
network *net = load_network(cfgfile, weightfile, 0);
srand(time(0));
list *options = read_data_cfg(datacfg);
char *test_list = option_find_str(options, "test", "data/test.list");
int top = option_find_int(options, "top", 1);
list *plist = get_paths(test_list);
char **paths = (char **)list_to_array(plist);
int m = plist->size;
free_list(plist);
int *indexes = calloc(top, sizeof(int));
for(i = 0; i < m; ++i){
double time = what_time_is_it_now();
char *path = paths[i];
image im = load_image_color(path, 0, 0);
image r = letterbox_image(im, net->w, net->h);
float *predictions = network_predict(net, r.data);
if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
top_k(predictions, net->outputs, top, indexes);
printf("%s", path);
for(j = 0; j < top; ++j){
printf("\t%d", indexes[j]);
}
printf("\n");
free_image(im);
free_image(r);
fprintf(stderr, "%lf seconds, %d images, %d total\n", what_time_is_it_now() - time, i+1, m);
}
}
void test_classifier(char *datacfg, char *cfgfile, char *weightfile, int target_layer)
{
@@ -869,8 +908,7 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i
}
if(1){
show_image(out, "Threat");
cvWaitKey(10);
show_image(out, "Threat", 10);
}
free_image(in_s);
free_image(in);
@@ -922,7 +960,6 @@ void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_inde
image in = get_image_from_stream(cap);
image in_s = resize_image(in, net->w, net->h);
show_image(in, "Threat Detection");
float *predictions = network_predict(net, in_s.data);
top_predictions(net, top, indexes);
@@ -947,11 +984,10 @@ void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_inde
}
}
show_image(in, "Threat Detection", 10);
free_image(in_s);
free_image(in);
cvWaitKey(10);
gettimeofday(&tval_after, NULL);
timersub(&tval_after, &tval_before, &tval_result);
float curr = 1000000.f/((long int)tval_result.tv_usec);
@@ -1036,12 +1072,10 @@ void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_ind
free_image(label);
}
show_image(in, base);
show_image(in, base, 10);
free_image(in_s);
free_image(in);
cvWaitKey(10);
gettimeofday(&tval_after, NULL);
timersub(&tval_after, &tval_before, &tval_result);
float curr = 1000000.f/((long int)tval_result.tv_usec);
@@ -1080,6 +1114,7 @@ void run_classifier(int argc, char **argv)
else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
else if(0==strcmp(argv[2], "csv")) csv_classifier(data, cfg, weights);
else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_classifier_single(data, cfg, weights);
else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights);