This commit is contained in:
Joseph Redmon 2016-11-15 22:53:58 -08:00
parent 9a01e6ccb7
commit 0d6b107ed2
22 changed files with 333 additions and 99 deletions

View File

@ -166,10 +166,10 @@ void forward_batchnorm_layer_gpu(layer l, network_state state)
fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
scal_ongpu(l.out_c, .95, l.rolling_mean_gpu, 1);
axpy_ongpu(l.out_c, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
scal_ongpu(l.out_c, .95, l.rolling_variance_gpu, 1);
axpy_ongpu(l.out_c, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
scal_ongpu(l.out_c, .99, l.rolling_mean_gpu, 1);
axpy_ongpu(l.out_c, .01, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
scal_ongpu(l.out_c, .99, l.rolling_variance_gpu, 1);
axpy_ongpu(l.out_c, .01, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);

View File

@ -6,7 +6,7 @@
#include <stdlib.h>
#include <string.h>
void reorg(float *x, int size, int layers, int batch, int forward)
void flatten(float *x, int size, int layers, int batch, int forward)
{
float *swap = calloc(size*layers*batch, sizeof(float));
int i,c,b;
@ -189,12 +189,12 @@ void softmax(float *input, int n, float temp, float *output)
if(input[i] > largest) largest = input[i];
}
for(i = 0; i < n; ++i){
sum += exp(input[i]/temp-largest/temp);
float e = exp(input[i]/temp - largest/temp);
sum += e;
output[i] = e;
}
if(sum) sum = largest/temp+log(sum);
else sum = largest-100;
for(i = 0; i < n; ++i){
output[i] = exp(input[i]/temp-sum);
output[i] /= sum;
}
}

View File

@ -1,6 +1,6 @@
#ifndef BLAS_H
#define BLAS_H
void reorg(float *x, int size, int layers, int batch, int forward);
void flatten(float *x, int size, int layers, int batch, int forward);
void pm(int M, int N, float *A);
float *random_matrix(int rows, int cols);
void time_random_matrix(int TA, int TB, int m, int k, int n);
@ -80,5 +80,7 @@ void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forwa
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out);
#endif
#endif

View File

@ -543,6 +543,30 @@ extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float *
check_error(cudaPeekAtLastError());
}
__global__ void flatten_kernel(int N, float *x, int spatial, int layers, int batch, int forward, float *out)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(i >= N) return;
int in_s = i%spatial;
i = i/spatial;
int in_c = i%layers;
i = i/layers;
int b = i;
int i1 = b*layers*spatial + in_c*spatial + in_s;
int i2 = b*layers*spatial + in_s*layers + in_c;
if (forward) out[i2] = x[i1];
else out[i1] = x[i2];
}
extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
{
int size = spatial*batch*layers;
flatten_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, spatial, layers, batch, forward, out);
check_error(cudaPeekAtLastError());
}
extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
{
int size = w*h*c*batch;
@ -718,11 +742,12 @@ __device__ void softmax_device(int n, float *input, float temp, float *output)
largest = (val>largest) ? val : largest;
}
for(i = 0; i < n; ++i){
sum += exp(input[i]/temp-largest/temp);
float e = exp(input[i]/temp - largest/temp);
sum += e;
output[i] = e;
}
sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
for(i = 0; i < n; ++i){
output[i] = exp(input[i]/temp-sum);
output[i] /= sum;
}
}

View File

@ -368,6 +368,14 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
if(l->batch_normalize){
cuda_free(l->x_gpu);
cuda_free(l->x_norm_gpu);
l->x_gpu = cuda_make_array(l->output, l->batch*l->outputs);
l->x_norm_gpu = cuda_make_array(l->output, l->batch*l->outputs);
}
#ifdef CUDNN
cudnn_convolutional_setup(l);
#endif

View File

