From 9d42f49a240136a8cd643cdc1f98230d4f22b05e Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Mon, 24 Aug 2015 18:27:42 -0700 Subject: [PATCH] changing data loading --- Makefile | 9 +- cfg/darknet.cfg | 2 +- src/box.c | 19 + src/box.h | 2 + src/captcha.c | 117 +--- src/coco.c | 91 ++-- src/darknet.c | 3 - src/data.c | 308 +++++------ src/data.h | 33 +- src/detection.c | 305 ----------- src/detection_layer.c | 8 +- src/image.c | 1145 ++++++++++++++++++++-------------------- src/image.h | 2 + src/imagenet.c | 35 +- src/layer.h | 1 + src/network.c | 19 +- src/network_kernels.cu | 5 + src/parser.c | 18 + src/region_layer.c | 161 ++++++ src/region_layer.h | 18 + src/yolo.c | 34 +- 21 files changed, 1104 insertions(+), 1231 deletions(-) delete mode 100644 src/detection.c create mode 100644 src/region_layer.c create mode 100644 src/region_layer.h diff --git a/Makefile b/Makefile index 556f2a4a..8ce68885 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,9 @@ -GPU=0 -OPENCV=0 +GPU=1 +OPENCV=1 DEBUG=0 ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20 +ARCH= -arch sm_52 VPATH=./src/ EXEC=darknet @@ -10,7 +11,7 @@ OBJDIR=./obj/ CC=gcc NVCC=nvcc -OPTS=-Ofast +OPTS=-O2 LDFLAGS= -lm -pthread -lstdc++ COMMON= -I/usr/local/cuda/include/ CFLAGS=-Wall -Wfatal-errors @@ -34,7 +35,7 @@ CFLAGS+= -DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o endif diff --git a/cfg/darknet.cfg b/cfg/darknet.cfg index 2a99279c..2e0a6247 100644 --- a/cfg/darknet.cfg +++ b/cfg/darknet.cfg @@ -1,6 +1,6 @@ [net] batch=128 -subdivisions=1 +subdivisions=32 height=256 width=256 channels=3 diff --git a/src/box.c b/src/box.c index 0518c050..d49be410 100644 --- a/src/box.c +++ b/src/box.c @@ -231,3 +231,22 @@ void do_nms(box *boxes, float **probs, int num_boxes, int classes, float thresh) } } +box encode_box(box b, box anchor) +{ + box encode; + encode.x = (b.x - anchor.x) / anchor.w; + encode.y = (b.y - anchor.y) / anchor.h; + encode.w = log2(b.w / anchor.w); + encode.h = log2(b.h / anchor.h); + return encode; +} + +box decode_box(box b, box anchor) +{ + box decode; + decode.x = b.x * anchor.w + anchor.x; + decode.y = b.y * anchor.h + anchor.y; + decode.w = pow(2., b.w) * anchor.w; + decode.h = pow(2., b.h) * anchor.h; + return decode; +} diff --git a/src/box.h b/src/box.h index 998f58a2..e45dd890 100644 --- a/src/box.h +++ b/src/box.h @@ -12,5 +12,7 @@ typedef struct{ float box_iou(box a, box b); dbox diou(box a, box b); void do_nms(box *boxes, float **probs, int num_boxes, int classes, float thresh); +box decode_box(box b, box anchor); +box encode_box(box b, box anchor); #endif diff --git a/src/captcha.c b/src/captcha.c index 772f9d7c..68d8915f 100644 --- a/src/captcha.c +++ b/src/captcha.c @@ -26,7 +26,7 @@ void fix_data_captcha(data d, int mask) } } -void train_captcha2(char *cfgfile, char *weightfile) +void train_captcha(char *cfgfile, char *weightfile) { data_seed = time(0); srand(time(0)); @@ -55,7 +55,19 @@ void train_captcha2(char *cfgfile, char *weightfile) pthread_t load_thread; data train; data buffer; - load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.classes = 26; + args.n = imgs; + args.m = plist->size; + args.labels = labels; + args.d = &buffer; + args.type = CLASSIFICATION_DATA; + + load_thread = load_data_in_thread(args); while(1){ ++i; time=clock(); @@ -69,7 +81,7 @@ void train_captcha2(char *cfgfile, char *weightfile) cvWaitKey(0); */ - load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer); + load_thread = load_data_in_thread(args); printf("Loaded: %lf seconds\n", sec(clock()-time)); time=clock(); float loss = train_network(net, train); @@ -86,7 +98,7 @@ void train_captcha2(char *cfgfile, char *weightfile) } } -void test_captcha2(char *cfgfile, char *weightfile, char *filename) +void test_captcha(char *cfgfile, char *weightfile, char *filename) { network net = parse_network_cfg(cfgfile); if(weightfile){ @@ -165,99 +177,6 @@ void valid_captcha(char *cfgfile, char *weightfile, char *filename) } } -void train_captcha(char *cfgfile, char *weightfile) -{ - data_seed = time(0); - srand(time(0)); - float avg_loss = -1; - char *base = basecfg(cfgfile); - printf("%s\n", base); - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - //net.seen=0; - int imgs = 1024; - int i = net.seen/imgs; - char **labels = get_labels("/data/captcha/reimgs.labels.list"); - list *plist = get_paths("/data/captcha/reimgs.train.list"); - char **paths = (char **)list_to_array(plist); - printf("%d\n", plist->size); - clock_t time; - pthread_t load_thread; - data train; - data buffer; - load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer); - while(1){ - ++i; - time=clock(); - pthread_join(load_thread, 0); - train = buffer; - - /* - image im = float_to_image(256, 256, 3, train.X.vals[114]); - show_image(im, "training"); - cvWaitKey(0); - */ - - load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer); - printf("Loaded: %lf seconds\n", sec(clock()-time)); - time=clock(); - float loss = train_network(net, train); - net.seen += imgs; - if(avg_loss == -1) avg_loss = loss; - avg_loss = avg_loss*.9 + loss*.1; - printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); - free_data(train); - if(i%100==0){ - char buff[256]; - sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); - save_weights(net, buff); - } - } -} - -void test_captcha(char *cfgfile, char *weightfile, char *filename) -{ - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - set_batch_network(&net, 1); - srand(2222222); - int i = 0; - char **names = get_labels("/data/captcha/reimgs.labels.list"); - char input[256]; - int indexes[13]; - while(1){ - if(filename){ - strncpy(input, filename, 256); - }else{ - //printf("Enter Image Path: "); - //fflush(stdout); - fgets(input, 256, stdin); - strtok(input, "\n"); - } - image im = load_image_color(input, net.w, net.h); - float *X = im.data; - float *predictions = network_predict(net, X); - top_predictions(net, 13, indexes); - //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - for(i = 0; i < 13; ++i){ - int index = indexes[i]; - if(i != 0) printf(", "); - printf("%s %f", names[index], predictions[index]); - } - printf("\n"); - fflush(stdout); - free_image(im); - if (filename) break; - } -} - - - /* void train_captcha(char *cfgfile, char *weightfile) { @@ -435,8 +354,8 @@ void run_captcha(int argc, char **argv) char *cfg = argv[3]; char *weights = (argc > 4) ? argv[4] : 0; char *filename = (argc > 5) ? argv[5]: 0; - if(0==strcmp(argv[2], "train")) train_captcha2(cfg, weights); - else if(0==strcmp(argv[2], "test")) test_captcha2(cfg, weights, filename); + if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights); + else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename); else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename); //if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights); //else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights); diff --git a/src/coco.c b/src/coco.c index 66fddf54..d2a108a7 100644 --- a/src/coco.c +++ b/src/coco.c @@ -15,41 +15,32 @@ char *coco_classes[] = {"person","bicycle","car","motorcycle","airplane","bus"," int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90}; -void draw_coco(image im, float *box, int side, int objectness, char *label) +void draw_coco(image im, float *pred, int side, char *label) { - int classes = 80; - int elems = 4+classes+objectness; + int classes = 81; + int elems = 4+classes; int j; int r, c; for(r = 0; r < side; ++r){ for(c = 0; c < side; ++c){ j = (r*side + c) * elems; - float scale = 1; - if(objectness) scale = 1 - box[j++]; - int class = max_index(box+j, classes); - if(scale * box[j+class] > 0.2){ - int width = box[j+class]*5 + 1; - printf("%f %s\n", scale * box[j+class], coco_classes[class]); + int class = max_index(pred+j, classes); + if (class == 0) continue; + if (pred[j+class] > 0.2){ + int width = pred[j+class]*5 + 1; + printf("%f %s\n", pred[j+class], coco_classes[class-1]); float red = get_color(0,class,classes); float green = get_color(1,class,classes); float blue = get_color(2,class,classes); j += classes; - float x = box[j+0]; - float y = box[j+1]; - x = (x+c)/side; - y = (y+r)/side; - float w = box[j+2]; //*maxwidth; - float h = box[j+3]; //*maxheight; - h = h*h; - w = w*w; - int left = (x-w/2)*im.w; - int right = (x+w/2)*im.w; - int top = (y-h/2)*im.h; - int bot = (y+h/2)*im.h; - draw_box_width(im, left, top, right, bot, width, red, green, blue); + box predict = {pred[j+0], pred[j+1], pred[j+2], pred[j+3]}; + box anchor = {(c+.5)/side, (r+.5)/side, .5, .5}; + box decode = decode_box(predict, anchor); + + draw_bbox(im, decode, width, red, green, blue); } } } @@ -69,39 +60,47 @@ void train_coco(char *cfgfile, char *weightfile) if(weightfile){ load_weights(&net, weightfile); } - detection_layer layer = get_network_detection_layer(net); printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); int imgs = 128; int i = net.seen/imgs; data train, buffer; - int classes = layer.classes; - int background = layer.objectness; - int side = sqrt(get_detection_layer_locations(layer)); + int classes = 81; + int side = 7; - char **paths; list *plist = get_paths(train_images); int N = plist->size; + char **paths = (char **)list_to_array(plist); - paths = (char **)list_to_array(plist); - pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.n = imgs; + args.m = plist->size; + args.classes = classes; + args.num_boxes = side; + args.d = &buffer; + args.type = REGION_DATA; + + pthread_t load_thread = load_data_in_thread(args); clock_t time; while(i*imgs < N*120){ i += 1; time=clock(); pthread_join(load_thread, 0); train = buffer; - load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); + load_thread = load_data_in_thread(args); printf("Loaded: %lf seconds\n", sec(clock()-time)); - /* - image im = float_to_image(net.w, net.h, 3, train.X.vals[114]); - image copy = copy_image(im); - draw_coco(copy, train.y.vals[114], 7, layer.objectness, "truth"); - cvWaitKey(0); - free_image(copy); - */ +/* + image im = float_to_image(net.w, net.h, 3, train.X.vals[114]); + image copy = copy_image(im); + draw_coco(copy, train.y.vals[114], 7, "truth"); + cvWaitKey(0); + free_image(copy); + */ time=clock(); float loss = train_network(net, train); @@ -220,6 +219,11 @@ void validate_coco(char *cfgfile, char *weightfile) int nms = 1; float iou_thresh = .5; + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.type = IMAGE_DATA; + int nthreads = 8; image *val = calloc(nthreads, sizeof(image)); image *val_resized = calloc(nthreads, sizeof(image)); @@ -227,7 +231,10 @@ void validate_coco(char *cfgfile, char *weightfile) image *buf_resized = calloc(nthreads, sizeof(image)); pthread_t *thr = calloc(nthreads, sizeof(pthread_t)); for(t = 0; t < nthreads; ++t){ - thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); + args.path = paths[i+t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); } time_t start = time(0); for(i = nthreads; i < m+nthreads; i += nthreads){ @@ -238,7 +245,10 @@ void validate_coco(char *cfgfile, char *weightfile) val_resized[t] = buf_resized[t]; } for(t = 0; t < nthreads && i+t < m; ++t){ - thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); + args.path = paths[i+t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); } for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ char *path = paths[i+t-nthreads]; @@ -267,7 +277,6 @@ void test_coco(char *cfgfile, char *weightfile, char *filename) if(weightfile){ load_weights(&net, weightfile); } - detection_layer layer = get_network_detection_layer(net); set_batch_network(&net, 1); srand(2222222); clock_t time; @@ -287,7 +296,7 @@ void test_coco(char *cfgfile, char *weightfile, char *filename) time=clock(); float *predictions = network_predict(net, X); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - draw_coco(im, predictions, 7, layer.objectness, "predictions"); + draw_coco(im, predictions, 7, "predictions"); free_image(im); free_image(sized); #ifdef OPENCV diff --git a/src/darknet.c b/src/darknet.c index 9b6eadbb..0928f28b 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -12,7 +12,6 @@ #endif extern void run_imagenet(int argc, char **argv); -extern void run_detection(int argc, char **argv); extern void run_yolo(int argc, char **argv); extern void run_coco(int argc, char **argv); extern void run_writing(int argc, char **argv); @@ -164,8 +163,6 @@ int main(int argc, char **argv) run_imagenet(argc, argv); } else if (0 == strcmp(argv[1], "average")){ average(argc, argv); - } else if (0 == strcmp(argv[1], "detection")){ - run_detection(argc, argv); } else if (0 == strcmp(argv[1], "yolo")){ run_yolo(argc, argv); } else if (0 == strcmp(argv[1], "coco")){ diff --git a/src/data.c b/src/data.c index 6a41d135..a335e07f 100644 --- a/src/data.c +++ b/src/data.c @@ -8,25 +8,6 @@ unsigned int data_seed; -typedef struct load_args{ - char **paths; - int n; - int m; - char **labels; - int k; - int h; - int w; - int nh; - int nw; - int num_boxes; - int classes; - int background; - data *d; - char *path; - image *im; - image *resized; -} load_args; - list *get_paths(char *filename) { char *path; @@ -138,6 +119,89 @@ void randomize_boxes(box_label *b, int n) } } +void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float sy, int flip) +{ + int i; + for(i = 0; i < n; ++i){ + boxes[i].left = boxes[i].left * sx - dx; + boxes[i].right = boxes[i].right * sx - dx; + boxes[i].top = boxes[i].top * sy - dy; + boxes[i].bottom = boxes[i].bottom* sy - dy; + + if(flip){ + float swap = boxes[i].left; + boxes[i].left = 1. - boxes[i].right; + boxes[i].right = 1. - swap; + } + + boxes[i].left = constrain(0, 1, boxes[i].left); + boxes[i].right = constrain(0, 1, boxes[i].right); + boxes[i].top = constrain(0, 1, boxes[i].top); + boxes[i].bottom = constrain(0, 1, boxes[i].bottom); + + boxes[i].x = (boxes[i].left+boxes[i].right)/2; + boxes[i].y = (boxes[i].top+boxes[i].bottom)/2; + boxes[i].w = (boxes[i].right - boxes[i].left); + boxes[i].h = (boxes[i].bottom - boxes[i].top); + + boxes[i].w = constrain(0, 1, boxes[i].w); + boxes[i].h = constrain(0, 1, boxes[i].h); + } +} + +void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int flip, float dx, float dy, float sx, float sy) +{ + char *labelpath = find_replace(path, "images", "labels"); + labelpath = find_replace(labelpath, ".jpg", ".txt"); + labelpath = find_replace(labelpath, ".JPEG", ".txt"); + int count = 0; + box_label *boxes = read_boxes(labelpath, &count); + randomize_boxes(boxes, count); + correct_boxes(boxes, count, dx, dy, sx, sy, flip); + float x,y,w,h; + int id; + int i; + + for(i = 0; i < num_boxes*num_boxes*(4+classes); i += 4+classes){ + truth[i] = 1; + } + + for(i = 0; i < count; ++i){ + x = boxes[i].x; + y = boxes[i].y; + w = boxes[i].w; + h = boxes[i].h; + id = boxes[i].id; + + if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue; + if (w < .01 || h < .01) continue; + + int col = (int)(x*num_boxes); + int row = (int)(y*num_boxes); + + float xa = (col+.5)/num_boxes; + float ya = (row+.5)/num_boxes; + float wa = .5; + float ha = .5; + + float tx = (x - xa) / wa; + float ty = (y - ya) / ha; + float tw = log2(w/wa); + float th = log2(h/ha); + + int index = (col+row*num_boxes)*(4+classes); + if(!truth[index]) continue; + truth[index] = 0; + truth[index+id+1] = 1; + index += classes; + truth[index++] = tx; + truth[index++] = ty; + truth[index++] = tw; + truth[index++] = th; + } + free(boxes); +} + void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, int flip, int background, float dx, float dy, float sx, float sy) { char *labelpath = find_replace(path, "JPEGImages", "labels"); @@ -178,20 +242,20 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, w = (right - left); h = (bot - top); - if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue; + if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue; - int i = (int)(x*num_boxes); - int j = (int)(y*num_boxes); + int col = (int)(x*num_boxes); + int row = (int)(y*num_boxes); - x = x*num_boxes - i; - y = y*num_boxes - j; + x = x*num_boxes - col; + y = y*num_boxes - row; /* - float maxwidth = distance_from_edge(i, num_boxes); - float maxheight = distance_from_edge(j, num_boxes); - w = w/maxwidth; - h = h/maxheight; - */ + float maxwidth = distance_from_edge(i, num_boxes); + float maxheight = distance_from_edge(j, num_boxes); + w = w/maxwidth; + h = h/maxheight; + */ w = constrain(0, 1, w); h = constrain(0, 1, h); @@ -201,7 +265,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, h = pow(h, 1./2.); } - int index = (i+j*num_boxes)*(4+classes+background); + int index = (col+row*num_boxes)*(4+classes+background); if(truth[index+classes+background+2]) continue; if(background) truth[index++] = 0; truth[index+id] = 1; @@ -214,57 +278,6 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, free(boxes); } -void fill_truth_localization(char *path, float *truth, int classes, int flip, float dx, float dy, float sx, float sy) -{ - char *labelpath = find_replace(path, "objects", "object_labels"); - labelpath = find_replace(labelpath, ".jpg", ".txt"); - labelpath = find_replace(labelpath, ".JPEG", ".txt"); - int count; - box_label *boxes = read_boxes(labelpath, &count); - box_label box = boxes[0]; - free(boxes); - float x,y,w,h; - float left, top, right, bot; - int id; - int i; - for(i = 0; i < count; ++i){ - left = box.left * sx - dx; - right = box.right * sx - dx; - top = box.top * sy - dy; - bot = box.bottom* sy - dy; - id = box.id; - - if(flip){ - float swap = left; - left = 1. - right; - right = 1. - swap; - } - - left = constrain(0, 1, left); - right = constrain(0, 1, right); - top = constrain(0, 1, top); - bot = constrain(0, 1, bot); - - x = (left+right)/2; - y = (top+bot)/2; - w = (right - left); - h = (bot - top); - - if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue; - - w = constrain(0, 1, w); - h = constrain(0, 1, h); - if (w == 0 || h == 0) continue; - - int index = id*4; - truth[index++] = x; - truth[index++] = y; - truth[index++] = w; - truth[index++] = h; - } -} - - #define NUMCHARS 37 void print_letters(float *pred, int n) @@ -362,7 +375,7 @@ void free_data(data d) } } -data load_data_localization(int n, char **paths, int m, int classes, int w, int h) +data load_data_region(int n, char **paths, int m, int classes, int w, int h, int num_boxes) { char **random_paths = get_random_paths(paths, n, m); int i; @@ -373,7 +386,7 @@ data load_data_localization(int n, char **paths, int m, int classes, int w, int d.X.vals = calloc(d.X.rows, sizeof(float*)); d.X.cols = h*w*3; - int k = (4*classes); + int k = num_boxes*num_boxes*(4+classes); d.y = make_matrix(n, k); for(i = 0; i < n; ++i){ image orig = load_image_color(random_paths[i], 0, 0); @@ -381,13 +394,13 @@ data load_data_localization(int n, char **paths, int m, int classes, int w, int int oh = orig.h; int ow = orig.w; - int dw = 32; - int dh = 32; + int dw = ow/10; + int dh = oh/10; - int pleft = (rand_uniform() * dw); - int pright = (rand_uniform() * dw); - int ptop = (rand_uniform() * dh); - int pbot = (rand_uniform() * dh); + int pleft = (rand_uniform() * 2*dw - dw); + int pright = (rand_uniform() * 2*dw - dw); + int ptop = (rand_uniform() * 2*dh - dh); + int pbot = (rand_uniform() * 2*dh - dh); int swidth = ow - pleft - pright; int sheight = oh - ptop - pbot; @@ -397,22 +410,24 @@ data load_data_localization(int n, char **paths, int m, int classes, int w, int int flip = rand_r(&data_seed)%2; image cropped = crop_image(orig, pleft, ptop, swidth, sheight); + float dx = ((float)pleft/ow)/sx; float dy = ((float)ptop /oh)/sy; - free_image(orig); image sized = resize_image(cropped, w, h); - free_image(cropped); if(flip) flip_image(sized); d.X.vals[i] = sized.data; - fill_truth_localization(random_paths[i], d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy); + fill_truth_region(random_paths[i], d.y.vals[i], classes, num_boxes, flip, dx, dy, 1./sx, 1./sy); + + free_image(orig); + free_image(cropped); } free(random_paths); return d; } -data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background) +data load_data_detection(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background) { char **random_paths = get_random_paths(paths, n, m); int i; @@ -471,81 +486,30 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes, return d; } -void *load_image_in_thread(void *ptr) -{ - load_args a = *(load_args*)ptr; - free(ptr); - *(a.im) = load_image_color(a.path, 0, 0); - *(a.resized) = resize_image(*(a.im), a.w, a.h); - return 0; -} - -pthread_t load_image_thread(char *path, image *im, image *resized, int w, int h) -{ - pthread_t thread; - struct load_args *args = calloc(1, sizeof(struct load_args)); - args->path = path; - args->w = w; - args->h = h; - args->im = im; - args->resized = resized; - if(pthread_create(&thread, 0, load_image_in_thread, args)) { - error("Thread creation failed"); - } - return thread; -} - -void *load_localization_thread(void *ptr) +void *load_thread(void *ptr) { printf("Loading data: %d\n", rand_r(&data_seed)); - struct load_args a = *(struct load_args*)ptr; - *a.d = load_data_localization(a.n, a.paths, a.m, a.classes, a.w, a.h); - free(ptr); - return 0; -} - -pthread_t load_data_localization_thread(int n, char **paths, int m, int classes, int w, int h, data *d) -{ - pthread_t thread; - struct load_args *args = calloc(1, sizeof(struct load_args)); - args->n = n; - args->paths = paths; - args->m = m; - args->w = w; - args->h = h; - args->classes = classes; - args->d = d; - if(pthread_create(&thread, 0, load_localization_thread, args)) { - error("Thread creation failed"); + load_args a = *(struct load_args*)ptr; + if (a.type == CLASSIFICATION_DATA){ + *a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h); + } else if (a.type == DETECTION_DATA){ + *a.d = load_data_detection(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes, a.background); + } else if (a.type == REGION_DATA){ + *a.d = load_data_region(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes); + } else if (a.type == IMAGE_DATA){ + *(a.im) = load_image_color(a.path, 0, 0); + *(a.resized) = resize_image(*(a.im), a.w, a.h); } - return thread; -} - -void *load_detection_thread(void *ptr) -{ - printf("Loading data: %d\n", rand_r(&data_seed)); - struct load_args a = *(struct load_args*)ptr; - *a.d = load_data_detection_jitter_random(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes, a.background); free(ptr); return 0; } -pthread_t load_data_detection_thread(int n, char **paths, int m, int classes, int w, int h, int nh, int nw, int background, data *d) +pthread_t load_data_in_thread(load_args args) { pthread_t thread; - struct load_args *args = calloc(1, sizeof(struct load_args)); - args->n = n; - args->paths = paths; - args->m = m; - args->h = h; - args->w = w; - args->nh = nh; - args->nw = nw; - args->num_boxes = nw; - args->classes = classes; - args->background = background; - args->d = d; - if(pthread_create(&thread, 0, load_detection_thread, args)) { + struct load_args *ptr = calloc(1, sizeof(struct load_args)); + *ptr = args; + if(pthread_create(&thread, 0, load_thread, ptr)) { error("Thread creation failed"); } return thread; @@ -577,32 +541,6 @@ data load_data(char **paths, int n, int m, char **labels, int k, int w, int h) return d; } -void *load_in_thread(void *ptr) -{ - struct load_args a = *(struct load_args*)ptr; - *a.d = load_data(a.paths, a.n, a.m, a.labels, a.k, a.w, a.h); - free(ptr); - return 0; -} - -pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int w, int h, data *d) -{ - pthread_t thread; - struct load_args *args = calloc(1, sizeof(struct load_args)); - args->n = n; - args->paths = paths; - args->m = m; - args->labels = labels; - args->k = k; - args->h = h; - args->w = w; - args->d = d; - if(pthread_create(&thread, 0, load_in_thread, args)) { - error("Thread creation failed"); - } - return thread; -} - matrix concat_matrix(matrix m1, matrix m2) { int i, count = 0; diff --git a/src/data.h b/src/data.h index 993a7a7f..f71e04a5 100644 --- a/src/data.h +++ b/src/data.h @@ -19,26 +19,45 @@ static inline float distance_from_edge(int x, int max) return dist; } - typedef struct{ matrix X; matrix y; int shallow; } data; +typedef enum { + CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA +} data_type; + +typedef struct load_args{ + char **paths; + char *path; + int n; + int m; + char **labels; + int k; + int h; + int w; + int nh; + int nw; + int num_boxes; + int classes; + int background; + data *d; + image *im; + image *resized; + data_type type; +} load_args; void free_data(data d); +pthread_t load_data_in_thread(load_args args); + void print_letters(float *pred, int n); data load_data_captcha(char **paths, int n, int m, int k, int w, int h); data load_data_captcha_encode(char **paths, int n, int m, int w, int h); data load_data(char **paths, int n, int m, char **labels, int k, int w, int h); -pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int w, int h, data *d); -pthread_t load_image_thread(char *path, image *im, image *resized, int w, int h); - -pthread_t load_data_detection_thread(int n, char **paths, int m, int classes, int w, int h, int nh, int nw, int background, data *d); -data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background); -pthread_t load_data_localization_thread(int n, char **paths, int m, int classes, int w, int h, data *d); +data load_data_detection(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background); data load_cifar10_data(char *filename); data load_all_cifar10(); diff --git a/src/detection.c b/src/detection.c deleted file mode 100644 index c9dea3e8..00000000 --- a/src/detection.c +++ /dev/null @@ -1,305 +0,0 @@ -#include "network.h" -#include "detection_layer.h" -#include "cost_layer.h" -#include "utils.h" -#include "parser.h" -#include "box.h" - -#ifdef OPENCV -#include "opencv2/highgui/highgui_c.h" -#endif - -char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; - -void draw_detection(image im, float *box, int side, int objectness, char *label) -{ - int classes = 20; - int elems = 4+classes+objectness; - int j; - int r, c; - - for(r = 0; r < side; ++r){ - for(c = 0; c < side; ++c){ - j = (r*side + c) * elems; - float scale = 1; - if(objectness) scale = 1 - box[j++]; - int class = max_index(box+j, classes); - if(scale * box[j+class] > 0.2){ - int width = box[j+class]*5 + 1; - printf("%f %s\n", scale * box[j+class], class_names[class]); - float red = get_color(0,class,classes); - float green = get_color(1,class,classes); - float blue = get_color(2,class,classes); - - j += classes; - float x = box[j+0]; - float y = box[j+1]; - x = (x+c)/side; - y = (y+r)/side; - float w = box[j+2]; //*maxwidth; - float h = box[j+3]; //*maxheight; - h = h*h; - w = w*w; - - int left = (x-w/2)*im.w; - int right = (x+w/2)*im.w; - int top = (y-h/2)*im.h; - int bot = (y+h/2)*im.h; - draw_box_width(im, left, top, right, bot, width, red, green, blue); - } - } - } - show_image(im, label); -} - -void train_detection(char *cfgfile, char *weightfile) -{ - char *train_images = "/home/pjreddie/data/voc/test/train.txt"; - char *backup_directory = "/home/pjreddie/backup/"; - srand(time(0)); - data_seed = time(0); - char *base = basecfg(cfgfile); - printf("%s\n", base); - float avg_loss = -1; - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - detection_layer layer = get_network_detection_layer(net); - printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - int imgs = 128; - int i = net.seen/imgs; - data train, buffer; - - int classes = layer.classes; - int background = layer.objectness; - int side = sqrt(get_detection_layer_locations(layer)); - - char **paths; - list *plist = get_paths(train_images); - int N = plist->size; - - paths = (char **)list_to_array(plist); - pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); - clock_t time; - while(i*imgs < N*130){ - i += 1; - time=clock(); - pthread_join(load_thread, 0); - train = buffer; - load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); - - printf("Loaded: %lf seconds\n", sec(clock()-time)); - time=clock(); - float loss = train_network(net, train); - net.seen += imgs; - if (avg_loss < 0) avg_loss = loss; - avg_loss = avg_loss*.9 + loss*.1; - - printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); - if((i-1)*imgs <= N && i*imgs > N){ - fprintf(stderr, "First stage done\n"); - net.learning_rate *= 10; - char buff[256]; - sprintf(buff, "%s/%s_first_stage.weights", backup_directory, base); - save_weights(net, buff); - } - if((i-1)*imgs <= 80*N && i*imgs > N*80){ - fprintf(stderr, "Second stage done.\n"); - net.learning_rate *= .1; - char buff[256]; - sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base); - save_weights(net, buff); - return; - } - if((i-1)*imgs <= 120*N && i*imgs > N*120){ - fprintf(stderr, "Third stage done.\n"); - char buff[256]; - sprintf(buff, "%s/%s_third_stage.weights", backup_directory, base); - net.layers[net.n-1].rescore = 1; - save_weights(net, buff); - } - if(i%1000==0){ - char buff[256]; - sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); - save_weights(net, buff); - } - free_data(train); - } - char buff[256]; - sprintf(buff, "%s/%s_final.weights", backup_directory, base); - save_weights(net, buff); -} - -void convert_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes) -{ - int i,j; - int per_box = 4+classes+(background || objectness); - for (i = 0; i < num_boxes*num_boxes; ++i){ - float scale = 1; - if(objectness) scale = 1-predictions[i*per_box]; - int offset = i*per_box+(background||objectness); - for(j = 0; j < classes; ++j){ - float prob = scale*predictions[offset+j]; - probs[i][j] = (prob > thresh) ? prob : 0; - } - int row = i / num_boxes; - int col = i % num_boxes; - offset += classes; - boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w; - boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h; - boxes[i].w = pow(predictions[offset + 2], 2) * w; - boxes[i].h = pow(predictions[offset + 3], 2) * h; - } -} - -void print_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h) -{ - int i, j; - for(i = 0; i < num_boxes*num_boxes; ++i){ - float xmin = boxes[i].x - boxes[i].w/2.; - float xmax = boxes[i].x + boxes[i].w/2.; - float ymin = boxes[i].y - boxes[i].h/2.; - float ymax = boxes[i].y + boxes[i].h/2.; - - if (xmin < 0) xmin = 0; - if (ymin < 0) ymin = 0; - if (xmax > w) xmax = w; - if (ymax > h) ymax = h; - - for(j = 0; j < classes; ++j){ - if (probs[i][j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, probs[i][j], - xmin, ymin, xmax, ymax); - } - } -} - -void validate_detection(char *cfgfile, char *weightfile) -{ - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - set_batch_network(&net, 1); - detection_layer layer = get_network_detection_layer(net); - fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - srand(time(0)); - - char *base = "results/comp4_det_test_"; - list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt"); - char **paths = (char **)list_to_array(plist); - - int classes = layer.classes; - int objectness = layer.objectness; - int background = layer.background; - int num_boxes = sqrt(get_detection_layer_locations(layer)); - - int j; - FILE **fps = calloc(classes, sizeof(FILE *)); - for(j = 0; j < classes; ++j){ - char buff[1024]; - snprintf(buff, 1024, "%s%s.txt", base, class_names[j]); - fps[j] = fopen(buff, "w"); - } - box *boxes = calloc(num_boxes*num_boxes, sizeof(box)); - float **probs = calloc(num_boxes*num_boxes, sizeof(float *)); - for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *)); - - int m = plist->size; - int i=0; - int t; - - float thresh = .001; - int nms = 1; - float iou_thresh = .5; - - int nthreads = 8; - image *val = calloc(nthreads, sizeof(image)); - image *val_resized = calloc(nthreads, sizeof(image)); - image *buf = calloc(nthreads, sizeof(image)); - image *buf_resized = calloc(nthreads, sizeof(image)); - pthread_t *thr = calloc(nthreads, sizeof(pthread_t)); - for(t = 0; t < nthreads; ++t){ - thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); - } - time_t start = time(0); - for(i = nthreads; i < m+nthreads; i += nthreads){ - fprintf(stderr, "%d\n", i); - for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ - pthread_join(thr[t], 0); - val[t] = buf[t]; - val_resized[t] = buf_resized[t]; - } - for(t = 0; t < nthreads && i+t < m; ++t){ - thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); - } - for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ - char *path = paths[i+t-nthreads]; - char *id = basecfg(path); - float *X = val_resized[t].data; - float *predictions = network_predict(net, X); - int w = val[t].w; - int h = val[t].h; - convert_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes); - if (nms) do_nms(boxes, probs, num_boxes, classes, iou_thresh); - print_detections(fps, id, boxes, probs, num_boxes, classes, w, h); - free(id); - free_image(val[t]); - free_image(val_resized[t]); - } - } - fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); -} - -void test_detection(char *cfgfile, char *weightfile, char *filename) -{ - - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - detection_layer layer = get_network_detection_layer(net); - set_batch_network(&net, 1); - srand(2222222); - clock_t time; - char input[256]; - while(1){ - if(filename){ - strncpy(input, filename, 256); - } else { - printf("Enter Image Path: "); - fflush(stdout); - fgets(input, 256, stdin); - strtok(input, "\n"); - } - image im = load_image_color(input,0,0); - image sized = resize_image(im, net.w, net.h); - float *X = sized.data; - time=clock(); - float *predictions = network_predict(net, X); - printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - draw_detection(im, predictions, 7, layer.objectness, "predictions"); - free_image(im); - free_image(sized); -#ifdef OPENCV - cvWaitKey(0); - cvDestroyAllWindows(); -#endif - if (filename) break; - } -} - -void run_detection(int argc, char **argv) -{ - if(argc < 4){ - fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); - return; - } - - char *cfg = argv[3]; - char *weights = (argc > 4) ? argv[4] : 0; - char *filename = (argc > 5) ? argv[5]: 0; - if(0==strcmp(argv[2], "test")) test_detection(cfg, weights, filename); - else if(0==strcmp(argv[2], "train")) train_detection(cfg, weights); - else if(0==strcmp(argv[2], "valid")) validate_detection(cfg, weights); -} diff --git a/src/detection_layer.c b/src/detection_layer.c index e48b8b37..f83e2e47 100644 --- a/src/detection_layer.c +++ b/src/detection_layer.c @@ -97,6 +97,7 @@ void forward_detection_layer(const detection_layer l, network_state state) truth.y = state.truth[j+1]/7; truth.w = pow(state.truth[j+2], 2); truth.h = pow(state.truth[j+3], 2); + box out; out.x = l.output[j+0]/7; out.y = l.output[j+1]/7; @@ -107,13 +108,6 @@ void forward_detection_layer(const detection_layer l, network_state state) float iou = box_iou(out, truth); avg_iou += iou; ++count; - dbox delta = diou(out, truth); - - l.delta[j+0] = 10 * delta.dx/7; - l.delta[j+1] = 10 * delta.dy/7; - l.delta[j+2] = 10 * delta.dw * 2 * sqrt(out.w); - l.delta[j+3] = 10 * delta.dh * 2 * sqrt(out.h); - *(l.cost) += pow((1-iou), 2); l.delta[j+0] = 4 * (state.truth[j+0] - l.output[j+0]); diff --git a/src/image.c b/src/image.c index d86f752a..8669294c 100644 --- a/src/image.c +++ b/src/image.c @@ -72,6 +72,19 @@ void draw_box_width(image a, int x1, int y1, int x2, int y2, int w, float r, flo } } +void draw_bbox(image a, box bbox, int w, float r, float g, float b) +{ + int left = (bbox.x-bbox.w/2)*a.w; + int right = (bbox.x+bbox.w/2)*a.w; + int top = (bbox.y-bbox.h/2)*a.h; + int bot = (bbox.y+bbox.h/2)*a.h; + + int i; + for(i = 0; i < w; ++i){ + draw_box(a, left+i, top+i, right-i, bot-i, r, g, b); + } +} + void flip_image(image a) { int i,j,k; @@ -214,7 +227,7 @@ void show_image_cv(image p, char *name) } free_image(copy); if(0){ - //if(disp->height < 448 || disp->width < 448 || disp->height > 1000){ + //if(disp->height < 448 || disp->width < 448 || disp->height > 1000){ int w = 448; int h = w*p.h/p.w; if(h > 1000){ @@ -228,41 +241,41 @@ void show_image_cv(image p, char *name) } cvShowImage(buff, disp); cvReleaseImage(&disp); -} + } #endif -void show_image(image p, char *name) -{ - #ifdef OPENCV - show_image_cv(p, name); - #else - fprintf(stderr, "Not compiled with OpenCV, saving to %s.png instead\n", name); - save_image(p, name); - #endif -} - -void save_image(image im, char *name) -{ - char buff[256]; - //sprintf(buff, "%s (%d)", name, windows); - sprintf(buff, "%s.png", name); - unsigned char *data = calloc(im.w*im.h*im.c, sizeof(char)); - int i,k; - for(k = 0; k < im.c; ++k){ - for(i = 0; i < im.w*im.h; ++i){ - data[i*im.c+k] = (unsigned char) (255*im.data[i + k*im.w*im.h]); - } + void show_image(image p, char *name) + { +#ifdef OPENCV + show_image_cv(p, name); +#else + fprintf(stderr, "Not compiled with OpenCV, saving to %s.png instead\n", name); + save_image(p, name); +#endif } - int success = stbi_write_png(buff, im.w, im.h, im.c, data, im.w*im.c); - free(data); - if(!success) fprintf(stderr, "Failed to write image %s\n", buff); -} -/* -void save_image_cv(image p, char *name) -{ - int x,y,k; - image copy = copy_image(p); + void save_image(image im, char *name) + { + char buff[256]; + //sprintf(buff, "%s (%d)", name, windows); + sprintf(buff, "%s.png", name); + unsigned char *data = calloc(im.w*im.h*im.c, sizeof(char)); + int i,k; + for(k = 0; k < im.c; ++k){ + for(i = 0; i < im.w*im.h; ++i){ + data[i*im.c+k] = (unsigned char) (255*im.data[i + k*im.w*im.h]); + } + } + int success = stbi_write_png(buff, im.w, im.h, im.c, data, im.w*im.c); + free(data); + if(!success) fprintf(stderr, "Failed to write image %s\n", buff); + } + + /* + void save_image_cv(image p, char *name) + { + int x,y,k; + image copy = copy_image(p); //normalize_image(copy); char buff[256]; @@ -272,603 +285,603 @@ void save_image_cv(image p, char *name) IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c); int step = disp->widthStep; for(y = 0; y < p.h; ++y){ - for(x = 0; x < p.w; ++x){ - for(k= 0; k < p.c; ++k){ - disp->imageData[y*step + x*p.c + k] = (unsigned char)(get_pixel(copy,x,y,k)*255); - } - } + for(x = 0; x < p.w; ++x){ + for(k= 0; k < p.c; ++k){ + disp->imageData[y*step + x*p.c + k] = (unsigned char)(get_pixel(copy,x,y,k)*255); + } + } } free_image(copy); cvSaveImage(buff, disp,0); cvReleaseImage(&disp); -} -*/ - -void show_image_layers(image p, char *name) -{ - int i; - char buff[256]; - for(i = 0; i < p.c; ++i){ - sprintf(buff, "%s - Layer %d", name, i); - image layer = get_image_layer(p, i); - show_image(layer, buff); - free_image(layer); } -} + */ -void show_image_collapsed(image p, char *name) -{ - image c = collapse_image_layers(p, 1); - show_image(c, name); - free_image(c); -} - -image make_empty_image(int w, int h, int c) -{ - image out; - out.data = 0; - out.h = h; - out.w = w; - out.c = c; - return out; -} - -image make_image(int w, int h, int c) -{ - image out = make_empty_image(w,h,c); - out.data = calloc(h*w*c, sizeof(float)); - return out; -} - -image float_to_image(int w, int h, int c, float *data) -{ - image out = make_empty_image(w,h,c); - out.data = data; - return out; -} - -image rotate_image(image im, float rad) -{ - int x, y, c; - float cx = im.w/2.; - float cy = im.h/2.; - image rot = make_image(im.w, im.h, im.c); - for(c = 0; c < im.c; ++c){ - for(y = 0; y < im.h; ++y){ - for(x = 0; x < im.w; ++x){ - float rx = cos(rad)*(x-cx) - sin(rad)*(y-cy) + cx; - float ry = sin(rad)*(x-cx) + cos(rad)*(y-cy) + cy; - float val = bilinear_interpolate(im, rx, ry, c); - set_pixel(rot, x, y, c, val); - } + void show_image_layers(image p, char *name) + { + int i; + char buff[256]; + for(i = 0; i < p.c; ++i){ + sprintf(buff, "%s - Layer %d", name, i); + image layer = get_image_layer(p, i); + show_image(layer, buff); + free_image(layer); } } - return rot; -} -void translate_image(image m, float s) -{ - int i; - for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s; -} - -void scale_image(image m, float s) -{ - int i; - for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] *= s; -} - -image crop_image(image im, int dx, int dy, int w, int h) -{ - image cropped = make_image(w, h, im.c); - int i, j, k; - for(k = 0; k < im.c; ++k){ - for(j = 0; j < h; ++j){ - for(i = 0; i < w; ++i){ - int r = j + dy; - int c = i + dx; - float val = 0; - if (r >= 0 && r < im.h && c >= 0 && c < im.w) { - val = get_pixel(im, c, r, k); - } - set_pixel(cropped, i, j, k, val); - } - } + void show_image_collapsed(image p, char *name) + { + image c = collapse_image_layers(p, 1); + show_image(c, name); + free_image(c); } - return cropped; -} -float three_way_max(float a, float b, float c) -{ - return (a > b) ? ( (a > c) ? a : c) : ( (b > c) ? b : c) ; -} - -float three_way_min(float a, float b, float c) -{ - return (a < b) ? ( (a < c) ? a : c) : ( (b < c) ? b : c) ; -} - -// http://www.cs.rit.edu/~ncs/color/t_convert.html -void rgb_to_hsv(image im) -{ - assert(im.c == 3); - int i, j; - float r, g, b; - float h, s, v; - for(j = 0; j < im.h; ++j){ - for(i = 0; i < im.w; ++i){ - r = get_pixel(im, i , j, 0); - g = get_pixel(im, i , j, 1); - b = get_pixel(im, i , j, 2); - float max = three_way_max(r,g,b); - float min = three_way_min(r,g,b); - float delta = max - min; - v = max; - if(max == 0){ - s = 0; - h = -1; - }else{ - s = delta/max; - if(r == max){ - h = (g - b) / delta; - } else if (g == max) { - h = 2 + (b - r) / delta; - } else { - h = 4 + (r - g) / delta; - } - if (h < 0) h += 6; - } - set_pixel(im, i, j, 0, h); - set_pixel(im, i, j, 1, s); - set_pixel(im, i, j, 2, v); - } + image make_empty_image(int w, int h, int c) + { + image out; + out.data = 0; + out.h = h; + out.w = w; + out.c = c; + return out; } -} -void hsv_to_rgb(image im) -{ - assert(im.c == 3); - int i, j; - float r, g, b; - float h, s, v; - float f, p, q, t; - for(j = 0; j < im.h; ++j){ - for(i = 0; i < im.w; ++i){ - h = get_pixel(im, i , j, 0); - s = get_pixel(im, i , j, 1); - v = get_pixel(im, i , j, 2); - if (s == 0) { - r = g = b = v; - } else { - int index = floor(h); - f = h - index; - p = v*(1-s); - q = v*(1-s*f); - t = v*(1-s*(1-f)); - if(index == 0){ - r = v; g = t; b = p; - } else if(index == 1){ - r = q; g = v; b = p; - } else if(index == 2){ - r = p; g = v; b = t; - } else if(index == 3){ - r = p; g = q; b = v; - } else if(index == 4){ - r = t; g = p; b = v; - } else { - r = v; g = p; b = q; + image make_image(int w, int h, int c) + { + image out = make_empty_image(w,h,c); + out.data = calloc(h*w*c, sizeof(float)); + return out; + } + + image float_to_image(int w, int h, int c, float *data) + { + image out = make_empty_image(w,h,c); + out.data = data; + return out; + } + + image rotate_image(image im, float rad) + { + int x, y, c; + float cx = im.w/2.; + float cy = im.h/2.; + image rot = make_image(im.w, im.h, im.c); + for(c = 0; c < im.c; ++c){ + for(y = 0; y < im.h; ++y){ + for(x = 0; x < im.w; ++x){ + float rx = cos(rad)*(x-cx) - sin(rad)*(y-cy) + cx; + float ry = sin(rad)*(x-cx) + cos(rad)*(y-cy) + cy; + float val = bilinear_interpolate(im, rx, ry, c); + set_pixel(rot, x, y, c, val); } } - set_pixel(im, i, j, 0, r); - set_pixel(im, i, j, 1, g); - set_pixel(im, i, j, 2, b); } + return rot; } -} -image grayscale_image(image im) -{ - assert(im.c == 3); - int i, j, k; - image gray = make_image(im.w, im.h, im.c); - float scale[] = {0.587, 0.299, 0.114}; - for(k = 0; k < im.c; ++k){ + void translate_image(image m, float s) + { + int i; + for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s; + } + + void scale_image(image m, float s) + { + int i; + for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] *= s; + } + + image crop_image(image im, int dx, int dy, int w, int h) + { + image cropped = make_image(w, h, im.c); + int i, j, k; + for(k = 0; k < im.c; ++k){ + for(j = 0; j < h; ++j){ + for(i = 0; i < w; ++i){ + int r = j + dy; + int c = i + dx; + float val = 0; + if (r >= 0 && r < im.h && c >= 0 && c < im.w) { + val = get_pixel(im, c, r, k); + } + set_pixel(cropped, i, j, k, val); + } + } + } + return cropped; + } + + float three_way_max(float a, float b, float c) + { + return (a > b) ? ( (a > c) ? a : c) : ( (b > c) ? b : c) ; + } + + float three_way_min(float a, float b, float c) + { + return (a < b) ? ( (a < c) ? a : c) : ( (b < c) ? b : c) ; + } + + // http://www.cs.rit.edu/~ncs/color/t_convert.html + void rgb_to_hsv(image im) + { + assert(im.c == 3); + int i, j; + float r, g, b; + float h, s, v; for(j = 0; j < im.h; ++j){ for(i = 0; i < im.w; ++i){ - gray.data[i+im.w*j] += scale[k]*get_pixel(im, i, j, k); - } - } - } - memcpy(gray.data + im.w*im.h*1, gray.data, sizeof(float)*im.w*im.h); - memcpy(gray.data + im.w*im.h*2, gray.data, sizeof(float)*im.w*im.h); - return gray; -} - -image blend_image(image fore, image back, float alpha) -{ - assert(fore.w == back.w && fore.h == back.h && fore.c == back.c); - image blend = make_image(fore.w, fore.h, fore.c); - int i, j, k; - for(k = 0; k < fore.c; ++k){ - for(j = 0; j < fore.h; ++j){ - for(i = 0; i < fore.w; ++i){ - float val = alpha * get_pixel(fore, i, j, k) + - (1 - alpha)* get_pixel(back, i, j, k); - set_pixel(blend, i, j, k, val); - } - } - } - return blend; -} - -void scale_image_channel(image im, int c, float v) -{ - int i, j; - for(j = 0; j < im.h; ++j){ - for(i = 0; i < im.w; ++i){ - float pix = get_pixel(im, i, j, c); - pix = pix*v; - set_pixel(im, i, j, c, pix); - } - } -} - -void saturate_image(image im, float sat) -{ - rgb_to_hsv(im); - scale_image_channel(im, 1, sat); - hsv_to_rgb(im); - constrain_image(im); -} - -void exposure_image(image im, float sat) -{ - rgb_to_hsv(im); - scale_image_channel(im, 2, sat); - hsv_to_rgb(im); - constrain_image(im); -} - -void saturate_exposure_image(image im, float sat, float exposure) -{ - rgb_to_hsv(im); - scale_image_channel(im, 1, sat); - scale_image_channel(im, 2, exposure); - hsv_to_rgb(im); - constrain_image(im); -} - -/* - image saturate_image(image im, float sat) - { - image gray = grayscale_image(im); - image blend = blend_image(im, gray, sat); - free_image(gray); - constrain_image(blend); - return blend; - } - - image brightness_image(image im, float b) - { - image bright = make_image(im.w, im.h, im.c); - return bright; - } - */ - -float bilinear_interpolate(image im, float x, float y, int c) -{ - int ix = (int) floorf(x); - int iy = (int) floorf(y); - - float dx = x - ix; - float dy = y - iy; - - float val = (1-dy) * (1-dx) * get_pixel_extend(im, ix, iy, c) + - dy * (1-dx) * get_pixel_extend(im, ix, iy+1, c) + - (1-dy) * dx * get_pixel_extend(im, ix+1, iy, c) + - dy * dx * get_pixel_extend(im, ix+1, iy+1, c); - return val; -} - -image resize_image(image im, int w, int h) -{ - image resized = make_image(w, h, im.c); - image part = make_image(w, im.h, im.c); - int r, c, k; - float w_scale = (float)(im.w - 1) / (w - 1); - float h_scale = (float)(im.h - 1) / (h - 1); - for(k = 0; k < im.c; ++k){ - for(r = 0; r < im.h; ++r){ - for(c = 0; c < w; ++c){ - float val = 0; - if(c == w-1){ - val = get_pixel(im, im.w-1, r, k); - } else { - float sx = c*w_scale; - int ix = (int) sx; - float dx = sx - ix; - val = (1 - dx) * get_pixel(im, ix, r, k) + dx * get_pixel(im, ix+1, r, k); + r = get_pixel(im, i , j, 0); + g = get_pixel(im, i , j, 1); + b = get_pixel(im, i , j, 2); + float max = three_way_max(r,g,b); + float min = three_way_min(r,g,b); + float delta = max - min; + v = max; + if(max == 0){ + s = 0; + h = -1; + }else{ + s = delta/max; + if(r == max){ + h = (g - b) / delta; + } else if (g == max) { + h = 2 + (b - r) / delta; + } else { + h = 4 + (r - g) / delta; + } + if (h < 0) h += 6; } - set_pixel(part, c, r, k, val); - } - } - } - for(k = 0; k < im.c; ++k){ - for(r = 0; r < h; ++r){ - float sy = r*h_scale; - int iy = (int) sy; - float dy = sy - iy; - for(c = 0; c < w; ++c){ - float val = (1-dy) * get_pixel(part, c, iy, k); - set_pixel(resized, c, r, k, val); - } - if(r == h-1) continue; - for(c = 0; c < w; ++c){ - float val = dy * get_pixel(part, c, iy+1, k); - add_pixel(resized, c, r, k, val); + set_pixel(im, i, j, 0, h); + set_pixel(im, i, j, 1, s); + set_pixel(im, i, j, 2, v); } } } - free_image(part); - return resized; -} + void hsv_to_rgb(image im) + { + assert(im.c == 3); + int i, j; + float r, g, b; + float h, s, v; + float f, p, q, t; + for(j = 0; j < im.h; ++j){ + for(i = 0; i < im.w; ++i){ + h = get_pixel(im, i , j, 0); + s = get_pixel(im, i , j, 1); + v = get_pixel(im, i , j, 2); + if (s == 0) { + r = g = b = v; + } else { + int index = floor(h); + f = h - index; + p = v*(1-s); + q = v*(1-s*f); + t = v*(1-s*(1-f)); + if(index == 0){ + r = v; g = t; b = p; + } else if(index == 1){ + r = q; g = v; b = p; + } else if(index == 2){ + r = p; g = v; b = t; + } else if(index == 3){ + r = p; g = q; b = v; + } else if(index == 4){ + r = t; g = p; b = v; + } else { + r = v; g = p; b = q; + } + } + set_pixel(im, i, j, 0, r); + set_pixel(im, i, j, 1, g); + set_pixel(im, i, j, 2, b); + } + } + } -void test_resize(char *filename) -{ - image im = load_image(filename, 0,0, 3); - image gray = grayscale_image(im); + image grayscale_image(image im) + { + assert(im.c == 3); + int i, j, k; + image gray = make_image(im.w, im.h, im.c); + float scale[] = {0.587, 0.299, 0.114}; + for(k = 0; k < im.c; ++k){ + for(j = 0; j < im.h; ++j){ + for(i = 0; i < im.w; ++i){ + gray.data[i+im.w*j] += scale[k]*get_pixel(im, i, j, k); + } + } + } + memcpy(gray.data + im.w*im.h*1, gray.data, sizeof(float)*im.w*im.h); + memcpy(gray.data + im.w*im.h*2, gray.data, sizeof(float)*im.w*im.h); + return gray; + } - image sat2 = copy_image(im); - saturate_image(sat2, 2); + image blend_image(image fore, image back, float alpha) + { + assert(fore.w == back.w && fore.h == back.h && fore.c == back.c); + image blend = make_image(fore.w, fore.h, fore.c); + int i, j, k; + for(k = 0; k < fore.c; ++k){ + for(j = 0; j < fore.h; ++j){ + for(i = 0; i < fore.w; ++i){ + float val = alpha * get_pixel(fore, i, j, k) + + (1 - alpha)* get_pixel(back, i, j, k); + set_pixel(blend, i, j, k, val); + } + } + } + return blend; + } - image sat5 = copy_image(im); - saturate_image(sat5, .5); + void scale_image_channel(image im, int c, float v) + { + int i, j; + for(j = 0; j < im.h; ++j){ + for(i = 0; i < im.w; ++i){ + float pix = get_pixel(im, i, j, c); + pix = pix*v; + set_pixel(im, i, j, c, pix); + } + } + } - image exp2 = copy_image(im); - exposure_image(exp2, 2); + void saturate_image(image im, float sat) + { + rgb_to_hsv(im); + scale_image_channel(im, 1, sat); + hsv_to_rgb(im); + constrain_image(im); + } - image exp5 = copy_image(im); - exposure_image(exp5, .5); + void exposure_image(image im, float sat) + { + rgb_to_hsv(im); + scale_image_channel(im, 2, sat); + hsv_to_rgb(im); + constrain_image(im); + } - show_image(im, "Original"); - show_image(gray, "Gray"); - show_image(sat2, "Saturation-2"); - show_image(sat5, "Saturation-.5"); - show_image(exp2, "Exposure-2"); - show_image(exp5, "Exposure-.5"); - #ifdef OPENCV - cvWaitKey(0); - #endif -} + void saturate_exposure_image(image im, float sat, float exposure) + { + rgb_to_hsv(im); + scale_image_channel(im, 1, sat); + scale_image_channel(im, 2, exposure); + hsv_to_rgb(im); + constrain_image(im); + } + + /* + image saturate_image(image im, float sat) + { + image gray = grayscale_image(im); + image blend = blend_image(im, gray, sat); + free_image(gray); + constrain_image(blend); + return blend; + } + + image brightness_image(image im, float b) + { + image bright = make_image(im.w, im.h, im.c); + return bright; + } + */ + + float bilinear_interpolate(image im, float x, float y, int c) + { + int ix = (int) floorf(x); + int iy = (int) floorf(y); + + float dx = x - ix; + float dy = y - iy; + + float val = (1-dy) * (1-dx) * get_pixel_extend(im, ix, iy, c) + + dy * (1-dx) * get_pixel_extend(im, ix, iy+1, c) + + (1-dy) * dx * get_pixel_extend(im, ix+1, iy, c) + + dy * dx * get_pixel_extend(im, ix+1, iy+1, c); + return val; + } + + image resize_image(image im, int w, int h) + { + image resized = make_image(w, h, im.c); + image part = make_image(w, im.h, im.c); + int r, c, k; + float w_scale = (float)(im.w - 1) / (w - 1); + float h_scale = (float)(im.h - 1) / (h - 1); + for(k = 0; k < im.c; ++k){ + for(r = 0; r < im.h; ++r){ + for(c = 0; c < w; ++c){ + float val = 0; + if(c == w-1){ + val = get_pixel(im, im.w-1, r, k); + } else { + float sx = c*w_scale; + int ix = (int) sx; + float dx = sx - ix; + val = (1 - dx) * get_pixel(im, ix, r, k) + dx * get_pixel(im, ix+1, r, k); + } + set_pixel(part, c, r, k, val); + } + } + } + for(k = 0; k < im.c; ++k){ + for(r = 0; r < h; ++r){ + float sy = r*h_scale; + int iy = (int) sy; + float dy = sy - iy; + for(c = 0; c < w; ++c){ + float val = (1-dy) * get_pixel(part, c, iy, k); + set_pixel(resized, c, r, k, val); + } + if(r == h-1) continue; + for(c = 0; c < w; ++c){ + float val = dy * get_pixel(part, c, iy+1, k); + add_pixel(resized, c, r, k, val); + } + } + } + + free_image(part); + return resized; + } + + void test_resize(char *filename) + { + image im = load_image(filename, 0,0, 3); + image gray = grayscale_image(im); + + image sat2 = copy_image(im); + saturate_image(sat2, 2); + + image sat5 = copy_image(im); + saturate_image(sat5, .5); + + image exp2 = copy_image(im); + exposure_image(exp2, 2); + + image exp5 = copy_image(im); + exposure_image(exp5, .5); + + show_image(im, "Original"); + show_image(gray, "Gray"); + show_image(sat2, "Saturation-2"); + show_image(sat5, "Saturation-.5"); + show_image(exp2, "Exposure-2"); + show_image(exp5, "Exposure-.5"); +#ifdef OPENCV + cvWaitKey(0); +#endif + } #ifdef OPENCV -image ipl_to_image(IplImage* src) -{ - unsigned char *data = (unsigned char *)src->imageData; - int h = src->height; - int w = src->width; - int c = src->nChannels; - int step = src->widthStep; - image out = make_image(w, h, c); - int i, j, k, count=0;; + image ipl_to_image(IplImage* src) + { + unsigned char *data = (unsigned char *)src->imageData; + int h = src->height; + int w = src->width; + int c = src->nChannels; + int step = src->widthStep; + image out = make_image(w, h, c); + int i, j, k, count=0;; - for(k= 0; k < c; ++k){ - for(i = 0; i < h; ++i){ - for(j = 0; j < w; ++j){ - out.data[count++] = data[i*step + j*c + k]/255.; + for(k= 0; k < c; ++k){ + for(i = 0; i < h; ++i){ + for(j = 0; j < w; ++j){ + out.data[count++] = data[i*step + j*c + k]/255.; + } } } + return out; } - return out; -} -image load_image_cv(char *filename, int channels) -{ - IplImage* src = 0; - int flag = -1; - if (channels == 0) flag = -1; - else if (channels == 1) flag = 0; - else if (channels == 3) flag = 1; - else { - fprintf(stderr, "OpenCV can't force load with %d channels\n", channels); - } - - if( (src = cvLoadImage(filename, flag)) == 0 ) + image load_image_cv(char *filename, int channels) { - printf("Cannot load file image %s\n", filename); - exit(0); + IplImage* src = 0; + int flag = -1; + if (channels == 0) flag = -1; + else if (channels == 1) flag = 0; + else if (channels == 3) flag = 1; + else { + fprintf(stderr, "OpenCV can't force load with %d channels\n", channels); + } + + if( (src = cvLoadImage(filename, flag)) == 0 ) + { + printf("Cannot load file image %s\n", filename); + exit(0); + } + image out = ipl_to_image(src); + cvReleaseImage(&src); + rgbgr_image(out); + return out; } - image out = ipl_to_image(src); - cvReleaseImage(&src); - rgbgr_image(out); - return out; -} #endif -image load_image_stb(char *filename, int channels) -{ - int w, h, c; - unsigned char *data = stbi_load(filename, &w, &h, &c, channels); - if (!data) { - fprintf(stderr, "Cannot load file image %s\nSTB Reason: %s\n", filename, stbi_failure_reason()); - exit(0); - } - if(channels) c = channels; - int i,j,k; - image im = make_image(w, h, c); - for(k = 0; k < c; ++k){ - for(j = 0; j < h; ++j){ - for(i = 0; i < w; ++i){ - int dst_index = i + w*j + w*h*k; - int src_index = k + c*i + c*w*j; - im.data[dst_index] = (float)data[src_index]/255.; + image load_image_stb(char *filename, int channels) + { + int w, h, c; + unsigned char *data = stbi_load(filename, &w, &h, &c, channels); + if (!data) { + fprintf(stderr, "Cannot load file image %s\nSTB Reason: %s\n", filename, stbi_failure_reason()); + exit(0); + } + if(channels) c = channels; + int i,j,k; + image im = make_image(w, h, c); + for(k = 0; k < c; ++k){ + for(j = 0; j < h; ++j){ + for(i = 0; i < w; ++i){ + int dst_index = i + w*j + w*h*k; + int src_index = k + c*i + c*w*j; + im.data[dst_index] = (float)data[src_index]/255.; + } } } + free(data); + return im; } - free(data); - return im; -} -image load_image(char *filename, int w, int h, int c) -{ - #ifdef OPENCV - image out = load_image_cv(filename, c); - #else - image out = load_image_stb(filename, c); - #endif + image load_image(char *filename, int w, int h, int c) + { +#ifdef OPENCV + image out = load_image_cv(filename, c); +#else + image out = load_image_stb(filename, c); +#endif - if((h && w) && (h != out.h || w != out.w)){ - image resized = resize_image(out, w, h); - free_image(out); - out = resized; + if((h && w) && (h != out.h || w != out.w)){ + image resized = resize_image(out, w, h); + free_image(out); + out = resized; + } + return out; } - return out; -} -image load_image_color(char *filename, int w, int h) -{ - return load_image(filename, w, h, 3); -} - -image get_image_layer(image m, int l) -{ - image out = make_image(m.w, m.h, 1); - int i; - for(i = 0; i < m.h*m.w; ++i){ - out.data[i] = m.data[i+l*m.h*m.w]; + image load_image_color(char *filename, int w, int h) + { + return load_image(filename, w, h, 3); } - return out; -} -float get_pixel(image m, int x, int y, int c) -{ - assert(x < m.w && y < m.h && c < m.c); - return m.data[c*m.h*m.w + y*m.w + x]; -} -float get_pixel_extend(image m, int x, int y, int c) -{ - if(x < 0 || x >= m.w || y < 0 || y >= m.h || c < 0 || c >= m.c) return 0; - return get_pixel(m, x, y, c); -} -void set_pixel(image m, int x, int y, int c, float val) -{ - assert(x < m.w && y < m.h && c < m.c); - m.data[c*m.h*m.w + y*m.w + x] = val; -} -void add_pixel(image m, int x, int y, int c, float val) -{ - assert(x < m.w && y < m.h && c < m.c); - m.data[c*m.h*m.w + y*m.w + x] += val; -} + image get_image_layer(image m, int l) + { + image out = make_image(m.w, m.h, 1); + int i; + for(i = 0; i < m.h*m.w; ++i){ + out.data[i] = m.data[i+l*m.h*m.w]; + } + return out; + } -void print_image(image m) -{ - int i, j, k; - for(i =0 ; i < m.c; ++i){ - for(j =0 ; j < m.h; ++j){ - for(k = 0; k < m.w; ++k){ - printf("%.2lf, ", m.data[i*m.h*m.w + j*m.w + k]); - if(k > 30) break; + float get_pixel(image m, int x, int y, int c) + { + assert(x < m.w && y < m.h && c < m.c); + return m.data[c*m.h*m.w + y*m.w + x]; + } + float get_pixel_extend(image m, int x, int y, int c) + { + if(x < 0 || x >= m.w || y < 0 || y >= m.h || c < 0 || c >= m.c) return 0; + return get_pixel(m, x, y, c); + } + void set_pixel(image m, int x, int y, int c, float val) + { + assert(x < m.w && y < m.h && c < m.c); + m.data[c*m.h*m.w + y*m.w + x] = val; + } + void add_pixel(image m, int x, int y, int c, float val) + { + assert(x < m.w && y < m.h && c < m.c); + m.data[c*m.h*m.w + y*m.w + x] += val; + } + + void print_image(image m) + { + int i, j, k; + for(i =0 ; i < m.c; ++i){ + for(j =0 ; j < m.h; ++j){ + for(k = 0; k < m.w; ++k){ + printf("%.2lf, ", m.data[i*m.h*m.w + j*m.w + k]); + if(k > 30) break; + } + printf("\n"); + if(j > 30) break; } printf("\n"); - if(j > 30) break; } printf("\n"); } - printf("\n"); -} -image collapse_images_vert(image *ims, int n) -{ - int color = 1; - int border = 1; - int h,w,c; - w = ims[0].w; - h = (ims[0].h + border) * n - border; - c = ims[0].c; - if(c != 3 || !color){ - w = (w+border)*c - border; - c = 1; - } - - image filters = make_image(w, h, c); - int i,j; - for(i = 0; i < n; ++i){ - int h_offset = i*(ims[0].h+border); - image copy = copy_image(ims[i]); - //normalize_image(copy); - if(c == 3 && color){ - embed_image(copy, filters, 0, h_offset); + image collapse_images_vert(image *ims, int n) + { + int color = 1; + int border = 1; + int h,w,c; + w = ims[0].w; + h = (ims[0].h + border) * n - border; + c = ims[0].c; + if(c != 3 || !color){ + w = (w+border)*c - border; + c = 1; } - else{ - for(j = 0; j < copy.c; ++j){ - int w_offset = j*(ims[0].w+border); - image layer = get_image_layer(copy, j); - embed_image(layer, filters, w_offset, h_offset); - free_image(layer); + + image filters = make_image(w, h, c); + int i,j; + for(i = 0; i < n; ++i){ + int h_offset = i*(ims[0].h+border); + image copy = copy_image(ims[i]); + //normalize_image(copy); + if(c == 3 && color){ + embed_image(copy, filters, 0, h_offset); } - } - free_image(copy); - } - return filters; -} - -image collapse_images_horz(image *ims, int n) -{ - int color = 1; - int border = 1; - int h,w,c; - int size = ims[0].h; - h = size; - w = (ims[0].w + border) * n - border; - c = ims[0].c; - if(c != 3 || !color){ - h = (h+border)*c - border; - c = 1; - } - - image filters = make_image(w, h, c); - int i,j; - for(i = 0; i < n; ++i){ - int w_offset = i*(size+border); - image copy = copy_image(ims[i]); - //normalize_image(copy); - if(c == 3 && color){ - embed_image(copy, filters, w_offset, 0); - } - else{ - for(j = 0; j < copy.c; ++j){ - int h_offset = j*(size+border); - image layer = get_image_layer(copy, j); - embed_image(layer, filters, w_offset, h_offset); - free_image(layer); + else{ + for(j = 0; j < copy.c; ++j){ + int w_offset = j*(ims[0].w+border); + image layer = get_image_layer(copy, j); + embed_image(layer, filters, w_offset, h_offset); + free_image(layer); + } } + free_image(copy); } - free_image(copy); - } - return filters; -} + return filters; + } -void show_images(image *ims, int n, char *window) -{ - image m = collapse_images_vert(ims, n); - int w = 448; - int h = ((float)m.h/m.w) * 448; - if(h > 896){ - h = 896; - w = ((float)m.w/m.h) * 896; - } - image sized = resize_image(m, w, h); - save_image(sized, window); - show_image(sized, window); - free_image(sized); - free_image(m); -} + image collapse_images_horz(image *ims, int n) + { + int color = 1; + int border = 1; + int h,w,c; + int size = ims[0].h; + h = size; + w = (ims[0].w + border) * n - border; + c = ims[0].c; + if(c != 3 || !color){ + h = (h+border)*c - border; + c = 1; + } -void free_image(image m) -{ - free(m.data); -} + image filters = make_image(w, h, c); + int i,j; + for(i = 0; i < n; ++i){ + int w_offset = i*(size+border); + image copy = copy_image(ims[i]); + //normalize_image(copy); + if(c == 3 && color){ + embed_image(copy, filters, w_offset, 0); + } + else{ + for(j = 0; j < copy.c; ++j){ + int h_offset = j*(size+border); + image layer = get_image_layer(copy, j); + embed_image(layer, filters, w_offset, h_offset); + free_image(layer); + } + } + free_image(copy); + } + return filters; + } + + void show_images(image *ims, int n, char *window) + { + image m = collapse_images_vert(ims, n); + int w = 448; + int h = ((float)m.h/m.w) * 448; + if(h > 896){ + h = 896; + w = ((float)m.w/m.h) * 896; + } + image sized = resize_image(m, w, h); + save_image(sized, window); + show_image(sized, window); + free_image(sized); + free_image(m); + } + + void free_image(image m) + { + free(m.data); + } diff --git a/src/image.h b/src/image.h index b635ff1f..f8577cd7 100644 --- a/src/image.h +++ b/src/image.h @@ -6,6 +6,7 @@ #include #include #include +#include "box.h" typedef struct { int h; @@ -18,6 +19,7 @@ float get_color(int c, int x, int max); void flip_image(image a); void draw_box(image a, int x1, int y1, int x2, int y2, float r, float g, float b); void draw_box_width(image a, int x1, int y1, int x2, int y2, int w, float r, float g, float b); +void draw_bbox(image a, box bbox, int w, float r, float g, float b); image image_distance(image a, image b); void scale_image(image m, float s); image crop_image(image im, int dx, int dy, int w, int h); diff --git a/src/imagenet.c b/src/imagenet.c index aeb7e690..fb573071 100644 --- a/src/imagenet.c +++ b/src/imagenet.c @@ -30,7 +30,19 @@ void train_imagenet(char *cfgfile, char *weightfile) pthread_t load_thread; data train; data buffer; - load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.classes = 1000; + args.n = imgs; + args.m = plist->size; + args.labels = labels; + args.d = &buffer; + args.type = CLASSIFICATION_DATA; + + load_thread = load_data_in_thread(args); while(1){ ++i; time=clock(); @@ -43,7 +55,7 @@ void train_imagenet(char *cfgfile, char *weightfile) cvWaitKey(0); */ - load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer); + load_thread = load_data_in_thread(args); printf("Loaded: %lf seconds\n", sec(clock()-time)); time=clock(); float loss = train_network(net, train); @@ -84,7 +96,19 @@ void validate_imagenet(char *filename, char *weightfile) int num = (i+1)*m/splits - i*m/splits; data val, buffer; - pthread_t load_thread = load_data_thread(paths, num, 0, labels, 1000, 256, 256, &buffer); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.classes = 1000; + args.n = num; + args.m = 0; + args.labels = labels; + args.d = &buffer; + args.type = CLASSIFICATION_DATA; + + pthread_t load_thread = load_data_in_thread(args); for(i = 1; i <= splits; ++i){ time=clock(); @@ -93,7 +117,10 @@ void validate_imagenet(char *filename, char *weightfile) num = (i+1)*m/splits - i*m/splits; char **part = paths+(i*m/splits); - if(i != splits) load_thread = load_data_thread(part, num, 0, labels, 1000, 256, 256, &buffer); + if(i != splits){ + args.paths = part; + load_thread = load_data_in_thread(args); + } printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time)); time=clock(); diff --git a/src/layer.h b/src/layer.h index 29abec05..4cd9f288 100644 --- a/src/layer.h +++ b/src/layer.h @@ -15,6 +15,7 @@ typedef enum { ROUTE, COST, NORMALIZATION, + REGION, AVGPOOL } LAYER_TYPE; diff --git a/src/network.c b/src/network.c index ff5cd616..de3e569f 100644 --- a/src/network.c +++ b/src/network.c @@ -11,6 +11,7 @@ #include "convolutional_layer.h" #include "deconvolutional_layer.h" #include "detection_layer.h" +#include "region_layer.h" #include "normalization_layer.h" #include "maxpool_layer.h" #include "avgpool_layer.h" @@ -36,6 +37,8 @@ char *get_layer_string(LAYER_TYPE a) return "softmax"; case DETECTION: return "detection"; + case REGION: + return "region"; case DROPOUT: return "dropout"; case CROP: @@ -80,6 +83,8 @@ void forward_network(network net, network_state state) forward_normalization_layer(l, state); } else if(l.type == DETECTION){ forward_detection_layer(l, state); + } else if(l.type == REGION){ + forward_region_layer(l, state); } else if(l.type == CONNECTED){ forward_connected_layer(l, state); } else if(l.type == CROP){ @@ -130,12 +135,16 @@ float get_network_cost(network net) float sum = 0; int count = 0; for(i = 0; i < net.n; ++i){ - if(net.layers[net.n-1].type == COST){ - sum += net.layers[net.n-1].output[0]; + if(net.layers[i].type == COST){ + sum += net.layers[i].output[0]; ++count; } - if(net.layers[net.n-1].type == DETECTION){ - sum += net.layers[net.n-1].cost[0]; + if(net.layers[i].type == DETECTION){ + sum += net.layers[i].cost[0]; + ++count; + } + if(net.layers[i].type == REGION){ + sum += net.layers[i].cost[0]; ++count; } } @@ -178,6 +187,8 @@ void backward_network(network net, network_state state) backward_dropout_layer(l, state); } else if(l.type == DETECTION){ backward_detection_layer(l, state); + } else if(l.type == REGION){ + backward_region_layer(l, state); } else if(l.type == SOFTMAX){ if(i != 0) backward_softmax_layer(l, state); } else if(l.type == CONNECTED){ diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 3340afae..593de0ae 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -12,6 +12,7 @@ extern "C" { #include "crop_layer.h" #include "connected_layer.h" #include "detection_layer.h" +#include "region_layer.h" #include "convolutional_layer.h" #include "deconvolutional_layer.h" #include "maxpool_layer.h" @@ -42,6 +43,8 @@ void forward_network_gpu(network net, network_state state) forward_deconvolutional_layer_gpu(l, state); } else if(l.type == DETECTION){ forward_detection_layer_gpu(l, state); + } else if(l.type == REGION){ + forward_region_layer_gpu(l, state); } else if(l.type == CONNECTED){ forward_connected_layer_gpu(l, state); } else if(l.type == CROP){ @@ -92,6 +95,8 @@ void backward_network_gpu(network net, network_state state) backward_dropout_layer_gpu(l, state); } else if(l.type == DETECTION){ backward_detection_layer_gpu(l, state); + } else if(l.type == REGION){ + backward_region_layer_gpu(l, state); } else if(l.type == NORMALIZATION){ backward_normalization_layer_gpu(l, state); } else if(l.type == SOFTMAX){ diff --git a/src/parser.c b/src/parser.c index b373c01a..242a83c8 100644 --- a/src/parser.c +++ b/src/parser.c @@ -14,6 +14,7 @@ #include "softmax_layer.h" #include "dropout_layer.h" #include "detection_layer.h" +#include "region_layer.h" #include "avgpool_layer.h" #include "route_layer.h" #include "list.h" @@ -37,6 +38,7 @@ int is_normalization(section *s); int is_crop(section *s); int is_cost(section *s); int is_detection(section *s); +int is_region(section *s); int is_route(section *s); list *read_cfg(char *filename); @@ -172,6 +174,16 @@ detection_layer parse_detection(list *options, size_params params) return layer; } +region_layer parse_region(list *options, size_params params) +{ + int coords = option_find_int(options, "coords", 1); + int classes = option_find_int(options, "classes", 1); + int rescore = option_find_int(options, "rescore", 0); + int num = option_find_int(options, "num", 1); + region_layer layer = make_region_layer(params.batch, params.inputs, num, classes, coords, rescore); + return layer; +} + cost_layer parse_cost(list *options, size_params params) { char *type_s = option_find_str(options, "type", "sse"); @@ -347,6 +359,8 @@ network parse_network_cfg(char *filename) l = parse_cost(options, params); }else if(is_detection(s)){ l = parse_detection(options, params); + }else if(is_region(s)){ + l = parse_region(options, params); }else if(is_softmax(s)){ l = parse_softmax(options, params); }else if(is_normalization(s)){ @@ -399,6 +413,10 @@ int is_detection(section *s) { return (strcmp(s->type, "[detection]")==0); } +int is_region(section *s) +{ + return (strcmp(s->type, "[region]")==0); +} int is_deconvolutional(section *s) { return (strcmp(s->type, "[deconv]")==0 diff --git a/src/region_layer.c b/src/region_layer.c new file mode 100644 index 00000000..7c34b5cc --- /dev/null +++ b/src/region_layer.c @@ -0,0 +1,161 @@ +#include "region_layer.h" +#include "activations.h" +#include "softmax_layer.h" +#include "blas.h" +#include "box.h" +#include "cuda.h" +#include "utils.h" +#include +#include +#include + +int get_region_layer_locations(region_layer l) +{ + return l.inputs / (l.classes+l.coords); +} + +region_layer make_region_layer(int batch, int inputs, int n, int classes, int coords, int rescore) +{ + region_layer l = {0}; + l.type = REGION; + + l.n = n; + l.batch = batch; + l.inputs = inputs; + l.classes = classes; + l.coords = coords; + l.rescore = rescore; + l.cost = calloc(1, sizeof(float)); + int outputs = inputs; + l.outputs = outputs; + l.output = calloc(batch*outputs, sizeof(float)); + l.delta = calloc(batch*outputs, sizeof(float)); + #ifdef GPU + l.output_gpu = cuda_make_array(0, batch*outputs); + l.delta_gpu = cuda_make_array(0, batch*outputs); + #endif + + fprintf(stderr, "Region Layer\n"); + srand(0); + + return l; +} + +void forward_region_layer(const region_layer l, network_state state) +{ + int locations = get_region_layer_locations(l); + int i,j; + for(i = 0; i < l.batch*locations; ++i){ + int index = i*(l.classes + l.coords); + int mask = (!state.truth || !state.truth[index]); + + for(j = 0; j < l.classes; ++j){ + l.output[index+j] = state.input[index+j]; + } + + softmax_array(l.output + index, l.classes, l.output + index); + index += l.classes; + + for(j = 0; j < l.coords; ++j){ + l.output[index+j] = mask*state.input[index+j]; + } + } + if(state.train){ + float avg_iou = 0; + int count = 0; + *(l.cost) = 0; + int size = l.outputs * l.batch; + memset(l.delta, 0, size * sizeof(float)); + for (i = 0; i < l.batch*locations; ++i) { + int offset = i*(l.classes+l.coords); + int bg = state.truth[offset]; + for (j = offset; j < offset+l.classes; ++j) { + //*(l.cost) += pow(state.truth[j] - l.output[j], 2); + //l.delta[j] = state.truth[j] - l.output[j]; + } + + box anchor = {0,0,.5,.5}; + box truth_code = {state.truth[j+0], state.truth[j+1], state.truth[j+2], state.truth[j+3]}; + box out_code = {l.output[j+0], l.output[j+1], l.output[j+2], l.output[j+3]}; + box out = decode_box(out_code, anchor); + box truth = decode_box(truth_code, anchor); + + if(bg) continue; + //printf("Box: %f %f %f %f\n", truth.x, truth.y, truth.w, truth.h); + //printf("Code: %f %f %f %f\n", truth_code.x, truth_code.y, truth_code.w, truth_code.h); + //printf("Pred : %f %f %f %f\n", out.x, out.y, out.w, out.h); + // printf("Pred Code: %f %f %f %f\n", out_code.x, out_code.y, out_code.w, out_code.h); + float iou = box_iou(out, truth); + avg_iou += iou; + ++count; + + /* + *(l.cost) += pow((1-iou), 2); + l.delta[j+0] = (state.truth[j+0] - l.output[j+0]); + l.delta[j+1] = (state.truth[j+1] - l.output[j+1]); + l.delta[j+2] = (state.truth[j+2] - l.output[j+2]); + l.delta[j+3] = (state.truth[j+3] - l.output[j+3]); + */ + + for (j = offset+l.classes; j < offset+l.classes+l.coords; ++j) { + //*(l.cost) += pow(state.truth[j] - l.output[j], 2); + //l.delta[j] = state.truth[j] - l.output[j]; + float diff = state.truth[j] - l.output[j]; + if (fabs(diff) < 1){ + l.delta[j] = diff; + *(l.cost) += .5*pow(state.truth[j] - l.output[j], 2); + } else { + l.delta[j] = (diff > 0) ? 1 : -1; + *(l.cost) += fabs(diff) - .5; + } + //l.delta[j] = state.truth[j] - l.output[j]; + } + + /* + if(l.rescore){ + for (j = offset; j < offset+l.classes; ++j) { + if(state.truth[j]) state.truth[j] = iou; + l.delta[j] = state.truth[j] - l.output[j]; + } + } + */ + } + printf("Avg IOU: %f\n", avg_iou/count); + } +} + +void backward_region_layer(const region_layer l, network_state state) +{ + axpy_cpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1); + //copy_cpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1); +} + +#ifdef GPU + +void forward_region_layer_gpu(const region_layer l, network_state state) +{ + float *in_cpu = calloc(l.batch*l.inputs, sizeof(float)); + float *truth_cpu = 0; + if(state.truth){ + truth_cpu = calloc(l.batch*l.outputs, sizeof(float)); + cuda_pull_array(state.truth, truth_cpu, l.batch*l.outputs); + } + cuda_pull_array(state.input, in_cpu, l.batch*l.inputs); + network_state cpu_state; + cpu_state.train = state.train; + cpu_state.truth = truth_cpu; + cpu_state.input = in_cpu; + forward_region_layer(l, cpu_state); + cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs); + cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs); + free(cpu_state.input); + if(cpu_state.truth) free(cpu_state.truth); +} + +void backward_region_layer_gpu(region_layer l, network_state state) +{ + axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1); + //copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1); +} +#endif + diff --git a/src/region_layer.h b/src/region_layer.h new file mode 100644 index 00000000..00fbeba3 --- /dev/null +++ b/src/region_layer.h @@ -0,0 +1,18 @@ +#ifndef REGION_LAYER_H +#define REGION_LAYER_H + +#include "params.h" +#include "layer.h" + +typedef layer region_layer; + +region_layer make_region_layer(int batch, int inputs, int n, int classes, int coords, int rescore); +void forward_region_layer(const region_layer l, network_state state); +void backward_region_layer(const region_layer l, network_state state); + +#ifdef GPU +void forward_region_layer_gpu(const region_layer l, network_state state); +void backward_region_layer_gpu(region_layer l, network_state state); +#endif + +#endif diff --git a/src/yolo.c b/src/yolo.c index 5ad95347..13f08240 100644 --- a/src/yolo.c +++ b/src/yolo.c @@ -88,14 +88,26 @@ void train_yolo(char *cfgfile, char *weightfile) int background = layer.objectness; int side = sqrt(get_detection_layer_locations(layer)); - pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.n = imgs; + args.m = plist->size; + args.classes = classes; + args.num_boxes = side; + args.background = background; + args.d = &buffer; + args.type = DETECTION_DATA; + + pthread_t load_thread = load_data_in_thread(args); clock_t time; while(i*imgs < N*130){ i += 1; time=clock(); pthread_join(load_thread, 0); train = buffer; - load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); + load_thread = load_data_in_thread(args); printf("Loaded: %lf seconds\n", sec(clock()-time)); time=clock(); @@ -126,7 +138,7 @@ void train_yolo(char *cfgfile, char *weightfile) pthread_join(load_thread, 0); free_data(buffer); - load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); + load_thread = load_data_in_thread(args); } if((i-1)*imgs <= 120*N && i*imgs > N*120){ @@ -237,8 +249,17 @@ void validate_yolo(char *cfgfile, char *weightfile) image *buf = calloc(nthreads, sizeof(image)); image *buf_resized = calloc(nthreads, sizeof(image)); pthread_t *thr = calloc(nthreads, sizeof(pthread_t)); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.type = IMAGE_DATA; + for(t = 0; t < nthreads; ++t){ - thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); + args.path = paths[i+t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); } time_t start = time(0); for(i = nthreads; i < m+nthreads; i += nthreads){ @@ -249,7 +270,10 @@ void validate_yolo(char *cfgfile, char *weightfile) val_resized[t] = buf_resized[t]; } for(t = 0; t < nthreads && i+t < m; ++t){ - thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); + args.path = paths[i+t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); } for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ char *path = paths[i+t-nthreads];