good chance I didn't break anything

This commit is contained in:
Joseph Redmon 2016-09-12 13:55:20 -07:00
parent 8ec889f103
commit 5c067dc447
19 changed files with 558 additions and 298 deletions

View File

@ -1,85 +1,91 @@
[net]
batch=128
subdivisions=1
height=256
width=256
height=227
width=227
channels=3
learning_rate=0.01
momentum=0.9
decay=0.0005
max_crop=256
[crop]
crop_height=224
crop_width=224
flip=1
angle=0
saturation=1
exposure=1
learning_rate=0.01
policy=poly
power=4
max_batches=800000
angle=7
hue = .1
saturation=.75
exposure=.75
aspect=.75
[convolutional]
filters=64
filters=96
size=11
stride=4
pad=0
activation=ramp
activation=relu
[maxpool]
size=3
stride=2
padding=0
[convolutional]
filters=192
filters=256
size=5
stride=1
pad=1
activation=ramp
activation=relu
[maxpool]
size=3
stride=2
padding=0
[convolutional]
filters=384
size=3
stride=1
pad=1
activation=ramp
activation=relu
[convolutional]
filters=384
size=3
stride=1
pad=1
activation=relu
[convolutional]
filters=256
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=256
size=3
stride=1
pad=1
activation=ramp
activation=relu
[maxpool]
size=3
stride=2
padding=0
[connected]
output=4096
activation=ramp
activation=relu
[dropout]
probability=.5
[connected]
output=4096
activation=ramp
activation=relu
[dropout]
probability=.5
[connected]
output=1000
activation=ramp
activation=linear
[softmax]
groups=1

View File

@ -5,6 +5,7 @@
#include "blas.h"
#include "assert.h"
#include "classifier.h"
#include "cuda.h"
#include <sys/time.h>
#ifdef OPENCV
@ -51,6 +52,134 @@ float *get_regression_values(char **labels, int n)
return v;
}
void train_classifier_multi(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{
#ifdef GPU
int nthreads = 8;
int i;
data_seed = time(0);
srand(time(0));
float avg_loss = -1;
char *base = basecfg(cfgfile);
printf("%s\n", base);
printf("%d\n", ngpus);
network *nets = calloc(ngpus, sizeof(network));
for(i = 0; i < ngpus; ++i){
cuda_set_device(gpus[i]);
nets[i] = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&(nets[i]), weightfile);
}
if(clear) *nets[i].seen = 0;
}
network net = nets[0];
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = net.batch*ngpus/nthreads;
assert(net.batch*ngpus % nthreads == 0);
list *options = read_data_cfg(datacfg);
char *backup_directory = option_find_str(options, "backup", "/backup/");
char *label_list = option_find_str(options, "labels", "data/labels.list");
char *train_list = option_find_str(options, "train", "data/train.list");
int classes = option_find_int(options, "classes", 2);
char **labels = get_labels(label_list);
list *plist = get_paths(train_list);
char **paths = (char **)list_to_array(plist);
printf("%d\n", plist->size);
int N = plist->size;
clock_t time;
pthread_t *load_threads = calloc(nthreads, sizeof(pthread_t));
data *trains = calloc(nthreads, sizeof(data));
data *buffers = calloc(nthreads, sizeof(data));
load_args args = {0};
args.w = net.w;
args.h = net.h;
args.min = net.min_crop;
args.max = net.max_crop;
args.angle = net.angle;
args.aspect = net.aspect;
args.exposure = net.exposure;
args.saturation = net.saturation;
args.hue = net.hue;
args.size = net.w;
args.paths = paths;
args.classes = classes;
args.n = imgs;
args.m = N;
args.labels = labels;
args.type = CLASSIFICATION_DATA;
for(i = 0; i < nthreads; ++i){
args.d = buffers + i;
load_threads[i] = load_data_in_thread(args);
}
int epoch = (*net.seen)/N;
while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
time=clock();
for(i = 0; i < nthreads; ++i){
pthread_join(load_threads[i], 0);
trains[i] = buffers[i];
}
data train = concat_datas(trains, nthreads);
for(i = 0; i < nthreads; ++i){
args.d = buffers + i;
load_threads[i] = load_data_in_thread(args);
}
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
float loss = train_networks(nets, ngpus, train);
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
free_data(train);
for(i = 0; i < nthreads; ++i){
free_data(trains[i]);
}
if(*net.seen/N > epoch){
epoch = *net.seen/N;
char buff[256];
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
save_weights(net, buff);
}
if(get_current_batch(net)%100 == 0){
char buff[256];
sprintf(buff, "%s/%s.backup",backup_directory,base);
save_weights(net, buff);
}
}
char buff[256];
sprintf(buff, "%s/%s.weights", backup_directory, base);
save_weights(net, buff);
for(i = 0; i < nthreads; ++i){
pthread_join(load_threads[i], 0);
free_data(buffers[i]);
}
free(buffers);
free(trains);
free(load_threads);
free_network(net);
free_ptrs((void**)labels, classes);
free_ptrs((void**)paths, plist->size);
free_list(plist);
free(base);
#endif
}
void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
{
int nthreads = 8;
@ -130,7 +259,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
#ifdef OPENCV
#ifdef OPENCV
if(0){
int u;
for(u = 0; u < imgs; ++u){
@ -139,7 +268,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
cvWaitKey(0);
}
}
#endif
#endif
float loss = train_network(net, train);
if(avg_loss == -1) avg_loss = loss;
@ -546,29 +675,29 @@ void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filena
float *X = im.data;
time=clock();
float *predictions = network_predict(net, X);
layer l = net.layers[layer_num];
for(i = 0; i < l.c; ++i){
if(l.rolling_mean) printf("%f %f %f\n", l.rolling_mean[i], l.rolling_variance[i], l.scales[i]);
if(l.rolling_mean) printf("%f %f %f\n", l.rolling_mean[i], l.rolling_variance[i], l.scales[i]);
}
#ifdef GPU
#ifdef GPU
cuda_pull_array(l.output_gpu, l.output, l.outputs);
#endif
#endif
for(i = 0; i < l.outputs; ++i){
printf("%f\n", l.output[i]);
}
/*
printf("\n\nWeights\n");
for(i = 0; i < l.n*l.size*l.size*l.c; ++i){
printf("%f\n", l.filters[i]);
}
printf("\n\nBiases\n");
for(i = 0; i < l.n; ++i){
printf("%f\n", l.biases[i]);
}
*/
printf("\n\nWeights\n");
for(i = 0; i < l.n*l.size*l.size*l.c; ++i){
printf("%f\n", l.filters[i]);
}
printf("\n\nBiases\n");
for(i = 0; i < l.n; ++i){
printf("%f\n", l.biases[i]);
}
*/
top_predictions(net, top, indexes);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
@ -794,15 +923,15 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i
if(!in.data) break;
image in_s = resize_image(in, net.w, net.h);
image out = in;
int x1 = out.w / 20;
int y1 = out.h / 20;
int x2 = 2*x1;
int y2 = out.h - out.h/20;
image out = in;
int x1 = out.w / 20;
int y1 = out.h / 20;
int x2 = 2*x1;
int y2 = out.h - out.h/20;
int border = .01*out.h;
int h = y2 - y1 - 2*border;
int w = x2 - x1 - 2*border;
int border = .01*out.h;
int h = y2 - y1 - 2*border;
int w = x2 - x1 - 2*border;
float *predictions = network_predict(net, in_s.data);
float curr_threat = predictions[0] * 0 + predictions[1] * .6 + predictions[2];
@ -821,11 +950,11 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i
y1 + .02*h + 3*border, .5*border, 0,0,0);
draw_box_width(out, x2 + border, y1 + .42*h, x2 + .5 * w, y1 + .42*h + border, border, 0,0,0);
if(threat > .57) {
draw_box_width(out, x2 + .5 * w + border,
y1 + .42*h - 2*border,
x2 + .5 * w + 6*border,
y1 + .42*h + 3*border, 3*border, 1,1,0);
}
draw_box_width(out, x2 + .5 * w + border,
y1 + .42*h - 2*border,
x2 + .5 * w + 6*border,
y1 + .42*h + 3*border, 3*border, 1,1,0);
}
draw_box_width(out, x2 + .5 * w + border,
y1 + .42*h - 2*border,
x2 + .5 * w + 6*border,
@ -942,6 +1071,24 @@ void run_classifier(int argc, char **argv)
return;
}
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
int *gpus = 0;
int ngpus = 0;
if(gpu_list){
printf("%s\n", gpu_list);
int len = strlen(gpu_list);
ngpus = 1;
int i;
for(i = 0; i < len; ++i){
if (gpu_list[i] == ',') ++ngpus;
}
gpus = calloc(ngpus, sizeof(int));
for(i = 0; i < ngpus; ++i){
gpus[i] = atoi(gpu_list);
gpu_list = strchr(gpu_list, ',')+1;
}
}
int cam_index = find_int_arg(argc, argv, "-c", 0);
int clear = find_arg(argc, argv, "-clear");
char *data = argv[3];
@ -953,6 +1100,7 @@ void run_classifier(int argc, char **argv)
if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename);
else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, clear);
else if(0==strcmp(argv[2], "trainm")) train_classifier_multi(data, cfg, weights, gpus, ngpus, clear);
else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);

