diff --git a/src/blas.h b/src/blas.h index e7c976d4..1657fc5e 100644 --- a/src/blas.h +++ b/src/blas.h @@ -18,6 +18,6 @@ void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY); void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY); void scal_ongpu(int N, float ALPHA, float * X, int INCX); -void mask_ongpu(int N, float * X, float * mask, float mod); +void mask_ongpu(int N, float * X, float * mask); #endif #endif diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index d6f71433..636a9b5e 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -15,10 +15,10 @@ __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX) if(i < N) X[i*INCX] *= ALPHA; } -__global__ void mask_kernel(int n, float *x, float *mask, int mod) +__global__ void mask_kernel(int n, float *x, float *mask) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; - if(i < n) x[i] = (i%mod && !mask[(i/mod)*mod]) ? 0 : x[i]; + if(i < n && mask[i] == 0) x[i] = 0; } __global__ void copy_kernel(int N, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY) @@ -49,9 +49,9 @@ extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * check_error(cudaPeekAtLastError()); } -extern "C" void mask_ongpu(int N, float * X, float * mask, float mod) +extern "C" void mask_ongpu(int N, float * X, float * mask) { - mask_kernel<<>>(N, X, mask, mod); + mask_kernel<<>>(N, X, mask); check_error(cudaPeekAtLastError()); } diff --git a/src/cost_layer.c b/src/cost_layer.c index 1ea03bb8..1f36232b 100644 --- a/src/cost_layer.c +++ b/src/cost_layer.c @@ -10,6 +10,7 @@ COST_TYPE get_cost_type(char *s) { if (strcmp(s, "sse")==0) return SSE; + if (strcmp(s, "masked")==0) return MASKED; fprintf(stderr, "Couldn't find activation function %s, going with SSE\n", s); return SSE; } @@ -19,6 +20,8 @@ char *get_cost_string(COST_TYPE a) switch(a){ case SSE: return "sse"; + case MASKED: + return "masked"; } return "sse"; } @@ -41,6 +44,12 @@ cost_layer *make_cost_layer(int batch, int inputs, COST_TYPE type) void forward_cost_layer(cost_layer layer, network_state state) { if (!state.truth) return; + if(layer.type == MASKED){ + int i; + for(i = 0; i < layer.batch*layer.inputs; ++i){ + if(state.truth[i] == 0) state.input[i] = 0; + } + } copy_cpu(layer.batch*layer.inputs, state.truth, 1, layer.delta, 1); axpy_cpu(layer.batch*layer.inputs, -1, state.input, 1, layer.delta, 1); *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1); @@ -67,6 +76,9 @@ void push_cost_layer(cost_layer layer) void forward_cost_layer_gpu(cost_layer layer, network_state state) { if (!state.truth) return; + if (layer.type == MASKED) { + mask_ongpu(layer.batch*layer.inputs, state.input, state.truth); + } copy_ongpu(layer.batch*layer.inputs, state.truth, 1, layer.delta_gpu, 1); axpy_ongpu(layer.batch*layer.inputs, -1, state.input, 1, layer.delta_gpu, 1); diff --git a/src/cost_layer.h b/src/cost_layer.h index d4416988..0b92a11b 100644 --- a/src/cost_layer.h +++ b/src/cost_layer.h @@ -3,7 +3,7 @@ #include "params.h" typedef enum{ - SSE + SSE, MASKED } COST_TYPE; typedef struct { diff --git a/src/data.c b/src/data.c index 79c203e0..6a111def 100644 --- a/src/data.c +++ b/src/data.c @@ -184,6 +184,57 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes, free(boxes); } +void fill_truth_localization(char *path, float *truth, int classes, int flip, float dx, float dy, float sx, float sy) +{ + char *labelpath = find_replace(path, "objects", "object_labels"); + labelpath = find_replace(labelpath, ".jpg", ".txt"); + labelpath = find_replace(labelpath, ".JPEG", ".txt"); + int count; + box_label *boxes = read_boxes(labelpath, &count); + box_label box = boxes[0]; + free(boxes); + float x,y,w,h; + float left, top, right, bot; + int id; + int i; + for(i = 0; i < count; ++i){ + left = box.left * sx - dx; + right = box.right * sx - dx; + top = box.top * sy - dy; + bot = box.bottom* sy - dy; + id = box.id; + + if(flip){ + float swap = left; + left = 1. - right; + right = 1. - swap; + } + + left = constrain(0, 1, left); + right = constrain(0, 1, right); + top = constrain(0, 1, top); + bot = constrain(0, 1, bot); + + x = (left+right)/2; + y = (top+bot)/2; + w = (right - left); + h = (bot - top); + + if (x <= 0 || x >= 1 || y <= 0 || y >= 1) continue; + + w = constrain(0, 1, w); + h = constrain(0, 1, h); + if (w == 0 || h == 0) continue; + + int index = id*4; + truth[index++] = x; + truth[index++] = y; + truth[index++] = w; + truth[index++] = h; + } +} + + #define NUMCHARS 37 void print_letters(float *pred, int n) @@ -281,6 +332,56 @@ void free_data(data d) } } +data load_data_localization(int n, char **paths, int m, int classes, int w, int h) +{ + char **random_paths = get_random_paths(paths, n, m); + int i; + data d; + d.shallow = 0; + + d.X.rows = n; + d.X.vals = calloc(d.X.rows, sizeof(float*)); + d.X.cols = h*w*3; + + int k = (4*classes); + d.y = make_matrix(n, k); + for(i = 0; i < n; ++i){ + image orig = load_image_color(random_paths[i], 0, 0); + + int oh = orig.h; + int ow = orig.w; + + int dw = 32; + int dh = 32; + + int pleft = (rand_uniform() * dw); + int pright = (rand_uniform() * dw); + int ptop = (rand_uniform() * dh); + int pbot = (rand_uniform() * dh); + + int swidth = ow - pleft - pright; + int sheight = oh - ptop - pbot; + + float sx = (float)swidth / ow; + float sy = (float)sheight / oh; + + int flip = rand_r(&data_seed)%2; + image cropped = crop_image(orig, pleft, ptop, swidth, sheight); + float dx = ((float)pleft/ow)/sx; + float dy = ((float)ptop /oh)/sy; + + free_image(orig); + image sized = resize_image(cropped, w, h); + free_image(cropped); + if(flip) flip_image(sized); + d.X.vals[i] = sized.data; + + fill_truth_localization(random_paths[i], d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy); + } + free(random_paths); + return d; +} + data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background) { char **random_paths = get_random_paths(paths, n, m); @@ -296,11 +397,6 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes, d.y = make_matrix(n, k); for(i = 0; i < n; ++i){ image orig = load_image_color(random_paths[i], 0, 0); - float exposure = rand_uniform()+1; - if(rand_uniform() > .5) exposure = 1/exposure; - - float saturation = rand_uniform()+1; - if(rand_uniform() > .5) saturation = 1/saturation; int oh = orig.h; int ow = orig.w; @@ -343,6 +439,32 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int classes, return d; } +void *load_localization_thread(void *ptr) +{ + printf("Loading data: %d\n", rand_r(&data_seed)); + struct load_args a = *(struct load_args*)ptr; + *a.d = load_data_localization(a.n, a.paths, a.m, a.classes, a.w, a.h); + free(ptr); + return 0; +} + +pthread_t load_data_localization_thread(int n, char **paths, int m, int classes, int w, int h, data *d) +{ + pthread_t thread; + struct load_args *args = calloc(1, sizeof(struct load_args)); + args->n = n; + args->paths = paths; + args->m = m; + args->w = w; + args->h = h; + args->classes = classes; + args->d = d; + if(pthread_create(&thread, 0, load_localization_thread, args)) { + error("Thread creation failed"); + } + return thread; +} + void *load_detection_thread(void *ptr) { printf("Loading data: %d\n", rand_r(&data_seed)); diff --git a/src/data.h b/src/data.h index 22fd248e..63a2b390 100644 --- a/src/data.h +++ b/src/data.h @@ -36,6 +36,7 @@ pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int pthread_t load_data_detection_thread(int n, char **paths, int m, int classes, int w, int h, int nh, int nw, int background, data *d); data load_data_detection_jitter_random(int n, char **paths, int m, int classes, int w, int h, int num_boxes, int background); +pthread_t load_data_localization_thread(int n, char **paths, int m, int classes, int w, int h, data *d); data load_cifar10_data(char *filename); data load_all_cifar10(); diff --git a/src/detection.c b/src/detection.c index 93e9fe13..1f1114f3 100644 --- a/src/detection.c +++ b/src/detection.c @@ -1,5 +1,6 @@ #include "network.h" #include "detection_layer.h" +#include "cost_layer.h" #include "utils.h" #include "parser.h" @@ -7,7 +8,7 @@ char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; char *inet_class_names[] = {"bg", "accordion", "airplane", "ant", "antelope", "apple", "armadillo", "artichoke", "axe", "baby bed", "backpack", "bagel", "balance beam", "banana", "band aid", "banjo", "baseball", "basketball", "bathing cap", "beaker", "bear", "bee", "bell pepper", "bench", "bicycle", "binder", "bird", "bookshelf", "bow tie", "bow", "bowl", "brassiere", "burrito", "bus", "butterfly", "camel", "can opener", "car", "cart", "cattle", "cello", "centipede", "chain saw", "chair", "chime", "cocktail shaker", "coffee maker", "computer keyboard", "computer mouse", "corkscrew", "cream", "croquet ball", "crutch", "cucumber", "cup or mug", "diaper", "digital clock", "dishwasher", "dog", "domestic cat", "dragonfly", "drum", "dumbbell", "electric fan", "elephant", "face powder", "fig", "filing cabinet", "flower pot", "flute", "fox", "french horn", "frog", "frying pan", "giant panda", "goldfish", "golf ball", "golfcart", "guacamole", "guitar", "hair dryer", "hair spray", "hamburger", "hammer", "hamster", "harmonica", "harp", "hat with a wide brim", "head cabbage", "helmet", "hippopotamus", "horizontal bar", "horse", "hotdog", "iPod", "isopod", "jellyfish", "koala bear", "ladle", "ladybug", "lamp", "laptop", "lemon", "lion", "lipstick", "lizard", "lobster", "maillot", "maraca", "microphone", "microwave", "milk can", "miniskirt", "monkey", "motorcycle", "mushroom", "nail", "neck brace", "oboe", "orange", "otter", "pencil box", "pencil sharpener", "perfume", "person", "piano", "pineapple", "ping-pong ball", "pitcher", "pizza", "plastic bag", "plate rack", "pomegranate", "popsicle", "porcupine", "power drill", "pretzel", "printer", "puck", "punching bag", "purse", "rabbit", "racket", "ray", "red panda", "refrigerator", "remote control", "rubber eraser", "rugby ball", "ruler", "salt or pepper shaker", "saxophone", "scorpion", "screwdriver", "seal", "sheep", "ski", "skunk", "snail", "snake", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sofa", "spatula", "squirrel", "starfish", "stethoscope", "stove", "strainer", "strawberry", "stretcher", "sunglasses", "swimming trunks", "swine", "syringe", "table", "tape player", "tennis ball", "tick", "tie", "tiger", "toaster", "traffic light", "train", "trombone", "trumpet", "turtle", "tv or monitor", "unicycle", "vacuum", "violin", "volleyball", "waffle iron", "washer", "water bottle", "watercraft", "whale", "wine bottle", "zebra"}; #define AMNT 3 -void draw_detection(image im, float *box, int side) +void draw_detection(image im, float *box, int side, char *label) { int classes = 20; int elems = 4+classes; @@ -20,7 +21,7 @@ void draw_detection(image im, float *box, int side) //printf("%d\n", j); //printf("Prob: %f\n", box[j]); int class = max_index(box+j, classes); - if(box[j+class] > .2){ + if(box[j+class] > .4){ //int z; //for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]); printf("%f %s\n", box[j+class], class_names[class]); @@ -35,8 +36,8 @@ void draw_detection(image im, float *box, int side) float x = box[j+1]; x = (x+c)/side; y = (y+r)/side; - float h = box[j+2]; //*maxheight; - float w = box[j+3]; //*maxwidth; + float w = box[j+2]; //*maxwidth; + float h = box[j+3]; //*maxheight; h = h*h; w = w*w; //printf("coords %f %f %f %f\n", x, y, w, h); @@ -50,8 +51,176 @@ void draw_detection(image im, float *box, int side) } } //printf("Done\n"); - show_image(im, "box"); - cvWaitKey(0); + show_image(im, label); +} + +void draw_localization(image im, float *box) +{ + int classes = 20; + int class; + for(class = 0; class < classes; ++class){ + //int z; + //for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]); + float red = get_color(0,class,classes); + float green = get_color(1,class,classes); + float blue = get_color(2,class,classes); + + int j = class*4; + float x = box[j+0]; + float y = box[j+1]; + float w = box[j+2]; //*maxheight; + float h = box[j+3]; //*maxwidth; + //printf("coords %f %f %f %f\n", x, y, w, h); + + int left = (x-w/2)*im.w; + int right = (x+w/2)*im.w; + int top = (y-h/2)*im.h; + int bot = (y+h/2)*im.h; + draw_box(im, left, top, right, bot, red, green, blue); + } + //printf("Done\n"); +} + +void train_localization(char *cfgfile, char *weightfile) +{ + srand(time(0)); + data_seed = time(0); + char *base = basecfg(cfgfile); + printf("%s\n", base); + float avg_loss = -1; + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = 128; + int classes = 20; + int i = net.seen/imgs; + data train, buffer; + + char **paths; + list *plist; + plist = get_paths("/home/pjreddie/data/voc/loc.2012val.txt"); + paths = (char **)list_to_array(plist); + pthread_t load_thread = load_data_localization_thread(imgs, paths, plist->size, classes, net.w, net.h, &buffer); + clock_t time; + while(1){ + i += 1; + time=clock(); + pthread_join(load_thread, 0); + train = buffer; + load_thread = load_data_localization_thread(imgs, paths, plist->size, classes, net.w, net.h, &buffer); + + printf("Loaded: %lf seconds\n", sec(clock()-time)); + time=clock(); + float loss = train_network(net, train); + + float *out = get_network_output_gpu(net); + image im = float_to_image(net.w, net.h, 3, train.X.vals[127]); + image copy = copy_image(im); + draw_localization(copy, &(out[63*80])); + draw_localization(copy, train.y.vals[127]); + show_image(copy, "box"); + cvWaitKey(0); + free_image(copy); + + net.seen += imgs; + if (avg_loss < 0) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); + if(i%100==0){ + char buff[256]; + sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); + save_weights(net, buff); + } + free_data(train); + } +} + +void train_detection_teststuff(char *cfgfile, char *weightfile) +{ + srand(time(0)); + data_seed = time(0); + int imgnet = 0; + char *base = basecfg(cfgfile); + printf("%s\n", base); + float avg_loss = -1; + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + detection_layer *layer = get_network_detection_layer(net); + net.learning_rate = 0; + net.decay = 0; + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = 128; + int i = net.seen/imgs; + data train, buffer; + + int classes = layer->classes; + int background = layer->background; + int side = sqrt(get_detection_layer_locations(*layer)); + + char **paths; + list *plist; + if (imgnet){ + plist = get_paths("/home/pjreddie/data/imagenet/det.train.list"); + }else{ + plist = get_paths("/home/pjreddie/data/voc/val_2012.txt"); + //plist = get_paths("/home/pjreddie/data/voc/no_2007_test.txt"); + //plist = get_paths("/home/pjreddie/data/coco/trainval.txt"); + //plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt"); + } + paths = (char **)list_to_array(plist); + pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); + clock_t time; + cost_layer clayer = *((cost_layer *)net.layers[net.n-1]); + while(1){ + i += 1; + time=clock(); + pthread_join(load_thread, 0); + train = buffer; + load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer); + + /* + image im = float_to_image(net.w, net.h, 3, train.X.vals[114]); + image copy = copy_image(im); + draw_detection(copy, train.y.vals[114], 7); + free_image(copy); + */ + + int z; + int count = 0; + float sx, sy, sw, sh; + sx = sy = sw = sh = 0; + for(z = 0; z < clayer.batch*clayer.inputs; z += 24){ + if(clayer.delta[z+20]){ + ++count; + sx += fabs(clayer.delta[z+20])*64; + sy += fabs(clayer.delta[z+21])*64; + sw += fabs(clayer.delta[z+22])*448; + sh += fabs(clayer.delta[z+23])*448; + } + } + printf("Avg error: %f, %f, %f x %f\n", sx/count, sy/count, sw/count, sh/count); + + printf("Loaded: %lf seconds\n", sec(clock()-time)); + time=clock(); + float loss = train_network(net, train); + net.seen += imgs; + if (avg_loss < 0) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); + if(i == 100){ + net.learning_rate *= 10; + } + if(i%100==0){ + char buff[256]; + sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); + save_weights(net, buff); + } + free_data(train); + } } void train_detection(char *cfgfile, char *weightfile) @@ -110,6 +279,9 @@ void train_detection(char *cfgfile, char *weightfile) if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); + if(i == 100){ + net.learning_rate *= 10; + } if(i%100==0){ char buff[256]; sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); @@ -140,7 +312,7 @@ void predict_detections(network net, data d, float threshold, int offset, int cl h = h*h; float prob = scale*pred.vals[j][k+class+background+nuisance]; if(prob < threshold) continue; - printf("%d %d %f %f %f %f %f\n", offset + j, class, prob, y, x, h, w); + printf("%d %d %f %f %f %f %f\n", offset + j, class, prob, x, y, w, h); } } } @@ -209,6 +381,130 @@ void validate_detection(char *cfgfile, char *weightfile) } } +void validate_detection_post(char *cfgfile, char *weightfile) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + + network post = parse_network_cfg("cfg/localize.cfg"); + load_weights(&post, "/home/pjreddie/imagenet_backup/localize_1000.weights"); + set_batch_network(&post, 1); + + detection_layer *layer = get_network_detection_layer(net); + fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + srand(time(0)); + + //list *plist = get_paths("/home/pjreddie/data/voc/test_2007.txt"); + list *plist = get_paths("/home/pjreddie/data/voc/val_2012.txt"); + //list *plist = get_paths("/home/pjreddie/data/voc/test.txt"); + //list *plist = get_paths("/home/pjreddie/data/voc/val.expanded.txt"); + //list *plist = get_paths("/home/pjreddie/data/voc/train.txt"); + char **paths = (char **)list_to_array(plist); + + int classes = layer->classes; + int nuisance = layer->nuisance; + int background = (layer->background && !nuisance); + int num_boxes = sqrt(get_detection_layer_locations(*layer)); + + int per_box = 4+classes+background+nuisance; + + int m = plist->size; + int i = 0; + float threshold = .01; + + clock_t time = clock(); + for(i = 0; i < m; ++i){ + image im = load_image_color(paths[i], 0, 0); + if(i % 100 == 0) { + fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time)); + time = clock(); + } + image sized = resize_image(im, net.w, net.h); + float *out = network_predict(net, sized.data); + free_image(sized); + int k, class; + //show_image(im, "original"); + int num_output = num_boxes*num_boxes*per_box; + //image cp1 = copy_image(im); + //draw_detection(cp1, out, 7, "before"); + for(k = 0; k < num_output; k += per_box){ + float *post_out = 0; + float scale = 1.; + int index = k/per_box; + int row = index / num_boxes; + int col = index % num_boxes; + if (nuisance) scale = 1.-out[k]; + for (class = 0; class < classes; ++class){ + int ci = k+classes+background+nuisance; + float x = (out[ci + 0] + col)/num_boxes; + float y = (out[ci + 1] + row)/num_boxes; + float w = out[ci + 2]; //* distance_from_edge(row, num_boxes); + float h = out[ci + 3]; //* distance_from_edge(col, num_boxes); + w = w*w; + h = h*h; + float prob = scale*out[k+class+background+nuisance]; + if (prob >= threshold) { + x *= im.w; + y *= im.h; + w *= im.w; + h *= im.h; + w += 32; + h += 32; + int left = (x - w/2); + int top = (y - h/2); + int right = (x + w/2); + int bot = (y+h/2); + if (left < 0) left = 0; + if (right > im.w) right = im.w; + if (top < 0) top = 0; + if (bot > im.h) bot = im.h; + + image crop = crop_image(im, left, top, right-left, bot-top); + image resize = resize_image(crop, post.w, post.h); + if (!post_out){ + post_out = network_predict(post, resize.data); + } + /* + draw_localization(resize, post_out); + show_image(resize, "second"); + fprintf(stderr, "%s\n", class_names[class]); + cvWaitKey(0); + */ + int index = 4*class; + float px = post_out[index+0]; + float py = post_out[index+1]; + float pw = post_out[index+2]; + float ph = post_out[index+3]; + px = (px * crop.w + left) / im.w; + py = (py * crop.h + top) / im.h; + pw = (pw * crop.w) / im.w; + ph = (ph * crop.h) / im.h; + + out[ci + 0] = px*num_boxes - col; + out[ci + 1] = py*num_boxes - row; + out[ci + 2] = sqrt(pw); + out[ci + 3] = sqrt(ph); + /* + show_image(crop, "cropped"); + cvWaitKey(0); + */ + free_image(crop); + free_image(resize); + printf("%d %d %f %f %f %f %f\n", i, class, prob, px, py, pw, ph); + } + } + } + /* + image cp2 = copy_image(im); + draw_detection(cp2, out, 7, "after"); + cvWaitKey(0); + */ + } +} + void test_detection(char *cfgfile, char *weightfile) { network net = parse_network_cfg(cfgfile); @@ -229,7 +525,7 @@ void test_detection(char *cfgfile, char *weightfile) time=clock(); float *predictions = network_predict(net, X); printf("%s: Predicted in %f seconds.\n", filename, sec(clock()-time)); - draw_detection(im, predictions, 7); + draw_detection(im, predictions, 7, "detections"); free_image(im); } } @@ -245,5 +541,8 @@ void run_detection(int argc, char **argv) char *weights = (argc > 4) ? argv[4] : 0; if(0==strcmp(argv[2], "test")) test_detection(cfg, weights); else if(0==strcmp(argv[2], "train")) train_detection(cfg, weights); + else if(0==strcmp(argv[2], "teststuff")) train_detection_teststuff(cfg, weights); + else if(0==strcmp(argv[2], "trainloc")) train_localization(cfg, weights); else if(0==strcmp(argv[2], "valid")) validate_detection(cfg, weights); + else if(0==strcmp(argv[2], "validpost")) validate_detection_post(cfg, weights); } diff --git a/src/network.h b/src/network.h index ed8872b7..d05f548f 100644 --- a/src/network.h +++ b/src/network.h @@ -47,6 +47,7 @@ float train_network_datum_gpu(network net, float *x, float *y); float *network_predict_gpu(network net, float *input); float * get_network_output_gpu_layer(network net, int i); float * get_network_delta_gpu_layer(network net, int i); +float *get_network_output_gpu(network net); #endif void compare_networks(network n1, network n2, data d);