diff --git a/src/batchnorm_layer.c b/src/batchnorm_layer.c
index 510f1b2f..7eac44ef 100644
--- a/src/batchnorm_layer.c
+++ b/src/batchnorm_layer.c
@@ -166,10 +166,10 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
         fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
         fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
 
-        scal_ongpu(l.out_c, .95, l.rolling_mean_gpu, 1);
-        axpy_ongpu(l.out_c, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
-        scal_ongpu(l.out_c, .95, l.rolling_variance_gpu, 1);
-        axpy_ongpu(l.out_c, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
+        scal_ongpu(l.out_c, .99, l.rolling_mean_gpu, 1);
+        axpy_ongpu(l.out_c, .01, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
+        scal_ongpu(l.out_c, .99, l.rolling_variance_gpu, 1);
+        axpy_ongpu(l.out_c, .01, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
 
         copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
         normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
diff --git a/src/blas.c b/src/blas.c
index d6ab88bd..c6d59eac 100644
--- a/src/blas.c
+++ b/src/blas.c
@@ -6,7 +6,7 @@
 #include
 #include
 
-void reorg(float *x, int size, int layers, int batch, int forward)
+void flatten(float *x, int size, int layers, int batch, int forward)
 {
     float *swap = calloc(size*layers*batch, sizeof(float));
     int i,c,b;
@@ -189,12 +189,12 @@ void softmax(float *input, int n, float temp, float *output)
         if(input[i] > largest) largest = input[i];
     }
     for(i = 0; i < n; ++i){
-        sum += exp(input[i]/temp-largest/temp);
+        float e = exp(input[i]/temp - largest/temp);
+        sum += e;
+        output[i] = e;
     }
-    if(sum) sum = largest/temp+log(sum);
-    else sum = largest-100;
     for(i = 0; i < n; ++i){
-        output[i] = exp(input[i]/temp-sum);
+        output[i] /= sum;
     }
 }
diff --git a/src/blas.h b/src/blas.h
index 51554a83..a942024b 100644
--- a/src/blas.h
+++ b/src/blas.h
@@ -1,6 +1,6 @@
 #ifndef BLAS_H
 #define BLAS_H
-void reorg(float *x, int size, int layers, int batch, int forward);
+void flatten(float *x, int size, int layers, int batch, int forward);
 void pm(int M, int N, float *A);
 float *random_matrix(int rows, int cols);
 void time_random_matrix(int TA, int TB, int m, int k, int n);
@@ -80,5 +80,7 @@ void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forwa
 void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
 void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
 
+void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out);
+
 #endif
 #endif
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 684e66d9..d9401766 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -543,6 +543,30 @@ extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float *
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void flatten_kernel(int N, float *x, int spatial, int layers, int batch, int forward, float *out)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i >= N) return;
+    int in_s = i%spatial;
+    i = i/spatial;
+    int in_c = i%layers;
+    i = i/layers;
+    int b = i;
+
+    int i1 = b*layers*spatial + in_c*spatial + in_s;
+    int i2 = b*layers*spatial + in_s*layers + in_c;
+
+    if (forward) out[i2] = x[i1];
+    else out[i1] = x[i2];
+}
+
+extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
+{
+    int size = spatial*batch*layers;
+    flatten_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, spatial, layers, batch, forward, out);
+    check_error(cudaPeekAtLastError());
+}
+
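/*
 * For reference: flatten_kernel above converts a batch of feature maps between
 * channel-major (b, c, s) and spatial-major (b, s, c) orderings, so that every
 * prediction belonging to one cell/anchor ends up contiguous. A minimal CPU
 * sketch of the same index mapping (illustrative only, not part of this patch):
 */
static void flatten_cpu(float *x, int spatial, int layers, int batch, int forward, float *out)
{
    int b, c, s;
    for(b = 0; b < batch; ++b){
        for(c = 0; c < layers; ++c){
            for(s = 0; s < spatial; ++s){
                int i1 = b*layers*spatial + c*spatial + s;   /* channel-major index */
                int i2 = b*layers*spatial + s*layers + c;    /* spatial-major index */
                if(forward) out[i2] = x[i1];
                else        out[i1] = x[i2];
            }
        }
    }
}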
 extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
 {
     int size = w*h*c*batch;
@@ -718,11 +742,12 @@ __device__ void softmax_device(int n, float *input, float temp, float *output)
         largest = (val>largest) ? val : largest;
     }
     for(i = 0; i < n; ++i){
-        sum += exp(input[i]/temp-largest/temp);
+        float e = exp(input[i]/temp - largest/temp);
+        sum += e;
+        output[i] = e;
     }
-    sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
     for(i = 0; i < n; ++i){
-        output[i] = exp(input[i]/temp-sum);
+        output[i] /= sum;
     }
 }
 
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 1d93b3fc..86285e03 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -368,6 +368,14 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
 
     l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
     l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
+
+    if(l->batch_normalize){
+        cuda_free(l->x_gpu);
+        cuda_free(l->x_norm_gpu);
+
+        l->x_gpu = cuda_make_array(l->output, l->batch*l->outputs);
+        l->x_norm_gpu = cuda_make_array(l->output, l->batch*l->outputs);
+    }
 #ifdef CUDNN
     cudnn_convolutional_setup(l);
 #endif
diff --git a/src/cuda.c b/src/cuda.c
index 617e7b35..1b51271f 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -26,6 +26,7 @@ int cuda_get_device()
 
 void check_error(cudaError_t status)
 {
+    //cudaDeviceSynchronize();
     cudaError_t status2 = cudaGetLastError();
     if (status != cudaSuccess)
     {
diff --git a/src/darknet.c b/src/darknet.c
index 989bf6fc..4419107a 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -127,7 +127,7 @@ void oneoff(char *cfgfile, char *weightfile, char *outfile)
     network net = parse_network_cfg(cfgfile);
     int oldn = net.layers[net.n - 2].n;
     int c = net.layers[net.n - 2].c;
-    net.layers[net.n - 2].n = 7879;
+    net.layers[net.n - 2].n = 9372;
     net.layers[net.n - 2].biases += 5;
     net.layers[net.n - 2].weights += 5*c;
     if(weightfile){
diff --git a/src/data.c b/src/data.c
index a2390a9b..8fb1a258 100644
--- a/src/data.c
+++ b/src/data.c
@@ -171,6 +171,13 @@ void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float
 {
     int i;
     for(i = 0; i < n; ++i){
+        if(boxes[i].x == 0 && boxes[i].y == 0) {
+            boxes[i].x = 999999;
+            boxes[i].y = 999999;
+            boxes[i].w = 999999;
+            boxes[i].h = 999999;
+            continue;
+        }
         boxes[i].left = boxes[i].left * sx - dx;
         boxes[i].right = boxes[i].right * sx - dx;
         boxes[i].top = boxes[i].top * sy - dy;
@@ -289,6 +296,7 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
 
     find_replace(path, "images", "labels", labelpath);
     find_replace(labelpath, "JPEGImages", "labels", labelpath);
+    find_replace(labelpath, "raw", "labels", labelpath);
     find_replace(labelpath, ".jpg", ".txt", labelpath);
     find_replace(labelpath, ".png", ".txt", labelpath);
     find_replace(labelpath, ".JPG", ".txt", labelpath);
@@ -309,7 +317,7 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
         h = boxes[i].h;
         id = boxes[i].id;
 
-        if (w < .01 || h < .01) continue;
+        if ((w < .01 || h < .01)) continue;
 
         truth[i*5+0] = x;
         truth[i*5+1] = y;
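/*
 * The 999999 sentinel written by correct_boxes() above marks a label that carries a
 * class but no usable box (x == y == 0 in the annotation); forward_region_layer()
 * later tests truth.x/y against 100000 and treats such entries as classification-only.
 * Illustrative sketch of that convention (hypothetical helper names, not part of
 * this patch):
 */
typedef struct { float x, y, w, h; } truth_box;        /* mirrors darknet's box */

static void mark_no_box(truth_box *b)                  /* what correct_boxes() writes */
{
    b->x = b->y = b->w = b->h = 999999;
}

static int is_classification_only(truth_box t)         /* what the region layer checks */
{
    return t.x > 100000 && t.y > 100000;
}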
diff --git a/src/detector.c b/src/detector.c
index f18ae517..3853ebb3 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -75,8 +75,27 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
 
     pthread_t load_thread = load_data(args);
     clock_t time;
+    int count = 0;
     //while(i*imgs < N*120){
     while(get_current_batch(net) < net.max_batches){
+        if(l.random && count++%10 == 0){
+            printf("Resizing\n");
+            int dim = (rand() % 10 + 10) * 32;
+            //int dim = (rand() % 4 + 16) * 32;
+            printf("%d\n", dim);
+            args.w = dim;
+            args.h = dim;
+
+            pthread_join(load_thread, 0);
+            train = buffer;
+            free_data(train);
+            load_thread = load_data(args);
+
+            for(i = 0; i < ngpus; ++i){
+                resize_network(nets + i, dim, dim);
+            }
+            net = nets[0];
+        }
         time=clock();
         pthread_join(load_thread, 0);
         train = buffer;
@@ -117,13 +136,15 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
 
         i = get_current_batch(net);
         printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
-        if(i%1000==0 || (i < 1000 && i%100 == 0)){
+        if(i%100==0 || (i < 1000 && i%100 == 0)){
+            if(ngpus != 1) sync_nets(nets, ngpus, 0);
             char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
            save_weights(net, buff);
         }
         free_data(train);
     }
+    if(ngpus != 1) sync_nets(nets, ngpus, 0);
     char buff[256];
     sprintf(buff, "%s/%s_final.weights", backup_directory, base);
     save_weights(net, buff);
@@ -183,6 +204,29 @@
     }
 }
 
+void print_imagenet_detections(FILE *fp, int id, box *boxes, float **probs, int total, int classes, int w, int h, int *map)
+{
+    int i, j;
+    for(i = 0; i < total; ++i){
+        float xmin = boxes[i].x - boxes[i].w/2.;
+        float xmax = boxes[i].x + boxes[i].w/2.;
+        float ymin = boxes[i].y - boxes[i].h/2.;
+        float ymax = boxes[i].y + boxes[i].h/2.;
+
+        if (xmin < 0) xmin = 0;
+        if (ymin < 0) ymin = 0;
+        if (xmax > w) xmax = w;
+        if (ymax > h) ymax = h;
+
+        for(j = 0; j < classes; ++j){
+            int class = j;
+            if (map) class = map[j];
+            if (probs[i][class]) fprintf(fp, "%d %d %f %f %f %f %f\n", id, j+1, probs[i][class],
+                    xmin, ymin, xmax, ymax);
+        }
+    }
+}
+
 void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
 {
     list *options = read_data_cfg(datacfg);
     char *name_list = option_find_str(options, "names", "data/names.list");
     char *prefix = option_find_str(options, "results", "results");
     char **names = get_labels(name_list);
+    char *mapf = option_find_str(options, "map", 0);
+    int *map = 0;
+    if (mapf) map = read_map(mapf);
 
     char buff[1024];
-    int coco = option_find_int_quiet(options, "coco", 0);
-    FILE *coco_fp = 0;
-    if(coco){
+    char *type = option_find_str(options, "eval", "voc");
+    FILE *fp = 0;
+    int coco = 0;
+    int imagenet = 0;
+    if(0==strcmp(type, "coco")){
         snprintf(buff, 1024, "%s/coco_results.json", prefix);
-        coco_fp = fopen(buff, "w");
-        fprintf(coco_fp, "[\n");
+        fp = fopen(buff, "w");
+        fprintf(fp, "[\n");
+        coco = 1;
+    } else if(0==strcmp(type, "imagenet")){
+        snprintf(buff, 1024, "%s/imagenet-detection.txt", prefix);
+        fp = fopen(buff, "w");
+        imagenet = 1;
     }
 
     network net = parse_network_cfg(cfgfile);
@@ -230,10 +284,10 @@
     int i=0;
     int t;
 
-    float thresh = .001;
-    float nms = .5;
+    float thresh = .005;
+    float nms = .45;
 
-    int nthreads = 2;
+    int nthreads = 4;
     image *val = calloc(nthreads, sizeof(image));
     image *val_resized = calloc(nthreads, sizeof(image));
     image *buf = calloc(nthreads, sizeof(image));
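/*
 * The resize branch above re-samples the training resolution every 10 batches:
 * (rand() % 10 + 10) * 32 gives a square size in {320, 352, ..., 608}. Multiples
 * of 32 keep the final feature map an integer size after the network's downsampling.
 * Minimal sketch of the same sampling (illustrative only, not part of this patch):
 */
#include <stdio.h>
#include <stdlib.h>

int main()
{
    int i;
    for(i = 0; i < 5; ++i){
        int dim = (rand() % 10 + 10) * 32;   /* 320 .. 608 in steps of 32 */
        printf("next 10 batches train at %d x %d\n", dim, dim);
    }
    return 0;
}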
@@ -274,9 +328,11 @@
             int h = val[t].h;
             get_region_boxes(l, w, h, thresh, probs, boxes, 0);
             if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
-            if(coco_fp){
-                print_cocos(coco_fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
-            }else{
+            if (coco){
+                print_cocos(fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
+            } else if (imagenet){
+                print_imagenet_detections(fp, i+t-nthreads+1 + 9741, boxes, probs, l.w*l.h*l.n, 200, w, h, map);
+            } else {
                 print_detector_detections(fps, id, boxes, probs, l.w*l.h*l.n, classes, w, h);
             }
             free(id);
@@ -287,10 +343,10 @@
     for(j = 0; j < classes; ++j){
         fclose(fps[j]);
     }
-    if(coco_fp){
-        fseek(coco_fp, -2, SEEK_CUR);
-        fprintf(coco_fp, "\n]\n");
-        fclose(coco_fp);
+    if(coco){
+        fseek(fp, -2, SEEK_CUR);
+        fprintf(fp, "\n]\n");
+        fclose(fp);
     }
     fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
 }
diff --git a/src/layer.h b/src/layer.h
index c149f29d..eb480c00 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -120,6 +120,7 @@ struct layer{
     int random;
     float thresh;
     int classfix;
+    int absolute;
 
     int dontload;
     int dontloadscales;
diff --git a/src/network.c b/src/network.c
index 8d46c55b..0914e37e 100644
--- a/src/network.c
+++ b/src/network.c
@@ -41,7 +41,7 @@ void reset_momentum(network net)
     net.momentum = 0;
     net.decay = 0;
 #ifdef GPU
-    if(gpu_index >= 0) update_network_gpu(net);
+    //if(net.gpu_index >= 0) update_network_gpu(net);
 #endif
 }
 
@@ -60,7 +60,7 @@ float get_current_rate(network net)
             for(i = 0; i < net.num_steps; ++i){
                 if(net.steps[i] > batch_num) return rate;
                 rate *= net.scales[i];
-                if(net.steps[i] > batch_num - 1) reset_momentum(net);
+                //if(net.steps[i] > batch_num - 1 && net.scales[i] > 1) reset_momentum(net);
             }
             return rate;
         case EXP:
@@ -321,6 +321,12 @@ void set_batch_network(network *net, int b)
 
 int resize_network(network *net, int w, int h)
 {
+#ifdef GPU
+    cuda_set_device(net->gpu_index);
+    if(gpu_index >= 0){
+        cuda_free(net->workspace);
+    }
+#endif
     int i;
     //if(w == net->w && h == net->h) return 0;
     net->w = w;
@@ -337,6 +343,10 @@ int resize_network(network *net, int w, int h)
             resize_crop_layer(&l, w, h);
         }else if(l.type == MAXPOOL){
             resize_maxpool_layer(&l, w, h);
+        }else if(l.type == REGION){
+            resize_region_layer(&l, w, h);
+        }else if(l.type == ROUTE){
+            resize_route_layer(&l, net);
         }else if(l.type == REORG){
             resize_reorg_layer(&l, w, h);
         }else if(l.type == AVGPOOL){
@@ -357,7 +367,12 @@ int resize_network(network *net, int w, int h)
     }
 #ifdef GPU
     if(gpu_index >= 0){
-        cuda_free(net->workspace);
+        if(net->input_gpu) {
+            cuda_free(*net->input_gpu);
+            *net->input_gpu = 0;
+            cuda_free(*net->truth_gpu);
+            *net->truth_gpu = 0;
+        }
         net->workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
     }else {
         free(net->workspace);
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index a7510e83..313cd6d1 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -78,6 +78,7 @@ void backward_network_gpu(network net, network_state state)
 
 void update_network_gpu(network net)
 {
+    cuda_set_device(net.gpu_index);
     int i;
     int update_batch = net.batch*net.subdivisions;
     float rate = get_current_rate(net);
@@ -377,7 +378,7 @@ float train_networks(network *nets, int n, data d, int interval)
 float *get_network_output_layer_gpu(network net, int i)
 {
     layer l = net.layers[i];
-    cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
+    if(l.type != REGION) cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
     return l.output;
 }
 
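/*
 * With REGION and ROUTE handled by resize_network() above, a loaded detector can be
 * switched to a new input resolution at run time, which is what the multi-scale
 * training loop relies on. A hedged usage sketch; the cfg/weight paths are
 * placeholders and not part of this patch:
 */
#include "network.h"
#include "parser.h"

void run_at_higher_resolution(char *cfgfile, char *weightfile)
{
    network net = parse_network_cfg(cfgfile);            /* e.g. "cfg/yolo.cfg" (placeholder) */
    if(weightfile) load_weights(&net, weightfile);
    set_batch_network(&net, 1);
    resize_network(&net, 608, 608);   /* reallocates per-layer buffers and the GPU workspace */
    /* ... run detection on 608 x 608 inputs ... */
}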
diff --git a/src/parser.c b/src/parser.c
index 4e71fe5e..db4cf368 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -2,32 +2,32 @@
 #include
 #include
-#include "blas.h"
-#include "parser.h"
-#include "assert.h"
-#include "activations.h"
-#include "crop_layer.h"
-#include "cost_layer.h"
-#include "convolutional_layer.h"
 #include "activation_layer.h"
-#include "normalization_layer.h"
-#include "batchnorm_layer.h"
-#include "connected_layer.h"
-#include "rnn_layer.h"
-#include "gru_layer.h"
-#include "crnn_layer.h"
-#include "maxpool_layer.h"
-#include "reorg_layer.h"
-#include "softmax_layer.h"
-#include "dropout_layer.h"
-#include "detection_layer.h"
-#include "region_layer.h"
+#include "activations.h"
+#include "assert.h"
 #include "avgpool_layer.h"
+#include "batchnorm_layer.h"
+#include "blas.h"
+#include "connected_layer.h"
+#include "convolutional_layer.h"
+#include "cost_layer.h"
+#include "crnn_layer.h"
+#include "crop_layer.h"
+#include "detection_layer.h"
+#include "dropout_layer.h"
+#include "gru_layer.h"
+#include "list.h"
 #include "local_layer.h"
+#include "maxpool_layer.h"
+#include "normalization_layer.h"
+#include "option_list.h"
+#include "parser.h"
+#include "region_layer.h"
+#include "reorg_layer.h"
+#include "rnn_layer.h"
 #include "route_layer.h"
 #include "shortcut_layer.h"
-#include "list.h"
-#include "option_list.h"
+#include "softmax_layer.h"
 #include "utils.h"
 
 typedef struct{
@@ -232,21 +232,6 @@ softmax_layer parse_softmax(list *options, size_params params)
     return layer;
 }
 
-int *read_map(char *filename)
-{
-    int n = 0;
-    int *map = 0;
-    char *str;
-    FILE *file = fopen(filename, "r");
-    if(!file) file_error(filename);
-    while((str=fgetl(file))){
-        ++n;
-        map = realloc(map, n*sizeof(int));
-        map[n-1] = atoi(str);
-    }
-    return map;
-}
-
 layer parse_region(list *options, size_params params)
 {
     int coords = option_find_int(options, "coords", 4);
@@ -269,6 +254,8 @@ layer parse_region(list *options, size_params params)
 
     l.thresh = option_find_float(options, "thresh", .5);
     l.classfix = option_find_int_quiet(options, "classfix", 0);
+    l.absolute = option_find_int_quiet(options, "absolute", 0);
+    l.random = option_find_int_quiet(options, "random", 0);
 
     l.coord_scale = option_find_float(options, "coord_scale", 1);
     l.object_scale = option_find_float(options, "object_scale", 1);
diff --git a/src/region_layer.c b/src/region_layer.c
index ac30e889..5e8387dd 100644
--- a/src/region_layer.c
+++ b/src/region_layer.c
@@ -9,6 +9,8 @@
 #include
 #include
 
+#define DOABS 1
+
 region_layer make_region_layer(int batch, int w, int h, int n, int classes, int coords)
 {
     region_layer l = {0};
@@ -48,7 +50,26 @@ region_layer make_region_layer(int batch, int w, int h, int n, int classes, int
     return l;
 }
 
-#define DOABS 1
+void resize_region_layer(layer *l, int w, int h)
+{
+    l->w = w;
+    l->h = h;
+
+    l->outputs = h*w*l->n*(l->classes + l->coords + 1);
+    l->inputs = l->outputs;
+
+    l->output = realloc(l->output, l->batch*l->outputs*sizeof(float));
+    l->delta = realloc(l->delta, l->batch*l->outputs*sizeof(float));
+
+#ifdef GPU
+    cuda_free(l->delta_gpu);
+    cuda_free(l->output_gpu);
+
+    l->delta_gpu = cuda_make_array(l->delta, l->batch*l->outputs);
+    l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs);
+#endif
+}
+
 box get_region_box(float *x, float *biases, int n, int index, int i, int j, int w, int h)
 {
     box b;
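/*
 * resize_region_layer() above recomputes the layer size from
 * outputs = h*w*n*(classes + coords + 1): per grid cell and per anchor there are
 * `coords` box terms, one objectness score and `classes` class scores. Worked
 * example with illustrative numbers only (13x13 grid, 5 anchors, 20 classes):
 */
#include <stdio.h>

int main()
{
    int w = 13, h = 13, n = 5, classes = 20, coords = 4;
    int outputs = h*w*n*(classes + coords + 1);
    printf("region layer outputs = %d\n", outputs);   /* 13*13*5*25 = 21125 */
    return 0;
}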
@@ -125,7 +146,9 @@ void forward_region_layer(const region_layer l, network_state state)
     int i,j,b,t,n;
     int size = l.coords + l.classes + 1;
     memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
-    reorg(l.output, l.w*l.h, size*l.n, l.batch, 1);
+    #ifndef GPU
+    flatten(l.output, l.w*l.h, size*l.n, l.batch, 1);
+    #endif
     for (b = 0; b < l.batch; ++b){
         for(i = 0; i < l.h*l.w*l.n; ++i){
             int index = size*i + b*l.outputs;
@@ -134,25 +157,14 @@
         }
     }
+#ifndef GPU
     if (l.softmax_tree){
-#ifdef GPU
-        cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
-        int i;
-        int count = 5;
-        for (i = 0; i < l.softmax_tree->groups; ++i) {
-            int group_size = l.softmax_tree->group_size[i];
-            softmax_gpu(l.output_gpu+count, group_size, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + count);
-            count += group_size;
-        }
-        cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
-#else
         for (b = 0; b < l.batch; ++b){
             for(i = 0; i < l.h*l.w*l.n; ++i){
                 int index = size*i + b*l.outputs;
                 softmax_tree(l.output + index + 5, 1, 0, 1, l.softmax_tree, l.output + index + 5);
             }
         }
-#endif
     } else if (l.softmax){
         for (b = 0; b < l.batch; ++b){
             for(i = 0; i < l.h*l.w*l.n; ++i){
                 int index = size*i + b*l.outputs;
@@ -161,6 +173,7 @@
             }
         }
     }
+#endif
     if(!state.train) return;
     memset(l.delta, 0, l.outputs * l.batch * sizeof(float));
     float avg_iou = 0;
@@ -172,6 +185,32 @@
     int class_count = 0;
     *(l.cost) = 0;
     for (b = 0; b < l.batch; ++b) {
+        if(l.softmax_tree){
+            int onlyclass = 0;
+            for(t = 0; t < 30; ++t){
+                box truth = float_to_box(state.truth + t*5 + b*l.truths);
+                if(!truth.x) break;
+                int class = state.truth[t*5 + b*l.truths + 4];
+                float maxp = 0;
+                int maxi = 0;
+                if(truth.x > 100000 && truth.y > 100000){
+                    for(n = 0; n < l.n*l.w*l.h; ++n){
+                        int index = size*n + b*l.outputs + 5;
+                        float p = get_hierarchy_probability(l.output + index, l.softmax_tree, class);
+                        if(p > maxp){
+                            maxp = p;
+                            maxi = n;
+                        }
+                    }
+                    int index = size*maxi + b*l.outputs + 5;
+                    delta_region_class(l.output, l.delta, index, class, l.classes, l.softmax_tree, l.class_scale, &avg_cat);
+                    ++class_count;
+                    onlyclass = 1;
+                    break;
+                }
+            }
+            if(onlyclass) continue;
+        }
         for (j = 0; j < l.h; ++j) {
             for (i = 0; i < l.w; ++i) {
                 for (n = 0; n < l.n; ++n) {
@@ -273,7 +312,9 @@
         }
     }
     //printf("\n");
-    reorg(l.delta, l.w*l.h, size*l.n, l.batch, 0);
+    #ifndef GPU
+    flatten(l.delta, l.w*l.h, size*l.n, l.batch, 0);
+    #endif
     *(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
     printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, Avg Recall: %f, count: %d\n", avg_iou/count, avg_cat/class_count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, count);
 }
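/*
 * get_hierarchy_probability() (added in src/tree.c further down) scores a class in
 * the softmax tree as the product of conditional probabilities along the path to
 * the root; that is the quantity the onlyclass branch above maximises over all
 * anchors. Minimal standalone sketch of the same walk (illustrative only, not part
 * of this patch):
 */
static float path_probability(const float *p, const int *parent, int c)
{
    float prob = 1;
    while(c >= 0){
        prob *= p[c];     /* conditional probability of node c given its parent */
        c = parent[c];    /* the root's parent index is -1 */
    }
    return prob;
}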
@@ -308,13 +349,18 @@ void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *b
             hierarchy_predictions(predictions + class_index, l.classes, l.softmax_tree, 0);
             int found = 0;
             for(j = l.classes - 1; j >= 0; --j){
-                if(!found && predictions[class_index + j] > .5){
-                    found = 1;
-                } else {
-                    predictions[class_index + j] = 0;
+                if(1){
+                    if(!found && predictions[class_index + j] > .5){
+                        found = 1;
+                    } else {
+                        predictions[class_index + j] = 0;
+                    }
+                    float prob = predictions[class_index+j];
+                    probs[index][j] = (scale > thresh) ? prob : 0;
+                }else{
+                    float prob = scale*predictions[class_index+j];
+                    probs[index][j] = (prob > thresh) ? prob : 0;
                 }
-                float prob = predictions[class_index+j];
-                probs[index][j] = (scale > thresh) ? prob : 0;
             }
         }else{
             for(j = 0; j < l.classes; ++j){
@@ -339,6 +385,18 @@ void forward_region_layer_gpu(const region_layer l, network_state state)
         return;
     }
 */
+    flatten_ongpu(state.input, l.h*l.w, l.n*(l.coords + l.classes + 1), l.batch, 1, l.output_gpu);
+    if(l.softmax_tree){
+        int i;
+        int count = 5;
+        for (i = 0; i < l.softmax_tree->groups; ++i) {
+            int group_size = l.softmax_tree->group_size[i];
+            softmax_gpu(l.output_gpu+count, group_size, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + count);
+            count += group_size;
+        }
+    }else if (l.softmax){
+        softmax_gpu(l.output_gpu+5, l.classes, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + 5);
+    }
 
     float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
     float *truth_cpu = 0;
@@ -347,22 +405,22 @@
         truth_cpu = calloc(num_truth, sizeof(float));
         cuda_pull_array(state.truth, truth_cpu, num_truth);
     }
-    cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
+    cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
     network_state cpu_state = state;
     cpu_state.train = state.train;
     cpu_state.truth = truth_cpu;
     cpu_state.input = in_cpu;
     forward_region_layer(l, cpu_state);
-    cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
-    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
+    //cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
     free(cpu_state.input);
+    if(!state.train) return;
+    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
     if(cpu_state.truth) free(cpu_state.truth);
 }
 
 void backward_region_layer_gpu(region_layer l, network_state state)
 {
-    axpy_ongpu(l.batch*l.outputs, 1, l.delta_gpu, 1, state.delta, 1);
-    //copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
+    flatten_ongpu(l.delta_gpu, l.h*l.w, l.n*(l.coords + l.classes + 1), l.batch, 0, state.delta);
 }
 #endif
diff --git a/src/region_layer.h b/src/region_layer.h
index 01901e07..3d04d662 100644
--- a/src/region_layer.h
+++ b/src/region_layer.h
@@ -10,6 +10,7 @@ region_layer make_region_layer(int batch, int h, int w, int n, int classes, int
 void forward_region_layer(const region_layer l, network_state state);
 void backward_region_layer(const region_layer l, network_state state);
 void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
+void resize_region_layer(layer *l, int w, int h);
 
 #ifdef GPU
 void forward_region_layer_gpu(const region_layer l, network_state state);
diff --git a/src/reorg_layer.c b/src/reorg_layer.c
index 0f2a1c21..d93dd976 100644
--- a/src/reorg_layer.c
+++ b/src/reorg_layer.c
@@ -22,6 +22,7 @@ layer make_reorg_layer(int batch, int h, int w, int c, int stride, int reverse)
         l.out_h = h/stride;
         l.out_c = c*(stride*stride);
     }
+    l.reverse = reverse;
     fprintf(stderr, "Reorg Layer: %d x %d x %d image -> %d x %d x %d image, \n", w,h,c,l.out_w, l.out_h, l.out_c);
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = h*w*c;
@@ -44,12 +45,20 @@
 
 void resize_reorg_layer(layer *l, int w, int h)
 {
     int stride = l->stride;
+    int c = l->c;
 
     l->h = h;
     l->w = w;
-    l->out_w = w*stride;
-    l->out_h = h*stride;
+    if(l->reverse){
+        l->out_w = w*stride;
+        l->out_h = h*stride;
+        l->out_c = c/(stride*stride);
+    }else{
+        l->out_w = w/stride;
+        l->out_h = h/stride;
+        l->out_c = c*(stride*stride);
+    }
 
     l->outputs = l->out_h * l->out_w * l->out_c;
     l->inputs = l->outputs;
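/*
 * resize_reorg_layer() above now honours l->reverse: a forward reorg packs spatial
 * blocks into channels, a reverse reorg unpacks channels back into space. Shape
 * arithmetic sketch with illustrative numbers only (stride 2 on a 26x26x512 map):
 */
#include <stdio.h>

int main()
{
    int w = 26, h = 26, c = 512, stride = 2;
    printf("forward: %d x %d x %d\n", w/stride, h/stride, c*stride*stride);   /* 13 x 13 x 2048 */
    printf("reverse: %d x %d x %d\n", w*stride, h*stride, c/(stride*stride)); /* 52 x 52 x 128  */
    return 0;
}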
diff --git a/src/route_layer.c b/src/route_layer.c
index 47e3d703..d18427a8 100644
--- a/src/route_layer.c
+++ b/src/route_layer.c
@@ -36,6 +36,40 @@ route_layer make_route_layer(int batch, int n, int *input_layers, int *input_siz
     return l;
 }
 
+void resize_route_layer(route_layer *l, network *net)
+{
+    int i;
+    layer first = net->layers[l->input_layers[0]];
+    l->out_w = first.out_w;
+    l->out_h = first.out_h;
+    l->out_c = first.out_c;
+    l->outputs = first.outputs;
+    l->input_sizes[0] = first.outputs;
+    for(i = 1; i < l->n; ++i){
+        int index = l->input_layers[i];
+        layer next = net->layers[index];
+        l->outputs += next.outputs;
+        l->input_sizes[i] = next.outputs;
+        if(next.out_w == first.out_w && next.out_h == first.out_h){
+            l->out_c += next.out_c;
+        }else{
+            printf("%d %d, %d %d\n", next.out_w, next.out_h, first.out_w, first.out_h);
+            l->out_h = l->out_w = l->out_c = 0;
+        }
+    }
+    l->inputs = l->outputs;
+    l->delta = realloc(l->delta, l->outputs*l->batch*sizeof(float));
+    l->output = realloc(l->output, l->outputs*l->batch*sizeof(float));
+
+#ifdef GPU
+    cuda_free(l->output_gpu);
+    cuda_free(l->delta_gpu);
+    l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
+    l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
+#endif
+
+}
+
 void forward_route_layer(const route_layer l, network_state state)
 {
     int i, j;
diff --git a/src/route_layer.h b/src/route_layer.h
index 77245a63..45467d95 100644
--- a/src/route_layer.h
+++ b/src/route_layer.h
@@ -8,6 +8,7 @@ typedef layer route_layer;
 route_layer make_route_layer(int batch, int n, int *input_layers, int *input_size);
 void forward_route_layer(const route_layer l, network_state state);
 void backward_route_layer(const route_layer l, network_state state);
+void resize_route_layer(route_layer *l, network *net);
 
 #ifdef GPU
 void forward_route_layer_gpu(const route_layer l, network_state state);
diff --git a/src/tree.c b/src/tree.c
index cd9fcd12..dfa41787 100644
--- a/src/tree.c
+++ b/src/tree.c
@@ -24,6 +24,16 @@ void change_leaves(tree *t, char *leaf_list)
     fprintf(stderr, "Found %d leaves.\n", found);
 }
 
+float get_hierarchy_probability(float *x, tree *hier, int c)
+{
+    float p = 1;
+    while(c >= 0){
+        p = p * x[c];
+        c = hier->parent[c];
+    }
+    return p;
+}
+
 void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
 {
     int j;
diff --git a/src/tree.h b/src/tree.h
index b0b0ecef..c3f49797 100644
--- a/src/tree.h
+++ b/src/tree.h
@@ -16,5 +16,6 @@ typedef struct{
 tree *read_tree(char *filename);
 void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves);
 void change_leaves(tree *t, char *leaf_list);
+float get_hierarchy_probability(float *x, tree *hier, int c);
 
 #endif
diff --git a/src/utils.c b/src/utils.c
index e8128b91..b5181d78 100644
--- a/src/utils.c
+++ b/src/utils.c
@@ -9,6 +9,21 @@
 
 #include "utils.h"
 
+int *read_map(char *filename)
+{
+    int n = 0;
+    int *map = 0;
+    char *str;
+    FILE *file = fopen(filename, "r");
+    if(!file) file_error(filename);
+    while((str=fgetl(file))){
+        ++n;
+        map = realloc(map, n*sizeof(int));
+        map[n-1] = atoi(str);
+    }
+    return map;
+}
+
 void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections)
 {
     size_t i;
diff --git a/src/utils.h b/src/utils.h
index 46676344..bbc67654 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -7,6 +7,7 @@
 #define SECRET_NUM -1234
 #define TWO_PI 6.2831853071795864769252866
 
+int *read_map(char *filename);
 void shuffle(void *arr, size_t n, size_t size);
 void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections);
 void free_ptrs(void **ptrs, int n);
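/*
 * read_map() (moved into src/utils.c above) reads one integer per line and is used
 * to translate detector class indices into evaluation class IDs, e.g. the `map`
 * option consumed by validate_detector() and passed to print_imagenet_detections().
 * Hedged usage sketch; the file path and the fixed loop bound are placeholders,
 * not part of this patch:
 */
#include <stdio.h>
#include <stdlib.h>
#include "utils.h"

void show_map(void)
{
    int *map = read_map("data/inet9k.map");   /* placeholder path: one int per line */
    int j;
    for(j = 0; j < 10; ++j){                  /* assumes the file has at least 10 lines */
        printf("detector class %d -> eval class %d\n", j, map[j]);
    }
    free(map);
}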