View File

@ -48,25 +48,25 @@ void binarize_input_gpu(float *input, int n, int size, float *binary)
}
__global__ void binarize_filters_kernel(float *filters, int n, int size, float *binary)
__global__ void binarize_weights_kernel(float *weights, int n, int size, float *binary)
{
int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (f >= n) return;
int i = 0;
float mean = 0;
for(i = 0; i < size; ++i){
mean += abs(filters[f*size + i]);
mean += abs(weights[f*size + i]);
}
mean = mean / size;
for(i = 0; i < size; ++i){
binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
//binary[f*size + i] = filters[f*size + i];
binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
//binary[f*size + i] = weights[f*size + i];
}
}
void binarize_filters_gpu(float *filters, int n, int size, float *binary)
void binarize_weights_gpu(float *weights, int n, int size, float *binary)
{
binarize_filters_kernel<<<cuda_gridsize(n), BLOCK>>>(filters, n, size, binary);
binarize_weights_kernel<<<cuda_gridsize(n), BLOCK>>>(weights, n, size, binary);
check_error(cudaPeekAtLastError());
}
@ -74,12 +74,12 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
if(l.binary){
binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
swap_binary(&l);
}
if(l.xnor){
binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
swap_binary(&l);
binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
state.input = l.binary_input_gpu;
@ -91,8 +91,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
&one,
l.srcTensorDesc,
state.input,
l.filterDesc,
l.filters_gpu,
l.weightDesc,
l.weights_gpu,
l.convDesc,
l.fw_algo,
state.workspace,
@ -108,7 +108,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
int n = l.out_w*l.out_h;
for(i = 0; i < l.batch; ++i){
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
float * a = l.filters_gpu;
float * a = l.weights_gpu;
float * b = state.workspace;
float * c = l.output_gpu;
gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
@ -150,15 +150,15 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
state.workspace,
l.workspace_size,
&one,
l.dfilterDesc,
l.filter_updates_gpu);
l.dweightDesc,
l.weight_updates_gpu);
if(state.delta){
if(l.binary || l.xnor) swap_binary(&l);
cudnnConvolutionBackwardData(cudnn_handle(),
&one,
l.filterDesc,
l.filters_gpu,
l.weightDesc,
l.weights_gpu,
l.ddstTensorDesc,
l.delta_gpu,
l.convDesc,
@ -181,14 +181,14 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
for(i = 0; i < l.batch; ++i){
float * a = l.delta_gpu;
float * b = state.workspace;
float * c = l.filter_updates_gpu;
float * c = l.weight_updates_gpu;
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
if(state.delta){
if(l.binary || l.xnor) swap_binary(&l);
float * a = l.filters_gpu;
float * a = l.weights_gpu;
float * b = l.delta_gpu;
float * c = state.workspace;
@ -206,9 +206,9 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
void pull_convolutional_layer(convolutional_layer layer)
{
cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
if (layer.batch_normalize){
cuda_pull_array(layer.scales_gpu, layer.scales, layer.n);
@ -219,9 +219,9 @@ void pull_convolutional_layer(convolutional_layer layer)
void push_convolutional_layer(convolutional_layer layer)
{
cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
if (layer.batch_normalize){
cuda_push_array(layer.scales_gpu, layer.scales, layer.n);
@ -240,9 +240,9 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
}

View File

@ -19,28 +19,28 @@ void forward_xnor_layer(layer l, network_state state);
void swap_binary(convolutional_layer *l)
{
float *swap = l->filters;
l->filters = l->binary_filters;
l->binary_filters = swap;
float *swap = l->weights;
l->weights = l->binary_weights;
l->binary_weights = swap;
#ifdef GPU
swap = l->filters_gpu;
l->filters_gpu = l->binary_filters_gpu;
l->binary_filters_gpu = swap;
swap = l->weights_gpu;
l->weights_gpu = l->binary_weights_gpu;
l->binary_weights_gpu = swap;
#endif
}
void binarize_filters(float *filters, int n, int size, float *binary)
void binarize_weights(float *weights, int n, int size, float *binary)
{
int i, f;
for(f = 0; f < n; ++f){
float mean = 0;
for(i = 0; i < size; ++i){
mean += fabs(filters[f*size + i]);
mean += fabs(weights[f*size + i]);
}
mean = mean / size;
for(i = 0; i < size; ++i){
binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
}
}
}
@ -103,7 +103,7 @@ size_t get_workspace_size(layer l){
size_t s = 0;
cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
l.srcTensorDesc,
l.filterDesc,
l.weightDesc,
l.convDesc,
l.dstTensorDesc,
l.fw_algo,
@ -113,12 +113,12 @@ size_t get_workspace_size(layer l){
l.srcTensorDesc,
l.ddstTensorDesc,
l.convDesc,
l.dfilterDesc,
l.dweightDesc,
l.bf_algo,
&s);
if (s > most) most = s;
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
l.filterDesc,
l.weightDesc,
l.ddstTensorDesc,
l.convDesc,
l.dsrcTensorDesc,
@ -137,22 +137,22 @@ void cudnn_convolutional_setup(layer *l)
{
cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
l->srcTensorDesc,
l->filterDesc,
l->weightDesc,
l->convDesc,
l->dstTensorDesc,
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0,
&l->fw_algo);
cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
l->filterDesc,
l->weightDesc,
l->ddstTensorDesc,
l->convDesc,
l->dsrcTensorDesc,
@ -163,7 +163,7 @@ void cudnn_convolutional_setup(layer *l)
l->srcTensorDesc,
l->ddstTensorDesc,
l->convDesc,
l->dfilterDesc,
l->dweightDesc,
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0,
&l->bf_algo);
@ -189,15 +189,15 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.pad = padding;
l.batch_normalize = batch_normalize;
l.filters = calloc(c*n*size*size, sizeof(float));
l.filter_updates = calloc(c*n*size*size, sizeof(float));
l.weights = calloc(c*n*size*size, sizeof(float));
l.weight_updates = calloc(c*n*size*size, sizeof(float));
l.biases = calloc(n, sizeof(float));
l.bias_updates = calloc(n, sizeof(float));
// float scale = 1./sqrt(size*size*c);
float scale = sqrt(2./(size*size*c));
for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_uniform(-1, 1);
for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1, 1);
int out_h = convolutional_out_height(l);
int out_w = convolutional_out_width(l);
l.out_h = out_h;
@ -210,12 +210,12 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
if(binary){
l.binary_filters = calloc(c*n*size*size, sizeof(float));
l.cfilters = calloc(c*n*size*size, sizeof(char));
l.binary_weights = calloc(c*n*size*size, sizeof(float));
l.cweights = calloc(c*n*size*size, sizeof(char));
l.scales = calloc(n, sizeof(float));
}
if(xnor){
l.binary_filters = calloc(c*n*size*size, sizeof(float));
l.binary_weights = calloc(c*n*size*size, sizeof(float));
l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
}
@ -235,8 +235,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
#ifdef GPU
if(gpu_index >= 0){
l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
l.biases_gpu = cuda_make_array(l.biases, n);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
@ -248,10 +248,10 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
if(binary){
l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
}
if(xnor){
l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
}
@ -271,10 +271,10 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
#ifdef CUDNN
cudnnCreateTensorDescriptor(&l.srcTensorDesc);
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
cudnnCreateFilterDescriptor(&l.filterDesc);
cudnnCreateFilterDescriptor(&l.weightDesc);
cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
cudnnCreateFilterDescriptor(&l.dfilterDesc);
cudnnCreateFilterDescriptor(&l.dweightDesc);
cudnnCreateConvolutionDescriptor(&l.convDesc);
cudnn_convolutional_setup(&l);
#endif
@ -294,7 +294,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
for(i = 0; i < l.n; ++i){
float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001);
for(j = 0; j < l.c*l.size*l.size; ++j){
l.filters[i*l.c*l.size*l.size + j] *= scale;
l.weights[i*l.c*l.size*l.size + j] *= scale;
}
l.biases[i] -= l.rolling_mean[i] * scale;
l.scales[i] = 1;
@ -403,8 +403,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
/*
if(l.binary){
binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
binarize_weights2(l.weights, l.n, l.c*l.size*l.size, l.cweights, l.scales);
swap_binary(&l);
}
*/
@ -415,7 +415,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
int k = l.size*l.size*l.c;
int n = out_h*out_w;
char *a = l.cfilters;
char *a = l.cweights;
float *b = state.workspace;
float *c = l.output;
@ -434,7 +434,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
*/
if(l.xnor){
binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
swap_binary(&l);
binarize_cpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input);
state.input = l.binary_input;
@ -449,7 +449,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
printf("xnor\n");
} else {
float *a = l.filters;
float *a = l.weights;
float *b = state.workspace;
float *c = l.output;
@ -485,7 +485,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
for(i = 0; i < l.batch; ++i){
float *a = l.delta + i*m*k;
float *b = state.workspace;
float *c = l.filter_updates;
float *c = l.weight_updates;
float *im = state.input+i*l.c*l.h*l.w;
@ -494,7 +494,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
if(state.delta){
a = l.filters;
a = l.weights;
b = l.delta + i*m*k;
c = state.workspace;
@ -511,36 +511,36 @@ void update_convolutional_layer(convolutional_layer l, int batch, float learning
axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
scal_cpu(l.n, momentum, l.bias_updates, 1);
axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1);
axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1);
scal_cpu(size, momentum, l.filter_updates, 1);
axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1);
axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
scal_cpu(size, momentum, l.weight_updates, 1);
}
image get_convolutional_filter(convolutional_layer l, int i)
image get_convolutional_weight(convolutional_layer l, int i)
{
int h = l.size;
int w = l.size;
int c = l.c;
return float_to_image(w,h,c,l.filters+i*h*w*c);
return float_to_image(w,h,c,l.weights+i*h*w*c);
}
void rgbgr_filters(convolutional_layer l)
void rgbgr_weights(convolutional_layer l)
{
int i;
for(i = 0; i < l.n; ++i){
image im = get_convolutional_filter(l, i);
image im = get_convolutional_weight(l, i);
if (im.c == 3) {
rgbgr_image(im);
}
}
}
void rescale_filters(convolutional_layer l, float scale, float trans)
void rescale_weights(convolutional_layer l, float scale, float trans)
{
int i;
for(i = 0; i < l.n; ++i){
image im = get_convolutional_filter(l, i);
image im = get_convolutional_weight(l, i);
if (im.c == 3) {
scale_image(im, scale);
float sum = sum_array(im.data, im.w*im.h*im.c);
@ -549,21 +549,21 @@ void rescale_filters(convolutional_layer l, float scale, float trans)
}
}
image *get_filters(convolutional_layer l)
image *get_weights(convolutional_layer l)
{
image *filters = calloc(l.n, sizeof(image));
image *weights = calloc(l.n, sizeof(image));
int i;
for(i = 0; i < l.n; ++i){
filters[i] = copy_image(get_convolutional_filter(l, i));
//normalize_image(filters[i]);
weights[i] = copy_image(get_convolutional_weight(l, i));
//normalize_image(weights[i]);
}
return filters;
return weights;
}
image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_filters)
image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_weights)
{
image *single_filters = get_filters(l);
show_images(single_filters, l.n, window);
image *single_weights = get_weights(l);
show_images(single_weights, l.n, window);
image delta = get_convolutional_image(l);
image dc = collapse_image_layers(delta, 1);
@ -572,6 +572,6 @@ image *visualize_convolutional_layer(convolutional_layer l, char *window, image
//show_image(dc, buff);
//save_image(dc, buff);
free_image(dc);
return single_filters;
return single_weights;
}

View File

@ -29,10 +29,10 @@ void denormalize_convolutional_layer(convolutional_layer l);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
void binarize_filters(float *filters, int n, int size, float *binary);
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_weights);
void binarize_weights(float *weights, int n, int size, float *binary);
void swap_binary(convolutional_layer *l);
void binarize_filters2(float *filters, int n, int size, char *binary, float *scales);
void binarize_weights2(float *weights, int n, int size, char *binary, float *scales);
void backward_convolutional_layer(convolutional_layer layer, network_state state);
@ -41,12 +41,12 @@ void backward_bias(float *bias_updates, float *delta, int batch, int n, int size
image get_convolutional_image(convolutional_layer layer);
image get_convolutional_delta(convolutional_layer layer);
image get_convolutional_filter(convolutional_layer layer, int i);
image get_convolutional_weight(convolutional_layer layer, int i);
int convolutional_out_height(convolutional_layer layer);
int convolutional_out_width(convolutional_layer layer);
void rescale_filters(convolutional_layer l, float scale, float trans);
void rgbgr_filters(convolutional_layer l);
void rescale_weights(convolutional_layer l, float scale, float trans);
void rgbgr_weights(convolutional_layer l);
#endif

View File

@ -9,6 +9,20 @@ int gpu_index = 0;
#include <stdlib.h>
#include <time.h>
void cuda_set_device(int n)
{
gpu_index = n;
cudaError_t status = cudaSetDevice(n);
check_error(status);
}
int cuda_get_device()
{
int n = 0;
cudaError_t status = cudaGetDevice(&n);
check_error(status);
return n;
}
void check_error(cudaError_t status)
{
@ -38,8 +52,8 @@ dim3 cuda_gridsize(size_t n){
size_t x = k;
size_t y = 1;
if(x > 65535){
x = ceil(sqrt(k));
y = (n-1)/(x*BLOCK) + 1;
x = ceil(sqrt(k));
y = (n-1)/(x*BLOCK) + 1;
}
dim3 d = {x, y, 1};
//printf("%ld %ld %ld %ld\n", n, x, y, x*y*BLOCK);
@ -49,25 +63,27 @@ dim3 cuda_gridsize(size_t n){
#ifdef CUDNN
cudnnHandle_t cudnn_handle()
{
static int init = 0;
static cudnnHandle_t handle;
if(!init) {
cudnnCreate(&handle);
init = 1;
static int init[16] = {0};
static cudnnHandle_t handle[16];
int i = cuda_get_device();
if(!init[i]) {
cudnnCreate(&handle[i]);
init[i] = 1;
}
return handle;
return handle[i];
}
#endif
cublasHandle_t blas_handle()
{
static int init = 0;
static cublasHandle_t handle;
if(!init) {
cublasCreate(&handle);
init = 1;
static int init[16] = {0};
static cublasHandle_t handle[16];
int i = cuda_get_device();
if(!init[i]) {
cublasCreate(&handle[i]);
init[i] = 1;
}
return handle;
return handle[i];
}
float *cuda_make_array(float *x, size_t n)
@ -86,14 +102,15 @@ float *cuda_make_array(float *x, size_t n)
void cuda_random(float *x_gpu, size_t n)
{
static curandGenerator_t gen;
static int init = 0;
if(!init){
curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(gen, time(0));
init = 1;
static curandGenerator_t gen[16];
static int init[16] = {0};
int i = cuda_get_device();
if(!init[i]){
curandCreateGenerator(&gen[i], CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(gen[i], time(0));
init[i] = 1;
}
curandGenerateUniform(gen, x_gpu, n);
curandGenerateUniform(gen[i], x_gpu, n);
check_error(cudaPeekAtLastError());
}

View File

@ -21,6 +21,7 @@ float *cuda_make_array(float *x, size_t n);
int *cuda_make_int_array(size_t n);
void cuda_push_array(float *x_gpu, float *x, size_t n);
void cuda_pull_array(float *x_gpu, float *x, size_t n);
void cuda_set_device(int n);
void cuda_free(float *x_gpu);
void cuda_random(float *x_gpu, size_t n);
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);

View File

@ -66,7 +66,7 @@ void average(int argc, char *argv[])
if(l.type == CONVOLUTIONAL){
int num = l.n*l.c*l.size*l.size;
axpy_cpu(l.n, 1, l.biases, 1, out.biases, 1);
axpy_cpu(num, 1, l.filters, 1, out.filters, 1);
axpy_cpu(num, 1, l.weights, 1, out.weights, 1);
}
if(l.type == CONNECTED){
axpy_cpu(l.outputs, 1, l.biases, 1, out.biases, 1);
@ -80,7 +80,7 @@ void average(int argc, char *argv[])
if(l.type == CONVOLUTIONAL){
int num = l.n*l.c*l.size*l.size;
scal_cpu(l.n, 1./n, l.biases, 1);
scal_cpu(num, 1./n, l.filters, 1);
scal_cpu(num, 1./n, l.weights, 1);
}
if(l.type == CONNECTED){
scal_cpu(l.outputs, 1./n, l.biases, 1);
@ -159,7 +159,7 @@ void rescale_net(char *cfgfile, char *weightfile, char *outfile)
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
if(l.type == CONVOLUTIONAL){
rescale_filters(l, 2, -.5);
rescale_weights(l, 2, -.5);
break;
}
}
@ -177,7 +177,7 @@ void rgbgr_net(char *cfgfile, char *weightfile, char *outfile)
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
if(l.type == CONVOLUTIONAL){
rgbgr_filters(l);
rgbgr_weights(l);
break;
}
}
@ -354,8 +354,7 @@ int main(int argc, char **argv)
gpu_index = -1;
#else
if(gpu_index >= 0){
cudaError_t status = cudaSetDevice(gpu_index);
check_error(status);
cuda_set_device(gpu_index);
}
#endif

View File

@ -27,7 +27,7 @@ extern "C" void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, n
fill_ongpu(layer.outputs*layer.batch, 0, layer.output_gpu, 1);
for(i = 0; i < layer.batch; ++i){
float *a = layer.filters_gpu;
float *a = layer.weights_gpu;
float *b = state.input + i*layer.c*layer.h*layer.w;
float *c = layer.col_image_gpu;
@ -59,7 +59,7 @@ extern "C" void backward_deconvolutional_layer_gpu(deconvolutional_layer layer,
float *a = state.input + i*m*n;
float *b = layer.col_image_gpu;
float *c = layer.filter_updates_gpu;
float *c = layer.weight_updates_gpu;
im2col_ongpu(layer.delta_gpu + i*layer.n*size, layer.n, out_h, out_w,
layer.size, layer.stride, 0, b);
@ -70,7 +70,7 @@ extern "C" void backward_deconvolutional_layer_gpu(deconvolutional_layer layer,
int n = layer.h*layer.w;
int k = layer.size*layer.size*layer.n;
float *a = layer.filters_gpu;
float *a = layer.weights_gpu;
float *b = layer.col_image_gpu;
float *c = state.delta + i*n*m;
@ -81,17 +81,17 @@ extern "C" void backward_deconvolutional_layer_gpu(deconvolutional_layer layer,
extern "C" void pull_deconvolutional_layer(deconvolutional_layer layer)
{
cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
}
extern "C" void push_deconvolutional_layer(deconvolutional_layer layer)
{
cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
}
@ -102,8 +102,8 @@ extern "C" void update_deconvolutional_layer_gpu(deconvolutional_layer layer, fl
axpy_ongpu(layer.n, learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
axpy_ongpu(size, -decay, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
axpy_ongpu(size, learning_rate, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
axpy_ongpu(size, -decay, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, learning_rate, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
}

View File

@ -57,13 +57,13 @@ deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c,
l.stride = stride;
l.size = size;
l.filters = calloc(c*n*size*size, sizeof(float));
l.filter_updates = calloc(c*n*size*size, sizeof(float));
l.weights = calloc(c*n*size*size, sizeof(float));
l.weight_updates = calloc(c*n*size*size, sizeof(float));
l.biases = calloc(n, sizeof(float));
l.bias_updates = calloc(n, sizeof(float));
float scale = 1./sqrt(size*size*c);
for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_normal();
for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_normal();
for(i = 0; i < n; ++i){
l.biases[i] = scale;
}
@ -81,8 +81,8 @@ deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c,
l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
#ifdef GPU
l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
l.biases_gpu = cuda_make_array(l.biases, n);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
@ -137,7 +137,7 @@ void forward_deconvolutional_layer(const deconvolutional_layer l, network_state
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
for(i = 0; i < l.batch; ++i){
float *a = l.filters;
float *a = l.weights;
float *b = state.input + i*l.c*l.h*l.w;
float *c = l.col_image;
@ -167,7 +167,7 @@ void backward_deconvolutional_layer(deconvolutional_layer l, network_state state
float *a = state.input + i*m*n;
float *b = l.col_image;
float *c = l.filter_updates;
float *c = l.weight_updates;
im2col_cpu(l.delta + i*l.n*size, l.n, out_h, out_w,
l.size, l.stride, 0, b);
@ -178,7 +178,7 @@ void backward_deconvolutional_layer(deconvolutional_layer l, network_state state
int n = l.h*l.w;
int k = l.size*l.size*l.n;
float *a = l.filters;
float *a = l.weights;
float *b = l.col_image;
float *c = state.delta + i*n*m;
@ -193,9 +193,9 @@ void update_deconvolutional_layer(deconvolutional_layer l, float learning_rate,
axpy_cpu(l.n, learning_rate, l.bias_updates, 1, l.biases, 1);
scal_cpu(l.n, momentum, l.bias_updates, 1);
axpy_cpu(size, -decay, l.filters, 1, l.filter_updates, 1);
axpy_cpu(size, learning_rate, l.filter_updates, 1, l.filters, 1);
scal_cpu(size, momentum, l.filter_updates, 1);
axpy_cpu(size, -decay, l.weights, 1, l.weight_updates, 1);
axpy_cpu(size, learning_rate, l.weight_updates, 1, l.weights, 1);
scal_cpu(size, momentum, l.weight_updates, 1);
}

View File

@ -117,12 +117,18 @@ static void convert_detections(float *predictions, int classes, int num, int squ
int box_index = index * (classes + 5);
boxes[index].x = (predictions[box_index + 0] + col + .5) / side * w;
boxes[index].y = (predictions[box_index + 1] + row + .5) / side * h;
if(1){
if(0){
boxes[index].x = (logistic_activate(predictions[box_index + 0]) + col) / side * w;
boxes[index].y = (logistic_activate(predictions[box_index + 1]) + row) / side * h;
}
boxes[index].w = pow(logistic_activate(predictions[box_index + 2]), (square?2:1)) * w;
boxes[index].h = pow(logistic_activate(predictions[box_index + 3]), (square?2:1)) * h;
if(1){
boxes[index].x = ((col + .5)/side + predictions[box_index + 0] * .5) * w;
boxes[index].y = ((row + .5)/side + predictions[box_index + 1] * .5) * h;
boxes[index].w = (exp(predictions[box_index + 2]) * .5) * w;
boxes[index].h = (exp(predictions[box_index + 3]) * .5) * h;
}
for(j = 0; j < classes; ++j){
int class_index = index * (classes + 5) + 5;
float prob = scale*predictions[class_index+j];

View File

@ -14,8 +14,6 @@ void free_layer(layer l)
if(l.indexes) free(l.indexes);
if(l.rand) free(l.rand);
if(l.cost) free(l.cost);
if(l.filters) free(l.filters);
if(l.filter_updates) free(l.filter_updates);
if(l.biases) free(l.biases);
if(l.bias_updates) free(l.bias_updates);
if(l.weights) free(l.weights);
@ -30,8 +28,8 @@ void free_layer(layer l)
#ifdef GPU
if(l.indexes_gpu) cuda_free((float *)l.indexes_gpu);
if(l.filters_gpu) cuda_free(l.filters_gpu);
if(l.filter_updates_gpu) cuda_free(l.filter_updates_gpu);
if(l.weights_gpu) cuda_free(l.weights_gpu);
if(l.weight_updates_gpu) cuda_free(l.weight_updates_gpu);
if(l.col_image_gpu) cuda_free(l.col_image_gpu);
if(l.weights_gpu) cuda_free(l.weights_gpu);
if(l.biases_gpu) cuda_free(l.biases_gpu);

View File

@ -105,9 +105,7 @@ struct layer{
int *indexes;
float *rand;
float *cost;
float *filters;
char *cfilters;
float *filter_updates;
char *cweights;
float *state;
float *prev_state;
float *forgot_state;
@ -117,7 +115,7 @@ struct layer{
float *concat;
float *concat_delta;
float *binary_filters;
float *binary_weights;
float *biases;
float *bias_updates;
@ -194,11 +192,9 @@ struct layer{
float * save_delta_gpu;
float * concat_gpu;
float * concat_delta_gpu;
float * filters_gpu;
float * filter_updates_gpu;
float *binary_input_gpu;
float *binary_filters_gpu;
float *binary_weights_gpu;
float * mean_gpu;
float * variance_gpu;
@ -230,8 +226,8 @@ struct layer{
#ifdef CUDNN
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
cudnnFilterDescriptor_t filterDesc;
cudnnFilterDescriptor_t dfilterDesc;
cudnnFilterDescriptor_t weightDesc;
cudnnFilterDescriptor_t dweightDesc;
cudnnConvolutionDescriptor_t convDesc;
cudnnConvolutionFwdAlgo_t fw_algo;
cudnnConvolutionBwdDataAlgo_t bd_algo;

View File

@ -47,23 +47,23 @@ local_layer make_local_layer(int batch, int h, int w, int c, int n, int size, in
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = l.w * l.h * l.c;
l.filters = calloc(c*n*size*size*locations, sizeof(float));
l.filter_updates = calloc(c*n*size*size*locations, sizeof(float));
l.weights = calloc(c*n*size*size*locations, sizeof(float));
l.weight_updates = calloc(c*n*size*size*locations, sizeof(float));
l.biases = calloc(l.outputs, sizeof(float));
l.bias_updates = calloc(l.outputs, sizeof(float));
// float scale = 1./sqrt(size*size*c);
float scale = sqrt(2./(size*size*c));
for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_uniform(-1,1);
for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1,1);
l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
#ifdef GPU
l.filters_gpu = cuda_make_array(l.filters, c*n*size*size*locations);
l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size*locations);
l.weights_gpu = cuda_make_array(l.weights, c*n*size*size*locations);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size*locations);
l.biases_gpu = cuda_make_array(l.biases, l.outputs);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, l.outputs);
@ -97,7 +97,7 @@ void forward_local_layer(const local_layer l, network_state state)
l.size, l.stride, l.pad, l.col_image);
float *output = l.output + i*l.outputs;
for(j = 0; j < locations; ++j){
float *a = l.filters + j*l.size*l.size*l.c*l.n;
float *a = l.weights + j*l.size*l.size*l.c*l.n;
float *b = l.col_image + j;
float *c = output + j;
@ -130,7 +130,7 @@ void backward_local_layer(local_layer l, network_state state)
for(j = 0; j < locations; ++j){
float *a = l.delta + i*l.outputs + j;
float *b = l.col_image + j;
float *c = l.filter_updates + j*l.size*l.size*l.c*l.n;
float *c = l.weight_updates + j*l.size*l.size*l.c*l.n;
int m = l.n;
int n = l.size*l.size*l.c;
int k = 1;
@ -140,7 +140,7 @@ void backward_local_layer(local_layer l, network_state state)
if(state.delta){
for(j = 0; j < locations; ++j){
float *a = l.filters + j*l.size*l.size*l.c*l.n;
float *a = l.weights + j*l.size*l.size*l.c*l.n;
float *b = l.delta + i*l.outputs + j;
float *c = l.col_image + j;
@ -163,9 +163,9 @@ void update_local_layer(local_layer l, int batch, float learning_rate, float mom
axpy_cpu(l.outputs, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
scal_cpu(l.outputs, momentum, l.bias_updates, 1);
axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1);
axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1);
scal_cpu(size, momentum, l.filter_updates, 1);
axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1);
axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
scal_cpu(size, momentum, l.weight_updates, 1);
}
#ifdef GPU
@ -187,7 +187,7 @@ void forward_local_layer_gpu(const local_layer l, network_state state)
l.size, l.stride, l.pad, l.col_image_gpu);
float *output = l.output_gpu + i*l.outputs;
for(j = 0; j < locations; ++j){
float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n;
float *a = l.weights_gpu + j*l.size*l.size*l.c*l.n;
float *b = l.col_image_gpu + j;
float *c = output + j;
@ -219,7 +219,7 @@ void backward_local_layer_gpu(local_layer l, network_state state)
for(j = 0; j < locations; ++j){
float *a = l.delta_gpu + i*l.outputs + j;
float *b = l.col_image_gpu + j;
float *c = l.filter_updates_gpu + j*l.size*l.size*l.c*l.n;
float *c = l.weight_updates_gpu + j*l.size*l.size*l.c*l.n;
int m = l.n;
int n = l.size*l.size*l.c;
int k = 1;
@ -229,7 +229,7 @@ void backward_local_layer_gpu(local_layer l, network_state state)
if(state.delta){
for(j = 0; j < locations; ++j){
float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n;
float *a = l.weights_gpu + j*l.size*l.size*l.c*l.n;
float *b = l.delta_gpu + i*l.outputs + j;
float *c = l.col_image_gpu + j;
@ -252,16 +252,16 @@ void update_local_layer_gpu(local_layer l, int batch, float learning_rate, float
axpy_ongpu(l.outputs, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1);
scal_ongpu(l.outputs, momentum, l.bias_updates_gpu, 1);
axpy_ongpu(size, -decay*batch, l.filters_gpu, 1, l.filter_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, l.filter_updates_gpu, 1, l.filters_gpu, 1);
scal_ongpu(size, momentum, l.filter_updates_gpu, 1);
axpy_ongpu(size, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
scal_ongpu(size, momentum, l.weight_updates_gpu, 1);
}
void pull_local_layer(local_layer l)
{
int locations = l.out_w*l.out_h;
int size = l.size*l.size*l.c*l.n*locations;
cuda_pull_array(l.filters_gpu, l.filters, size);
cuda_pull_array(l.weights_gpu, l.weights, size);
cuda_pull_array(l.biases_gpu, l.biases, l.outputs);
}
@ -269,7 +269,7 @@ void push_local_layer(local_layer l)
{
int locations = l.out_w*l.out_h;
int size = l.size*l.size*l.c*l.n*locations;
cuda_push_array(l.filters_gpu, l.filters, size);
cuda_push_array(l.weights_gpu, l.weights, size);
cuda_push_array(l.biases_gpu, l.biases, l.outputs);
}
#endif

View File

@ -318,11 +318,11 @@ void backward_network(network net, network_state state)
float train_network_datum(network net, float *x, float *y)
{
*net.seen += net.batch;
#ifdef GPU
if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
#endif
network_state state;
*net.seen += net.batch;
state.index = 0;
state.net = net;
state.input = x;

View File

@ -65,6 +65,7 @@ typedef struct network_state {
} network_state;
#ifdef GPU
float train_networks(network *nets, int n, data d);
float train_network_datum_gpu(network net, float *x, float *y);
float *network_predict_gpu(network net, float *input);
float * get_network_output_gpu_layer(network net, int i);

View File

@ -209,6 +209,7 @@ void forward_backward_network_gpu(network net, float *x, float *y)
float train_network_datum_gpu(network net, float *x, float *y)
{
*net.seen += net.batch;
forward_backward_network_gpu(net, x, y);
float error = get_network_cost(net);
if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net);
@ -226,25 +227,115 @@ void *train_thread(void *ptr)
{
train_args args = *(train_args*)ptr;
cudaError_t status = cudaSetDevice(args.net.gpu_index);
check_error(status);
cuda_set_device(args.net.gpu_index);
forward_backward_network_gpu(args.net, args.X, args.y);
free(ptr);
return 0;
}
pthread_t train_network_in_thread(train_args args)
pthread_t train_network_in_thread(network net, float *X, float *y)
{
pthread_t thread;
train_args *ptr = (train_args *)calloc(1, sizeof(train_args));
*ptr = args;
ptr->net = net;
ptr->X = X;
ptr->y = y;
if(pthread_create(&thread, 0, train_thread, ptr)) error("Thread creation failed");
return thread;
}
void pull_updates(layer l)
{
#ifdef GPU
if(l.type == CONVOLUTIONAL){
cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n);
cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
if(l.scale_updates) cuda_pull_array(l.scale_updates_gpu, l.scale_updates, l.n);
} else if(l.type == CONNECTED){
cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
}
#endif
}
void push_updates(layer l)
{
#ifdef GPU
if(l.type == CONVOLUTIONAL){
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
if(l.scale_updates) cuda_push_array(l.scale_updates_gpu, l.scale_updates, l.n);
} else if(l.type == CONNECTED){
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
}
#endif
}
void merge_updates(layer l, layer base)
{
if (l.type == CONVOLUTIONAL) {
axpy_cpu(l.n, 1, l.bias_updates, 1, base.bias_updates, 1);
axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weight_updates, 1, base.weight_updates, 1);
if (l.scale_updates) {
axpy_cpu(l.n, 1, l.scale_updates, 1, base.scale_updates, 1);
}
} else if(l.type == CONNECTED) {
axpy_cpu(l.outputs, 1, l.bias_updates, 1, base.bias_updates, 1);
axpy_cpu(l.outputs*l.inputs, 1, l.weight_updates, 1, base.weight_updates, 1);
}
}
void distribute_updates(layer l, layer base)
{
if (l.type == CONVOLUTIONAL) {
copy_cpu(l.n, base.bias_updates, 1, l.bias_updates, 1);
copy_cpu(l.n*l.size*l.size*l.c, base.weight_updates, 1, l.weight_updates, 1);
if (l.scale_updates) {
copy_cpu(l.n, base.scale_updates, 1, l.scale_updates, 1);
}
} else if(l.type == CONNECTED) {
copy_cpu(l.outputs, base.bias_updates, 1, l.bias_updates, 1);
copy_cpu(l.outputs*l.inputs, base.weight_updates, 1, l.weight_updates, 1);
}
}
void sync_updates(network *nets, int n)
{
int i,j;
int layers = nets[0].n;
network net = nets[0];
for (j = 0; j < layers; ++j) {
layer base = net.layers[j];
cuda_set_device(net.gpu_index);
pull_updates(base);
for (i = 1; i < n; ++i) {
cuda_set_device(nets[i].gpu_index);
layer l = nets[i].layers[j];
pull_updates(l);
merge_updates(l, base);
}
for (i = 1; i < n; ++i) {
cuda_set_device(nets[i].gpu_index);
layer l = nets[i].layers[j];
distribute_updates(l, base);
push_updates(l);
}
cuda_set_device(net.gpu_index);
push_updates(base);
}
for (i = 0; i < n; ++i) {
cuda_set_device(nets[i].gpu_index);
if(i > 0) nets[i].momentum = 0;
update_network_gpu(nets[i]);
}
}
float train_networks(network *nets, int n, data d)
{
int batch = nets[0].batch;
assert(batch * n == d.X.rows);
assert(nets[0].subdivisions % n == 0);
float **X = (float **) calloc(n, sizeof(float *));
float **y = (float **) calloc(n, sizeof(float *));
pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t));
@ -255,11 +346,20 @@ float train_networks(network *nets, int n, data d)
X[i] = (float *) calloc(batch*d.X.cols, sizeof(float));
y[i] = (float *) calloc(batch*d.y.cols, sizeof(float));
get_next_batch(d, batch, i*batch, X[i], y[i]);
float err = train_network_datum(nets[i], X[i], y[i]);
sum += err;
threads[i] = train_network_in_thread(nets[i], X[i], y[i]);
}
for(i = 0; i < n; ++i){
pthread_join(threads[i], 0);
*nets[i].seen += n*nets[i].batch;
printf("%f\n", get_network_cost(nets[i]) / batch);
sum += get_network_cost(nets[i]);
free(X[i]);
free(y[i]);
}
if (((*nets[0].seen) / nets[0].batch) % nets[0].subdivisions == 0) sync_updates(nets, n);
free(X);
free(y);
free(threads);
return (float)sum/(n*batch);
}

View File

@ -551,6 +551,7 @@ network parse_network_cfg(char *filename)
node *n = sections->front;
if(!n) error("Config file has no sections");
network net = make_network(sections->size - 1);
net.gpu_index = gpu_index;
size_params params;
section *s = (section *)n->val;
@ -856,13 +857,13 @@ void save_weights_double(network net, char *filename)
for (j = 0; j < l.n; ++j){
int index = j*l.c*l.size*l.size;
fwrite(l.filters+index, sizeof(float), l.c*l.size*l.size, fp);
fwrite(l.weights+index, sizeof(float), l.c*l.size*l.size, fp);
for (k = 0; k < l.c*l.size*l.size; ++k) fwrite(&zero, sizeof(float), 1, fp);
}
for (j = 0; j < l.n; ++j){
int index = j*l.c*l.size*l.size;
for (k = 0; k < l.c*l.size*l.size; ++k) fwrite(&zero, sizeof(float), 1, fp);
fwrite(l.filters+index, sizeof(float), l.c*l.size*l.size, fp);
fwrite(l.weights+index, sizeof(float), l.c*l.size*l.size, fp);
}
}
}
@ -876,7 +877,7 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
pull_convolutional_layer(l);
}
#endif
binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
int size = l.c*l.size*l.size;
int i, j, k;
fwrite(l.biases, sizeof(float), l.n, fp);
@ -886,7 +887,7 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
fwrite(l.rolling_variance, sizeof(float), l.n, fp);
}
for(i = 0; i < l.n; ++i){
float mean = l.binary_filters[i*size];
float mean = l.binary_weights[i*size];
if(mean < 0) mean = -mean;
fwrite(&mean, sizeof(float), 1, fp);
for(j = 0; j < size/8; ++j){
@ -894,7 +895,7 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
unsigned char c = 0;
for(k = 0; k < 8; ++k){
if (j*8 + k >= size) break;
if (l.binary_filters[index + k] > 0) c = (c | 1<<k);
if (l.binary_weights[index + k] > 0) c = (c | 1<<k);
}
fwrite(&c, sizeof(char), 1, fp);
}
@ -919,7 +920,7 @@ void save_convolutional_weights(layer l, FILE *fp)
fwrite(l.rolling_mean, sizeof(float), l.n, fp);
fwrite(l.rolling_variance, sizeof(float), l.n, fp);
}
fwrite(l.filters, sizeof(float), num, fp);
fwrite(l.weights, sizeof(float), num, fp);
}
void save_batchnorm_weights(layer l, FILE *fp)
@ -952,6 +953,9 @@ void save_connected_weights(layer l, FILE *fp)
void save_weights_upto(network net, char *filename, int cutoff)
{
#ifdef GPU
cuda_set_device(net.gpu_index);
#endif
fprintf(stderr, "Saving weights to %s\n", filename);
FILE *fp = fopen(filename, "w");
if(!fp) file_error(filename);
@ -997,7 +1001,7 @@ void save_weights_upto(network net, char *filename, int cutoff)
int locations = l.out_w*l.out_h;
int size = l.size*l.size*l.c*l.n*locations;
fwrite(l.biases, sizeof(float), l.outputs, fp);
fwrite(l.filters, sizeof(float), size, fp);
fwrite(l.weights, sizeof(float), size, fp);
}
}
fclose(fp);
@ -1075,7 +1079,7 @@ void load_convolutional_weights_binary(layer l, FILE *fp)
fread(&c, sizeof(char), 1, fp);
for(k = 0; k < 8; ++k){
if (j*8 + k >= size) break;
l.filters[index + k] = (c & 1<<k) ? mean : -mean;
l.weights[index + k] = (c & 1<<k) ? mean : -mean;
}
}
}
@ -1099,12 +1103,12 @@ void load_convolutional_weights(layer l, FILE *fp)
fread(l.rolling_mean, sizeof(float), l.n, fp);
fread(l.rolling_variance, sizeof(float), l.n, fp);
}
fread(l.filters, sizeof(float), num, fp);
//if(l.c == 3) scal_cpu(num, 1./256, l.filters, 1);
fread(l.weights, sizeof(float), num, fp);
//if(l.c == 3) scal_cpu(num, 1./256, l.weights, 1);
if (l.flipped) {
transpose_matrix(l.filters, l.c*l.size*l.size, l.n);
transpose_matrix(l.weights, l.c*l.size*l.size, l.n);
}
//if (l.binary) binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.filters);
//if (l.binary) binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.weights);
#ifdef GPU
if(gpu_index >= 0){
push_convolutional_layer(l);
@ -1115,6 +1119,9 @@ void load_convolutional_weights(layer l, FILE *fp)
void load_weights_upto(network *net, char *filename, int cutoff)
{
#ifdef GPU
cuda_set_device(net->gpu_index);
#endif
fprintf(stderr, "Loading weights from %s...", filename);
fflush(stdout);
FILE *fp = fopen(filename, "rb");
@ -1139,7 +1146,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
if(l.type == DECONVOLUTIONAL){
int num = l.n*l.c*l.size*l.size;
fread(l.biases, sizeof(float), l.n, fp);
fread(l.filters, sizeof(float), num, fp);
fread(l.weights, sizeof(float), num, fp);
#ifdef GPU
if(gpu_index >= 0){
push_deconvolutional_layer(l);
@ -1174,7 +1181,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
int locations = l.out_w*l.out_h;
int size = l.size*l.size*l.c*l.n*locations;
fread(l.biases, sizeof(float), l.outputs, fp);
fread(l.filters, sizeof(float), size, fp);
fread(l.weights, sizeof(float), size, fp);
#ifdef GPU
if(gpu_index >= 0){
push_local_layer(l);

View File

@ -22,11 +22,18 @@ region_layer make_region_layer(int batch, int w, int h, int n, int classes, int
l.classes = classes;
l.coords = coords;
l.cost = calloc(1, sizeof(float));
l.biases = calloc(n*2, sizeof(float));
l.bias_updates = calloc(n*2, sizeof(float));
l.outputs = h*w*n*(classes + coords + 1);
l.inputs = l.outputs;
l.truths = 30*(5);
l.delta = calloc(batch*l.outputs, sizeof(float));
l.output = calloc(batch*l.outputs, sizeof(float));
int i;
for(i = 0; i < n*2; ++i){
l.biases[i] = .5;
}
#ifdef GPU
l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
@ -38,62 +45,30 @@ region_layer make_region_layer(int batch, int w, int h, int n, int classes, int
return l;
}
box get_region_box2(float *x, int index, int i, int j, int w, int h)
box get_region_box(float *x, float *biases, int n, int index, int i, int j, int w, int h)
{
float aspect = exp(x[index+0]);
float scale = logistic_activate(x[index+1]);
float move_x = x[index+2];
float move_y = x[index+3];
box b;
b.w = sqrt(scale * aspect);
b.h = b.w * 1./aspect;
b.x = move_x * b.w + (i + .5)/w;
b.y = move_y * b.h + (j + .5)/h;
b.x = (i + .5)/w + x[index + 0] * biases[2*n];
b.y = (j + .5)/h + x[index + 1] * biases[2*n + 1];
b.w = exp(x[index + 2]) * biases[2*n];
b.h = exp(x[index + 3]) * biases[2*n+1];
return b;
}
float delta_region_box2(box truth, float *output, int index, int i, int j, int w, int h, float *delta)
float delta_region_box(box truth, float *x, float *biases, int n, int index, int i, int j, int w, int h, float *delta, float scale)
{
box pred = get_region_box2(output, index, i, j, w, h);
float iou = box_iou(pred, truth);
float true_aspect = truth.w/truth.h;
float true_scale = truth.w*truth.h;
float true_dx = (truth.x - (i+.5)/w) / truth.w;
float true_dy = (truth.y - (j+.5)/h) / truth.h;
delta[index + 0] = (true_aspect - exp(output[index + 0])) * exp(output[index + 0]);
delta[index + 1] = (true_scale - logistic_activate(output[index + 1])) * logistic_gradient(logistic_activate(output[index + 1]));
delta[index + 2] = true_dx - output[index + 2];
delta[index + 3] = true_dy - output[index + 3];
return iou;
}
box get_region_box(float *x, int index, int i, int j, int w, int h, int adjust, int logistic)
{
box b;
b.x = (x[index + 0] + i + .5)/w;
b.y = (x[index + 1] + j + .5)/h;
b.w = x[index + 2];
b.h = x[index + 3];
if(logistic){
b.w = logistic_activate(x[index + 2]);
b.h = logistic_activate(x[index + 3]);
}
if(adjust && b.w < .01) b.w = .01;
if(adjust && b.h < .01) b.h = .01;
return b;
}
float delta_region_box(box truth, float *output, int index, int i, int j, int w, int h, float *delta, int logistic, float scale)
{
box pred = get_region_box(output, index, i, j, w, h, 0, logistic);
box pred = get_region_box(x, biases, n, index, i, j, w, h);
float iou = box_iou(pred, truth);
delta[index + 0] = scale * (truth.x - pred.x);
delta[index + 1] = scale * (truth.y - pred.y);
delta[index + 2] = scale * ((truth.w - pred.w)*(logistic ? logistic_gradient(pred.w) : 1));
delta[index + 3] = scale * ((truth.h - pred.h)*(logistic ? logistic_gradient(pred.h) : 1));
float tx = (truth.x - (i + .5)/w) / biases[2*n];
float ty = (truth.y - (j + .5)/h) / biases[2*n + 1];
float tw = log(truth.w / biases[2*n]);
float th = log(truth.h / biases[2*n + 1]);
delta[index + 0] = scale * (tx - x[index + 0]);
delta[index + 1] = scale * (ty - x[index + 1]);
delta[index + 2] = scale * (tw - x[index + 2]);
delta[index + 3] = scale * (th - x[index + 3]);
return iou;
}
@ -107,7 +82,7 @@ float tisnan(float x)
return (x != x);
}
#define LOG 1
#define LOG 0
void forward_region_layer(const region_layer l, network_state state)
{
@ -127,6 +102,7 @@ void forward_region_layer(const region_layer l, network_state state)
if(!state.train) return;
memset(l.delta, 0, l.outputs * l.batch * sizeof(float));
float avg_iou = 0;
float recall = 0;
float avg_cat = 0;
float avg_obj = 0;
float avg_anyobj = 0;
@ -137,7 +113,7 @@ void forward_region_layer(const region_layer l, network_state state)
for (i = 0; i < l.w; ++i) {
for (n = 0; n < l.n; ++n) {
int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs;
box pred = get_region_box(l.output, index, i, j, l.w, l.h, 1, LOG);
box pred = get_region_box(l.output, l.biases, n, index, i, j, l.w, l.h);
float best_iou = 0;
for(t = 0; t < 30; ++t){
box truth = float_to_box(state.truth + t*5 + b*l.truths);
@ -155,7 +131,11 @@ void forward_region_layer(const region_layer l, network_state state)
truth.y = (j + .5)/l.h;
truth.w = .5;
truth.h = .5;
delta_region_box(truth, l.output, index, i, j, l.w, l.h, l.delta, LOG, 1);
delta_region_box(truth, l.output, l.biases, n, index, i, j, l.w, l.h, l.delta, .01);
//l.delta[index + 0] = .1 * (0 - l.output[index + 0]);
//l.delta[index + 1] = .1 * (0 - l.output[index + 1]);
//l.delta[index + 2] = .1 * (0 - l.output[index + 2]);
//l.delta[index + 3] = .1 * (0 - l.output[index + 3]);
}
}
}
@ -176,8 +156,8 @@ void forward_region_layer(const region_layer l, network_state state)
printf("index %d %d\n",i, j);
for(n = 0; n < l.n; ++n){
int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs;
box pred = get_region_box(l.output, index, i, j, l.w, l.h, 1, LOG);
printf("pred: (%f, %f) %f x %f\n", pred.x, pred.y, pred.w, pred.h);
box pred = get_region_box(l.output, l.biases, n, index, i, j, l.w, l.h);
printf("pred: (%f, %f) %f x %f\n", pred.x*l.w - i - .5, pred.y * l.h - j - .5, pred.w, pred.h);
pred.x = 0;
pred.y = 0;
float iou = box_iou(pred, truth_shift);
@ -187,9 +167,10 @@ void forward_region_layer(const region_layer l, network_state state)
best_n = n;
}
}
printf("%d %f (%f, %f) %f x %f\n", best_n, best_iou, truth.x, truth.y, truth.w, truth.h);
printf("%d %f (%f, %f) %f x %f\n", best_n, best_iou, truth.x * l.w - i - .5, truth.y*l.h - j - .5, truth.w, truth.h);
float iou = delta_region_box(truth, l.output, best_index, i, j, l.w, l.h, l.delta, LOG, l.coord_scale);
float iou = delta_region_box(truth, l.output, l.biases, best_n, best_index, i, j, l.w, l.h, l.delta, l.coord_scale);
if(iou > .5) recall += 1;
avg_iou += iou;
//l.delta[best_index + 4] = iou - l.output[best_index + 4];
@ -239,7 +220,7 @@ void forward_region_layer(const region_layer l, network_state state)
printf("\n");
reorg(l.delta, l.w*l.h, size*l.n, l.batch, 0);
*(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), count);
printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, Avg Recall: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, count);
}
void backward_region_layer(const region_layer l, network_state state)