Mirror of https://github.com/pjreddie/darknet.git, synced 2023-08-10 21:13:14 +03:00

Commit 5c067dc447: "good chance I didn't break anything"
Parent: 8ec889f103
@@ -1,85 +1,91 @@
 [net]
 batch=128
 subdivisions=1
-height=256
-width=256
+height=227
+width=227
 channels=3
-learning_rate=0.01
 momentum=0.9
 decay=0.0005
+max_crop=256

-[crop]
-crop_height=224
-crop_width=224
-flip=1
-angle=0
-saturation=1
-exposure=1
+learning_rate=0.01
+policy=poly
+power=4
+max_batches=800000
+
+angle=7
+hue = .1
+saturation=.75
+exposure=.75
+aspect=.75

 [convolutional]
-filters=64
+filters=96
 size=11
 stride=4
 pad=0
-activation=ramp
+activation=relu

 [maxpool]
 size=3
 stride=2
 padding=0

 [convolutional]
-filters=192
+filters=256
 size=5
 stride=1
 pad=1
-activation=ramp
+activation=relu

 [maxpool]
 size=3
 stride=2
 padding=0

 [convolutional]
 filters=384
 size=3
 stride=1
 pad=1
-activation=ramp
+activation=relu

-[convolutional]
-filters=256
-size=3
-stride=1
-pad=1
-activation=ramp
+[convolutional]
+filters=384
+size=3
+stride=1
+pad=1
+activation=relu

 [convolutional]
 filters=256
 size=3
 stride=1
 pad=1
-activation=ramp
+activation=relu

 [maxpool]
 size=3
 stride=2
 padding=0

 [connected]
 output=4096
-activation=ramp
+activation=relu

 [dropout]
 probability=.5

 [connected]
 output=4096
-activation=ramp
+activation=relu

 [dropout]
 probability=.5

 [connected]
 output=1000
-activation=ramp
+activation=linear

 [softmax]
 groups=1
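The schedule is now driven by policy=poly rather than a fixed rate. As a hedged illustration (this helper is not part of the diff; darknet's own routine lives in network.c and may differ in detail), a poly schedule with the values configured above decays the rate like this:

    #include <math.h>

    /* Sketch of a poly learning-rate schedule (learning_rate=0.01,
     * power=4, max_batches=800000): decays smoothly to 0 over
     * max_batches updates. */
    float poly_rate(float base_rate, int batch_num, int max_batches, float power)
    {
        return base_rate * powf(1 - (float)batch_num / max_batches, power);
    }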
src/classifier.c (206 changed lines)

@@ -5,6 +5,7 @@
 #include "blas.h"
+#include "assert.h"
 #include "classifier.h"
 #include "cuda.h"
 #include <sys/time.h>

 #ifdef OPENCV
@@ -51,6 +52,134 @@ float *get_regression_values(char **labels, int n)
     return v;
 }

+void train_classifier_multi(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
+{
+#ifdef GPU
+    int nthreads = 8;
+    int i;
+
+    data_seed = time(0);
+    srand(time(0));
+    float avg_loss = -1;
+    char *base = basecfg(cfgfile);
+    printf("%s\n", base);
+    printf("%d\n", ngpus);
+    network *nets = calloc(ngpus, sizeof(network));
+    for(i = 0; i < ngpus; ++i){
+        cuda_set_device(gpus[i]);
+        nets[i] = parse_network_cfg(cfgfile);
+        if(weightfile){
+            load_weights(&(nets[i]), weightfile);
+        }
+        if(clear) *nets[i].seen = 0;
+    }
+    network net = nets[0];
+
+    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+    int imgs = net.batch*ngpus/nthreads;
+    assert(net.batch*ngpus % nthreads == 0);
+
+    list *options = read_data_cfg(datacfg);
+
+    char *backup_directory = option_find_str(options, "backup", "/backup/");
+    char *label_list = option_find_str(options, "labels", "data/labels.list");
+    char *train_list = option_find_str(options, "train", "data/train.list");
+    int classes = option_find_int(options, "classes", 2);
+
+    char **labels = get_labels(label_list);
+    list *plist = get_paths(train_list);
+    char **paths = (char **)list_to_array(plist);
+    printf("%d\n", plist->size);
+    int N = plist->size;
+    clock_t time;
+
+    pthread_t *load_threads = calloc(nthreads, sizeof(pthread_t));
+    data *trains = calloc(nthreads, sizeof(data));
+    data *buffers = calloc(nthreads, sizeof(data));
+
+    load_args args = {0};
+    args.w = net.w;
+    args.h = net.h;
+
+    args.min = net.min_crop;
+    args.max = net.max_crop;
+    args.angle = net.angle;
+    args.aspect = net.aspect;
+    args.exposure = net.exposure;
+    args.saturation = net.saturation;
+    args.hue = net.hue;
+    args.size = net.w;
+
+    args.paths = paths;
+    args.classes = classes;
+    args.n = imgs;
+    args.m = N;
+    args.labels = labels;
+    args.type = CLASSIFICATION_DATA;
+
+    for(i = 0; i < nthreads; ++i){
+        args.d = buffers + i;
+        load_threads[i] = load_data_in_thread(args);
+    }
+
+    int epoch = (*net.seen)/N;
+    while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
+        time=clock();
+        for(i = 0; i < nthreads; ++i){
+            pthread_join(load_threads[i], 0);
+            trains[i] = buffers[i];
+        }
+        data train = concat_datas(trains, nthreads);
+
+        for(i = 0; i < nthreads; ++i){
+            args.d = buffers + i;
+            load_threads[i] = load_data_in_thread(args);
+        }
+
+        printf("Loaded: %lf seconds\n", sec(clock()-time));
+        time=clock();
+
+        float loss = train_networks(nets, ngpus, train);
+        if(avg_loss == -1) avg_loss = loss;
+        avg_loss = avg_loss*.9 + loss*.1;
+        printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
+        free_data(train);
+        for(i = 0; i < nthreads; ++i){
+            free_data(trains[i]);
+        }
+        if(*net.seen/N > epoch){
+            epoch = *net.seen/N;
+            char buff[256];
+            sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
+            save_weights(net, buff);
+        }
+        if(get_current_batch(net)%100 == 0){
+            char buff[256];
+            sprintf(buff, "%s/%s.backup",backup_directory,base);
+            save_weights(net, buff);
+        }
+    }
+    char buff[256];
+    sprintf(buff, "%s/%s.weights", backup_directory, base);
+    save_weights(net, buff);
+
+    for(i = 0; i < nthreads; ++i){
+        pthread_join(load_threads[i], 0);
+        free_data(buffers[i]);
+    }
+    free(buffers);
+    free(trains);
+    free(load_threads);
+
+    free_network(net);
+    free_ptrs((void**)labels, classes);
+    free_ptrs((void**)paths, plist->size);
+    free_list(plist);
+    free(base);
+#endif
+}
+
+
 void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
 {
     int nthreads = 8;
@@ -130,7 +259,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
         printf("Loaded: %lf seconds\n", sec(clock()-time));
         time=clock();

-#ifdef OPENCV
+#ifdef OPENCV
         if(0){
             int u;
             for(u = 0; u < imgs; ++u){
@@ -139,7 +268,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
                 cvWaitKey(0);
             }
         }
-#endif
+#endif

         float loss = train_network(net, train);
         if(avg_loss == -1) avg_loss = loss;
@@ -546,29 +675,29 @@ void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filena
     float *X = im.data;
     time=clock();
     float *predictions = network_predict(net, X);

     layer l = net.layers[layer_num];
     for(i = 0; i < l.c; ++i){
-        if(l.rolling_mean) printf("%f %f %f\n", l.rolling_mean[i], l.rolling_variance[i], l.scales[i]);
+        if(l.rolling_mean) printf("%f %f %f\n", l.rolling_mean[i], l.rolling_variance[i], l.scales[i]);
     }
-#ifdef GPU
+#ifdef GPU
     cuda_pull_array(l.output_gpu, l.output, l.outputs);
-#endif
+#endif
     for(i = 0; i < l.outputs; ++i){
         printf("%f\n", l.output[i]);
     }
     /*
-    printf("\n\nWeights\n");
-    for(i = 0; i < l.n*l.size*l.size*l.c; ++i){
-        printf("%f\n", l.filters[i]);
-    }
-
-    printf("\n\nBiases\n");
-    for(i = 0; i < l.n; ++i){
-        printf("%f\n", l.biases[i]);
-    }
-    */
+    printf("\n\nWeights\n");
+    for(i = 0; i < l.n*l.size*l.size*l.c; ++i){
+        printf("%f\n", l.filters[i]);
+    }
+
+    printf("\n\nBiases\n");
+    for(i = 0; i < l.n; ++i){
+        printf("%f\n", l.biases[i]);
+    }
+    */

     top_predictions(net, top, indexes);
     printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
@@ -794,15 +923,15 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i
         if(!in.data) break;
         image in_s = resize_image(in, net.w, net.h);

-        image out = in;
-        int x1 = out.w / 20;
-        int y1 = out.h / 20;
-        int x2 = 2*x1;
-        int y2 = out.h - out.h/20;
+        image out = in;
+        int x1 = out.w / 20;
+        int y1 = out.h / 20;
+        int x2 = 2*x1;
+        int y2 = out.h - out.h/20;

-        int border = .01*out.h;
-        int h = y2 - y1 - 2*border;
-        int w = x2 - x1 - 2*border;
+        int border = .01*out.h;
+        int h = y2 - y1 - 2*border;
+        int w = x2 - x1 - 2*border;

         float *predictions = network_predict(net, in_s.data);
         float curr_threat = predictions[0] * 0 + predictions[1] * .6 + predictions[2];
@@ -821,11 +950,11 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i
                 y1 + .02*h + 3*border, .5*border, 0,0,0);
         draw_box_width(out, x2 + border, y1 + .42*h, x2 + .5 * w, y1 + .42*h + border, border, 0,0,0);
         if(threat > .57) {
-            draw_box_width(out, x2 + .5 * w + border,
-                    y1 + .42*h - 2*border,
-                    x2 + .5 * w + 6*border,
-                    y1 + .42*h + 3*border, 3*border, 1,1,0);
-        }
+            draw_box_width(out, x2 + .5 * w + border,
+                    y1 + .42*h - 2*border,
+                    x2 + .5 * w + 6*border,
+                    y1 + .42*h + 3*border, 3*border, 1,1,0);
+        }
         draw_box_width(out, x2 + .5 * w + border,
                 y1 + .42*h - 2*border,
                 x2 + .5 * w + 6*border,
@@ -942,6 +1071,24 @@ void run_classifier(int argc, char **argv)
         return;
     }

+    char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
+    int *gpus = 0;
+    int ngpus = 0;
+    if(gpu_list){
+        printf("%s\n", gpu_list);
+        int len = strlen(gpu_list);
+        ngpus = 1;
+        int i;
+        for(i = 0; i < len; ++i){
+            if (gpu_list[i] == ',') ++ngpus;
+        }
+        gpus = calloc(ngpus, sizeof(int));
+        for(i = 0; i < ngpus; ++i){
+            gpus[i] = atoi(gpu_list);
+            gpu_list = strchr(gpu_list, ',')+1;
+        }
+    }
+
     int cam_index = find_int_arg(argc, argv, "-c", 0);
     int clear = find_arg(argc, argv, "-clear");
     char *data = argv[3];
@@ -953,6 +1100,7 @@ void run_classifier(int argc, char **argv)
     if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename);
     else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
     else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, clear);
+    else if(0==strcmp(argv[2], "trainm")) train_classifier_multi(data, cfg, weights, gpus, ngpus, clear);
     else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
     else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
     else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
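With the trainm entry point and the -gpus parsing above, multi-GPU training is invoked from the command line. A hypothetical invocation (the .data and .cfg paths are placeholders, not taken from this commit; the weights file is optional):

    ./darknet classifier trainm cfg/imagenet1k.data cfg/alexnet.cfg alexnet.weights -gpus 0,1,2,3

Each comma-separated device index becomes one network replica created on that GPU by train_classifier_multi.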
@@ -48,25 +48,25 @@ void binarize_input_gpu(float *input, int n, int size, float *binary)
 }


-__global__ void binarize_filters_kernel(float *filters, int n, int size, float *binary)
+__global__ void binarize_weights_kernel(float *weights, int n, int size, float *binary)
 {
     int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (f >= n) return;
     int i = 0;
     float mean = 0;
     for(i = 0; i < size; ++i){
-        mean += abs(filters[f*size + i]);
+        mean += abs(weights[f*size + i]);
     }
     mean = mean / size;
     for(i = 0; i < size; ++i){
-        binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
-        //binary[f*size + i] = filters[f*size + i];
+        binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
+        //binary[f*size + i] = weights[f*size + i];
     }
 }

-void binarize_filters_gpu(float *filters, int n, int size, float *binary)
+void binarize_weights_gpu(float *weights, int n, int size, float *binary)
 {
-    binarize_filters_kernel<<<cuda_gridsize(n), BLOCK>>>(filters, n, size, binary);
+    binarize_weights_kernel<<<cuda_gridsize(n), BLOCK>>>(weights, n, size, binary);
     check_error(cudaPeekAtLastError());
 }

@@ -74,12 +74,12 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
     fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
     if(l.binary){
-        binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
+        binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
         swap_binary(&l);
     }

     if(l.xnor){
-        binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
+        binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
         swap_binary(&l);
         binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
         state.input = l.binary_input_gpu;
@@ -91,8 +91,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
             &one,
             l.srcTensorDesc,
             state.input,
-            l.filterDesc,
-            l.filters_gpu,
+            l.weightDesc,
+            l.weights_gpu,
             l.convDesc,
             l.fw_algo,
             state.workspace,
@@ -108,7 +108,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     int n = l.out_w*l.out_h;
     for(i = 0; i < l.batch; ++i){
         im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
-        float * a = l.filters_gpu;
+        float * a = l.weights_gpu;
         float * b = state.workspace;
         float * c = l.output_gpu;
         gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
@@ -150,15 +150,15 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
             state.workspace,
             l.workspace_size,
             &one,
-            l.dfilterDesc,
-            l.filter_updates_gpu);
+            l.dweightDesc,
+            l.weight_updates_gpu);

     if(state.delta){
         if(l.binary || l.xnor) swap_binary(&l);
         cudnnConvolutionBackwardData(cudnn_handle(),
                 &one,
-                l.filterDesc,
-                l.filters_gpu,
+                l.weightDesc,
+                l.weights_gpu,
                 l.ddstTensorDesc,
                 l.delta_gpu,
                 l.convDesc,
@@ -181,14 +181,14 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
     for(i = 0; i < l.batch; ++i){
         float * a = l.delta_gpu;
         float * b = state.workspace;
-        float * c = l.filter_updates_gpu;
+        float * c = l.weight_updates_gpu;

         im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
         gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);

         if(state.delta){
             if(l.binary || l.xnor) swap_binary(&l);
-            float * a = l.filters_gpu;
+            float * a = l.weights_gpu;
             float * b = l.delta_gpu;
             float * c = state.workspace;

@@ -206,9 +206,9 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state

 void pull_convolutional_layer(convolutional_layer layer)
 {
-    cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
+    cuda_pull_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
     cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
+    cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
     cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
     if (layer.batch_normalize){
         cuda_pull_array(layer.scales_gpu, layer.scales, layer.n);
@@ -219,9 +219,9 @@ void pull_convolutional_layer(convolutional_layer layer)

 void push_convolutional_layer(convolutional_layer layer)
 {
-    cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
+    cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
+    cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
     if (layer.batch_normalize){
         cuda_push_array(layer.scales_gpu, layer.scales, layer.n);
@@ -240,9 +240,9 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
     axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
     scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);

-    axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
-    axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
-    scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
+    axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
+    axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
+    scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
 }
@@ -19,28 +19,28 @@ void forward_xnor_layer(layer l, network_state state);

 void swap_binary(convolutional_layer *l)
 {
-    float *swap = l->filters;
-    l->filters = l->binary_filters;
-    l->binary_filters = swap;
+    float *swap = l->weights;
+    l->weights = l->binary_weights;
+    l->binary_weights = swap;

 #ifdef GPU
-    swap = l->filters_gpu;
-    l->filters_gpu = l->binary_filters_gpu;
-    l->binary_filters_gpu = swap;
+    swap = l->weights_gpu;
+    l->weights_gpu = l->binary_weights_gpu;
+    l->binary_weights_gpu = swap;
 #endif
 }

-void binarize_filters(float *filters, int n, int size, float *binary)
+void binarize_weights(float *weights, int n, int size, float *binary)
 {
     int i, f;
     for(f = 0; f < n; ++f){
         float mean = 0;
         for(i = 0; i < size; ++i){
-            mean += fabs(filters[f*size + i]);
+            mean += fabs(weights[f*size + i]);
         }
         mean = mean / size;
         for(i = 0; i < size; ++i){
-            binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
+            binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
         }
     }
 }
@@ -103,7 +103,7 @@ size_t get_workspace_size(layer l){
     size_t s = 0;
     cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
             l.srcTensorDesc,
-            l.filterDesc,
+            l.weightDesc,
             l.convDesc,
             l.dstTensorDesc,
             l.fw_algo,
@@ -113,12 +113,12 @@ size_t get_workspace_size(layer l){
             l.srcTensorDesc,
             l.ddstTensorDesc,
             l.convDesc,
-            l.dfilterDesc,
+            l.dweightDesc,
             l.bf_algo,
             &s);
     if (s > most) most = s;
     cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
-            l.filterDesc,
+            l.weightDesc,
             l.ddstTensorDesc,
             l.convDesc,
             l.dsrcTensorDesc,
@@ -137,22 +137,22 @@ void cudnn_convolutional_setup(layer *l)
 {
     cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
     cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
-    cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+    cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);

     cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
     cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
-    cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+    cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
     cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
     cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
             l->srcTensorDesc,
-            l->filterDesc,
+            l->weightDesc,
             l->convDesc,
             l->dstTensorDesc,
             CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
             0,
             &l->fw_algo);
     cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
-            l->filterDesc,
+            l->weightDesc,
             l->ddstTensorDesc,
             l->convDesc,
             l->dsrcTensorDesc,
@@ -163,7 +163,7 @@ void cudnn_convolutional_setup(layer *l)
             l->srcTensorDesc,
             l->ddstTensorDesc,
             l->convDesc,
-            l->dfilterDesc,
+            l->dweightDesc,
             CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
             0,
             &l->bf_algo);
@@ -189,15 +189,15 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     l.pad = padding;
     l.batch_normalize = batch_normalize;

-    l.filters = calloc(c*n*size*size, sizeof(float));
-    l.filter_updates = calloc(c*n*size*size, sizeof(float));
+    l.weights = calloc(c*n*size*size, sizeof(float));
+    l.weight_updates = calloc(c*n*size*size, sizeof(float));

     l.biases = calloc(n, sizeof(float));
     l.bias_updates = calloc(n, sizeof(float));

     // float scale = 1./sqrt(size*size*c);
     float scale = sqrt(2./(size*size*c));
-    for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_uniform(-1, 1);
+    for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1, 1);
     int out_h = convolutional_out_height(l);
     int out_w = convolutional_out_width(l);
     l.out_h = out_h;
@@ -210,12 +210,12 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));

     if(binary){
-        l.binary_filters = calloc(c*n*size*size, sizeof(float));
-        l.cfilters = calloc(c*n*size*size, sizeof(char));
+        l.binary_weights = calloc(c*n*size*size, sizeof(float));
+        l.cweights = calloc(c*n*size*size, sizeof(char));
         l.scales = calloc(n, sizeof(float));
     }
     if(xnor){
-        l.binary_filters = calloc(c*n*size*size, sizeof(float));
+        l.binary_weights = calloc(c*n*size*size, sizeof(float));
         l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
     }

@@ -235,8 +235,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int

 #ifdef GPU
     if(gpu_index >= 0){
-        l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
-        l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
+        l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
+        l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);

         l.biases_gpu = cuda_make_array(l.biases, n);
         l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
@@ -248,10 +248,10 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
         l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);

         if(binary){
-            l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+            l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
         }
         if(xnor){
-            l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+            l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
             l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
         }

@@ -271,10 +271,10 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
 #ifdef CUDNN
     cudnnCreateTensorDescriptor(&l.srcTensorDesc);
     cudnnCreateTensorDescriptor(&l.dstTensorDesc);
-    cudnnCreateFilterDescriptor(&l.filterDesc);
+    cudnnCreateFilterDescriptor(&l.weightDesc);
     cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
     cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
-    cudnnCreateFilterDescriptor(&l.dfilterDesc);
+    cudnnCreateFilterDescriptor(&l.dweightDesc);
     cudnnCreateConvolutionDescriptor(&l.convDesc);
     cudnn_convolutional_setup(&l);
 #endif
@@ -294,7 +294,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
     for(i = 0; i < l.n; ++i){
         float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001);
         for(j = 0; j < l.c*l.size*l.size; ++j){
-            l.filters[i*l.c*l.size*l.size + j] *= scale;
+            l.weights[i*l.c*l.size*l.size + j] *= scale;
         }
         l.biases[i] -= l.rolling_mean[i] * scale;
         l.scales[i] = 1;
@@ -403,8 +403,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)

     /*
     if(l.binary){
-        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
-        binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
+        binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
+        binarize_weights2(l.weights, l.n, l.c*l.size*l.size, l.cweights, l.scales);
         swap_binary(&l);
     }
     */
@@ -415,7 +415,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
         int k = l.size*l.size*l.c;
         int n = out_h*out_w;

-        char *a = l.cfilters;
+        char *a = l.cweights;
         float *b = state.workspace;
         float *c = l.output;

@@ -434,7 +434,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
     */

     if(l.xnor){
-        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+        binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
         swap_binary(&l);
         binarize_cpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input);
         state.input = l.binary_input;
@@ -449,7 +449,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
         printf("xnor\n");
     } else {

-        float *a = l.filters;
+        float *a = l.weights;
         float *b = state.workspace;
         float *c = l.output;

@@ -485,7 +485,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
     for(i = 0; i < l.batch; ++i){
         float *a = l.delta + i*m*k;
         float *b = state.workspace;
-        float *c = l.filter_updates;
+        float *c = l.weight_updates;

         float *im = state.input+i*l.c*l.h*l.w;

@@ -494,7 +494,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
         gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);

         if(state.delta){
-            a = l.filters;
+            a = l.weights;
             b = l.delta + i*m*k;
             c = state.workspace;

@@ -511,36 +511,36 @@ void update_convolutional_layer(convolutional_layer l, int batch, float learning
     axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
     scal_cpu(l.n, momentum, l.bias_updates, 1);

-    axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1);
-    axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1);
-    scal_cpu(size, momentum, l.filter_updates, 1);
+    axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1);
+    axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
+    scal_cpu(size, momentum, l.weight_updates, 1);
 }


-image get_convolutional_filter(convolutional_layer l, int i)
+image get_convolutional_weight(convolutional_layer l, int i)
 {
     int h = l.size;
     int w = l.size;
     int c = l.c;
-    return float_to_image(w,h,c,l.filters+i*h*w*c);
+    return float_to_image(w,h,c,l.weights+i*h*w*c);
 }

-void rgbgr_filters(convolutional_layer l)
+void rgbgr_weights(convolutional_layer l)
 {
     int i;
     for(i = 0; i < l.n; ++i){
-        image im = get_convolutional_filter(l, i);
+        image im = get_convolutional_weight(l, i);
         if (im.c == 3) {
             rgbgr_image(im);
         }
     }
 }

-void rescale_filters(convolutional_layer l, float scale, float trans)
+void rescale_weights(convolutional_layer l, float scale, float trans)
 {
     int i;
     for(i = 0; i < l.n; ++i){
-        image im = get_convolutional_filter(l, i);
+        image im = get_convolutional_weight(l, i);
         if (im.c == 3) {
             scale_image(im, scale);
             float sum = sum_array(im.data, im.w*im.h*im.c);
@@ -549,21 +549,21 @@ void rescale_filters(convolutional_layer l, float scale, float trans)
     }
 }

-image *get_filters(convolutional_layer l)
+image *get_weights(convolutional_layer l)
 {
-    image *filters = calloc(l.n, sizeof(image));
+    image *weights = calloc(l.n, sizeof(image));
     int i;
     for(i = 0; i < l.n; ++i){
-        filters[i] = copy_image(get_convolutional_filter(l, i));
-        //normalize_image(filters[i]);
+        weights[i] = copy_image(get_convolutional_weight(l, i));
+        //normalize_image(weights[i]);
     }
-    return filters;
+    return weights;
 }

-image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_filters)
+image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_weights)
 {
-    image *single_filters = get_filters(l);
-    show_images(single_filters, l.n, window);
+    image *single_weights = get_weights(l);
+    show_images(single_weights, l.n, window);

     image delta = get_convolutional_image(l);
     image dc = collapse_image_layers(delta, 1);
@@ -572,6 +572,6 @@ image *visualize_convolutional_layer(convolutional_layer l, char *window, image
     //show_image(dc, buff);
     //save_image(dc, buff);
     free_image(dc);
-    return single_filters;
+    return single_weights;
 }
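The binary/XNOR path above relies on a pointer-swap idiom: binarize into a scratch buffer, swap it in for the forward pass, and swap back afterwards. A condensed sketch of the pattern, isolated from the forward path in the diff (not new code, just the idiom pulled out for clarity):

    /* Each output filter becomes sign(w) scaled by the mean |w| of that filter. */
    binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
    swap_binary(&l);   /* l.weights now points at the binarized copy */
    /* ... run the convolution with binarized weights ... */
    swap_binary(&l);   /* restore the real weights before the weight update */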
@@ -29,10 +29,10 @@ void denormalize_convolutional_layer(convolutional_layer l);
 void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
 void forward_convolutional_layer(const convolutional_layer layer, network_state state);
 void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
-image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
-void binarize_filters(float *filters, int n, int size, float *binary);
+image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_weights);
+void binarize_weights(float *weights, int n, int size, float *binary);
 void swap_binary(convolutional_layer *l);
-void binarize_filters2(float *filters, int n, int size, char *binary, float *scales);
+void binarize_weights2(float *weights, int n, int size, char *binary, float *scales);

 void backward_convolutional_layer(convolutional_layer layer, network_state state);

@@ -41,12 +41,12 @@ void backward_bias(float *bias_updates, float *delta, int batch, int n, int size

 image get_convolutional_image(convolutional_layer layer);
 image get_convolutional_delta(convolutional_layer layer);
-image get_convolutional_filter(convolutional_layer layer, int i);
+image get_convolutional_weight(convolutional_layer layer, int i);

 int convolutional_out_height(convolutional_layer layer);
 int convolutional_out_width(convolutional_layer layer);
-void rescale_filters(convolutional_layer l, float scale, float trans);
-void rgbgr_filters(convolutional_layer l);
+void rescale_weights(convolutional_layer l, float scale, float trans);
+void rgbgr_weights(convolutional_layer l);

 #endif
src/cuda.c (59 changed lines)

@@ -9,6 +9,20 @@ int gpu_index = 0;
 #include <stdlib.h>
 #include <time.h>

+void cuda_set_device(int n)
+{
+    gpu_index = n;
+    cudaError_t status = cudaSetDevice(n);
+    check_error(status);
+}
+
+int cuda_get_device()
+{
+    int n = 0;
+    cudaError_t status = cudaGetDevice(&n);
+    check_error(status);
+    return n;
+}
+
 void check_error(cudaError_t status)
 {
@@ -38,8 +52,8 @@ dim3 cuda_gridsize(size_t n){
     size_t x = k;
     size_t y = 1;
     if(x > 65535){
-        x = ceil(sqrt(k));
-        y = (n-1)/(x*BLOCK) + 1;
+        x = ceil(sqrt(k));
+        y = (n-1)/(x*BLOCK) + 1;
     }
     dim3 d = {x, y, 1};
     //printf("%ld %ld %ld %ld\n", n, x, y, x*y*BLOCK);
@@ -49,25 +63,27 @@ dim3 cuda_gridsize(size_t n){
 #ifdef CUDNN
 cudnnHandle_t cudnn_handle()
 {
-    static int init = 0;
-    static cudnnHandle_t handle;
-    if(!init) {
-        cudnnCreate(&handle);
-        init = 1;
+    static int init[16] = {0};
+    static cudnnHandle_t handle[16];
+    int i = cuda_get_device();
+    if(!init[i]) {
+        cudnnCreate(&handle[i]);
+        init[i] = 1;
     }
-    return handle;
+    return handle[i];
 }
 #endif

 cublasHandle_t blas_handle()
 {
-    static int init = 0;
-    static cublasHandle_t handle;
-    if(!init) {
-        cublasCreate(&handle);
-        init = 1;
+    static int init[16] = {0};
+    static cublasHandle_t handle[16];
+    int i = cuda_get_device();
+    if(!init[i]) {
+        cublasCreate(&handle[i]);
+        init[i] = 1;
     }
-    return handle;
+    return handle[i];
 }

 float *cuda_make_array(float *x, size_t n)
@@ -86,14 +102,15 @@ float *cuda_make_array(float *x, size_t n)

 void cuda_random(float *x_gpu, size_t n)
 {
-    static curandGenerator_t gen;
-    static int init = 0;
-    if(!init){
-        curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
-        curandSetPseudoRandomGeneratorSeed(gen, time(0));
-        init = 1;
+    static curandGenerator_t gen[16];
+    static int init[16] = {0};
+    int i = cuda_get_device();
+    if(!init[i]){
+        curandCreateGenerator(&gen[i], CURAND_RNG_PSEUDO_DEFAULT);
+        curandSetPseudoRandomGeneratorSeed(gen[i], time(0));
+        init[i] = 1;
     }
-    curandGenerateUniform(gen, x_gpu, n);
+    curandGenerateUniform(gen[i], x_gpu, n);
     check_error(cudaPeekAtLastError());
 }
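The handle caching above moves from one static handle to one slot per device because cuBLAS, cuDNN, and cuRAND handles are bound to the device that was current when they were created; reusing a single handle across GPUs would issue work against the wrong device. Each worker thread therefore binds its device first and then fetches the handle cached for it (both calls exist in this diff; the 16-slot limit is the hard-coded array size above):

    cuda_set_device(net.gpu_index);     /* bind this thread to its GPU */
    cublasHandle_t h = blas_handle();   /* lazily created, cached per device */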
@@ -21,6 +21,7 @@ float *cuda_make_array(float *x, size_t n);
 int *cuda_make_int_array(size_t n);
 void cuda_push_array(float *x_gpu, float *x, size_t n);
 void cuda_pull_array(float *x_gpu, float *x, size_t n);
+void cuda_set_device(int n);
 void cuda_free(float *x_gpu);
 void cuda_random(float *x_gpu, size_t n);
 float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
@@ -66,7 +66,7 @@ void average(int argc, char *argv[])
             if(l.type == CONVOLUTIONAL){
                 int num = l.n*l.c*l.size*l.size;
                 axpy_cpu(l.n, 1, l.biases, 1, out.biases, 1);
-                axpy_cpu(num, 1, l.filters, 1, out.filters, 1);
+                axpy_cpu(num, 1, l.weights, 1, out.weights, 1);
             }
             if(l.type == CONNECTED){
                 axpy_cpu(l.outputs, 1, l.biases, 1, out.biases, 1);
@@ -80,7 +80,7 @@ void average(int argc, char *argv[])
         if(l.type == CONVOLUTIONAL){
             int num = l.n*l.c*l.size*l.size;
             scal_cpu(l.n, 1./n, l.biases, 1);
-            scal_cpu(num, 1./n, l.filters, 1);
+            scal_cpu(num, 1./n, l.weights, 1);
         }
         if(l.type == CONNECTED){
             scal_cpu(l.outputs, 1./n, l.biases, 1);
@@ -159,7 +159,7 @@ void rescale_net(char *cfgfile, char *weightfile, char *outfile)
     for(i = 0; i < net.n; ++i){
         layer l = net.layers[i];
         if(l.type == CONVOLUTIONAL){
-            rescale_filters(l, 2, -.5);
+            rescale_weights(l, 2, -.5);
             break;
         }
     }
@@ -177,7 +177,7 @@ void rgbgr_net(char *cfgfile, char *weightfile, char *outfile)
     for(i = 0; i < net.n; ++i){
         layer l = net.layers[i];
         if(l.type == CONVOLUTIONAL){
-            rgbgr_filters(l);
+            rgbgr_weights(l);
             break;
         }
     }
@@ -354,8 +354,7 @@ int main(int argc, char **argv)
     gpu_index = -1;
 #else
     if(gpu_index >= 0){
-        cudaError_t status = cudaSetDevice(gpu_index);
-        check_error(status);
+        cuda_set_device(gpu_index);
     }
 #endif
@@ -27,7 +27,7 @@ extern "C" void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, n
     fill_ongpu(layer.outputs*layer.batch, 0, layer.output_gpu, 1);

     for(i = 0; i < layer.batch; ++i){
-        float *a = layer.filters_gpu;
+        float *a = layer.weights_gpu;
         float *b = state.input + i*layer.c*layer.h*layer.w;
         float *c = layer.col_image_gpu;

@@ -59,7 +59,7 @@ extern "C" void backward_deconvolutional_layer_gpu(deconvolutional_layer layer,

             float *a = state.input + i*m*n;
             float *b = layer.col_image_gpu;
-            float *c = layer.filter_updates_gpu;
+            float *c = layer.weight_updates_gpu;

             im2col_ongpu(layer.delta_gpu + i*layer.n*size, layer.n, out_h, out_w,
                     layer.size, layer.stride, 0, b);
@@ -70,7 +70,7 @@ extern "C" void backward_deconvolutional_layer_gpu(deconvolutional_layer layer,
             int n = layer.h*layer.w;
             int k = layer.size*layer.size*layer.n;

-            float *a = layer.filters_gpu;
+            float *a = layer.weights_gpu;
             float *b = layer.col_image_gpu;
             float *c = state.delta + i*n*m;

@@ -81,17 +81,17 @@ extern "C" void backward_deconvolutional_layer_gpu(deconvolutional_layer layer,

 extern "C" void pull_deconvolutional_layer(deconvolutional_layer layer)
 {
-    cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
+    cuda_pull_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
     cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
+    cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
     cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
 }

 extern "C" void push_deconvolutional_layer(deconvolutional_layer layer)
 {
-    cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
+    cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
+    cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
 }

@@ -102,8 +102,8 @@ extern "C" void update_deconvolutional_layer_gpu(deconvolutional_layer layer, fl
     axpy_ongpu(layer.n, learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
     scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);

-    axpy_ongpu(size, -decay, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
-    axpy_ongpu(size, learning_rate, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
-    scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
+    axpy_ongpu(size, -decay, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
+    axpy_ongpu(size, learning_rate, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
+    scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
 }
@@ -57,13 +57,13 @@ deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c,
     l.stride = stride;
     l.size = size;

-    l.filters = calloc(c*n*size*size, sizeof(float));
-    l.filter_updates = calloc(c*n*size*size, sizeof(float));
+    l.weights = calloc(c*n*size*size, sizeof(float));
+    l.weight_updates = calloc(c*n*size*size, sizeof(float));

     l.biases = calloc(n, sizeof(float));
     l.bias_updates = calloc(n, sizeof(float));
     float scale = 1./sqrt(size*size*c);
-    for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_normal();
+    for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_normal();
     for(i = 0; i < n; ++i){
         l.biases[i] = scale;
     }
@@ -81,8 +81,8 @@ deconvolutional_layer make_deconvolutional_layer(int batch, int h, int w, int c,
     l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));

 #ifdef GPU
-    l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
-    l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
+    l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
+    l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);

     l.biases_gpu = cuda_make_array(l.biases, n);
     l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
@@ -137,7 +137,7 @@ void forward_deconvolutional_layer(const deconvolutional_layer l, network_state
     fill_cpu(l.outputs*l.batch, 0, l.output, 1);

     for(i = 0; i < l.batch; ++i){
-        float *a = l.filters;
+        float *a = l.weights;
         float *b = state.input + i*l.c*l.h*l.w;
         float *c = l.col_image;

@@ -167,7 +167,7 @@ void backward_deconvolutional_layer(deconvolutional_layer l, network_state state

             float *a = state.input + i*m*n;
             float *b = l.col_image;
-            float *c = l.filter_updates;
+            float *c = l.weight_updates;

             im2col_cpu(l.delta + i*l.n*size, l.n, out_h, out_w,
                     l.size, l.stride, 0, b);
@@ -178,7 +178,7 @@ void backward_deconvolutional_layer(deconvolutional_layer l, network_state state
             int n = l.h*l.w;
             int k = l.size*l.size*l.n;

-            float *a = l.filters;
+            float *a = l.weights;
             float *b = l.col_image;
             float *c = state.delta + i*n*m;

@@ -193,9 +193,9 @@ void update_deconvolutional_layer(deconvolutional_layer l, float learning_rate,
     axpy_cpu(l.n, learning_rate, l.bias_updates, 1, l.biases, 1);
     scal_cpu(l.n, momentum, l.bias_updates, 1);

-    axpy_cpu(size, -decay, l.filters, 1, l.filter_updates, 1);
-    axpy_cpu(size, learning_rate, l.filter_updates, 1, l.filters, 1);
-    scal_cpu(size, momentum, l.filter_updates, 1);
+    axpy_cpu(size, -decay, l.weights, 1, l.weight_updates, 1);
+    axpy_cpu(size, learning_rate, l.weight_updates, 1, l.weights, 1);
+    scal_cpu(size, momentum, l.weight_updates, 1);
 }
@@ -117,12 +117,18 @@ static void convert_detections(float *predictions, int classes, int num, int squ
             int box_index = index * (classes + 5);
             boxes[index].x = (predictions[box_index + 0] + col + .5) / side * w;
             boxes[index].y = (predictions[box_index + 1] + row + .5) / side * h;
-            if(1){
+            if(0){
                 boxes[index].x = (logistic_activate(predictions[box_index + 0]) + col) / side * w;
                 boxes[index].y = (logistic_activate(predictions[box_index + 1]) + row) / side * h;
             }
             boxes[index].w = pow(logistic_activate(predictions[box_index + 2]), (square?2:1)) * w;
             boxes[index].h = pow(logistic_activate(predictions[box_index + 3]), (square?2:1)) * h;
+            if(1){
+                boxes[index].x = ((col + .5)/side + predictions[box_index + 0] * .5) * w;
+                boxes[index].y = ((row + .5)/side + predictions[box_index + 1] * .5) * h;
+                boxes[index].w = (exp(predictions[box_index + 2]) * .5) * w;
+                boxes[index].h = (exp(predictions[box_index + 3]) * .5) * h;
+            }
             for(j = 0; j < classes; ++j){
                 int class_index = index * (classes + 5) + 5;
                 float prob = scale*predictions[class_index+j];
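The newly enabled branch changes how raw outputs decode into boxes: the center starts at the cell center (col + .5)/side and shifts by half the predicted offset, and width/height come from an exponential instead of a squared logistic. A hedged restatement of the new decode as a standalone helper (tx..th are just names for the four raw predictions; box is assumed to be darknet's x/y/w/h struct):

    /* Decode one box: side is the grid resolution, (row, col) the cell,
     * and (w, h) the image dimensions. */
    box decode_box(float tx, float ty, float tw, float th,
                   int row, int col, int side, int w, int h)
    {
        box b;
        b.x = ((col + .5f)/side + tx * .5f) * w;  /* cell center plus half offset */
        b.y = ((row + .5f)/side + ty * .5f) * h;
        b.w = expf(tw) * .5f * w;                 /* exponential size encoding */
        b.h = expf(th) * .5f * h;
        return b;
    }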
@@ -14,8 +14,6 @@ void free_layer(layer l)
     if(l.indexes) free(l.indexes);
     if(l.rand) free(l.rand);
     if(l.cost) free(l.cost);
-    if(l.filters) free(l.filters);
-    if(l.filter_updates) free(l.filter_updates);
     if(l.biases) free(l.biases);
     if(l.bias_updates) free(l.bias_updates);
     if(l.weights) free(l.weights);
@@ -30,8 +28,8 @@ void free_layer(layer l)

 #ifdef GPU
     if(l.indexes_gpu) cuda_free((float *)l.indexes_gpu);
-    if(l.filters_gpu) cuda_free(l.filters_gpu);
-    if(l.filter_updates_gpu) cuda_free(l.filter_updates_gpu);
+    if(l.weights_gpu) cuda_free(l.weights_gpu);
+    if(l.weight_updates_gpu) cuda_free(l.weight_updates_gpu);
     if(l.col_image_gpu) cuda_free(l.col_image_gpu);
     if(l.weights_gpu) cuda_free(l.weights_gpu);
     if(l.biases_gpu) cuda_free(l.biases_gpu);
src/layer.h (14 changed lines)

@@ -105,9 +105,7 @@ struct layer{
     int *indexes;
     float *rand;
     float *cost;
-    float *filters;
-    char *cfilters;
-    float *filter_updates;
+    char *cweights;
     float *state;
     float *prev_state;
     float *forgot_state;
@@ -117,7 +115,7 @@ struct layer{
     float *concat;
     float *concat_delta;

-    float *binary_filters;
+    float *binary_weights;

     float *biases;
     float *bias_updates;
@@ -194,11 +192,9 @@ struct layer{
     float * save_delta_gpu;
     float * concat_gpu;
     float * concat_delta_gpu;
-    float * filters_gpu;
-    float * filter_updates_gpu;

     float *binary_input_gpu;
-    float *binary_filters_gpu;
+    float *binary_weights_gpu;

     float * mean_gpu;
     float * variance_gpu;
@@ -230,8 +226,8 @@ struct layer{
 #ifdef CUDNN
     cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
     cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
-    cudnnFilterDescriptor_t filterDesc;
-    cudnnFilterDescriptor_t dfilterDesc;
+    cudnnFilterDescriptor_t weightDesc;
+    cudnnFilterDescriptor_t dweightDesc;
     cudnnConvolutionDescriptor_t convDesc;
     cudnnConvolutionFwdAlgo_t fw_algo;
     cudnnConvolutionBwdDataAlgo_t bd_algo;
@@ -47,23 +47,23 @@ local_layer make_local_layer(int batch, int h, int w, int c, int n, int size, in
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = l.w * l.h * l.c;

-    l.filters = calloc(c*n*size*size*locations, sizeof(float));
-    l.filter_updates = calloc(c*n*size*size*locations, sizeof(float));
+    l.weights = calloc(c*n*size*size*locations, sizeof(float));
+    l.weight_updates = calloc(c*n*size*size*locations, sizeof(float));

     l.biases = calloc(l.outputs, sizeof(float));
     l.bias_updates = calloc(l.outputs, sizeof(float));

     // float scale = 1./sqrt(size*size*c);
     float scale = sqrt(2./(size*size*c));
-    for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_uniform(-1,1);
+    for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1,1);

     l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
     l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
     l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));

 #ifdef GPU
-    l.filters_gpu = cuda_make_array(l.filters, c*n*size*size*locations);
-    l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size*locations);
+    l.weights_gpu = cuda_make_array(l.weights, c*n*size*size*locations);
+    l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size*locations);

     l.biases_gpu = cuda_make_array(l.biases, l.outputs);
     l.bias_updates_gpu = cuda_make_array(l.bias_updates, l.outputs);
@@ -97,7 +97,7 @@ void forward_local_layer(const local_layer l, network_state state)
                 l.size, l.stride, l.pad, l.col_image);
         float *output = l.output + i*l.outputs;
         for(j = 0; j < locations; ++j){
-            float *a = l.filters + j*l.size*l.size*l.c*l.n;
+            float *a = l.weights + j*l.size*l.size*l.c*l.n;
             float *b = l.col_image + j;
             float *c = output + j;

@@ -130,7 +130,7 @@ void backward_local_layer(local_layer l, network_state state)
         for(j = 0; j < locations; ++j){
             float *a = l.delta + i*l.outputs + j;
             float *b = l.col_image + j;
-            float *c = l.filter_updates + j*l.size*l.size*l.c*l.n;
+            float *c = l.weight_updates + j*l.size*l.size*l.c*l.n;
             int m = l.n;
             int n = l.size*l.size*l.c;
             int k = 1;
@@ -140,7 +140,7 @@ void backward_local_layer(local_layer l, network_state state)

         if(state.delta){
             for(j = 0; j < locations; ++j){
-                float *a = l.filters + j*l.size*l.size*l.c*l.n;
+                float *a = l.weights + j*l.size*l.size*l.c*l.n;
                 float *b = l.delta + i*l.outputs + j;
                 float *c = l.col_image + j;

@@ -163,9 +163,9 @@ void update_local_layer(local_layer l, int batch, float learning_rate, float mom
     axpy_cpu(l.outputs, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
     scal_cpu(l.outputs, momentum, l.bias_updates, 1);

-    axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1);
-    axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1);
-    scal_cpu(size, momentum, l.filter_updates, 1);
+    axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1);
+    axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
+    scal_cpu(size, momentum, l.weight_updates, 1);
 }

 #ifdef GPU
@@ -187,7 +187,7 @@ void forward_local_layer_gpu(const local_layer l, network_state state)
                 l.size, l.stride, l.pad, l.col_image_gpu);
         float *output = l.output_gpu + i*l.outputs;
         for(j = 0; j < locations; ++j){
-            float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n;
+            float *a = l.weights_gpu + j*l.size*l.size*l.c*l.n;
             float *b = l.col_image_gpu + j;
             float *c = output + j;

@@ -219,7 +219,7 @@ void backward_local_layer_gpu(local_layer l, network_state state)
         for(j = 0; j < locations; ++j){
             float *a = l.delta_gpu + i*l.outputs + j;
             float *b = l.col_image_gpu + j;
-            float *c = l.filter_updates_gpu + j*l.size*l.size*l.c*l.n;
+            float *c = l.weight_updates_gpu + j*l.size*l.size*l.c*l.n;
             int m = l.n;
             int n = l.size*l.size*l.c;
             int k = 1;
@@ -229,7 +229,7 @@ void backward_local_layer_gpu(local_layer l, network_state state)

         if(state.delta){
             for(j = 0; j < locations; ++j){
-                float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n;
+                float *a = l.weights_gpu + j*l.size*l.size*l.c*l.n;
                 float *b = l.delta_gpu + i*l.outputs + j;
                 float *c = l.col_image_gpu + j;

@@ -252,16 +252,16 @@ void update_local_layer_gpu(local_layer l, int batch, float learning_rate, float
     axpy_ongpu(l.outputs, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1);
     scal_ongpu(l.outputs, momentum, l.bias_updates_gpu, 1);

-    axpy_ongpu(size, -decay*batch, l.filters_gpu, 1, l.filter_updates_gpu, 1);
-    axpy_ongpu(size, learning_rate/batch, l.filter_updates_gpu, 1, l.filters_gpu, 1);
-    scal_ongpu(size, momentum, l.filter_updates_gpu, 1);
+    axpy_ongpu(size, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
+    axpy_ongpu(size, learning_rate/batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
+    scal_ongpu(size, momentum, l.weight_updates_gpu, 1);
 }

 void pull_local_layer(local_layer l)
 {
     int locations = l.out_w*l.out_h;
     int size = l.size*l.size*l.c*l.n*locations;
-    cuda_pull_array(l.filters_gpu, l.filters, size);
+    cuda_pull_array(l.weights_gpu, l.weights, size);
     cuda_pull_array(l.biases_gpu, l.biases, l.outputs);
 }

@@ -269,7 +269,7 @@ void push_local_layer(local_layer l)
 {
     int locations = l.out_w*l.out_h;
     int size = l.size*l.size*l.c*l.n*locations;
-    cuda_push_array(l.filters_gpu, l.filters, size);
+    cuda_push_array(l.weights_gpu, l.weights, size);
     cuda_push_array(l.biases_gpu, l.biases, l.outputs);
 }
 #endif
@@ -318,11 +318,11 @@ void backward_network(network net, network_state state)

 float train_network_datum(network net, float *x, float *y)
 {
+    *net.seen += net.batch;
 #ifdef GPU
     if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
 #endif
     network_state state;
-    *net.seen += net.batch;
     state.index = 0;
     state.net = net;
     state.input = x;
@@ -65,6 +65,7 @@ typedef struct network_state {
 } network_state;

 #ifdef GPU
+float train_networks(network *nets, int n, data d);
 float train_network_datum_gpu(network net, float *x, float *y);
 float *network_predict_gpu(network net, float *input);
 float * get_network_output_gpu_layer(network net, int i);
@ -209,6 +209,7 @@ void forward_backward_network_gpu(network net, float *x, float *y)
|
||||
|
||||
float train_network_datum_gpu(network net, float *x, float *y)
|
||||
{
|
||||
*net.seen += net.batch;
|
||||
forward_backward_network_gpu(net, x, y);
|
||||
float error = get_network_cost(net);
|
||||
if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net);
|
||||
@ -226,25 +227,115 @@ void *train_thread(void *ptr)
|
||||
{
|
||||
train_args args = *(train_args*)ptr;
|
||||
|
||||
cudaError_t status = cudaSetDevice(args.net.gpu_index);
|
||||
check_error(status);
|
||||
cuda_set_device(args.net.gpu_index);
|
||||
forward_backward_network_gpu(args.net, args.X, args.y);
|
||||
free(ptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
pthread_t train_network_in_thread(train_args args)
|
||||
pthread_t train_network_in_thread(network net, float *X, float *y)
|
||||
{
|
||||
pthread_t thread;
|
||||
train_args *ptr = (train_args *)calloc(1, sizeof(train_args));
|
||||
*ptr = args;
|
||||
ptr->net = net;
|
||||
ptr->X = X;
|
||||
ptr->y = y;
|
||||
if(pthread_create(&thread, 0, train_thread, ptr)) error("Thread creation failed");
|
||||
return thread;
|
||||
}
|
||||
|
||||
void pull_updates(layer l)
{
#ifdef GPU
    if(l.type == CONVOLUTIONAL){
        cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n);
        cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
        if(l.scale_updates) cuda_pull_array(l.scale_updates_gpu, l.scale_updates, l.n);
    } else if(l.type == CONNECTED){
        cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
        cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
    }
#endif
}

void push_updates(layer l)
{
#ifdef GPU
    if(l.type == CONVOLUTIONAL){
        cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
        cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
        if(l.scale_updates) cuda_push_array(l.scale_updates_gpu, l.scale_updates, l.n);
    } else if(l.type == CONNECTED){
        cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
        cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
    }
#endif
}
void merge_updates(layer l, layer base)
{
    if (l.type == CONVOLUTIONAL) {
        axpy_cpu(l.n, 1, l.bias_updates, 1, base.bias_updates, 1);
        axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weight_updates, 1, base.weight_updates, 1);
        if (l.scale_updates) {
            axpy_cpu(l.n, 1, l.scale_updates, 1, base.scale_updates, 1);
        }
    } else if(l.type == CONNECTED) {
        axpy_cpu(l.outputs, 1, l.bias_updates, 1, base.bias_updates, 1);
        axpy_cpu(l.outputs*l.inputs, 1, l.weight_updates, 1, base.weight_updates, 1);
    }
}

void distribute_updates(layer l, layer base)
{
    if (l.type == CONVOLUTIONAL) {
        copy_cpu(l.n, base.bias_updates, 1, l.bias_updates, 1);
        copy_cpu(l.n*l.size*l.size*l.c, base.weight_updates, 1, l.weight_updates, 1);
        if (l.scale_updates) {
            copy_cpu(l.n, base.scale_updates, 1, l.scale_updates, 1);
        }
    } else if(l.type == CONNECTED) {
        copy_cpu(l.outputs, base.bias_updates, 1, l.bias_updates, 1);
        copy_cpu(l.outputs*l.inputs, base.weight_updates, 1, l.weight_updates, 1);
    }
}
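merge_updates sums each worker's accumulated gradients into the base network's buffers, and distribute_updates copies the merged result back out; axpy_cpu and copy_cpu follow BLAS saxpy/scopy semantics. A reference sketch of what those calls compute (darknet's blas.c implements the equivalent loops):

/* axpy: y[i*incy] += a * x[i*incx];  copy: y[i*incy] = x[i*incx]. */
void axpy_cpu_ref(int n, float a, float *x, int incx, float *y, int incy)
{
    int i;
    for(i = 0; i < n; ++i) y[i*incy] += a*x[i*incx];
}

void copy_cpu_ref(int n, float *x, int incx, float *y, int incy)
{
    int i;
    for(i = 0; i < n; ++i) y[i*incy] = x[i*incx];
}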
void sync_updates(network *nets, int n)
{
    int i,j;
    int layers = nets[0].n;
    network net = nets[0];
    for (j = 0; j < layers; ++j) {
        layer base = net.layers[j];
        cuda_set_device(net.gpu_index);
        pull_updates(base);
        for (i = 1; i < n; ++i) {
            cuda_set_device(nets[i].gpu_index);
            layer l = nets[i].layers[j];
            pull_updates(l);
            merge_updates(l, base);
        }
        for (i = 1; i < n; ++i) {
            cuda_set_device(nets[i].gpu_index);
            layer l = nets[i].layers[j];
            distribute_updates(l, base);
            push_updates(l);
        }
        cuda_set_device(net.gpu_index);
        push_updates(base);
    }
    for (i = 0; i < n; ++i) {
        cuda_set_device(nets[i].gpu_index);
        if(i > 0) nets[i].momentum = 0;
        update_network_gpu(nets[i]);
    }
}
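In outline, sync_updates is a host-side all-reduce done one layer at a time: pull every GPU's accumulated updates to the host, sum them into net 0's buffers (merge_updates), copy the merged result back to each worker (distribute_updates) and push it, then let every network apply the now-identical update with update_network_gpu.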
float train_networks(network *nets, int n, data d)
{
    int batch = nets[0].batch;
    assert(batch * n == d.X.rows);
    assert(nets[0].subdivisions % n == 0);
    float **X = (float **) calloc(n, sizeof(float *));
    float **y = (float **) calloc(n, sizeof(float *));
    pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t));
@@ -255,11 +346,20 @@ float train_networks(network *nets, int n, data d)
        X[i] = (float *) calloc(batch*d.X.cols, sizeof(float));
        y[i] = (float *) calloc(batch*d.y.cols, sizeof(float));
        get_next_batch(d, batch, i*batch, X[i], y[i]);
        float err = train_network_datum(nets[i], X[i], y[i]);
        sum += err;
        threads[i] = train_network_in_thread(nets[i], X[i], y[i]);
    }
    for(i = 0; i < n; ++i){
        pthread_join(threads[i], 0);
        *nets[i].seen += n*nets[i].batch;
        printf("%f\n", get_network_cost(nets[i]) / batch);
        sum += get_network_cost(nets[i]);
        free(X[i]);
        free(y[i]);
    }
    if (((*nets[0].seen) / nets[0].batch) % nets[0].subdivisions == 0) sync_updates(nets, n);
    free(X);
    free(y);
    free(threads);
    return (float)sum/(n*batch);
}
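Each worker gets its own contiguous slice of the loaded batch via get_next_batch(d, batch, i*batch, X[i], y[i]), so the n threads never share input buffers. A sketch of the semantics relied on here, assuming darknet's data/matrix types from data.h (d.X.vals is a float** of row pointers):

#include <string.h>

/* Copies rows [offset, offset+n) of d.X and d.y into the flat buffers X and y. */
void get_next_batch_ref(data d, int n, int offset, float *X, float *y)
{
    int j;
    for(j = 0; j < n; ++j){
        int index = offset + j;
        memcpy(X + j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));
        memcpy(y + j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));
    }
}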
35
src/parser.c
@@ -551,6 +551,7 @@ network parse_network_cfg(char *filename)
    node *n = sections->front;
    if(!n) error("Config file has no sections");
    network net = make_network(sections->size - 1);
    net.gpu_index = gpu_index;
    size_params params;

    section *s = (section *)n->val;
@@ -856,13 +857,13 @@ void save_weights_double(network net, char *filename)

            for (j = 0; j < l.n; ++j){
                int index = j*l.c*l.size*l.size;
                fwrite(l.filters+index, sizeof(float), l.c*l.size*l.size, fp);
                fwrite(l.weights+index, sizeof(float), l.c*l.size*l.size, fp);
                for (k = 0; k < l.c*l.size*l.size; ++k) fwrite(&zero, sizeof(float), 1, fp);
            }
            for (j = 0; j < l.n; ++j){
                int index = j*l.c*l.size*l.size;
                for (k = 0; k < l.c*l.size*l.size; ++k) fwrite(&zero, sizeof(float), 1, fp);
                fwrite(l.filters+index, sizeof(float), l.c*l.size*l.size, fp);
                fwrite(l.weights+index, sizeof(float), l.c*l.size*l.size, fp);
            }
        }
    }
@@ -876,7 +877,7 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
        pull_convolutional_layer(l);
    }
#endif
    binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
    binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
    int size = l.c*l.size*l.size;
    int i, j, k;
    fwrite(l.biases, sizeof(float), l.n, fp);
@@ -886,7 +887,7 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
        fwrite(l.rolling_variance, sizeof(float), l.n, fp);
    }
    for(i = 0; i < l.n; ++i){
        float mean = l.binary_filters[i*size];
        float mean = l.binary_weights[i*size];
        if(mean < 0) mean = -mean;
        fwrite(&mean, sizeof(float), 1, fp);
        for(j = 0; j < size/8; ++j){
@@ -894,7 +895,7 @@ void save_convolutional_weights_binary(layer l, FILE *fp)
            unsigned char c = 0;
            for(k = 0; k < 8; ++k){
                if (j*8 + k >= size) break;
                if (l.binary_filters[index + k] > 0) c = (c | 1<<k);
                if (l.binary_weights[index + k] > 0) c = (c | 1<<k);
            }
            fwrite(&c, sizeof(char), 1, fp);
        }
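The binary format written above stores one float per filter (the mean magnitude) plus one sign bit per weight, cutting the weight payload by roughly 32x. A self-contained round trip of the packing arithmetic, with illustrative values:

#include <stdio.h>

/* Pack the signs of 8 weights into one byte (bit k set when w[k] > 0),
   then expand each bit back to +mean or -mean, as the save/load code does. */
int main()
{
    float w[8] = {.3f, -.1f, .7f, -.9f, .2f, .4f, -.6f, .8f};
    float mean = .5f;
    unsigned char c = 0;
    int k;
    for(k = 0; k < 8; ++k) if(w[k] > 0) c = (c | 1<<k);
    for(k = 0; k < 8; ++k) printf("%+.1f -> %+.1f\n", w[k], (c & 1<<k) ? mean : -mean);
    return 0;
}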
@@ -919,7 +920,7 @@ void save_convolutional_weights(layer l, FILE *fp)
        fwrite(l.rolling_mean, sizeof(float), l.n, fp);
        fwrite(l.rolling_variance, sizeof(float), l.n, fp);
    }
    fwrite(l.filters, sizeof(float), num, fp);
    fwrite(l.weights, sizeof(float), num, fp);
}

void save_batchnorm_weights(layer l, FILE *fp)
@@ -952,6 +953,9 @@ void save_connected_weights(layer l, FILE *fp)

void save_weights_upto(network net, char *filename, int cutoff)
{
#ifdef GPU
    cuda_set_device(net.gpu_index);
#endif
    fprintf(stderr, "Saving weights to %s\n", filename);
    FILE *fp = fopen(filename, "w");
    if(!fp) file_error(filename);
@@ -997,7 +1001,7 @@ void save_weights_upto(network net, char *filename, int cutoff)
            int locations = l.out_w*l.out_h;
            int size = l.size*l.size*l.c*l.n*locations;
            fwrite(l.biases, sizeof(float), l.outputs, fp);
            fwrite(l.filters, sizeof(float), size, fp);
            fwrite(l.weights, sizeof(float), size, fp);
        }
    }
    fclose(fp);
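The cuda_set_device(net.gpu_index) guard added to save_weights_upto (and to load_weights_upto below) pins the active CUDA device before any layer is pulled or pushed, so with several networks resident on different GPUs the weight buffers are read from, and written to, the right card.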
@@ -1075,7 +1079,7 @@ void load_convolutional_weights_binary(layer l, FILE *fp)
            fread(&c, sizeof(char), 1, fp);
            for(k = 0; k < 8; ++k){
                if (j*8 + k >= size) break;
                l.filters[index + k] = (c & 1<<k) ? mean : -mean;
                l.weights[index + k] = (c & 1<<k) ? mean : -mean;
            }
        }
    }
@@ -1099,12 +1103,12 @@ void load_convolutional_weights(layer l, FILE *fp)
        fread(l.rolling_mean, sizeof(float), l.n, fp);
        fread(l.rolling_variance, sizeof(float), l.n, fp);
    }
    fread(l.filters, sizeof(float), num, fp);
    //if(l.c == 3) scal_cpu(num, 1./256, l.filters, 1);
    fread(l.weights, sizeof(float), num, fp);
    //if(l.c == 3) scal_cpu(num, 1./256, l.weights, 1);
    if (l.flipped) {
        transpose_matrix(l.filters, l.c*l.size*l.size, l.n);
        transpose_matrix(l.weights, l.c*l.size*l.size, l.n);
    }
    //if (l.binary) binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.filters);
    //if (l.binary) binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.weights);
#ifdef GPU
    if(gpu_index >= 0){
        push_convolutional_layer(l);
@@ -1115,6 +1119,9 @@ void load_convolutional_weights(layer l, FILE *fp)

void load_weights_upto(network *net, char *filename, int cutoff)
{
#ifdef GPU
    cuda_set_device(net->gpu_index);
#endif
    fprintf(stderr, "Loading weights from %s...", filename);
    fflush(stdout);
    FILE *fp = fopen(filename, "rb");
@@ -1139,7 +1146,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
        if(l.type == DECONVOLUTIONAL){
            int num = l.n*l.c*l.size*l.size;
            fread(l.biases, sizeof(float), l.n, fp);
            fread(l.filters, sizeof(float), num, fp);
            fread(l.weights, sizeof(float), num, fp);
#ifdef GPU
            if(gpu_index >= 0){
                push_deconvolutional_layer(l);
@@ -1174,7 +1181,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
            int locations = l.out_w*l.out_h;
            int size = l.size*l.size*l.c*l.n*locations;
            fread(l.biases, sizeof(float), l.outputs, fp);
            fread(l.filters, sizeof(float), size, fp);
            fread(l.weights, sizeof(float), size, fp);
#ifdef GPU
            if(gpu_index >= 0){
                push_local_layer(l);
@@ -22,11 +22,18 @@ region_layer make_region_layer(int batch, int w, int h, int n, int classes, int
    l.classes = classes;
    l.coords = coords;
    l.cost = calloc(1, sizeof(float));
    l.biases = calloc(n*2, sizeof(float));
    l.bias_updates = calloc(n*2, sizeof(float));
    l.outputs = h*w*n*(classes + coords + 1);
    l.inputs = l.outputs;
    l.truths = 30*(5);
    l.delta = calloc(batch*l.outputs, sizeof(float));
    l.output = calloc(batch*l.outputs, sizeof(float));
    int i;
    for(i = 0; i < n*2; ++i){
        l.biases[i] = .5;
    }

#ifdef GPU
    l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
    l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
@@ -38,62 +45,30 @@ region_layer make_region_layer(int batch, int w, int h, int n, int classes, int
    return l;
}

box get_region_box2(float *x, int index, int i, int j, int w, int h)
box get_region_box(float *x, float *biases, int n, int index, int i, int j, int w, int h)
{
    float aspect = exp(x[index+0]);
    float scale = logistic_activate(x[index+1]);
    float move_x = x[index+2];
    float move_y = x[index+3];

    box b;
    b.w = sqrt(scale * aspect);
    b.h = b.w * 1./aspect;
    b.x = move_x * b.w + (i + .5)/w;
    b.y = move_y * b.h + (j + .5)/h;
    b.x = (i + .5)/w + x[index + 0] * biases[2*n];
    b.y = (j + .5)/h + x[index + 1] * biases[2*n + 1];
    b.w = exp(x[index + 2]) * biases[2*n];
    b.h = exp(x[index + 3]) * biases[2*n+1];
    return b;
}

float delta_region_box2(box truth, float *output, int index, int i, int j, int w, int h, float *delta)
float delta_region_box(box truth, float *x, float *biases, int n, int index, int i, int j, int w, int h, float *delta, float scale)
{
    box pred = get_region_box2(output, index, i, j, w, h);
    float iou = box_iou(pred, truth);
    float true_aspect = truth.w/truth.h;
    float true_scale = truth.w*truth.h;

    float true_dx = (truth.x - (i+.5)/w) / truth.w;
    float true_dy = (truth.y - (j+.5)/h) / truth.h;
    delta[index + 0] = (true_aspect - exp(output[index + 0])) * exp(output[index + 0]);
    delta[index + 1] = (true_scale - logistic_activate(output[index + 1])) * logistic_gradient(logistic_activate(output[index + 1]));
    delta[index + 2] = true_dx - output[index + 2];
    delta[index + 3] = true_dy - output[index + 3];
    return iou;
}

box get_region_box(float *x, int index, int i, int j, int w, int h, int adjust, int logistic)
{
    box b;
    b.x = (x[index + 0] + i + .5)/w;
    b.y = (x[index + 1] + j + .5)/h;
    b.w = x[index + 2];
    b.h = x[index + 3];
    if(logistic){
        b.w = logistic_activate(x[index + 2]);
        b.h = logistic_activate(x[index + 3]);
    }
    if(adjust && b.w < .01) b.w = .01;
    if(adjust && b.h < .01) b.h = .01;
    return b;
}

float delta_region_box(box truth, float *output, int index, int i, int j, int w, int h, float *delta, int logistic, float scale)
{
    box pred = get_region_box(output, index, i, j, w, h, 0, logistic);
    box pred = get_region_box(x, biases, n, index, i, j, w, h);
    float iou = box_iou(pred, truth);

    delta[index + 0] = scale * (truth.x - pred.x);
    delta[index + 1] = scale * (truth.y - pred.y);
    delta[index + 2] = scale * ((truth.w - pred.w)*(logistic ? logistic_gradient(pred.w) : 1));
    delta[index + 3] = scale * ((truth.h - pred.h)*(logistic ? logistic_gradient(pred.h) : 1));
    float tx = (truth.x - (i + .5)/w) / biases[2*n];
    float ty = (truth.y - (j + .5)/h) / biases[2*n + 1];
    float tw = log(truth.w / biases[2*n]);
    float th = log(truth.h / biases[2*n + 1]);

    delta[index + 0] = scale * (tx - x[index + 0]);
    delta[index + 1] = scale * (ty - x[index + 1]);
    delta[index + 2] = scale * (tw - x[index + 2]);
    delta[index + 3] = scale * (th - x[index + 3]);
    return iou;
}
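The new parameterization predicts offsets relative to a per-box anchor prior stored in biases[2n] and biases[2n+1], replacing the old aspect/scale encoding. A self-contained round-trip check of the arithmetic above, with illustrative numbers (decode with get_region_box's formulas, re-encode with delta_region_box's targets, recover the inputs):

#include <math.h>
#include <stdio.h>

int main()
{
    float bw = 1.2f, bh = .9f;              /* anchor prior: biases[2*n], biases[2*n+1] */
    int i = 3, j = 5, w = 13, h = 13;       /* grid cell (i,j) in a w x h grid */
    float t[4] = {.1f, -.2f, .3f, -.4f};    /* raw network outputs x[index+0..3] */

    /* decode, as in get_region_box */
    float bx = (i + .5f)/w + t[0]*bw;
    float by = (j + .5f)/h + t[1]*bh;
    float pw = expf(t[2])*bw;
    float ph = expf(t[3])*bh;

    /* re-encode, as in delta_region_box's targets */
    printf("%f %f %f %f\n",
        (bx - (i + .5f)/w)/bw, (by - (j + .5f)/h)/bh,
        logf(pw/bw), logf(ph/bh));          /* prints ~ .1 -.2 .3 -.4 */
    return 0;
}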
@@ -107,7 +82,7 @@ float tisnan(float x)
    return (x != x);
}

#define LOG 1
#define LOG 0

void forward_region_layer(const region_layer l, network_state state)
{
@@ -127,6 +102,7 @@ void forward_region_layer(const region_layer l, network_state state)
    if(!state.train) return;
    memset(l.delta, 0, l.outputs * l.batch * sizeof(float));
    float avg_iou = 0;
    float recall = 0;
    float avg_cat = 0;
    float avg_obj = 0;
    float avg_anyobj = 0;
@@ -137,7 +113,7 @@ void forward_region_layer(const region_layer l, network_state state)
        for (i = 0; i < l.w; ++i) {
            for (n = 0; n < l.n; ++n) {
                int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs;
                box pred = get_region_box(l.output, index, i, j, l.w, l.h, 1, LOG);
                box pred = get_region_box(l.output, l.biases, n, index, i, j, l.w, l.h);
                float best_iou = 0;
                for(t = 0; t < 30; ++t){
                    box truth = float_to_box(state.truth + t*5 + b*l.truths);
@@ -155,7 +131,11 @@ void forward_region_layer(const region_layer l, network_state state)
                    truth.y = (j + .5)/l.h;
                    truth.w = .5;
                    truth.h = .5;
                    delta_region_box(truth, l.output, index, i, j, l.w, l.h, l.delta, LOG, 1);
                    delta_region_box(truth, l.output, l.biases, n, index, i, j, l.w, l.h, l.delta, .01);
                    //l.delta[index + 0] = .1 * (0 - l.output[index + 0]);
                    //l.delta[index + 1] = .1 * (0 - l.output[index + 1]);
                    //l.delta[index + 2] = .1 * (0 - l.output[index + 2]);
                    //l.delta[index + 3] = .1 * (0 - l.output[index + 3]);
                }
            }
        }
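Note the .01 scale on the delta_region_box call above: predictions that match no ground truth are regressed toward a default box centered in their cell with w = h = .5, at a far lower weight than matched boxes, presumably to keep unmatched predictions near sensible values without drowning out the real localization gradients.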
@@ -176,8 +156,8 @@ void forward_region_layer(const region_layer l, network_state state)
            printf("index %d %d\n",i, j);
            for(n = 0; n < l.n; ++n){
                int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs;
                box pred = get_region_box(l.output, index, i, j, l.w, l.h, 1, LOG);
                printf("pred: (%f, %f) %f x %f\n", pred.x, pred.y, pred.w, pred.h);
                box pred = get_region_box(l.output, l.biases, n, index, i, j, l.w, l.h);
                printf("pred: (%f, %f) %f x %f\n", pred.x*l.w - i - .5, pred.y * l.h - j - .5, pred.w, pred.h);
                pred.x = 0;
                pred.y = 0;
                float iou = box_iou(pred, truth_shift);
@@ -187,9 +167,10 @@ void forward_region_layer(const region_layer l, network_state state)
                    best_n = n;
                }
            }
            printf("%d %f (%f, %f) %f x %f\n", best_n, best_iou, truth.x, truth.y, truth.w, truth.h);
            printf("%d %f (%f, %f) %f x %f\n", best_n, best_iou, truth.x * l.w - i - .5, truth.y*l.h - j - .5, truth.w, truth.h);

            float iou = delta_region_box(truth, l.output, best_index, i, j, l.w, l.h, l.delta, LOG, l.coord_scale);
            float iou = delta_region_box(truth, l.output, l.biases, best_n, best_index, i, j, l.w, l.h, l.delta, l.coord_scale);
            if(iou > .5) recall += 1;
            avg_iou += iou;

            //l.delta[best_index + 4] = iou - l.output[best_index + 4];
@@ -239,7 +220,7 @@ void forward_region_layer(const region_layer l, network_state state)
    printf("\n");
    reorg(l.delta, l.w*l.h, size*l.n, l.batch, 0);
    *(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
    printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), count);
    printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, Avg Recall: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, count);
}

void backward_region_layer(const region_layer l, network_state state)