Set the triaining data path to global variables.

Set the triaining data path to global variables.
This commit is contained in:
wisdom 2016-06-19 14:46:26 +08:00 committed by GitHub
parent 391ccbc4ff
commit 28fec8644f

View File

@ -9,13 +9,20 @@
#include "opencv2/highgui/highgui_c.h" #include "opencv2/highgui/highgui_c.h"
#endif #endif
#define CLASSES_NUM (20)
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
image voc_labels[20]; image voc_labels[CLASSES_NUM];
char * g_train_images_path = "/usr/local/darknet/2007_train.txt";
char * g_val_images_path = "/usr/local/darknet/2007_val.txt";
char * g_test_images_path = "/usr/local/darknet/2007_test.txt";
char * g_backup_directory_path = "/data/darknet/darknet_backup/darknet_my_test/";
void train_yolo(char *cfgfile, char *weightfile) void train_yolo(char *cfgfile, char *weightfile)
{ {
char *train_images = "/data/voc/train.txt"; char *train_images = g_train_images_path;
char *backup_directory = "/home/pjreddie/backup/"; char *backup_directory = g_backup_directory_path;
srand(time(0)); srand(time(0));
data_seed = time(0); data_seed = time(0);
char *base = basecfg(cfgfile); char *base = basecfg(cfgfile);
@ -144,7 +151,7 @@ void validate_yolo(char *cfgfile, char *weightfile)
char *base = "results/comp4_det_test_"; char *base = "results/comp4_det_test_";
//list *plist = get_paths("data/voc.2007.test"); //list *plist = get_paths("data/voc.2007.test");
list *plist = get_paths("/home/pjreddie/data/voc/2007_test.txt"); list *plist = get_paths(g_val_images_path);
//list *plist = get_paths("data/voc.2012.test"); //list *plist = get_paths("data/voc.2012.test");
char **paths = (char **)list_to_array(plist); char **paths = (char **)list_to_array(plist);
@ -233,7 +240,7 @@ void validate_yolo_recall(char *cfgfile, char *weightfile)
srand(time(0)); srand(time(0));
char *base = "results/comp4_det_test_"; char *base = "results/comp4_det_test_";
list *plist = get_paths("data/voc.2007.test"); list *plist = get_paths(g_test_images_path);
char **paths = (char **)list_to_array(plist); char **paths = (char **)list_to_array(plist);
layer l = net.layers[net.n-1]; layer l = net.layers[net.n-1];
@ -277,6 +284,18 @@ void validate_yolo_recall(char *cfgfile, char *weightfile)
labelpath = find_replace(labelpath, "JPEGImages", "labels"); labelpath = find_replace(labelpath, "JPEGImages", "labels");
labelpath = find_replace(labelpath, ".jpg", ".txt"); labelpath = find_replace(labelpath, ".jpg", ".txt");
labelpath = find_replace(labelpath, ".JPEG", ".txt"); labelpath = find_replace(labelpath, ".JPEG", ".txt");
labelpath = find_replace(labelpath, ".bmp", ".txt");
labelpath = find_replace(labelpath, ".dib", ".txt");
labelpath = find_replace(labelpath, ".jpe", ".txt");
labelpath = find_replace(labelpath, ".jp2", ".txt");
labelpath = find_replace(labelpath, ".png", ".txt");
labelpath = find_replace(labelpath, ".pbm", ".txt");
labelpath = find_replace(labelpath, ".pgm", ".txt");
labelpath = find_replace(labelpath, ".ppm", ".txt");
labelpath = find_replace(labelpath, ".sr", ".txt");
labelpath = find_replace(labelpath, ".ras", ".txt");
labelpath = find_replace(labelpath, ".tiff", ".txt");
labelpath = find_replace(labelpath, ".tif", ".txt");
int num_labels = 0; int num_labels = 0;
box_label *truth = read_boxes(labelpath, &num_labels); box_label *truth = read_boxes(labelpath, &num_labels);
@ -345,7 +364,7 @@ void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms); if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
//draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20); //draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20);
draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20); draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, CLASSES_NUM);
save_image(im, "predictions"); save_image(im, "predictions");
show_image(im, "predictions"); show_image(im, "predictions");