@ -26,6 +26,7 @@ int cuda_get_device()
void check_error(cudaError_t status)
{
//cudaDeviceSynchronize();
cudaError_t status2 = cudaGetLastError();
if (status != cudaSuccess)
{

View File

@ -127,7 +127,7 @@ void oneoff(char *cfgfile, char *weightfile, char *outfile)
network net = parse_network_cfg(cfgfile);
int oldn = net.layers[net.n - 2].n;
int c = net.layers[net.n - 2].c;
net.layers[net.n - 2].n = 7879;
net.layers[net.n - 2].n = 9372;
net.layers[net.n - 2].biases += 5;
net.layers[net.n - 2].weights += 5*c;
if(weightfile){

View File

@ -171,6 +171,13 @@ void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float
{
int i;
for(i = 0; i < n; ++i){
if(boxes[i].x == 0 && boxes[i].y == 0) {
boxes[i].x = 999999;
boxes[i].y = 999999;
boxes[i].w = 999999;
boxes[i].h = 999999;
continue;
}
boxes[i].left = boxes[i].left * sx - dx;
boxes[i].right = boxes[i].right * sx - dx;
boxes[i].top = boxes[i].top * sy - dy;
@ -289,6 +296,7 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
find_replace(path, "images", "labels", labelpath);
find_replace(labelpath, "JPEGImages", "labels", labelpath);
find_replace(labelpath, "raw", "labels", labelpath);
find_replace(labelpath, ".jpg", ".txt", labelpath);
find_replace(labelpath, ".png", ".txt", labelpath);
find_replace(labelpath, ".JPG", ".txt", labelpath);
@ -309,7 +317,7 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
h = boxes[i].h;
id = boxes[i].id;
if (w < .01 || h < .01) continue;
if ((w < .01 || h < .01)) continue;
truth[i*5+0] = x;
truth[i*5+1] = y;

View File

@ -75,8 +75,27 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
pthread_t load_thread = load_data(args);
clock_t time;
int count = 0;
//while(i*imgs < N*120){
while(get_current_batch(net) < net.max_batches){
if(l.random && count++%10 == 0){
printf("Resizing\n");
int dim = (rand() % 10 + 10) * 32;
//int dim = (rand() % 4 + 16) * 32;
printf("%d\n", dim);
args.w = dim;
args.h = dim;
pthread_join(load_thread, 0);
train = buffer;
free_data(train);
load_thread = load_data(args);
for(i = 0; i < ngpus; ++i){
resize_network(nets + i, dim, dim);
}
net = nets[0];
}
time=clock();
pthread_join(load_thread, 0);
train = buffer;
@ -117,13 +136,15 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
i = get_current_batch(net);
printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
if(i%1000==0 || (i < 1000 && i%100 == 0)){
if(i%100==0 || (i < 1000 && i%100 == 0)){
if(ngpus != 1) sync_nets(nets, ngpus, 0);
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
save_weights(net, buff);
}
free_data(train);
}
if(ngpus != 1) sync_nets(nets, ngpus, 0);
char buff[256];
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
save_weights(net, buff);
@ -183,6 +204,29 @@ void print_detector_detections(FILE **fps, char *id, box *boxes, float **probs,
}
}
void print_imagenet_detections(FILE *fp, int id, box *boxes, float **probs, int total, int classes, int w, int h, int *map)
{
int i, j;
for(i = 0; i < total; ++i){
float xmin = boxes[i].x - boxes[i].w/2.;
float xmax = boxes[i].x + boxes[i].w/2.;
float ymin = boxes[i].y - boxes[i].h/2.;
float ymax = boxes[i].y + boxes[i].h/2.;
if (xmin < 0) xmin = 0;
if (ymin < 0) ymin = 0;
if (xmax > w) xmax = w;
if (ymax > h) ymax = h;
for(j = 0; j < classes; ++j){
int class = j;
if (map) class = map[j];
if (probs[i][class]) fprintf(fp, "%d %d %f %f %f %f %f\n", id, j+1, probs[i][class],
xmin, ymin, xmax, ymax);
}
}
}
void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
{
list *options = read_data_cfg(datacfg);
@ -190,15 +234,25 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
char *name_list = option_find_str(options, "names", "data/names.list");
char *prefix = option_find_str(options, "results", "results");
char **names = get_labels(name_list);
char *mapf = option_find_str(options, "map", 0);
int *map = 0;
if (mapf) map = read_map(mapf);
char buff[1024];
int coco = option_find_int_quiet(options, "coco", 0);
FILE *coco_fp = 0;
if(coco){
char *type = option_find_str(options, "eval", "voc");
FILE *fp = 0;
int coco = 0;
int imagenet = 0;
if(0==strcmp(type, "coco")){
snprintf(buff, 1024, "%s/coco_results.json", prefix);
coco_fp = fopen(buff, "w");
fprintf(coco_fp, "[\n");
fp = fopen(buff, "w");
fprintf(fp, "[\n");
coco = 1;
} else if(0==strcmp(type, "imagenet")){
snprintf(buff, 1024, "%s/imagenet-detection.txt", prefix);
fp = fopen(buff, "w");
imagenet = 1;
}
network net = parse_network_cfg(cfgfile);
@ -230,10 +284,10 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
int i=0;
int t;
float thresh = .001;
float nms = .5;
float thresh = .005;
float nms = .45;
int nthreads = 2;
int nthreads = 4;
image *val = calloc(nthreads, sizeof(image));
image *val_resized = calloc(nthreads, sizeof(image));
image *buf = calloc(nthreads, sizeof(image));
@ -274,9 +328,11 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
int h = val[t].h;
get_region_boxes(l, w, h, thresh, probs, boxes, 0);
if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
if(coco_fp){
print_cocos(coco_fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
}else{
if (coco){
print_cocos(fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
} else if (imagenet){
print_imagenet_detections(fp, i+t-nthreads+1 + 9741, boxes, probs, l.w*l.h*l.n, 200, w, h, map);
} else {
print_detector_detections(fps, id, boxes, probs, l.w*l.h*l.n, classes, w, h);
}
free(id);
@ -287,10 +343,10 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
for(j = 0; j < classes; ++j){
fclose(fps[j]);
}
if(coco_fp){
fseek(coco_fp, -2, SEEK_CUR);
fprintf(coco_fp, "\n]\n");
fclose(coco_fp);
if(coco){
fseek(fp, -2, SEEK_CUR);
fprintf(fp, "\n]\n");
fclose(fp);
}
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}

View File

@ -120,6 +120,7 @@ struct layer{
int random;
float thresh;
int classfix;
int absolute;
int dontload;
int dontloadscales;

View File

@ -41,7 +41,7 @@ void reset_momentum(network net)
net.momentum = 0;
net.decay = 0;
#ifdef GPU
if(gpu_index >= 0) update_network_gpu(net);
//if(net.gpu_index >= 0) update_network_gpu(net);
#endif
}
@ -60,7 +60,7 @@ float get_current_rate(network net)
for(i = 0; i < net.num_steps; ++i){
if(net.steps[i] > batch_num) return rate;
rate *= net.scales[i];
if(net.steps[i] > batch_num - 1) reset_momentum(net);
//if(net.steps[i] > batch_num - 1 && net.scales[i] > 1) reset_momentum(net);
}
return rate;
case EXP:
@ -321,6 +321,12 @@ void set_batch_network(network *net, int b)
int resize_network(network *net, int w, int h)
{
#ifdef GPU
cuda_set_device(net->gpu_index);
if(gpu_index >= 0){
cuda_free(net->workspace);
}
#endif
int i;
//if(w == net->w && h == net->h) return 0;
net->w = w;
@ -337,6 +343,10 @@ int resize_network(network *net, int w, int h)
resize_crop_layer(&l, w, h);
}else if(l.type == MAXPOOL){
resize_maxpool_layer(&l, w, h);
}else if(l.type == REGION){
resize_region_layer(&l, w, h);
}else if(l.type == ROUTE){
resize_route_layer(&l, net);
}else if(l.type == REORG){
resize_reorg_layer(&l, w, h);
}else if(l.type == AVGPOOL){
@ -357,7 +367,12 @@ int resize_network(network *net, int w, int h)
}
#ifdef GPU
if(gpu_index >= 0){
cuda_free(net->workspace);
if(net->input_gpu) {
cuda_free(*net->input_gpu);
*net->input_gpu = 0;
cuda_free(*net->truth_gpu);
*net->truth_gpu = 0;
}
net->workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
}else {
free(net->workspace);

View File

@ -78,6 +78,7 @@ void backward_network_gpu(network net, network_state state)
void update_network_gpu(network net)
{
cuda_set_device(net.gpu_index);
int i;
int update_batch = net.batch*net.subdivisions;
float rate = get_current_rate(net);
@ -377,7 +378,7 @@ float train_networks(network *nets, int n, data d, int interval)
float *get_network_output_layer_gpu(network net, int i)
{
layer l = net.layers[i];
cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
if(l.type != REGION) cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
return l.output;
}

View File

@ -2,32 +2,32 @@
#include <string.h>
#include <stdlib.h>
#include "blas.h"
#include "parser.h"
#include "assert.h"
#include "activations.h"
#include "crop_layer.h"
#include "cost_layer.h"
#include "convolutional_layer.h"
#include "activation_layer.h"
#include "normalization_layer.h"
#include "batchnorm_layer.h"
#include "connected_layer.h"
#include "rnn_layer.h"
#include "gru_layer.h"
#include "crnn_layer.h"
#include "maxpool_layer.h"
#include "reorg_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
#include "region_layer.h"
#include "activations.h"
#include "assert.h"
#include "avgpool_layer.h"
#include "batchnorm_layer.h"
#include "blas.h"
#include "connected_layer.h"
#include "convolutional_layer.h"
#include "cost_layer.h"
#include "crnn_layer.h"
#include "crop_layer.h"
#include "detection_layer.h"
#include "dropout_layer.h"
#include "gru_layer.h"
#include "list.h"
#include "local_layer.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "option_list.h"
#include "parser.h"
#include "region_layer.h"
#include "reorg_layer.h"
#include "rnn_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "list.h"
#include "option_list.h"
#include "softmax_layer.h"
#include "utils.h"
typedef struct{
@ -232,21 +232,6 @@ softmax_layer parse_softmax(list *options, size_params params)
return layer;
}
int *read_map(char *filename)
{
int n = 0;
int *map = 0;
char *str;
FILE *file = fopen(filename, "r");
if(!file) file_error(filename);
while((str=fgetl(file))){
++n;
map = realloc(map, n*sizeof(int));
map[n-1] = atoi(str);
}
return map;
}
layer parse_region(list *options, size_params params)
{
int coords = option_find_int(options, "coords", 4);
@ -269,6 +254,8 @@ layer parse_region(list *options, size_params params)
l.thresh = option_find_float(options, "thresh", .5);
l.classfix = option_find_int_quiet(options, "classfix", 0);
l.absolute = option_find_int_quiet(options, "absolute", 0);
l.random = option_find_int_quiet(options, "random", 0);
l.coord_scale = option_find_float(options, "coord_scale", 1);
l.object_scale = option_find_float(options, "object_scale", 1);

View File

@ -9,6 +9,8 @@
#include <string.h>
#include <stdlib.h>
#define DOABS 1
region_layer make_region_layer(int batch, int w, int h, int n, int classes, int coords)
{
region_layer l = {0};
@ -48,7 +50,26 @@ region_layer make_region_layer(int batch, int w, int h, int n, int classes, int
return l;
}
#define DOABS 1
void resize_region_layer(layer *l, int w, int h)
{
l->w = w;
l->h = h;
l->outputs = h*w*l->n*(l->classes + l->coords + 1);
l->inputs = l->outputs;
l->output = realloc(l->output, l->batch*l->outputs*sizeof(float));
l->delta = realloc(l->delta, l->batch*l->outputs*sizeof(float));
#ifdef GPU
cuda_free(l->delta_gpu);
cuda_free(l->output_gpu);
l->delta_gpu = cuda_make_array(l->delta, l->batch*l->outputs);
l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs);
#endif
}
box get_region_box(float *x, float *biases, int n, int index, int i, int j, int w, int h)
{
box b;
@ -125,7 +146,9 @@ void forward_region_layer(const region_layer l, network_state state)
int i,j,b,t,n;
int size = l.coords + l.classes + 1;
memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
reorg(l.output, l.w*l.h, size*l.n, l.batch, 1);
#ifndef GPU
flatten(l.output, l.w*l.h, size*l.n, l.batch, 1);
#endif
for (b = 0; b < l.batch; ++b){
for(i = 0; i < l.h*l.w*l.n; ++i){
int index = size*i + b*l.outputs;
@ -134,25 +157,14 @@ void forward_region_layer(const region_layer l, network_state state)
}
#ifndef GPU
if (l.softmax_tree){
#ifdef GPU
cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
int i;
int count = 5;
for (i = 0; i < l.softmax_tree->groups; ++i) {
int group_size = l.softmax_tree->group_size[i];
softmax_gpu(l.output_gpu+count, group_size, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + count);
count += group_size;
}
cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
#else
for (b = 0; b < l.batch; ++b){
for(i = 0; i < l.h*l.w*l.n; ++i){
int index = size*i + b*l.outputs;
softmax_tree(l.output + index + 5, 1, 0, 1, l.softmax_tree, l.output + index + 5);
}
}
#endif
} else if (l.softmax){
for (b = 0; b < l.batch; ++b){
for(i = 0; i < l.h*l.w*l.n; ++i){
@ -161,6 +173,7 @@ void forward_region_layer(const region_layer l, network_state state)
}
}
}
#endif
if(!state.train) return;
memset(l.delta, 0, l.outputs * l.batch * sizeof(float));
float avg_iou = 0;
@ -172,6 +185,32 @@ void forward_region_layer(const region_layer l, network_state state)
int class_count = 0;
*(l.cost) = 0;
for (b = 0; b < l.batch; ++b) {
if(l.softmax_tree){
int onlyclass = 0;
for(t = 0; t < 30; ++t){
box truth = float_to_box(state.truth + t*5 + b*l.truths);
if(!truth.x) break;
int class = state.truth[t*5 + b*l.truths + 4];
float maxp = 0;
int maxi = 0;
if(truth.x > 100000 && truth.y > 100000){
for(n = 0; n < l.n*l.w*l.h; ++n){
int index = size*n + b*l.outputs + 5;
float p = get_hierarchy_probability(l.output + index, l.softmax_tree, class);
if(p > maxp){
maxp = p;
maxi = n;
}
}
int index = size*maxi + b*l.outputs + 5;
delta_region_class(l.output, l.delta, index, class, l.classes, l.softmax_tree, l.class_scale, &avg_cat);
++class_count;
onlyclass = 1;
break;
}
}
if(onlyclass) continue;
}
for (j = 0; j < l.h; ++j) {
for (i = 0; i < l.w; ++i) {
for (n = 0; n < l.n; ++n) {
@ -273,7 +312,9 @@ void forward_region_layer(const region_layer l, network_state state)
}
}
//printf("\n");
reorg(l.delta, l.w*l.h, size*l.n, l.batch, 0);
#ifndef GPU
flatten(l.delta, l.w*l.h, size*l.n, l.batch, 0);
#endif
*(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, Avg Recall: %f, count: %d\n", avg_iou/count, avg_cat/class_count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, count);
}
@ -308,13 +349,18 @@ void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *b
hierarchy_predictions(predictions + class_index, l.classes, l.softmax_tree, 0);
int found = 0;
for(j = l.classes - 1; j >= 0; --j){
if(!found && predictions[class_index + j] > .5){
found = 1;
} else {
predictions[class_index + j] = 0;
if(1){
if(!found && predictions[class_index + j] > .5){
found = 1;
} else {
predictions[class_index + j] = 0;
}
float prob = predictions[class_index+j];
probs[index][j] = (scale > thresh) ? prob : 0;
}else{
float prob = scale*predictions[class_index+j];
probs[index][j] = (prob > thresh) ? prob : 0;
}
float prob = predictions[class_index+j];
probs[index][j] = (scale > thresh) ? prob : 0;
}
}else{
for(j = 0; j < l.classes; ++j){
@ -339,6 +385,18 @@ void forward_region_layer_gpu(const region_layer l, network_state state)
return;
}
*/
flatten_ongpu(state.input, l.h*l.w, l.n*(l.coords + l.classes + 1), l.batch, 1, l.output_gpu);
if(l.softmax_tree){
int i;
int count = 5;
for (i = 0; i < l.softmax_tree->groups; ++i) {
int group_size = l.softmax_tree->group_size[i];
softmax_gpu(l.output_gpu+count, group_size, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + count);
count += group_size;
}
}else if (l.softmax){
softmax_gpu(l.output_gpu+5, l.classes, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + 5);
}
float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
float *truth_cpu = 0;
@ -347,22 +405,22 @@ void forward_region_layer_gpu(const region_layer l, network_state state)
truth_cpu = calloc(num_truth, sizeof(float));
cuda_pull_array(state.truth, truth_cpu, num_truth);
}
cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
network_state cpu_state = state;
cpu_state.train = state.train;
cpu_state.truth = truth_cpu;
cpu_state.input = in_cpu;
forward_region_layer(l, cpu_state);
cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
//cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
free(cpu_state.input);
if(!state.train) return;
cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
if(cpu_state.truth) free(cpu_state.truth);
}
void backward_region_layer_gpu(region_layer l, network_state state)
{
axpy_ongpu(l.batch*l.outputs, 1, l.delta_gpu, 1, state.delta, 1);
//copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
flatten_ongpu(l.delta_gpu, l.h*l.w, l.n*(l.coords + l.classes + 1), l.batch, 0, state.delta);
}
#endif

View File

@ -10,6 +10,7 @@ region_layer make_region_layer(int batch, int h, int w, int n, int classes, int
void forward_region_layer(const region_layer l, network_state state);
void backward_region_layer(const region_layer l, network_state state);
void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
void resize_region_layer(layer *l, int w, int h);
#ifdef GPU
void forward_region_layer_gpu(const region_layer l, network_state state);

View File

@ -22,6 +22,7 @@ layer make_reorg_layer(int batch, int h, int w, int c, int stride, int reverse)
l.out_h = h/stride;
l.out_c = c*(stride*stride);
}
l.reverse = reverse;
fprintf(stderr, "Reorg Layer: %d x %d x %d image -> %d x %d x %d image, \n", w,h,c,l.out_w, l.out_h, l.out_c);
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = h*w*c;
@ -44,12 +45,20 @@ layer make_reorg_layer(int batch, int h, int w, int c, int stride, int reverse)
void resize_reorg_layer(layer *l, int w, int h)
{
int stride = l->stride;
int c = l->c;
l->h = h;
l->w = w;
l->out_w = w*stride;
l->out_h = h*stride;
if(l->reverse){
l->out_w = w*stride;
l->out_h = h*stride;
l->out_c = c/(stride*stride);
}else{
l->out_w = w/stride;
l->out_h = h/stride;
l->out_c = c*(stride*stride);
}
l->outputs = l->out_h * l->out_w * l->out_c;
l->inputs = l->outputs;

View File

@ -36,6 +36,40 @@ route_layer make_route_layer(int batch, int n, int *input_layers, int *input_siz
return l;
}
void resize_route_layer(route_layer *l, network *net)
{
int i;
layer first = net->layers[l->input_layers[0]];
l->out_w = first.out_w;
l->out_h = first.out_h;
l->out_c = first.out_c;
l->outputs = first.outputs;
l->input_sizes[0] = first.outputs;
for(i = 1; i < l->n; ++i){
int index = l->input_layers[i];
layer next = net->layers[index];
l->outputs += next.outputs;
l->input_sizes[i] = next.outputs;
if(next.out_w == first.out_w && next.out_h == first.out_h){
l->out_c += next.out_c;
}else{
printf("%d %d, %d %d\n", next.out_w, next.out_h, first.out_w, first.out_h);
l->out_h = l->out_w = l->out_c = 0;
}
}
l->inputs = l->outputs;
l->delta = realloc(l->delta, l->outputs*l->batch*sizeof(float));
l->output = realloc(l->output, l->outputs*l->batch*sizeof(float));
#ifdef GPU
cuda_free(l->output_gpu);
cuda_free(l->delta_gpu);
l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
#endif
}
void forward_route_layer(const route_layer l, network_state state)
{
int i, j;

View File

@ -8,6 +8,7 @@ typedef layer route_layer;
route_layer make_route_layer(int batch, int n, int *input_layers, int *input_size);
void forward_route_layer(const route_layer l, network_state state);
void backward_route_layer(const route_layer l, network_state state);
void resize_route_layer(route_layer *l, network *net);
#ifdef GPU
void forward_route_layer_gpu(const route_layer l, network_state state);

View File

@ -24,6 +24,16 @@ void change_leaves(tree *t, char *leaf_list)
fprintf(stderr, "Found %d leaves.\n", found);
}
float get_hierarchy_probability(float *x, tree *hier, int c)
{
float p = 1;
while(c >= 0){
p = p * x[c];
c = hier->parent[c];
}
return p;
}
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
int j;

View File

@ -16,5 +16,6 @@ typedef struct{
tree *read_tree(char *filename);
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves);
void change_leaves(tree *t, char *leaf_list);
float get_hierarchy_probability(float *x, tree *hier, int c);
#endif

View File

@ -9,6 +9,21 @@
#include "utils.h"
int *read_map(char *filename)
{
int n = 0;
int *map = 0;
char *str;
FILE *file = fopen(filename, "r");
if(!file) file_error(filename);
while((str=fgetl(file))){
++n;
map = realloc(map, n*sizeof(int));
map[n-1] = atoi(str);
}
return map;
}
void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections)
{
size_t i;

View File

@ -7,6 +7,7 @@
#define SECRET_NUM -1234
#define TWO_PI 6.2831853071795864769252866
int *read_map(char *filename);
void shuffle(void *arr, size_t n, size_t size);
void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections);
void free_ptrs(void **ptrs, int n);