From f98efe6c32a064e77712f25fb40a673c3249cfd4 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Mon, 15 Jun 2015 23:22:44 -0700 Subject: [PATCH] what happened? Conflicts: Makefile --- Makefile | 6 +- src/box.c | 213 +++++++++++++++++++++++++++++++++++++++++ src/box.h | 15 +++ src/data.c | 31 +++++- src/data.h | 2 + src/detection.c | 143 ++++++++++++++++++++++++++++ src/detection_layer.c | 215 +----------------------------------------- src/imagenet.c | 2 +- src/utils.c | 2 - src/utils.h | 3 - 10 files changed, 407 insertions(+), 225 deletions(-) create mode 100644 src/box.c create mode 100644 src/box.h diff --git a/Makefile b/Makefile index 1d1fdf2b..25eacf8f 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ -GPU=0 -OPENCV=0 +GPU=1 +OPENCV=1 DEBUG=0 ARCH= -arch=sm_52 @@ -34,7 +34,7 @@ CFLAGS+= -DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o endif diff --git a/src/box.c b/src/box.c new file mode 100644 index 00000000..cce56bd0 --- /dev/null +++ b/src/box.c @@ -0,0 +1,213 @@ +#include "box.h" +#include +#include + +dbox derivative(box a, box b) +{ + dbox d; + d.dx = 0; + d.dw = 0; + float l1 = a.x - a.w/2; + float l2 = b.x - b.w/2; + if (l1 > l2){ + d.dx -= 1; + d.dw += .5; + } + float r1 = a.x + a.w/2; + float r2 = b.x + b.w/2; + if(r1 < r2){ + d.dx += 1; + d.dw += .5; + } + if (l1 > r2) { + d.dx = -1; + d.dw = 0; + } + if (r1 < l2){ + d.dx = 1; + d.dw = 0; + } + + d.dy = 0; + d.dh = 0; + float t1 = a.y - a.h/2; + float t2 = b.y - b.h/2; + if (t1 > t2){ + d.dy -= 1; + d.dh += .5; + } + float b1 = a.y + a.h/2; + float b2 = b.y + b.h/2; + if(b1 < b2){ + d.dy += 1; + d.dh += .5; + } + if (t1 > b2) { + d.dy = -1; + d.dh = 0; + } + if (b1 < t2){ + d.dy = 1; + d.dh = 0; + } + return d; +} + +float overlap(float x1, float w1, float x2, float w2) +{ + float l1 = x1 - w1/2; + float l2 = x2 - w2/2; + float left = l1 > l2 ? l1 : l2; + float r1 = x1 + w1/2; + float r2 = x2 + w2/2; + float right = r1 < r2 ? r1 : r2; + return right - left; +} + +float box_intersection(box a, box b) +{ + float w = overlap(a.x, a.w, b.x, b.w); + float h = overlap(a.y, a.h, b.y, b.h); + if(w < 0 || h < 0) return 0; + float area = w*h; + return area; +} + +float box_union(box a, box b) +{ + float i = box_intersection(a, b); + float u = a.w*a.h + b.w*b.h - i; + return u; +} + +float box_iou(box a, box b) +{ + return box_intersection(a, b)/box_union(a, b); +} + +dbox dintersect(box a, box b) +{ + float w = overlap(a.x, a.w, b.x, b.w); + float h = overlap(a.y, a.h, b.y, b.h); + dbox dover = derivative(a, b); + dbox di; + + di.dw = dover.dw*h; + di.dx = dover.dx*h; + di.dh = dover.dh*w; + di.dy = dover.dy*w; + + return di; +} + +dbox dunion(box a, box b) +{ + dbox du; + + dbox di = dintersect(a, b); + du.dw = a.h - di.dw; + du.dh = a.w - di.dh; + du.dx = -di.dx; + du.dy = -di.dy; + + return du; +} + + +void test_dunion() +{ + box a = {0, 0, 1, 1}; + box dxa= {0+.0001, 0, 1, 1}; + box dya= {0, 0+.0001, 1, 1}; + box dwa= {0, 0, 1+.0001, 1}; + box dha= {0, 0, 1, 1+.0001}; + + box b = {.5, .5, .2, .2}; + dbox di = dunion(a,b); + printf("Union: %f %f %f %f\n", di.dx, di.dy, di.dw, di.dh); + float inter = box_union(a, b); + float xinter = box_union(dxa, b); + float yinter = box_union(dya, b); + float winter = box_union(dwa, b); + float hinter = box_union(dha, b); + xinter = (xinter - inter)/(.0001); + yinter = (yinter - inter)/(.0001); + winter = (winter - inter)/(.0001); + hinter = (hinter - inter)/(.0001); + printf("Union Manual %f %f %f %f\n", xinter, yinter, winter, hinter); +} +void test_dintersect() +{ + box a = {0, 0, 1, 1}; + box dxa= {0+.0001, 0, 1, 1}; + box dya= {0, 0+.0001, 1, 1}; + box dwa= {0, 0, 1+.0001, 1}; + box dha= {0, 0, 1, 1+.0001}; + + box b = {.5, .5, .2, .2}; + dbox di = dintersect(a,b); + printf("Inter: %f %f %f %f\n", di.dx, di.dy, di.dw, di.dh); + float inter = box_intersection(a, b); + float xinter = box_intersection(dxa, b); + float yinter = box_intersection(dya, b); + float winter = box_intersection(dwa, b); + float hinter = box_intersection(dha, b); + xinter = (xinter - inter)/(.0001); + yinter = (yinter - inter)/(.0001); + winter = (winter - inter)/(.0001); + hinter = (hinter - inter)/(.0001); + printf("Inter Manual %f %f %f %f\n", xinter, yinter, winter, hinter); +} + +void test_box() +{ + test_dintersect(); + test_dunion(); + box a = {0, 0, 1, 1}; + box dxa= {0+.00001, 0, 1, 1}; + box dya= {0, 0+.00001, 1, 1}; + box dwa= {0, 0, 1+.00001, 1}; + box dha= {0, 0, 1, 1+.00001}; + + box b = {.5, 0, .2, .2}; + + float iou = box_iou(a,b); + iou = (1-iou)*(1-iou); + printf("%f\n", iou); + dbox d = diou(a, b); + printf("%f %f %f %f\n", d.dx, d.dy, d.dw, d.dh); + + float xiou = box_iou(dxa, b); + float yiou = box_iou(dya, b); + float wiou = box_iou(dwa, b); + float hiou = box_iou(dha, b); + xiou = ((1-xiou)*(1-xiou) - iou)/(.00001); + yiou = ((1-yiou)*(1-yiou) - iou)/(.00001); + wiou = ((1-wiou)*(1-wiou) - iou)/(.00001); + hiou = ((1-hiou)*(1-hiou) - iou)/(.00001); + printf("manual %f %f %f %f\n", xiou, yiou, wiou, hiou); +} + +dbox diou(box a, box b) +{ + float u = box_union(a,b); + float i = box_intersection(a,b); + dbox di = dintersect(a,b); + dbox du = dunion(a,b); + dbox dd = {0,0,0,0}; + + if(i <= 0 || 1) { + dd.dx = b.x - a.x; + dd.dy = b.y - a.y; + dd.dw = b.w - a.w; + dd.dh = b.h - a.h; + return dd; + } + + dd.dx = 2*pow((1-(i/u)),1)*(di.dx*u - du.dx*i)/(u*u); + dd.dy = 2*pow((1-(i/u)),1)*(di.dy*u - du.dy*i)/(u*u); + dd.dw = 2*pow((1-(i/u)),1)*(di.dw*u - du.dw*i)/(u*u); + dd.dh = 2*pow((1-(i/u)),1)*(di.dh*u - du.dh*i)/(u*u); + return dd; +} + diff --git a/src/box.h b/src/box.h new file mode 100644 index 00000000..e3831d8e --- /dev/null +++ b/src/box.h @@ -0,0 +1,15 @@ +#ifndef BOX_H +#define BOX_H + +typedef struct{ + float x, y, w, h; +} box; + +typedef struct{ + float dx, dy, dw, dh; +} dbox; + +float box_iou(box a, box b); +dbox diou(box a, box b); + +#endif diff --git a/src/data.c b/src/data.c index ad485920..dafcc98f 100644 --- a/src/data.c +++ b/src/data.c @@ -8,7 +8,7 @@ unsigned int data_seed; -struct load_args{ +typedef struct load_args{ char **paths; int n; int m; @@ -22,7 +22,10 @@ struct load_args{ int classes; int background; data *d; -}; + char *path; + image *im; + image *resized; +} load_args; list *get_paths(char *filename) { @@ -468,6 +471,30 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes, return d; } +void *load_image_in_thread(void *ptr) +{ + load_args a = *(load_args*)ptr; + free(ptr); + *(a.im) = load_image_color(a.path, 0, 0); + *(a.resized) = resize_image(*(a.im), a.w, a.h); + return 0; +} + +pthread_t load_image_thread(char *path, image *im, image *resized, int w, int h) +{ + pthread_t thread; + struct load_args *args = calloc(1, sizeof(struct load_args)); + args->path = path; + args->w = w; + args->h = h; + args->im = im; + args->resized = resized; + if(pthread_create(&thread, 0, load_image_in_thread, args)) { + error("Thread creation failed"); + } + return thread; +} + void *load_localization_thread(void *ptr) { printf("Loading data: %d\n", rand_r(&data_seed)); diff --git a/src/data.h b/src/data.h index bda804fc..993a7a7f 100644 --- a/src/data.h +++ b/src/data.h @@ -4,6 +4,7 @@ #include "matrix.h" #include "list.h" +#include "image.h" extern unsigned int data_seed; @@ -33,6 +34,7 @@ data load_data_captcha(char **paths, int n, int m, int k, int w, int h); data load_data_captcha_encode(char **paths, int n, int m, int w, int h); data load_data(char **paths, int n, int m, char **labels, int k, int w, int h); pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int w, int h, data *d); +pthread_t load_image_thread(char *path, image *im, image *resized, int w, int h); pthread_t load_data_detection_thread(int n, char **paths, int m, int classes, int w, int h, int nh, int nw, int background, data *d); data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background); diff --git a/src/detection.c b/src/detection.c index e21e1207..25531157 100644 --- a/src/detection.c +++ b/src/detection.c @@ -3,6 +3,7 @@ #include "cost_layer.h" #include "utils.h" #include "parser.h" +#include "box.h" char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; @@ -206,6 +207,147 @@ void validate_detection(char *cfgfile, char *weightfile) fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); } + +void convert_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes) +{ + int i,j; + int per_box = 4+classes+(background || objectness); + for (i = 0; i < num_boxes*num_boxes; ++i){ + float scale = 1; + if(objectness) scale = 1-predictions[i*per_box]; + int offset = i*per_box+(background||objectness); + for(j = 0; j < classes; ++j){ + float prob = scale*predictions[offset+j]; + probs[i][j] = (prob > thresh) ? prob : 0; + } + int row = i / num_boxes; + int col = i % num_boxes; + offset += classes; + boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w; + boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h; + boxes[i].w = pow(predictions[offset + 2], 2) * w; + boxes[i].h = pow(predictions[offset + 3], 2) * h; + } +} + +void do_nms(box *boxes, float **probs, int num_boxes, int classes, float thresh) +{ + int i, j, k; + for(i = 0; i < num_boxes*num_boxes; ++i){ + int any = 0; + for(k = 0; k < classes; ++k) any = any || (probs[i][k] > 0); + if(!any) { + continue; + } + for(j = i+1; j < num_boxes*num_boxes; ++j){ + if (box_iou(boxes[i], boxes[j]) > thresh){ + for(k = 0; k < classes; ++k){ + if (probs[i][k] < probs[j][k]) probs[i][k] = 0; + else probs[j][k] = 0; + } + } + } + } +} + +void print_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h) +{ + int i, j; + for(i = 0; i < num_boxes*num_boxes; ++i){ + float xmin = boxes[i].x - boxes[i].w/2.; + float xmax = boxes[i].x + boxes[i].w/2.; + float ymin = boxes[i].y - boxes[i].h/2.; + float ymax = boxes[i].y + boxes[i].h/2.; + + if (xmin < 0) xmin = 0; + if (ymin < 0) ymin = 0; + if (xmax > w) xmax = w; + if (ymax > h) ymax = h; + + for(j = 0; j < classes; ++j){ + if (probs[i][j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, probs[i][j], + xmin, ymin, xmax, ymax); + } + } +} + +void valid_detection(char *cfgfile, char *weightfile) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + detection_layer layer = get_network_detection_layer(net); + fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + srand(time(0)); + + char *base = "/home/pjreddie/data/voc/devkit/results/VOC2012/Main/comp4_det_test_"; + list *plist = get_paths("/home/pjreddie/data/voc/test.txt"); + char **paths = (char **)list_to_array(plist); + + int classes = layer.classes; + int objectness = layer.objectness; + int background = layer.background; + int num_boxes = sqrt(get_detection_layer_locations(layer)); + + int j; + FILE **fps = calloc(classes, sizeof(FILE *)); + for(j = 0; j < classes; ++j){ + char buff[1024]; + snprintf(buff, 1024, "%s%s.txt", base, class_names[j]); + fps[j] = fopen(buff, "w"); + } + box *boxes = calloc(num_boxes*num_boxes, sizeof(box)); + float **probs = calloc(num_boxes*num_boxes, sizeof(float *)); + for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *)); + + int m = plist->size; + int i=0; + int t; + + float thresh = .001; + int nms = 1; + float iou_thresh = .5; + + int nthreads = 8; + image *val = calloc(nthreads, sizeof(image)); + image *val_resized = calloc(nthreads, sizeof(image)); + image *buf = calloc(nthreads, sizeof(image)); + image *buf_resized = calloc(nthreads, sizeof(image)); + pthread_t *thr = calloc(nthreads, sizeof(pthread_t)); + for(t = 0; t < nthreads; ++t){ + thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); + } + time_t start = time(0); + for(i = nthreads; i < m+nthreads; i += nthreads){ + fprintf(stderr, "%d\n", i); + for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ + pthread_join(thr[t], 0); + val[t] = buf[t]; + val_resized[t] = buf_resized[t]; + } + for(t = 0; t < nthreads && i+t < m; ++t){ + thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h); + } + for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ + char *path = paths[i+t-nthreads]; + char *id = basecfg(path); + float *X = val_resized[t].data; + float *predictions = network_predict(net, X); + int w = val[t].w; + int h = val[t].h; + convert_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes); + if (nms) do_nms(boxes, probs, num_boxes, classes, iou_thresh); + print_detections(fps, id, boxes, probs, num_boxes, classes, w, h); + free(id); + free_image(val[t]); + free_image(val_resized[t]); + } + } + fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); +} + void test_detection(char *cfgfile, char *weightfile, char *filename) { @@ -259,4 +401,5 @@ void run_detection(int argc, char **argv) if(0==strcmp(argv[2], "test")) test_detection(cfg, weights, filename); else if(0==strcmp(argv[2], "train")) train_detection(cfg, weights); else if(0==strcmp(argv[2], "valid")) validate_detection(cfg, weights); + else if(0==strcmp(argv[2], "run")) valid_detection(cfg, weights); } diff --git a/src/detection_layer.c b/src/detection_layer.c index 3ab793ae..9ef89d9a 100644 --- a/src/detection_layer.c +++ b/src/detection_layer.c @@ -2,6 +2,7 @@ #include "activations.h" #include "softmax_layer.h" #include "blas.h" +#include "box.h" #include "cuda.h" #include "utils.h" #include @@ -48,220 +49,6 @@ detection_layer make_detection_layer(int batch, int inputs, int classes, int coo return l; } -typedef struct{ - float dx, dy, dw, dh; -} dbox; - -dbox derivative(box a, box b) -{ - dbox d; - d.dx = 0; - d.dw = 0; - float l1 = a.x - a.w/2; - float l2 = b.x - b.w/2; - if (l1 > l2){ - d.dx -= 1; - d.dw += .5; - } - float r1 = a.x + a.w/2; - float r2 = b.x + b.w/2; - if(r1 < r2){ - d.dx += 1; - d.dw += .5; - } - if (l1 > r2) { - d.dx = -1; - d.dw = 0; - } - if (r1 < l2){ - d.dx = 1; - d.dw = 0; - } - - d.dy = 0; - d.dh = 0; - float t1 = a.y - a.h/2; - float t2 = b.y - b.h/2; - if (t1 > t2){ - d.dy -= 1; - d.dh += .5; - } - float b1 = a.y + a.h/2; - float b2 = b.y + b.h/2; - if(b1 < b2){ - d.dy += 1; - d.dh += .5; - } - if (t1 > b2) { - d.dy = -1; - d.dh = 0; - } - if (b1 < t2){ - d.dy = 1; - d.dh = 0; - } - return d; -} - -float overlap(float x1, float w1, float x2, float w2) -{ - float l1 = x1 - w1/2; - float l2 = x2 - w2/2; - float left = l1 > l2 ? l1 : l2; - float r1 = x1 + w1/2; - float r2 = x2 + w2/2; - float right = r1 < r2 ? r1 : r2; - return right - left; -} - -float box_intersection(box a, box b) -{ - float w = overlap(a.x, a.w, b.x, b.w); - float h = overlap(a.y, a.h, b.y, b.h); - if(w < 0 || h < 0) return 0; - float area = w*h; - return area; -} - -float box_union(box a, box b) -{ - float i = box_intersection(a, b); - float u = a.w*a.h + b.w*b.h - i; - return u; -} - -float box_iou(box a, box b) -{ - return box_intersection(a, b)/box_union(a, b); -} - -dbox dintersect(box a, box b) -{ - float w = overlap(a.x, a.w, b.x, b.w); - float h = overlap(a.y, a.h, b.y, b.h); - dbox dover = derivative(a, b); - dbox di; - - di.dw = dover.dw*h; - di.dx = dover.dx*h; - di.dh = dover.dh*w; - di.dy = dover.dy*w; - - return di; -} - -dbox dunion(box a, box b) -{ - dbox du; - - dbox di = dintersect(a, b); - du.dw = a.h - di.dw; - du.dh = a.w - di.dh; - du.dx = -di.dx; - du.dy = -di.dy; - - return du; -} - -dbox diou(box a, box b); - -void test_dunion() -{ - box a = {0, 0, 1, 1}; - box dxa= {0+.0001, 0, 1, 1}; - box dya= {0, 0+.0001, 1, 1}; - box dwa= {0, 0, 1+.0001, 1}; - box dha= {0, 0, 1, 1+.0001}; - - box b = {.5, .5, .2, .2}; - dbox di = dunion(a,b); - printf("Union: %f %f %f %f\n", di.dx, di.dy, di.dw, di.dh); - float inter = box_union(a, b); - float xinter = box_union(dxa, b); - float yinter = box_union(dya, b); - float winter = box_union(dwa, b); - float hinter = box_union(dha, b); - xinter = (xinter - inter)/(.0001); - yinter = (yinter - inter)/(.0001); - winter = (winter - inter)/(.0001); - hinter = (hinter - inter)/(.0001); - printf("Union Manual %f %f %f %f\n", xinter, yinter, winter, hinter); -} -void test_dintersect() -{ - box a = {0, 0, 1, 1}; - box dxa= {0+.0001, 0, 1, 1}; - box dya= {0, 0+.0001, 1, 1}; - box dwa= {0, 0, 1+.0001, 1}; - box dha= {0, 0, 1, 1+.0001}; - - box b = {.5, .5, .2, .2}; - dbox di = dintersect(a,b); - printf("Inter: %f %f %f %f\n", di.dx, di.dy, di.dw, di.dh); - float inter = box_intersection(a, b); - float xinter = box_intersection(dxa, b); - float yinter = box_intersection(dya, b); - float winter = box_intersection(dwa, b); - float hinter = box_intersection(dha, b); - xinter = (xinter - inter)/(.0001); - yinter = (yinter - inter)/(.0001); - winter = (winter - inter)/(.0001); - hinter = (hinter - inter)/(.0001); - printf("Inter Manual %f %f %f %f\n", xinter, yinter, winter, hinter); -} - -void test_box() -{ - test_dintersect(); - test_dunion(); - box a = {0, 0, 1, 1}; - box dxa= {0+.00001, 0, 1, 1}; - box dya= {0, 0+.00001, 1, 1}; - box dwa= {0, 0, 1+.00001, 1}; - box dha= {0, 0, 1, 1+.00001}; - - box b = {.5, 0, .2, .2}; - - float iou = box_iou(a,b); - iou = (1-iou)*(1-iou); - printf("%f\n", iou); - dbox d = diou(a, b); - printf("%f %f %f %f\n", d.dx, d.dy, d.dw, d.dh); - - float xiou = box_iou(dxa, b); - float yiou = box_iou(dya, b); - float wiou = box_iou(dwa, b); - float hiou = box_iou(dha, b); - xiou = ((1-xiou)*(1-xiou) - iou)/(.00001); - yiou = ((1-yiou)*(1-yiou) - iou)/(.00001); - wiou = ((1-wiou)*(1-wiou) - iou)/(.00001); - hiou = ((1-hiou)*(1-hiou) - iou)/(.00001); - printf("manual %f %f %f %f\n", xiou, yiou, wiou, hiou); -} - -dbox diou(box a, box b) -{ - float u = box_union(a,b); - float i = box_intersection(a,b); - dbox di = dintersect(a,b); - dbox du = dunion(a,b); - dbox dd = {0,0,0,0}; - - if(i <= 0 || 1) { - dd.dx = b.x - a.x; - dd.dy = b.y - a.y; - dd.dw = b.w - a.w; - dd.dh = b.h - a.h; - return dd; - } - - dd.dx = 2*pow((1-(i/u)),1)*(di.dx*u - du.dx*i)/(u*u); - dd.dy = 2*pow((1-(i/u)),1)*(di.dy*u - du.dy*i)/(u*u); - dd.dw = 2*pow((1-(i/u)),1)*(di.dw*u - du.dw*i)/(u*u); - dd.dh = 2*pow((1-(i/u)),1)*(di.dh*u - du.dh*i)/(u*u); - return dd; -} - void forward_detection_layer(const detection_layer l, network_state state) { int in_i = 0; diff --git a/src/imagenet.c b/src/imagenet.c index ef932d16..487039cd 100644 --- a/src/imagenet.c +++ b/src/imagenet.c @@ -47,7 +47,7 @@ void train_imagenet(char *cfgfile, char *weightfile) avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); free_data(train); - if((i % 15000) == 0) net.learning_rate *= .1; + if((i % 20000) == 0) net.learning_rate *= .1; //if(i%100 == 0 && net.learning_rate > .00001) net.learning_rate *= .97; if(i%1000==0){ char buff[256]; diff --git a/src/utils.c b/src/utils.c index bd2f54dd..63664ec9 100644 --- a/src/utils.c +++ b/src/utils.c @@ -18,8 +18,6 @@ char *basecfg(char *cfgfile) c = next+1; } c = copy_string(c); - next = strchr(c, '_'); - if (next) *next = 0; next = strchr(c, '.'); if (next) *next = 0; return c; diff --git a/src/utils.h b/src/utils.h index 0db16de1..7148e547 100644 --- a/src/utils.h +++ b/src/utils.h @@ -37,8 +37,5 @@ float mag_array(float *a, int n); float **one_hot_encode(float *a, int n, int k); float sec(clock_t clocks); -typedef struct{ - float x, y, w, h; -} box; #endif