mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
changing data loading
This commit is contained in:
parent
3a99b8151f
commit
9d42f49a24
9
Makefile
9
Makefile
@ -1,8 +1,9 @@
|
||||
GPU=0
|
||||
OPENCV=0
|
||||
GPU=1
|
||||
OPENCV=1
|
||||
DEBUG=0
|
||||
|
||||
ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20
|
||||
ARCH= -arch sm_52
|
||||
|
||||
VPATH=./src/
|
||||
EXEC=darknet
|
||||
@ -10,7 +11,7 @@ OBJDIR=./obj/
|
||||
|
||||
CC=gcc
|
||||
NVCC=nvcc
|
||||
OPTS=-Ofast
|
||||
OPTS=-O2
|
||||
LDFLAGS= -lm -pthread -lstdc++
|
||||
COMMON= -I/usr/local/cuda/include/
|
||||
CFLAGS=-Wall -Wfatal-errors
|
||||
@ -34,7 +35,7 @@ CFLAGS+= -DGPU
|
||||
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
||||
endif
|
||||
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o
|
||||
ifeq ($(GPU), 1)
|
||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
|
||||
endif
|
||||
|
@ -1,6 +1,6 @@
|
||||
[net]
|
||||
batch=128
|
||||
subdivisions=1
|
||||
subdivisions=32
|
||||
height=256
|
||||
width=256
|
||||
channels=3
|
||||
|
19
src/box.c
19
src/box.c
@ -231,3 +231,22 @@ void do_nms(box *boxes, float **probs, int num_boxes, int classes, float thresh)
|
||||
}
|
||||
}
|
||||
|
||||
box encode_box(box b, box anchor)
|
||||
{
|
||||
box encode;
|
||||
encode.x = (b.x - anchor.x) / anchor.w;
|
||||
encode.y = (b.y - anchor.y) / anchor.h;
|
||||
encode.w = log2(b.w / anchor.w);
|
||||
encode.h = log2(b.h / anchor.h);
|
||||
return encode;
|
||||
}
|
||||
|
||||
box decode_box(box b, box anchor)
|
||||
{
|
||||
box decode;
|
||||
decode.x = b.x * anchor.w + anchor.x;
|
||||
decode.y = b.y * anchor.h + anchor.y;
|
||||
decode.w = pow(2., b.w) * anchor.w;
|
||||
decode.h = pow(2., b.h) * anchor.h;
|
||||
return decode;
|
||||
}
|
||||
|
@ -12,5 +12,7 @@ typedef struct{
|
||||
float box_iou(box a, box b);
|
||||
dbox diou(box a, box b);
|
||||
void do_nms(box *boxes, float **probs, int num_boxes, int classes, float thresh);
|
||||
box decode_box(box b, box anchor);
|
||||
box encode_box(box b, box anchor);
|
||||
|
||||
#endif
|
||||
|
117
src/captcha.c
117
src/captcha.c
@ -26,7 +26,7 @@ void fix_data_captcha(data d, int mask)
|
||||
}
|
||||
}
|
||||
|
||||
void train_captcha2(char *cfgfile, char *weightfile)
|
||||
void train_captcha(char *cfgfile, char *weightfile)
|
||||
{
|
||||
data_seed = time(0);
|
||||
srand(time(0));
|
||||
@ -55,7 +55,19 @@ void train_captcha2(char *cfgfile, char *weightfile)
|
||||
pthread_t load_thread;
|
||||
data train;
|
||||
data buffer;
|
||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer);
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.paths = paths;
|
||||
args.classes = 26;
|
||||
args.n = imgs;
|
||||
args.m = plist->size;
|
||||
args.labels = labels;
|
||||
args.d = &buffer;
|
||||
args.type = CLASSIFICATION_DATA;
|
||||
|
||||
load_thread = load_data_in_thread(args);
|
||||
while(1){
|
||||
++i;
|
||||
time=clock();
|
||||
@ -69,7 +81,7 @@ void train_captcha2(char *cfgfile, char *weightfile)
|
||||
cvWaitKey(0);
|
||||
*/
|
||||
|
||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer);
|
||||
load_thread = load_data_in_thread(args);
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
float loss = train_network(net, train);
|
||||
@ -86,7 +98,7 @@ void train_captcha2(char *cfgfile, char *weightfile)
|
||||
}
|
||||
}
|
||||
|
||||
void test_captcha2(char *cfgfile, char *weightfile, char *filename)
|
||||
void test_captcha(char *cfgfile, char *weightfile, char *filename)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
@ -165,99 +177,6 @@ void valid_captcha(char *cfgfile, char *weightfile, char *filename)
|
||||
}
|
||||
}
|
||||
|
||||
void train_captcha(char *cfgfile, char *weightfile)
|
||||
{
|
||||
data_seed = time(0);
|
||||
srand(time(0));
|
||||
float avg_loss = -1;
|
||||
char *base = basecfg(cfgfile);
|
||||
printf("%s\n", base);
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
//net.seen=0;
|
||||
int imgs = 1024;
|
||||
int i = net.seen/imgs;
|
||||
char **labels = get_labels("/data/captcha/reimgs.labels.list");
|
||||
list *plist = get_paths("/data/captcha/reimgs.train.list");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
printf("%d\n", plist->size);
|
||||
clock_t time;
|
||||
pthread_t load_thread;
|
||||
data train;
|
||||
data buffer;
|
||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer);
|
||||
while(1){
|
||||
++i;
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
|
||||
/*
|
||||
image im = float_to_image(256, 256, 3, train.X.vals[114]);
|
||||
show_image(im, "training");
|
||||
cvWaitKey(0);
|
||||
*/
|
||||
|
||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer);
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
float loss = train_network(net, train);
|
||||
net.seen += imgs;
|
||||
if(avg_loss == -1) avg_loss = loss;
|
||||
avg_loss = avg_loss*.9 + loss*.1;
|
||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
|
||||
free_data(train);
|
||||
if(i%100==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
||||
save_weights(net, buff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void test_captcha(char *cfgfile, char *weightfile, char *filename)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
set_batch_network(&net, 1);
|
||||
srand(2222222);
|
||||
int i = 0;
|
||||
char **names = get_labels("/data/captcha/reimgs.labels.list");
|
||||
char input[256];
|
||||
int indexes[13];
|
||||
while(1){
|
||||
if(filename){
|
||||
strncpy(input, filename, 256);
|
||||
}else{
|
||||
//printf("Enter Image Path: ");
|
||||
//fflush(stdout);
|
||||
fgets(input, 256, stdin);
|
||||
strtok(input, "\n");
|
||||
}
|
||||
image im = load_image_color(input, net.w, net.h);
|
||||
float *X = im.data;
|
||||
float *predictions = network_predict(net, X);
|
||||
top_predictions(net, 13, indexes);
|
||||
//printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
for(i = 0; i < 13; ++i){
|
||||
int index = indexes[i];
|
||||
if(i != 0) printf(", ");
|
||||
printf("%s %f", names[index], predictions[index]);
|
||||
}
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
free_image(im);
|
||||
if (filename) break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/*
|
||||
void train_captcha(char *cfgfile, char *weightfile)
|
||||
{
|
||||
@ -435,8 +354,8 @@ void run_captcha(int argc, char **argv)
|
||||
char *cfg = argv[3];
|
||||
char *weights = (argc > 4) ? argv[4] : 0;
|
||||
char *filename = (argc > 5) ? argv[5]: 0;
|
||||
if(0==strcmp(argv[2], "train")) train_captcha2(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "test")) test_captcha2(cfg, weights, filename);
|
||||
if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename);
|
||||
else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename);
|
||||
//if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights);
|
||||
//else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights);
|
||||
|
91
src/coco.c
91
src/coco.c
@ -15,41 +15,32 @@ char *coco_classes[] = {"person","bicycle","car","motorcycle","airplane","bus","
|
||||
|
||||
int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
|
||||
|
||||
void draw_coco(image im, float *box, int side, int objectness, char *label)
|
||||
void draw_coco(image im, float *pred, int side, char *label)
|
||||
{
|
||||
int classes = 80;
|
||||
int elems = 4+classes+objectness;
|
||||
int classes = 81;
|
||||
int elems = 4+classes;
|
||||
int j;
|
||||
int r, c;
|
||||
|
||||
for(r = 0; r < side; ++r){
|
||||
for(c = 0; c < side; ++c){
|
||||
j = (r*side + c) * elems;
|
||||
float scale = 1;
|
||||
if(objectness) scale = 1 - box[j++];
|
||||
int class = max_index(box+j, classes);
|
||||
if(scale * box[j+class] > 0.2){
|
||||
int width = box[j+class]*5 + 1;
|
||||
printf("%f %s\n", scale * box[j+class], coco_classes[class]);
|
||||
int class = max_index(pred+j, classes);
|
||||
if (class == 0) continue;
|
||||
if (pred[j+class] > 0.2){
|
||||
int width = pred[j+class]*5 + 1;
|
||||
printf("%f %s\n", pred[j+class], coco_classes[class-1]);
|
||||
float red = get_color(0,class,classes);
|
||||
float green = get_color(1,class,classes);
|
||||
float blue = get_color(2,class,classes);
|
||||
|
||||
j += classes;
|
||||
float x = box[j+0];
|
||||
float y = box[j+1];
|
||||
x = (x+c)/side;
|
||||
y = (y+r)/side;
|
||||
float w = box[j+2]; //*maxwidth;
|
||||
float h = box[j+3]; //*maxheight;
|
||||
h = h*h;
|
||||
w = w*w;
|
||||
|
||||
int left = (x-w/2)*im.w;
|
||||
int right = (x+w/2)*im.w;
|
||||
int top = (y-h/2)*im.h;
|
||||
int bot = (y+h/2)*im.h;
|
||||
draw_box_width(im, left, top, right, bot, width, red, green, blue);
|
||||
box predict = {pred[j+0], pred[j+1], pred[j+2], pred[j+3]};
|
||||
box anchor = {(c+.5)/side, (r+.5)/side, .5, .5};
|
||||
box decode = decode_box(predict, anchor);
|
||||
|
||||
draw_bbox(im, decode, width, red, green, blue);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -69,39 +60,47 @@ void train_coco(char *cfgfile, char *weightfile)
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
detection_layer layer = get_network_detection_layer(net);
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = 128;
|
||||
int i = net.seen/imgs;
|
||||
data train, buffer;
|
||||
|
||||
int classes = layer.classes;
|
||||
int background = layer.objectness;
|
||||
int side = sqrt(get_detection_layer_locations(layer));
|
||||
int classes = 81;
|
||||
int side = 7;
|
||||
|
||||
char **paths;
|
||||
list *plist = get_paths(train_images);
|
||||
int N = plist->size;
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
|
||||
paths = (char **)list_to_array(plist);
|
||||
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.paths = paths;
|
||||
args.n = imgs;
|
||||
args.m = plist->size;
|
||||
args.classes = classes;
|
||||
args.num_boxes = side;
|
||||
args.d = &buffer;
|
||||
args.type = REGION_DATA;
|
||||
|
||||
pthread_t load_thread = load_data_in_thread(args);
|
||||
clock_t time;
|
||||
while(i*imgs < N*120){
|
||||
i += 1;
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
|
||||
load_thread = load_data_in_thread(args);
|
||||
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
|
||||
/*
|
||||
image im = float_to_image(net.w, net.h, 3, train.X.vals[114]);
|
||||
image copy = copy_image(im);
|
||||
draw_coco(copy, train.y.vals[114], 7, layer.objectness, "truth");
|
||||
cvWaitKey(0);
|
||||
free_image(copy);
|
||||
*/
|
||||
/*
|
||||
image im = float_to_image(net.w, net.h, 3, train.X.vals[114]);
|
||||
image copy = copy_image(im);
|
||||
draw_coco(copy, train.y.vals[114], 7, "truth");
|
||||
cvWaitKey(0);
|
||||
free_image(copy);
|
||||
*/
|
||||
|
||||
time=clock();
|
||||
float loss = train_network(net, train);
|
||||
@ -220,6 +219,11 @@ void validate_coco(char *cfgfile, char *weightfile)
|
||||
int nms = 1;
|
||||
float iou_thresh = .5;
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.type = IMAGE_DATA;
|
||||
|
||||
int nthreads = 8;
|
||||
image *val = calloc(nthreads, sizeof(image));
|
||||
image *val_resized = calloc(nthreads, sizeof(image));
|
||||
@ -227,7 +231,10 @@ void validate_coco(char *cfgfile, char *weightfile)
|
||||
image *buf_resized = calloc(nthreads, sizeof(image));
|
||||
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
||||
for(t = 0; t < nthreads; ++t){
|
||||
thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
|
||||
args.path = paths[i+t];
|
||||
args.im = &buf[t];
|
||||
args.resized = &buf_resized[t];
|
||||
thr[t] = load_data_in_thread(args);
|
||||
}
|
||||
time_t start = time(0);
|
||||
for(i = nthreads; i < m+nthreads; i += nthreads){
|
||||
@ -238,7 +245,10 @@ void validate_coco(char *cfgfile, char *weightfile)
|
||||
val_resized[t] = buf_resized[t];
|
||||
}
|
||||
for(t = 0; t < nthreads && i+t < m; ++t){
|
||||
thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
|
||||
args.path = paths[i+t];
|
||||
args.im = &buf[t];
|
||||
args.resized = &buf_resized[t];
|
||||
thr[t] = load_data_in_thread(args);
|
||||
}
|
||||
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
||||
char *path = paths[i+t-nthreads];
|
||||
@ -267,7 +277,6 @@ void test_coco(char *cfgfile, char *weightfile, char *filename)
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
detection_layer layer = get_network_detection_layer(net);
|
||||
set_batch_network(&net, 1);
|
||||
srand(2222222);
|
||||
clock_t time;
|
||||
@ -287,7 +296,7 @@ void test_coco(char *cfgfile, char *weightfile, char *filename)
|
||||
time=clock();
|
||||
float *predictions = network_predict(net, X);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
draw_coco(im, predictions, 7, layer.objectness, "predictions");
|
||||
draw_coco(im, predictions, 7, "predictions");
|
||||
free_image(im);
|
||||
free_image(sized);
|
||||
#ifdef OPENCV
|
||||
|
@ -12,7 +12,6 @@
|
||||
#endif
|
||||
|
||||
extern void run_imagenet(int argc, char **argv);
|
||||
extern void run_detection(int argc, char **argv);
|
||||
extern void run_yolo(int argc, char **argv);
|
||||
extern void run_coco(int argc, char **argv);
|
||||
extern void run_writing(int argc, char **argv);
|
||||
@ -164,8 +163,6 @@ int main(int argc, char **argv)
|
||||
run_imagenet(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "average")){
|
||||
average(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "detection")){
|
||||
run_detection(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "yolo")){
|
||||
run_yolo(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "coco")){
|
||||
|
308
src/data.c
308
src/data.c
@ -8,25 +8,6 @@
|
||||
|
||||
unsigned int data_seed;
|
||||
|
||||
typedef struct load_args{
|
||||
char **paths;
|
||||
int n;
|
||||
int m;
|
||||
char **labels;
|
||||
int k;
|
||||
int h;
|
||||
int w;
|
||||
int nh;
|
||||
int nw;
|
||||
int num_boxes;
|
||||
int classes;
|
||||
int background;
|
||||
data *d;
|
||||
char *path;
|
||||
image *im;
|
||||
image *resized;
|
||||
} load_args;
|
||||
|
||||
list *get_paths(char *filename)
|
||||
{
|
||||
char *path;
|
||||
@ -138,6 +119,89 @@ void randomize_boxes(box_label *b, int n)
|
||||
}
|
||||
}
|
||||
|
||||
void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float sy, int flip)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < n; ++i){
|
||||
boxes[i].left = boxes[i].left * sx - dx;
|
||||
boxes[i].right = boxes[i].right * sx - dx;
|
||||
boxes[i].top = boxes[i].top * sy - dy;
|
||||
boxes[i].bottom = boxes[i].bottom* sy - dy;
|
||||
|
||||
if(flip){
|
||||
float swap = boxes[i].left;
|
||||
boxes[i].left = 1. - boxes[i].right;
|
||||
boxes[i].right = 1. - swap;
|
||||
}
|
||||
|
||||
boxes[i].left = constrain(0, 1, boxes[i].left);
|
||||
boxes[i].right = constrain(0, 1, boxes[i].right);
|
||||
boxes[i].top = constrain(0, 1, boxes[i].top);
|
||||
boxes[i].bottom = constrain(0, 1, boxes[i].bottom);
|
||||
|
||||
boxes[i].x = (boxes[i].left+boxes[i].right)/2;
|
||||
boxes[i].y = (boxes[i].top+boxes[i].bottom)/2;
|
||||
boxes[i].w = (boxes[i].right - boxes[i].left);
|
||||
boxes[i].h = (boxes[i].bottom - boxes[i].top);
|
||||
|
||||
boxes[i].w = constrain(0, 1, boxes[i].w);
|
||||
boxes[i].h = constrain(0, 1, boxes[i].h);
|
||||
}
|
||||
}
|
||||
|
||||
void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int flip, float dx, float dy, float sx, float sy)
|
||||
{
|
||||
char *labelpath = find_replace(path, "images", "labels");
|
||||
labelpath = find_replace(labelpath, ".jpg", ".txt");
|
||||
labelpath = find_replace(labelpath, ".JPEG", ".txt");
|
||||
int count = 0;
|
||||
box_label *boxes = read_boxes(labelpath, &count);
|
||||
randomize_boxes(boxes, count);
|
||||
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
|
||||
float x,y,w,h;
|
||||
int id;
|
||||
int i;
|
||||
|
||||
for(i = 0; i < num_boxes*num_boxes*(4+classes); i += 4+classes){
|
||||
truth[i] = 1;
|
||||
}
|
||||
|
||||
for(i = 0; i < count; ++i){
|
||||
x = boxes[i].x;
|
||||
y = boxes[i].y;
|
||||
w = boxes[i].w;
|
||||
h = boxes[i].h;
|
||||
id = boxes[i].id;
|
||||
|
||||
if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue;
|
||||
if (w < .01 || h < .01) continue;
|
||||
|
||||
int col = (int)(x*num_boxes);
|
||||
int row = (int)(y*num_boxes);
|
||||
|
||||
float xa = (col+.5)/num_boxes;
|
||||
float ya = (row+.5)/num_boxes;
|
||||
float wa = .5;
|
||||
float ha = .5;
|
||||
|
||||
float tx = (x - xa) / wa;
|
||||
float ty = (y - ya) / ha;
|
||||
float tw = log2(w/wa);
|
||||
float th = log2(h/ha);
|
||||
|
||||
int index = (col+row*num_boxes)*(4+classes);
|
||||
if(!truth[index]) continue;
|
||||
truth[index] = 0;
|
||||
truth[index+id+1] = 1;
|
||||
index += classes;
|
||||
truth[index++] = tx;
|
||||
truth[index++] = ty;
|
||||
truth[index++] = tw;
|
||||
truth[index++] = th;
|
||||
}
|
||||
free(boxes);
|
||||
}
|
||||
|
||||
void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, int flip, int background, float dx, float dy, float sx, float sy)
|
||||
{
|
||||
char *labelpath = find_replace(path, "JPEGImages", "labels");
|
||||
@ -178,20 +242,20 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes,
|
||||
w = (right - left);
|
||||
h = (bot - top);
|
||||
|
||||
if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue;
|
||||
if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue;
|
||||
|
||||
int i = (int)(x*num_boxes);
|
||||
int j = (int)(y*num_boxes);
|
||||
int col = (int)(x*num_boxes);
|
||||
int row = (int)(y*num_boxes);
|
||||
|
||||
x = x*num_boxes - i;
|
||||
y = y*num_boxes - j;
|
||||
x = x*num_boxes - col;
|
||||
y = y*num_boxes - row;
|
||||
|
||||
/*
|
||||
float maxwidth = distance_from_edge(i, num_boxes);
|
||||
float maxheight = distance_from_edge(j, num_boxes);
|
||||
w = w/maxwidth;
|
||||
h = h/maxheight;
|
||||
*/
|
||||
float maxwidth = distance_from_edge(i, num_boxes);
|
||||
float maxheight = distance_from_edge(j, num_boxes);
|
||||
w = w/maxwidth;
|
||||
h = h/maxheight;
|
||||
*/
|
||||
|
||||
w = constrain(0, 1, w);
|
||||
h = constrain(0, 1, h);
|
||||
@ -201,7 +265,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes,
|
||||
h = pow(h, 1./2.);
|
||||
}
|
||||
|
||||
int index = (i+j*num_boxes)*(4+classes+background);
|
||||
int index = (col+row*num_boxes)*(4+classes+background);
|
||||
if(truth[index+classes+background+2]) continue;
|
||||
if(background) truth[index++] = 0;
|
||||
truth[index+id] = 1;
|
||||
@ -214,57 +278,6 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes,
|
||||
free(boxes);
|
||||
}
|
||||
|
||||
void fill_truth_localization(char *path, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
|
||||
{
|
||||
char *labelpath = find_replace(path, "objects", "object_labels");
|
||||
labelpath = find_replace(labelpath, ".jpg", ".txt");
|
||||
labelpath = find_replace(labelpath, ".JPEG", ".txt");
|
||||
int count;
|
||||
box_label *boxes = read_boxes(labelpath, &count);
|
||||
box_label box = boxes[0];
|
||||
free(boxes);
|
||||
float x,y,w,h;
|
||||
float left, top, right, bot;
|
||||
int id;
|
||||
int i;
|
||||
for(i = 0; i < count; ++i){
|
||||
left = box.left * sx - dx;
|
||||
right = box.right * sx - dx;
|
||||
top = box.top * sy - dy;
|
||||
bot = box.bottom* sy - dy;
|
||||
id = box.id;
|
||||
|
||||
if(flip){
|
||||
float swap = left;
|
||||
left = 1. - right;
|
||||
right = 1. - swap;
|
||||
}
|
||||
|
||||
left = constrain(0, 1, left);
|
||||
right = constrain(0, 1, right);
|
||||
top = constrain(0, 1, top);
|
||||
bot = constrain(0, 1, bot);
|
||||
|
||||
x = (left+right)/2;
|
||||
y = (top+bot)/2;
|
||||
w = (right - left);
|
||||
h = (bot - top);
|
||||
|
||||
if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue;
|
||||
|
||||
w = constrain(0, 1, w);
|
||||
h = constrain(0, 1, h);
|
||||
if (w == 0 || h == 0) continue;
|
||||
|
||||
int index = id*4;
|
||||
truth[index++] = x;
|
||||
truth[index++] = y;
|
||||
truth[index++] = w;
|
||||
truth[index++] = h;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define NUMCHARS 37
|
||||
|
||||
void print_letters(float *pred, int n)
|
||||
@ -362,7 +375,7 @@ void free_data(data d)
|
||||
}
|
||||
}
|
||||
|
||||
data load_data_localization(int n, char **paths, int m, int classes, int w, int h)
|
||||
data load_data_region(int n, char **paths, int m, int classes, int w, int h, int num_boxes)
|
||||
{
|
||||
char **random_paths = get_random_paths(paths, n, m);
|
||||
int i;
|
||||
@ -373,7 +386,7 @@ data load_data_localization(int n, char **paths, int m, int classes, int w, int
|
||||
d.X.vals = calloc(d.X.rows, sizeof(float*));
|
||||
d.X.cols = h*w*3;
|
||||
|
||||
int k = (4*classes);
|
||||
int k = num_boxes*num_boxes*(4+classes);
|
||||
d.y = make_matrix(n, k);
|
||||
for(i = 0; i < n; ++i){
|
||||
image orig = load_image_color(random_paths[i], 0, 0);
|
||||
@ -381,13 +394,13 @@ data load_data_localization(int n, char **paths, int m, int classes, int w, int
|
||||
int oh = orig.h;
|
||||
int ow = orig.w;
|
||||
|
||||
int dw = 32;
|
||||
int dh = 32;
|
||||
int dw = ow/10;
|
||||
int dh = oh/10;
|
||||
|
||||
int pleft = (rand_uniform() * dw);
|
||||
int pright = (rand_uniform() * dw);
|
||||
int ptop = (rand_uniform() * dh);
|
||||
int pbot = (rand_uniform() * dh);
|
||||
int pleft = (rand_uniform() * 2*dw - dw);
|
||||
int pright = (rand_uniform() * 2*dw - dw);
|
||||
int ptop = (rand_uniform() * 2*dh - dh);
|
||||
int pbot = (rand_uniform() * 2*dh - dh);
|
||||
|
||||
int swidth = ow - pleft - pright;
|
||||
int sheight = oh - ptop - pbot;
|
||||
@ -397,22 +410,24 @@ data load_data_localization(int n, char **paths, int m, int classes, int w, int
|
||||
|
||||
int flip = rand_r(&data_seed)%2;
|
||||
image cropped = crop_image(orig, pleft, ptop, swidth, sheight);
|
||||
|
||||
float dx = ((float)pleft/ow)/sx;
|
||||
float dy = ((float)ptop /oh)/sy;
|
||||
|
||||
free_image(orig);
|
||||
image sized = resize_image(cropped, w, h);
|
||||
free_image(cropped);
|
||||
if(flip) flip_image(sized);
|
||||
d.X.vals[i] = sized.data;
|
||||
|
||||
fill_truth_localization(random_paths[i], d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy);
|
||||
fill_truth_region(random_paths[i], d.y.vals[i], classes, num_boxes, flip, dx, dy, 1./sx, 1./sy);
|
||||
|
||||
free_image(orig);
|
||||
free_image(cropped);
|
||||
}
|
||||
free(random_paths);
|
||||
return d;
|
||||
}
|
||||
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background)
|
||||
data load_data_detection(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background)
|
||||
{
|
||||
char **random_paths = get_random_paths(paths, n, m);
|
||||
int i;
|
||||
@ -471,81 +486,30 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes,
|
||||
return d;
|
||||
}
|
||||
|
||||
void *load_image_in_thread(void *ptr)
|
||||
{
|
||||
load_args a = *(load_args*)ptr;
|
||||
free(ptr);
|
||||
*(a.im) = load_image_color(a.path, 0, 0);
|
||||
*(a.resized) = resize_image(*(a.im), a.w, a.h);
|
||||
return 0;
|
||||
}
|
||||
|
||||
pthread_t load_image_thread(char *path, image *im, image *resized, int w, int h)
|
||||
{
|
||||
pthread_t thread;
|
||||
struct load_args *args = calloc(1, sizeof(struct load_args));
|
||||
args->path = path;
|
||||
args->w = w;
|
||||
args->h = h;
|
||||
args->im = im;
|
||||
args->resized = resized;
|
||||
if(pthread_create(&thread, 0, load_image_in_thread, args)) {
|
||||
error("Thread creation failed");
|
||||
}
|
||||
return thread;
|
||||
}
|
||||
|
||||
void *load_localization_thread(void *ptr)
|
||||
void *load_thread(void *ptr)
|
||||
{
|
||||
printf("Loading data: %d\n", rand_r(&data_seed));
|
||||
struct load_args a = *(struct load_args*)ptr;
|
||||
*a.d = load_data_localization(a.n, a.paths, a.m, a.classes, a.w, a.h);
|
||||
free(ptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
pthread_t load_data_localization_thread(int n, char **paths, int m, int classes, int w, int h, data *d)
|
||||
{
|
||||
pthread_t thread;
|
||||
struct load_args *args = calloc(1, sizeof(struct load_args));
|
||||
args->n = n;
|
||||
args->paths = paths;
|
||||
args->m = m;
|
||||
args->w = w;
|
||||
args->h = h;
|
||||
args->classes = classes;
|
||||
args->d = d;
|
||||
if(pthread_create(&thread, 0, load_localization_thread, args)) {
|
||||
error("Thread creation failed");
|
||||
load_args a = *(struct load_args*)ptr;
|
||||
if (a.type == CLASSIFICATION_DATA){
|
||||
*a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
|
||||
} else if (a.type == DETECTION_DATA){
|
||||
*a.d = load_data_detection(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes, a.background);
|
||||
} else if (a.type == REGION_DATA){
|
||||
*a.d = load_data_region(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes);
|
||||
} else if (a.type == IMAGE_DATA){
|
||||
*(a.im) = load_image_color(a.path, 0, 0);
|
||||
*(a.resized) = resize_image(*(a.im), a.w, a.h);
|
||||
}
|
||||
return thread;
|
||||
}
|
||||
|
||||
void *load_detection_thread(void *ptr)
|
||||
{
|
||||
printf("Loading data: %d\n", rand_r(&data_seed));
|
||||
struct load_args a = *(struct load_args*)ptr;
|
||||
*a.d = load_data_detection_jitter_random(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes, a.background);
|
||||
free(ptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
pthread_t load_data_detection_thread(int n, char **paths, int m, int classes, int w, int h, int nh, int nw, int background, data *d)
|
||||
pthread_t load_data_in_thread(load_args args)
|
||||
{
|
||||
pthread_t thread;
|
||||
struct load_args *args = calloc(1, sizeof(struct load_args));
|
||||
args->n = n;
|
||||
args->paths = paths;
|
||||
args->m = m;
|
||||
args->h = h;
|
||||
args->w = w;
|
||||
args->nh = nh;
|
||||
args->nw = nw;
|
||||
args->num_boxes = nw;
|
||||
args->classes = classes;
|
||||
args->background = background;
|
||||
args->d = d;
|
||||
if(pthread_create(&thread, 0, load_detection_thread, args)) {
|
||||
struct load_args *ptr = calloc(1, sizeof(struct load_args));
|
||||
*ptr = args;
|
||||
if(pthread_create(&thread, 0, load_thread, ptr)) {
|
||||
error("Thread creation failed");
|
||||
}
|
||||
return thread;
|
||||
@ -577,32 +541,6 @@ data load_data(char **paths, int n, int m, char **labels, int k, int w, int h)
|
||||
return d;
|
||||
}
|
||||
|
||||
void *load_in_thread(void *ptr)
|
||||
{
|
||||
struct load_args a = *(struct load_args*)ptr;
|
||||
*a.d = load_data(a.paths, a.n, a.m, a.labels, a.k, a.w, a.h);
|
||||
free(ptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int w, int h, data *d)
|
||||
{
|
||||
pthread_t thread;
|
||||
struct load_args *args = calloc(1, sizeof(struct load_args));
|
||||
args->n = n;
|
||||
args->paths = paths;
|
||||
args->m = m;
|
||||
args->labels = labels;
|
||||
args->k = k;
|
||||
args->h = h;
|
||||
args->w = w;
|
||||
args->d = d;
|
||||
if(pthread_create(&thread, 0, load_in_thread, args)) {
|
||||
error("Thread creation failed");
|
||||
}
|
||||
return thread;
|
||||
}
|
||||
|
||||
matrix concat_matrix(matrix m1, matrix m2)
|
||||
{
|
||||
int i, count = 0;
|
||||
|
33
src/data.h
33
src/data.h
@ -19,26 +19,45 @@ static inline float distance_from_edge(int x, int max)
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
typedef struct{
|
||||
matrix X;
|
||||
matrix y;
|
||||
int shallow;
|
||||
} data;
|
||||
|
||||
typedef enum {
|
||||
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA
|
||||
} data_type;
|
||||
|
||||
typedef struct load_args{
|
||||
char **paths;
|
||||
char *path;
|
||||
int n;
|
||||
int m;
|
||||
char **labels;
|
||||
int k;
|
||||
int h;
|
||||
int w;
|
||||
int nh;
|
||||
int nw;
|
||||
int num_boxes;
|
||||
int classes;
|
||||
int background;
|
||||
data *d;
|
||||
image *im;
|
||||
image *resized;
|
||||
data_type type;
|
||||
} load_args;
|
||||
|
||||
void free_data(data d);
|
||||
|
||||
pthread_t load_data_in_thread(load_args args);
|
||||
|
||||
void print_letters(float *pred, int n);
|
||||
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
|
||||
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
|
||||
data load_data(char **paths, int n, int m, char **labels, int k, int w, int h);
|
||||
pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int w, int h, data *d);
|
||||
pthread_t load_image_thread(char *path, image *im, image *resized, int w, int h);
|
||||
|
||||
pthread_t load_data_detection_thread(int n, char **paths, int m, int classes, int w, int h, int nh, int nw, int background, data *d);
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background);
|
||||
pthread_t load_data_localization_thread(int n, char **paths, int m, int classes, int w, int h, data *d);
|
||||
data load_data_detection(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background);
|
||||
|
||||
data load_cifar10_data(char *filename);
|
||||
data load_all_cifar10();
|
||||
|
305
src/detection.c
305
src/detection.c
@ -1,305 +0,0 @@
|
||||
#include "network.h"
|
||||
#include "detection_layer.h"
|
||||
#include "cost_layer.h"
|
||||
#include "utils.h"
|
||||
#include "parser.h"
|
||||
#include "box.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
|
||||
|
||||
void draw_detection(image im, float *box, int side, int objectness, char *label)
|
||||
{
|
||||
int classes = 20;
|
||||
int elems = 4+classes+objectness;
|
||||
int j;
|
||||
int r, c;
|
||||
|
||||
for(r = 0; r < side; ++r){
|
||||
for(c = 0; c < side; ++c){
|
||||
j = (r*side + c) * elems;
|
||||
float scale = 1;
|
||||
if(objectness) scale = 1 - box[j++];
|
||||
int class = max_index(box+j, classes);
|
||||
if(scale * box[j+class] > 0.2){
|
||||
int width = box[j+class]*5 + 1;
|
||||
printf("%f %s\n", scale * box[j+class], class_names[class]);
|
||||
float red = get_color(0,class,classes);
|
||||
float green = get_color(1,class,classes);
|
||||
float blue = get_color(2,class,classes);
|
||||
|
||||
j += classes;
|
||||
float x = box[j+0];
|
||||
float y = box[j+1];
|
||||
x = (x+c)/side;
|
||||
y = (y+r)/side;
|
||||
float w = box[j+2]; //*maxwidth;
|
||||
float h = box[j+3]; //*maxheight;
|
||||
h = h*h;
|
||||
w = w*w;
|
||||
|
||||
int left = (x-w/2)*im.w;
|
||||
int right = (x+w/2)*im.w;
|
||||
int top = (y-h/2)*im.h;
|
||||
int bot = (y+h/2)*im.h;
|
||||
draw_box_width(im, left, top, right, bot, width, red, green, blue);
|
||||
}
|
||||
}
|
||||
}
|
||||
show_image(im, label);
|
||||
}
|
||||
|
||||
void train_detection(char *cfgfile, char *weightfile)
|
||||
{
|
||||
char *train_images = "/home/pjreddie/data/voc/test/train.txt";
|
||||
char *backup_directory = "/home/pjreddie/backup/";
|
||||
srand(time(0));
|
||||
data_seed = time(0);
|
||||
char *base = basecfg(cfgfile);
|
||||
printf("%s\n", base);
|
||||
float avg_loss = -1;
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
detection_layer layer = get_network_detection_layer(net);
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = 128;
|
||||
int i = net.seen/imgs;
|
||||
data train, buffer;
|
||||
|
||||
int classes = layer.classes;
|
||||
int background = layer.objectness;
|
||||
int side = sqrt(get_detection_layer_locations(layer));
|
||||
|
||||
char **paths;
|
||||
list *plist = get_paths(train_images);
|
||||
int N = plist->size;
|
||||
|
||||
paths = (char **)list_to_array(plist);
|
||||
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
|
||||
clock_t time;
|
||||
while(i*imgs < N*130){
|
||||
i += 1;
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
|
||||
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
float loss = train_network(net, train);
|
||||
net.seen += imgs;
|
||||
if (avg_loss < 0) avg_loss = loss;
|
||||
avg_loss = avg_loss*.9 + loss*.1;
|
||||
|
||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
|
||||
if((i-1)*imgs <= N && i*imgs > N){
|
||||
fprintf(stderr, "First stage done\n");
|
||||
net.learning_rate *= 10;
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_first_stage.weights", backup_directory, base);
|
||||
save_weights(net, buff);
|
||||
}
|
||||
if((i-1)*imgs <= 80*N && i*imgs > N*80){
|
||||
fprintf(stderr, "Second stage done.\n");
|
||||
net.learning_rate *= .1;
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base);
|
||||
save_weights(net, buff);
|
||||
return;
|
||||
}
|
||||
if((i-1)*imgs <= 120*N && i*imgs > N*120){
|
||||
fprintf(stderr, "Third stage done.\n");
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_third_stage.weights", backup_directory, base);
|
||||
net.layers[net.n-1].rescore = 1;
|
||||
save_weights(net, buff);
|
||||
}
|
||||
if(i%1000==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
||||
save_weights(net, buff);
|
||||
}
|
||||
free_data(train);
|
||||
}
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
||||
save_weights(net, buff);
|
||||
}
|
||||
|
||||
void convert_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes)
|
||||
{
|
||||
int i,j;
|
||||
int per_box = 4+classes+(background || objectness);
|
||||
for (i = 0; i < num_boxes*num_boxes; ++i){
|
||||
float scale = 1;
|
||||
if(objectness) scale = 1-predictions[i*per_box];
|
||||
int offset = i*per_box+(background||objectness);
|
||||
for(j = 0; j < classes; ++j){
|
||||
float prob = scale*predictions[offset+j];
|
||||
probs[i][j] = (prob > thresh) ? prob : 0;
|
||||
}
|
||||
int row = i / num_boxes;
|
||||
int col = i % num_boxes;
|
||||
offset += classes;
|
||||
boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w;
|
||||
boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h;
|
||||
boxes[i].w = pow(predictions[offset + 2], 2) * w;
|
||||
boxes[i].h = pow(predictions[offset + 3], 2) * h;
|
||||
}
|
||||
}
|
||||
|
||||
void print_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h)
|
||||
{
|
||||
int i, j;
|
||||
for(i = 0; i < num_boxes*num_boxes; ++i){
|
||||
float xmin = boxes[i].x - boxes[i].w/2.;
|
||||
float xmax = boxes[i].x + boxes[i].w/2.;
|
||||
float ymin = boxes[i].y - boxes[i].h/2.;
|
||||
float ymax = boxes[i].y + boxes[i].h/2.;
|
||||
|
||||
if (xmin < 0) xmin = 0;
|
||||
if (ymin < 0) ymin = 0;
|
||||
if (xmax > w) xmax = w;
|
||||
if (ymax > h) ymax = h;
|
||||
|
||||
for(j = 0; j < classes; ++j){
|
||||
if (probs[i][j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, probs[i][j],
|
||||
xmin, ymin, xmax, ymax);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void validate_detection(char *cfgfile, char *weightfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
set_batch_network(&net, 1);
|
||||
detection_layer layer = get_network_detection_layer(net);
|
||||
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
srand(time(0));
|
||||
|
||||
char *base = "results/comp4_det_test_";
|
||||
list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
|
||||
int classes = layer.classes;
|
||||
int objectness = layer.objectness;
|
||||
int background = layer.background;
|
||||
int num_boxes = sqrt(get_detection_layer_locations(layer));
|
||||
|
||||
int j;
|
||||
FILE **fps = calloc(classes, sizeof(FILE *));
|
||||
for(j = 0; j < classes; ++j){
|
||||
char buff[1024];
|
||||
snprintf(buff, 1024, "%s%s.txt", base, class_names[j]);
|
||||
fps[j] = fopen(buff, "w");
|
||||
}
|
||||
box *boxes = calloc(num_boxes*num_boxes, sizeof(box));
|
||||
float **probs = calloc(num_boxes*num_boxes, sizeof(float *));
|
||||
for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *));
|
||||
|
||||
int m = plist->size;
|
||||
int i=0;
|
||||
int t;
|
||||
|
||||
float thresh = .001;
|
||||
int nms = 1;
|
||||
float iou_thresh = .5;
|
||||
|
||||
int nthreads = 8;
|
||||
image *val = calloc(nthreads, sizeof(image));
|
||||
image *val_resized = calloc(nthreads, sizeof(image));
|
||||
image *buf = calloc(nthreads, sizeof(image));
|
||||
image *buf_resized = calloc(nthreads, sizeof(image));
|
||||
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
||||
for(t = 0; t < nthreads; ++t){
|
||||
thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
|
||||
}
|
||||
time_t start = time(0);
|
||||
for(i = nthreads; i < m+nthreads; i += nthreads){
|
||||
fprintf(stderr, "%d\n", i);
|
||||
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
||||
pthread_join(thr[t], 0);
|
||||
val[t] = buf[t];
|
||||
val_resized[t] = buf_resized[t];
|
||||
}
|
||||
for(t = 0; t < nthreads && i+t < m; ++t){
|
||||
thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
|
||||
}
|
||||
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
||||
char *path = paths[i+t-nthreads];
|
||||
char *id = basecfg(path);
|
||||
float *X = val_resized[t].data;
|
||||
float *predictions = network_predict(net, X);
|
||||
int w = val[t].w;
|
||||
int h = val[t].h;
|
||||
convert_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes);
|
||||
if (nms) do_nms(boxes, probs, num_boxes, classes, iou_thresh);
|
||||
print_detections(fps, id, boxes, probs, num_boxes, classes, w, h);
|
||||
free(id);
|
||||
free_image(val[t]);
|
||||
free_image(val_resized[t]);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
|
||||
}
|
||||
|
||||
void test_detection(char *cfgfile, char *weightfile, char *filename)
|
||||
{
|
||||
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
detection_layer layer = get_network_detection_layer(net);
|
||||
set_batch_network(&net, 1);
|
||||
srand(2222222);
|
||||
clock_t time;
|
||||
char input[256];
|
||||
while(1){
|
||||
if(filename){
|
||||
strncpy(input, filename, 256);
|
||||
} else {
|
||||
printf("Enter Image Path: ");
|
||||
fflush(stdout);
|
||||
fgets(input, 256, stdin);
|
||||
strtok(input, "\n");
|
||||
}
|
||||
image im = load_image_color(input,0,0);
|
||||
image sized = resize_image(im, net.w, net.h);
|
||||
float *X = sized.data;
|
||||
time=clock();
|
||||
float *predictions = network_predict(net, X);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
draw_detection(im, predictions, 7, layer.objectness, "predictions");
|
||||
free_image(im);
|
||||
free_image(sized);
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(0);
|
||||
cvDestroyAllWindows();
|
||||
#endif
|
||||
if (filename) break;
|
||||
}
|
||||
}
|
||||
|
||||
void run_detection(int argc, char **argv)
|
||||
{
|
||||
if(argc < 4){
|
||||
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
||||
return;
|
||||
}
|
||||
|
||||
char *cfg = argv[3];
|
||||
char *weights = (argc > 4) ? argv[4] : 0;
|
||||
char *filename = (argc > 5) ? argv[5]: 0;
|
||||
if(0==strcmp(argv[2], "test")) test_detection(cfg, weights, filename);
|
||||
else if(0==strcmp(argv[2], "train")) train_detection(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "valid")) validate_detection(cfg, weights);
|
||||
}
|
@ -97,6 +97,7 @@ void forward_detection_layer(const detection_layer l, network_state state)
|
||||
truth.y = state.truth[j+1]/7;
|
||||
truth.w = pow(state.truth[j+2], 2);
|
||||
truth.h = pow(state.truth[j+3], 2);
|
||||
|
||||
box out;
|
||||
out.x = l.output[j+0]/7;
|
||||
out.y = l.output[j+1]/7;
|
||||
@ -107,13 +108,6 @@ void forward_detection_layer(const detection_layer l, network_state state)
|
||||
float iou = box_iou(out, truth);
|
||||
avg_iou += iou;
|
||||
++count;
|
||||
dbox delta = diou(out, truth);
|
||||
|
||||
l.delta[j+0] = 10 * delta.dx/7;
|
||||
l.delta[j+1] = 10 * delta.dy/7;
|
||||
l.delta[j+2] = 10 * delta.dw * 2 * sqrt(out.w);
|
||||
l.delta[j+3] = 10 * delta.dh * 2 * sqrt(out.h);
|
||||
|
||||
|
||||
*(l.cost) += pow((1-iou), 2);
|
||||
l.delta[j+0] = 4 * (state.truth[j+0] - l.output[j+0]);
|
||||
|
1145
src/image.c
1145
src/image.c
File diff suppressed because it is too large
Load Diff
@ -6,6 +6,7 @@
|
||||
#include <float.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "box.h"
|
||||
|
||||
typedef struct {
|
||||
int h;
|
||||
@ -18,6 +19,7 @@ float get_color(int c, int x, int max);
|
||||
void flip_image(image a);
|
||||
void draw_box(image a, int x1, int y1, int x2, int y2, float r, float g, float b);
|
||||
void draw_box_width(image a, int x1, int y1, int x2, int y2, int w, float r, float g, float b);
|
||||
void draw_bbox(image a, box bbox, int w, float r, float g, float b);
|
||||
image image_distance(image a, image b);
|
||||
void scale_image(image m, float s);
|
||||
image crop_image(image im, int dx, int dy, int w, int h);
|
||||
|
@ -30,7 +30,19 @@ void train_imagenet(char *cfgfile, char *weightfile)
|
||||
pthread_t load_thread;
|
||||
data train;
|
||||
data buffer;
|
||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.paths = paths;
|
||||
args.classes = 1000;
|
||||
args.n = imgs;
|
||||
args.m = plist->size;
|
||||
args.labels = labels;
|
||||
args.d = &buffer;
|
||||
args.type = CLASSIFICATION_DATA;
|
||||
|
||||
load_thread = load_data_in_thread(args);
|
||||
while(1){
|
||||
++i;
|
||||
time=clock();
|
||||
@ -43,7 +55,7 @@ void train_imagenet(char *cfgfile, char *weightfile)
|
||||
cvWaitKey(0);
|
||||
*/
|
||||
|
||||
load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
|
||||
load_thread = load_data_in_thread(args);
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
float loss = train_network(net, train);
|
||||
@ -84,7 +96,19 @@ void validate_imagenet(char *filename, char *weightfile)
|
||||
int num = (i+1)*m/splits - i*m/splits;
|
||||
|
||||
data val, buffer;
|
||||
pthread_t load_thread = load_data_thread(paths, num, 0, labels, 1000, 256, 256, &buffer);
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.paths = paths;
|
||||
args.classes = 1000;
|
||||
args.n = num;
|
||||
args.m = 0;
|
||||
args.labels = labels;
|
||||
args.d = &buffer;
|
||||
args.type = CLASSIFICATION_DATA;
|
||||
|
||||
pthread_t load_thread = load_data_in_thread(args);
|
||||
for(i = 1; i <= splits; ++i){
|
||||
time=clock();
|
||||
|
||||
@ -93,7 +117,10 @@ void validate_imagenet(char *filename, char *weightfile)
|
||||
|
||||
num = (i+1)*m/splits - i*m/splits;
|
||||
char **part = paths+(i*m/splits);
|
||||
if(i != splits) load_thread = load_data_thread(part, num, 0, labels, 1000, 256, 256, &buffer);
|
||||
if(i != splits){
|
||||
args.paths = part;
|
||||
load_thread = load_data_in_thread(args);
|
||||
}
|
||||
printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
|
||||
|
||||
time=clock();
|
||||
|
@ -15,6 +15,7 @@ typedef enum {
|
||||
ROUTE,
|
||||
COST,
|
||||
NORMALIZATION,
|
||||
REGION,
|
||||
AVGPOOL
|
||||
} LAYER_TYPE;
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "detection_layer.h"
|
||||
#include "region_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
#include "avgpool_layer.h"
|
||||
@ -36,6 +37,8 @@ char *get_layer_string(LAYER_TYPE a)
|
||||
return "softmax";
|
||||
case DETECTION:
|
||||
return "detection";
|
||||
case REGION:
|
||||
return "region";
|
||||
case DROPOUT:
|
||||
return "dropout";
|
||||
case CROP:
|
||||
@ -80,6 +83,8 @@ void forward_network(network net, network_state state)
|
||||
forward_normalization_layer(l, state);
|
||||
} else if(l.type == DETECTION){
|
||||
forward_detection_layer(l, state);
|
||||
} else if(l.type == REGION){
|
||||
forward_region_layer(l, state);
|
||||
} else if(l.type == CONNECTED){
|
||||
forward_connected_layer(l, state);
|
||||
} else if(l.type == CROP){
|
||||
@ -130,12 +135,16 @@ float get_network_cost(network net)
|
||||
float sum = 0;
|
||||
int count = 0;
|
||||
for(i = 0; i < net.n; ++i){
|
||||
if(net.layers[net.n-1].type == COST){
|
||||
sum += net.layers[net.n-1].output[0];
|
||||
if(net.layers[i].type == COST){
|
||||
sum += net.layers[i].output[0];
|
||||
++count;
|
||||
}
|
||||
if(net.layers[net.n-1].type == DETECTION){
|
||||
sum += net.layers[net.n-1].cost[0];
|
||||
if(net.layers[i].type == DETECTION){
|
||||
sum += net.layers[i].cost[0];
|
||||
++count;
|
||||
}
|
||||
if(net.layers[i].type == REGION){
|
||||
sum += net.layers[i].cost[0];
|
||||
++count;
|
||||
}
|
||||
}
|
||||
@ -178,6 +187,8 @@ void backward_network(network net, network_state state)
|
||||
backward_dropout_layer(l, state);
|
||||
} else if(l.type == DETECTION){
|
||||
backward_detection_layer(l, state);
|
||||
} else if(l.type == REGION){
|
||||
backward_region_layer(l, state);
|
||||
} else if(l.type == SOFTMAX){
|
||||
if(i != 0) backward_softmax_layer(l, state);
|
||||
} else if(l.type == CONNECTED){
|
||||
|
@ -12,6 +12,7 @@ extern "C" {
|
||||
#include "crop_layer.h"
|
||||
#include "connected_layer.h"
|
||||
#include "detection_layer.h"
|
||||
#include "region_layer.h"
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
@ -42,6 +43,8 @@ void forward_network_gpu(network net, network_state state)
|
||||
forward_deconvolutional_layer_gpu(l, state);
|
||||
} else if(l.type == DETECTION){
|
||||
forward_detection_layer_gpu(l, state);
|
||||
} else if(l.type == REGION){
|
||||
forward_region_layer_gpu(l, state);
|
||||
} else if(l.type == CONNECTED){
|
||||
forward_connected_layer_gpu(l, state);
|
||||
} else if(l.type == CROP){
|
||||
@ -92,6 +95,8 @@ void backward_network_gpu(network net, network_state state)
|
||||
backward_dropout_layer_gpu(l, state);
|
||||
} else if(l.type == DETECTION){
|
||||
backward_detection_layer_gpu(l, state);
|
||||
} else if(l.type == REGION){
|
||||
backward_region_layer_gpu(l, state);
|
||||
} else if(l.type == NORMALIZATION){
|
||||
backward_normalization_layer_gpu(l, state);
|
||||
} else if(l.type == SOFTMAX){
|
||||
|
18
src/parser.c
18
src/parser.c
@ -14,6 +14,7 @@
|
||||
#include "softmax_layer.h"
|
||||
#include "dropout_layer.h"
|
||||
#include "detection_layer.h"
|
||||
#include "region_layer.h"
|
||||
#include "avgpool_layer.h"
|
||||
#include "route_layer.h"
|
||||
#include "list.h"
|
||||
@ -37,6 +38,7 @@ int is_normalization(section *s);
|
||||
int is_crop(section *s);
|
||||
int is_cost(section *s);
|
||||
int is_detection(section *s);
|
||||
int is_region(section *s);
|
||||
int is_route(section *s);
|
||||
list *read_cfg(char *filename);
|
||||
|
||||
@ -172,6 +174,16 @@ detection_layer parse_detection(list *options, size_params params)
|
||||
return layer;
|
||||
}
|
||||
|
||||
region_layer parse_region(list *options, size_params params)
|
||||
{
|
||||
int coords = option_find_int(options, "coords", 1);
|
||||
int classes = option_find_int(options, "classes", 1);
|
||||
int rescore = option_find_int(options, "rescore", 0);
|
||||
int num = option_find_int(options, "num", 1);
|
||||
region_layer layer = make_region_layer(params.batch, params.inputs, num, classes, coords, rescore);
|
||||
return layer;
|
||||
}
|
||||
|
||||
cost_layer parse_cost(list *options, size_params params)
|
||||
{
|
||||
char *type_s = option_find_str(options, "type", "sse");
|
||||
@ -347,6 +359,8 @@ network parse_network_cfg(char *filename)
|
||||
l = parse_cost(options, params);
|
||||
}else if(is_detection(s)){
|
||||
l = parse_detection(options, params);
|
||||
}else if(is_region(s)){
|
||||
l = parse_region(options, params);
|
||||
}else if(is_softmax(s)){
|
||||
l = parse_softmax(options, params);
|
||||
}else if(is_normalization(s)){
|
||||
@ -399,6 +413,10 @@ int is_detection(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[detection]")==0);
|
||||
}
|
||||
int is_region(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[region]")==0);
|
||||
}
|
||||
int is_deconvolutional(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[deconv]")==0
|
||||
|
161
src/region_layer.c
Normal file
161
src/region_layer.c
Normal file
@ -0,0 +1,161 @@
|
||||
#include "region_layer.h"
|
||||
#include "activations.h"
|
||||
#include "softmax_layer.h"
|
||||
#include "blas.h"
|
||||
#include "box.h"
|
||||
#include "cuda.h"
|
||||
#include "utils.h"
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
int get_region_layer_locations(region_layer l)
|
||||
{
|
||||
return l.inputs / (l.classes+l.coords);
|
||||
}
|
||||
|
||||
region_layer make_region_layer(int batch, int inputs, int n, int classes, int coords, int rescore)
|
||||
{
|
||||
region_layer l = {0};
|
||||
l.type = REGION;
|
||||
|
||||
l.n = n;
|
||||
l.batch = batch;
|
||||
l.inputs = inputs;
|
||||
l.classes = classes;
|
||||
l.coords = coords;
|
||||
l.rescore = rescore;
|
||||
l.cost = calloc(1, sizeof(float));
|
||||
int outputs = inputs;
|
||||
l.outputs = outputs;
|
||||
l.output = calloc(batch*outputs, sizeof(float));
|
||||
l.delta = calloc(batch*outputs, sizeof(float));
|
||||
#ifdef GPU
|
||||
l.output_gpu = cuda_make_array(0, batch*outputs);
|
||||
l.delta_gpu = cuda_make_array(0, batch*outputs);
|
||||
#endif
|
||||
|
||||
fprintf(stderr, "Region Layer\n");
|
||||
srand(0);
|
||||
|
||||
return l;
|
||||
}
|
||||
|
||||
void forward_region_layer(const region_layer l, network_state state)
|
||||
{
|
||||
int locations = get_region_layer_locations(l);
|
||||
int i,j;
|
||||
for(i = 0; i < l.batch*locations; ++i){
|
||||
int index = i*(l.classes + l.coords);
|
||||
int mask = (!state.truth || !state.truth[index]);
|
||||
|
||||
for(j = 0; j < l.classes; ++j){
|
||||
l.output[index+j] = state.input[index+j];
|
||||
}
|
||||
|
||||
softmax_array(l.output + index, l.classes, l.output + index);
|
||||
index += l.classes;
|
||||
|
||||
for(j = 0; j < l.coords; ++j){
|
||||
l.output[index+j] = mask*state.input[index+j];
|
||||
}
|
||||
}
|
||||
if(state.train){
|
||||
float avg_iou = 0;
|
||||
int count = 0;
|
||||
*(l.cost) = 0;
|
||||
int size = l.outputs * l.batch;
|
||||
memset(l.delta, 0, size * sizeof(float));
|
||||
for (i = 0; i < l.batch*locations; ++i) {
|
||||
int offset = i*(l.classes+l.coords);
|
||||
int bg = state.truth[offset];
|
||||
for (j = offset; j < offset+l.classes; ++j) {
|
||||
//*(l.cost) += pow(state.truth[j] - l.output[j], 2);
|
||||
//l.delta[j] = state.truth[j] - l.output[j];
|
||||
}
|
||||
|
||||
box anchor = {0,0,.5,.5};
|
||||
box truth_code = {state.truth[j+0], state.truth[j+1], state.truth[j+2], state.truth[j+3]};
|
||||
box out_code = {l.output[j+0], l.output[j+1], l.output[j+2], l.output[j+3]};
|
||||
box out = decode_box(out_code, anchor);
|
||||
box truth = decode_box(truth_code, anchor);
|
||||
|
||||
if(bg) continue;
|
||||
//printf("Box: %f %f %f %f\n", truth.x, truth.y, truth.w, truth.h);
|
||||
//printf("Code: %f %f %f %f\n", truth_code.x, truth_code.y, truth_code.w, truth_code.h);
|
||||
//printf("Pred : %f %f %f %f\n", out.x, out.y, out.w, out.h);
|
||||
// printf("Pred Code: %f %f %f %f\n", out_code.x, out_code.y, out_code.w, out_code.h);
|
||||
float iou = box_iou(out, truth);
|
||||
avg_iou += iou;
|
||||
++count;
|
||||
|
||||
/*
|
||||
*(l.cost) += pow((1-iou), 2);
|
||||
l.delta[j+0] = (state.truth[j+0] - l.output[j+0]);
|
||||
l.delta[j+1] = (state.truth[j+1] - l.output[j+1]);
|
||||
l.delta[j+2] = (state.truth[j+2] - l.output[j+2]);
|
||||
l.delta[j+3] = (state.truth[j+3] - l.output[j+3]);
|
||||
*/
|
||||
|
||||
for (j = offset+l.classes; j < offset+l.classes+l.coords; ++j) {
|
||||
//*(l.cost) += pow(state.truth[j] - l.output[j], 2);
|
||||
//l.delta[j] = state.truth[j] - l.output[j];
|
||||
float diff = state.truth[j] - l.output[j];
|
||||
if (fabs(diff) < 1){
|
||||
l.delta[j] = diff;
|
||||
*(l.cost) += .5*pow(state.truth[j] - l.output[j], 2);
|
||||
} else {
|
||||
l.delta[j] = (diff > 0) ? 1 : -1;
|
||||
*(l.cost) += fabs(diff) - .5;
|
||||
}
|
||||
//l.delta[j] = state.truth[j] - l.output[j];
|
||||
}
|
||||
|
||||
/*
|
||||
if(l.rescore){
|
||||
for (j = offset; j < offset+l.classes; ++j) {
|
||||
if(state.truth[j]) state.truth[j] = iou;
|
||||
l.delta[j] = state.truth[j] - l.output[j];
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
printf("Avg IOU: %f\n", avg_iou/count);
|
||||
}
|
||||
}
|
||||
|
||||
void backward_region_layer(const region_layer l, network_state state)
|
||||
{
|
||||
axpy_cpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
|
||||
//copy_cpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
|
||||
void forward_region_layer_gpu(const region_layer l, network_state state)
|
||||
{
|
||||
float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
|
||||
float *truth_cpu = 0;
|
||||
if(state.truth){
|
||||
truth_cpu = calloc(l.batch*l.outputs, sizeof(float));
|
||||
cuda_pull_array(state.truth, truth_cpu, l.batch*l.outputs);
|
||||
}
|
||||
cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
|
||||
network_state cpu_state;
|
||||
cpu_state.train = state.train;
|
||||
cpu_state.truth = truth_cpu;
|
||||
cpu_state.input = in_cpu;
|
||||
forward_region_layer(l, cpu_state);
|
||||
cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
|
||||
cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
|
||||
free(cpu_state.input);
|
||||
if(cpu_state.truth) free(cpu_state.truth);
|
||||
}
|
||||
|
||||
void backward_region_layer_gpu(region_layer l, network_state state)
|
||||
{
|
||||
axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
|
||||
//copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
|
||||
}
|
||||
#endif
|
||||
|
18
src/region_layer.h
Normal file
18
src/region_layer.h
Normal file
@ -0,0 +1,18 @@
|
||||
#ifndef REGION_LAYER_H
|
||||
#define REGION_LAYER_H
|
||||
|
||||
#include "params.h"
|
||||
#include "layer.h"
|
||||
|
||||
typedef layer region_layer;
|
||||
|
||||
region_layer make_region_layer(int batch, int inputs, int n, int classes, int coords, int rescore);
|
||||
void forward_region_layer(const region_layer l, network_state state);
|
||||
void backward_region_layer(const region_layer l, network_state state);
|
||||
|
||||
#ifdef GPU
|
||||
void forward_region_layer_gpu(const region_layer l, network_state state);
|
||||
void backward_region_layer_gpu(region_layer l, network_state state);
|
||||
#endif
|
||||
|
||||
#endif
|
34
src/yolo.c
34
src/yolo.c
@ -88,14 +88,26 @@ void train_yolo(char *cfgfile, char *weightfile)
|
||||
int background = layer.objectness;
|
||||
int side = sqrt(get_detection_layer_locations(layer));
|
||||
|
||||
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.paths = paths;
|
||||
args.n = imgs;
|
||||
args.m = plist->size;
|
||||
args.classes = classes;
|
||||
args.num_boxes = side;
|
||||
args.background = background;
|
||||
args.d = &buffer;
|
||||
args.type = DETECTION_DATA;
|
||||
|
||||
pthread_t load_thread = load_data_in_thread(args);
|
||||
clock_t time;
|
||||
while(i*imgs < N*130){
|
||||
i += 1;
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
|
||||
load_thread = load_data_in_thread(args);
|
||||
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
@ -126,7 +138,7 @@ void train_yolo(char *cfgfile, char *weightfile)
|
||||
|
||||
pthread_join(load_thread, 0);
|
||||
free_data(buffer);
|
||||
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
|
||||
load_thread = load_data_in_thread(args);
|
||||
}
|
||||
|
||||
if((i-1)*imgs <= 120*N && i*imgs > N*120){
|
||||
@ -237,8 +249,17 @@ void validate_yolo(char *cfgfile, char *weightfile)
|
||||
image *buf = calloc(nthreads, sizeof(image));
|
||||
image *buf_resized = calloc(nthreads, sizeof(image));
|
||||
pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
args.type = IMAGE_DATA;
|
||||
|
||||
for(t = 0; t < nthreads; ++t){
|
||||
thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
|
||||
args.path = paths[i+t];
|
||||
args.im = &buf[t];
|
||||
args.resized = &buf_resized[t];
|
||||
thr[t] = load_data_in_thread(args);
|
||||
}
|
||||
time_t start = time(0);
|
||||
for(i = nthreads; i < m+nthreads; i += nthreads){
|
||||
@ -249,7 +270,10 @@ void validate_yolo(char *cfgfile, char *weightfile)
|
||||
val_resized[t] = buf_resized[t];
|
||||
}
|
||||
for(t = 0; t < nthreads && i+t < m; ++t){
|
||||
thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
|
||||
args.path = paths[i+t];
|
||||
args.im = &buf[t];
|
||||
args.resized = &buf_resized[t];
|
||||
thr[t] = load_data_in_thread(args);
|
||||
}
|
||||
for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
|
||||
char *path = paths[i+t-nthreads];
|
||||
|
Loading…
Reference in New Issue
Block a user