NO FUCKING SPOILERS DOUG

This commit is contained in:
Joseph Redmon 2018-01-16 14:30:00 -08:00
parent 6e79145309
commit e3931c75cd
23 changed files with 450 additions and 90 deletions

View File

@ -57,7 +57,7 @@ CFLAGS+= -DCUDNN
LDFLAGS+= -lcudnn
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o l2norm_layer.o
EXECOBJA=captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o darknet.o
ifeq ($(GPU), 1)
LDFLAGS+= -lstdc++

View File

@ -1,10 +1,10 @@
[net]
# Train
# batch=128
# subdivisions=1
batch=128
subdivisions=1
# Test
batch=1
subdivisions=1
#batch=1
#subdivisions=1
height=256
width=256
channels=3
@ -88,7 +88,6 @@ activation=leaky
[maxpool]
size=2
stride=2
padding=1
[convolutional]
batch_normalize=1
@ -110,6 +109,3 @@ activation=leaky
[softmax]
groups=1
[cost]
type=sse

View File

@ -1,17 +1,31 @@
[net]
batch=128
subdivisions=1
height=224
width=224
# Training
#batch=128
#subdivisions=2
# Testing
batch=1
subdivisions=1
height=256
width=256
min_crop=128
max_crop=448
channels=3
momentum=0.9
decay=0.0005
max_crop=448
burn_in=1000
learning_rate=0.1
policy=poly
power=4
max_batches=1600000
max_batches=800000
angle=7
hue=.1
saturation=.75
exposure=.75
aspect=.75
[convolutional]
batch_normalize=1

View File

@ -578,6 +578,8 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
}
image im = load_image_color(input, 0, 0);
image r = letterbox_image(im, net->w, net->h);
//image r = resize_min(im, 320);
//printf("%d %d\n", r.w, r.h);
//resize_network(net, r.w, r.h);
//printf("%d %d\n", r.w, r.h);

View File

@ -383,6 +383,92 @@ void train_pix2pix(char *cfg, char *weight, char *acfg, char *aweight, int clear
}
*/
/*
 * Spherical linear interpolation between two n-element vectors.
 *
 * start, end: input vectors; the angle between them is taken directly from
 *             the acos of their dot product, so they are expected to be
 *             unit length.
 * s:          interpolation parameter (s = 0 weights only start, s = 1 only end).
 * n:          number of elements in start, end and out.
 * out:        receives the interpolated vector, rescaled to unit magnitude.
 */
void slerp(float *start, float *end, float s, int n, float *out)
{
    float cos_omega = dot_cpu(n, start, 1, end, 1);
    /* Rounding can push the dot product of two unit vectors slightly outside
     * [-1, 1], where acos returns NaN. */
    if (cos_omega > 1.f) cos_omega = 1.f;
    if (cos_omega < -1.f) cos_omega = -1.f;

    float omega = acos(cos_omega);
    float so = sin(omega);

    float w_start, w_end;
    if (fabs(so) < 1e-6f) {
        /* (Anti)parallel inputs make both slerp weights 0/0; use plain
         * linear weights instead of propagating NaN. */
        w_start = 1 - s;
        w_end = s;
    } else {
        w_start = sin((1 - s) * omega) / so;
        w_end = sin(s * omega) / so;
    }

    fill_cpu(n, 0, out, 1);
    axpy_cpu(n, w_start, start, 1, out, 1);
    axpy_cpu(n, w_end, end, 1, out, 1);

    /* Skip the rescale when the blend cancelled to the zero vector. */
    float mag = mag_array(out, n);
    if (mag > 0) scale_array(out, n, 1./mag);
}
/*
 * Builds a w x h x c image filled with rand_normal() samples and rescales it
 * so that the whole data array has unit magnitude.
 *
 * The parameters are now explicitly typed as int; the untyped parameter list
 * relied on implicit int, which is not valid C99 or later.
 *
 * Returns the new image; the caller owns it.
 */
image random_unit_vector_image(int w, int h, int c)
{
    image im = make_image(w, h, c);
    int n = im.w*im.h*im.c;
    int i;
    for(i = 0; i < n; ++i){
        im.data[i] = rand_normal();
    }
    float mag = mag_array(im.data, n);
    scale_array(im.data, n, 1./mag);
    return im;
}
/*
 * Renders an interpolation sequence through the generator's input space.
 *
 * Loads the network, picks the first layer whose out_c is 3 as the image
 * layer, then repeatedly slerps between random unit-vector inputs in
 * max_count steps per segment. Every frame is shown, saved as "out" and
 * saved again under a numbered name "outNNNNN". Once more than 300 frames
 * have been produced the path heads back to the first input and the
 * function returns.
 *
 * cfgfile, weightfile: passed straight to load_network().
 */
void inter_dcgan(char *cfgfile, char *weightfile)
{
    network *net = load_network(cfgfile, weightfile, 0);
    set_batch_network(net, 1);
    srand(2222222);

    clock_t time;
    char buff[256];
    int i, imlayer = 0;

    for (i = 0; i < net->n; ++i) {
        if (net->layers[i].out_c == 3) {
            imlayer = i;
            printf("%d\n", i);
            break;
        }
    }

    image start = random_unit_vector_image(net->w, net->h, net->c);
    image end = random_unit_vector_image(net->w, net->h, net->c);
    image im = make_image(net->w, net->h, net->c);
    image orig = copy_image(start);

    int c = 0;
    int count = 0;
    int max_count = 15;
    while(1){
        ++c;
        if(count == max_count){
            count = 0;
            free_image(start);
            start = end;
            // Only generate a new target when it will be used; creating one
            // and then overwriting it with orig leaked the fresh image.
            if(c > 300){
                end = orig;
            } else {
                end = random_unit_vector_image(net->w, net->h, net->c);
            }
            if(c > 300 + max_count){
                // The previous segment already targeted orig, so start and
                // end both alias it here: free it once.
                free_image(orig);
                free_image(im);
                return;
            }
        }
        ++count;

        slerp(start.data, end.data, (float)count / max_count, im.w*im.h*im.c, im.data);

        float *X = im.data;
        time=clock();
        network_predict(net, X);
        image out = get_network_image_layer(net, imlayer);
        normalize_image(out);

        // Build the frame name before logging it: the buffer used to be
        // printed with %s while still uninitialised on the first frame.
        snprintf(buff, sizeof(buff), "out%05d", c);
        printf("%s: Predicted in %f seconds.\n", buff, sec(clock()-time));

        show_image(out, "out");
        save_image(out, "out");
        save_image(out, buff);
    }
}
void test_dcgan(char *cfgfile, char *weightfile)
{
network *net = load_network(cfgfile, weightfile, 0);
@ -409,7 +495,7 @@ void test_dcgan(char *cfgfile, char *weightfile)
im.data[i] = rand_normal();
}
float mag = mag_array(im.data, im.w*im.h*im.c);
//scale_array(im.data, im.w*im.h*im.c, 1./mag);
scale_array(im.data, im.w*im.h*im.c, 1./mag);
float *X = im.data;
time=clock();
@ -429,7 +515,7 @@ void test_dcgan(char *cfgfile, char *weightfile)
}
void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images)
void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images, int maxbatch)
{
#ifdef GPU
char *backup_directory = "/home/pjreddie/backup/";
@ -441,7 +527,6 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear,
network *anet = load_network(acfg, aweight, clear);
//float orig_rate = anet->learning_rate;
int start = 0;
int i, j, k;
layer imlayer = {0};
for (i = 0; i < gnet->n; ++i) {
@ -488,8 +573,8 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear,
//data generated = copy_data(train);
while (get_current_batch(gnet) < gnet->max_batches) {
start += 1;
if (maxbatch == 0) maxbatch = gnet->max_batches;
while (get_current_batch(gnet) < maxbatch) {
i += 1;
time=clock();
pthread_join(load_thread, 0);
@ -519,6 +604,13 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear,
float mag = mag_array(gnet->input + z*gnet->inputs, gnet->inputs);
scale_array(gnet->input + z*gnet->inputs, gnet->inputs, 1./mag);
}
/*
for(z = 0; z < 100; ++z){
printf("%f, ", gnet->input[z]);
}
printf("\n");
printf("input: %f %f\n", mean_array(gnet->input, x_size), variance_array(gnet->input, x_size));
*/
//cuda_push_array(gnet->input_gpu, gnet->input, x_size);
//cuda_push_array(gnet->truth_gpu, gnet->truth, y_size);
@ -533,18 +625,26 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear,
backward_network(anet);
float genaloss = *anet->cost / anet->batch;
printf("%f\n", genaloss);
//printf("%f\n", genaloss);
scal_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
scal_gpu(imlayer.outputs*imlayer.batch, 0, gnet->layers[gnet->n-1].delta_gpu, 1);
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
printf("features %f\n", cuda_mag_array(gnet->layers[gnet->n-1].delta_gpu, imlayer.outputs*imlayer.batch));
//printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
//printf("features %f\n", cuda_mag_array(gnet->layers[gnet->n-1].delta_gpu, imlayer.outputs*imlayer.batch));
axpy_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, gnet->layers[gnet->n-1].delta_gpu, 1);
backward_network(gnet);
/*
for(k = 0; k < gnet->n; ++k){
layer l = gnet->layers[k];
cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
printf("%d: %f %f\n", k, mean_array(l.output, l.outputs*l.batch), variance_array(l.output, l.outputs*l.batch));
}
*/
for(k = 0; k < gnet->batch; ++k){
int index = j*gnet->batch + k;
copy_cpu(gnet->outputs, gnet->output + k*gnet->outputs, 1, gen.X.vals[index], 1);
@ -1104,6 +1204,7 @@ void run_lsd(int argc, char **argv)
int clear = find_arg(argc, argv, "-clear");
int display = find_arg(argc, argv, "-display");
int batches = find_int_arg(argc, argv, "-b", 0);
char *file = find_char_arg(argc, argv, "-file", "/home/pjreddie/data/imagenet/imagenet1k.train.list");
char *cfg = argv[3];
@ -1115,9 +1216,10 @@ void run_lsd(int argc, char **argv)
//else if(0==strcmp(argv[2], "train2")) train_lsd2(cfg, weights, acfg, aweights, clear);
//else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear);
//else if(0==strcmp(argv[2], "train3")) train_lsd3(argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], clear);
if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear, display, file);
if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear, display, file, batches);
else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear, display);
else if(0==strcmp(argv[2], "gan")) test_dcgan(cfg, weights);
else if(0==strcmp(argv[2], "inter")) inter_dcgan(cfg, weights);
else if(0==strcmp(argv[2], "test")) test_lsd(cfg, weights, filename, 0);
else if(0==strcmp(argv[2], "color")) test_lsd(cfg, weights, filename, 1);
/*

View File

@ -87,11 +87,12 @@ typedef enum {
REORG,
UPSAMPLE,
LOGXENT,
L2NORM,
BLANK
} LAYER_TYPE;
typedef enum{
SSE, MASKED, L1, SEG, SMOOTH
SSE, MASKED, L1, SEG, SMOOTH,WGAN
} COST_TYPE;
typedef struct{
@ -162,6 +163,7 @@ struct layer{
float shift;
float ratio;
float learning_rate_scale;
float clip;
int softmax;
int classes;
int coords;
@ -475,6 +477,7 @@ typedef struct network{
int train;
int index;
float *cost;
float clip;
#ifdef GPU
float *input_gpu;
@ -604,6 +607,7 @@ void backward_network(network *net);
void update_network(network *net);
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY);
void scal_cpu(int N, float ALPHA, float *X, int INCX);

View File

@ -123,6 +123,27 @@ void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, fl
}
}
void l2normalize_cpu(float *x, float *dx, int batch, int filters, int spatial)
{
int b,f,i;
for(b = 0; b < batch; ++b){
for(i = 0; i < spatial; ++i){
float sum = 0;
for(f = 0; f < filters; ++f){
int index = b*filters*spatial + f*spatial + i;
sum += powf(x[index], 2);
}
sum = sqrtf(sum);
for(f = 0; f < filters; ++f){
int index = b*filters*spatial + f*spatial + i;
x[index] /= sum;
dx[index] = (1 - x[index]) / sum;
}
}
}
}
void normalize_cpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
int b, f, i;
@ -310,3 +331,21 @@ void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, i
}
}
/*
 * Nearest-neighbour upsampling and its adjoint.
 *
 * in:      small tensor, laid out [batch][c][h][w].
 * out:     large tensor, laid out [batch][c][h*stride][w*stride].
 * forward: non-zero -> out[...] = in[...] (each input pixel is replicated
 *                      into a stride x stride patch);
 *          zero     -> in[...] += out[...] (each input pixel accumulates the
 *                      values of its stride x stride patch).
 *
 * The output index uses the enlarged row pitch (w*stride) and plane size
 * (w*h*stride*stride); indexing it with the input pitch made rows, channels
 * and batches overlap.
 */
void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out)
{
    int i, j, k, b;
    int out_w = w*stride;
    int out_h = h*stride;
    for(b = 0; b < batch; ++b){
        for(k = 0; k < c; ++k){
            for(j = 0; j < out_h; ++j){
                for(i = 0; i < out_w; ++i){
                    int in_index = b*w*h*c + k*w*h + (j/stride)*w + i/stride;
                    int out_index = b*out_w*out_h*c + k*out_w*out_h + j*out_w + i;
                    if(forward) out[out_index] = in[in_index];
                    else in[in_index] += out[out_index];
                }
            }
        }
    }
}

View File

@ -19,7 +19,6 @@ void constrain_gpu(int N, float ALPHA, float * X, int INCX);
void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void mul_cpu(int N, float *X, int INCX, float *Y, int INCY);
float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
int test_gpu_blas();
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
@ -31,6 +30,7 @@ void backward_scale_cpu(float *x_norm, float *delta, int batch, int n, int size,
void mean_delta_cpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta);
void variance_delta_cpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta);
void normalize_delta_cpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta);
void l2normalize_cpu(float *x, float *dx, int batch, int filters, int spatial);
void smooth_l1_cpu(int n, float *pred, float *truth, float *delta, float *error);
void l2_cpu(int n, float *pred, float *truth, float *delta, float *error);
@ -42,6 +42,7 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa
void softmax(float *input, int n, float temp, int stride, float *output);
void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output);
void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out);
#ifdef GPU
#include "cuda.h"
@ -62,6 +63,7 @@ void mul_gpu(int N, float *X, int INCX, float *Y, int INCY);
void mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial);
void l2normalize_gpu(float *x, float *dx, int batch, int filters, int spatial);
void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta);
@ -82,6 +84,7 @@ void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *er
void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error);
void l2_gpu(int n, float *pred, float *truth, float *delta, float *error);
void l1_gpu(int n, float *pred, float *truth, float *delta, float *error);
void wgan_gpu(int n, float *pred, float *truth, float *delta, float *error);
void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc);
void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c);
void mult_add_into_gpu(int num, float *a, float *b, float *c);

View File

@ -164,8 +164,11 @@ __global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float
{
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (index >= N) return;
float mhat = m[index] / (1.f - powf(B1, t));
float vhat = v[index] / (1.f - powf(B2, t));
x[index] = x[index] + (rate * sqrtf(1.f-powf(B2, t)) / (1.f-powf(B1, t)) * m[index] / (sqrtf(v[index]) + eps));
x[index] = x[index] + rate * mhat / (sqrtf(vhat) + eps);
}
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
@ -466,6 +469,35 @@ extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch,
check_error(cudaPeekAtLastError());
}
// L2-normalises x in place across the filter axis and writes
// (1 - x_normalised) / norm into dx. Each thread handles one
// (batch, spatial position) pair; N is the number of such pairs.
// x and dx are indexed as [batch][filters][spatial].
__global__ void l2norm_kernel(int N, float *x, float *dx, int batch, int filters, int spatial)
{
    int tid = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (tid >= N) return;

    int b = tid / spatial;
    int i = tid % spatial;
    // Offset of filter 0 for this (batch, position); filters are `spatial` apart.
    int base = b*filters*spatial + i;

    float norm = 0;
    int f;
    for(f = 0; f < filters; ++f){
        norm += powf(x[base + f*spatial], 2);
    }
    norm = sqrtf(norm);
    // An all-zero column is left untouched rather than divided by zero.
    if(norm == 0) norm = 1;

    for(f = 0; f < filters; ++f){
        int idx = base + f*spatial;
        x[idx] /= norm;
        dx[idx] = (1 - x[idx]) / norm;
    }
}
// Host-side launcher for l2norm_kernel.
// Launches batch*spatial work items (one per batch/spatial position), passing
// x and dx through unchanged, then reports any launch error via check_error.
extern "C" void l2normalize_gpu(float *x, float *dx, int batch, int filters, int spatial)
{
size_t N = batch*spatial;
l2norm_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, dx, batch, filters, spatial);
check_error(cudaPeekAtLastError());
}
__global__ void fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
{
const int threads = BLOCK;
@ -757,7 +789,7 @@ __global__ void logistic_x_ent_kernel(int n, float *pred, float *truth, float *d
if(i < n){
float t = truth[i];
float p = pred[i];
error[i] = -t*log(p) - (1-t)*log(1-p);
error[i] = -t*log(p+.0000001) - (1-t)*log(1-p+.0000001);
delta[i] = t-p;
}
}
@ -800,6 +832,21 @@ extern "C" void l1_gpu(int n, float *pred, float *truth, float *delta, float *er
check_error(cudaPeekAtLastError());
}
// Element-wise WGAN-style cost.
// For each i < n: a positive truth[i] marks the sample as "real", giving
// error = -pred and delta = +1; otherwise error = pred and delta = -1.
// Both outputs now use the same predicate; previously error tested
// truth != 0 while delta tested truth > 0, so a negative label produced an
// error and a delta that disagreed with each other.
__global__ void wgan_kernel(int n, float *pred, float *truth, float *delta, float *error)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n){
        int real = truth[i] > 0;
        error[i] = real ? -pred[i] : pred[i];
        delta[i] = real ? 1 : -1;
    }
}
// Host-side launcher for wgan_kernel over n elements; pred, truth, delta and
// error are forwarded unchanged. Any launch error is reported via check_error.
extern "C" void wgan_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
wgan_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
check_error(cudaPeekAtLastError());
}
@ -926,13 +973,13 @@ extern "C" void softmax_tree(float *input, int spatial, int batch, int stride, f
int *tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
int *tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
/*
static int *tree_groups_size = 0;
static int *tree_groups_offset = 0;
if(!tree_groups_size){
tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
}
*/
static int *tree_groups_size = 0;
static int *tree_groups_offset = 0;
if(!tree_groups_size){
tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
}
*/
int num = spatial*batch*hier.groups;
softmax_tree_kernel<<<cuda_gridsize(num), BLOCK>>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
check_error(cudaPeekAtLastError());
@ -976,7 +1023,7 @@ __global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int bat
int in_index = b*w*h*c + in_c*w*h + in_h*w + in_w;
if(forward) out[out_index] = x[in_index];
if(forward) out[out_index] += x[in_index];
else atomicAdd(x+in_index, out[out_index]);
}
extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out)

View File

@ -314,6 +314,9 @@ void update_convolutional_layer_gpu(layer l, update_args a)
scal_gpu(l.n, momentum, l.scale_updates_gpu, 1);
}
}
if(l.clip){
constrain_gpu(l.nweights, l.clip, l.weights_gpu, 1);
}
}

View File

@ -322,7 +322,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.workspace_size = get_workspace_size(l);
l.activation = activation;
fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c);
fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BFLOPs\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, (2.0 * l.n * l.size*l.size*l.c/l.groups * l.out_h*l.out_w)/1000000000.);
return l;
}

View File

@ -14,6 +14,7 @@ COST_TYPE get_cost_type(char *s)
if (strcmp(s, "masked")==0) return MASKED;
if (strcmp(s, "smooth")==0) return SMOOTH;
if (strcmp(s, "L1")==0) return L1;
if (strcmp(s, "wgan")==0) return WGAN;
fprintf(stderr, "Couldn't find cost type %s, going with SSE\n", s);
return SSE;
}
@ -31,6 +32,8 @@ char *get_cost_string(COST_TYPE a)
return "smooth";
case L1:
return "L1";
case WGAN:
return "wgan";
}
return "sse";
}
@ -133,6 +136,8 @@ void forward_cost_layer_gpu(cost_layer l, network net)
smooth_l1_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
} else if (l.cost_type == L1){
l1_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
} else if (l.cost_type == WGAN){
wgan_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
} else {
l2_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
}

View File

@ -230,7 +230,7 @@ void fill_truth_swag(char *path, float *truth, int classes, int flip, float dx,
int id;
int i;
for (i = 0; i < count && i < 30; ++i) {
for (i = 0; i < count && i < 90; ++i) {
x = boxes[i].x;
y = boxes[i].y;
w = boxes[i].w;
@ -424,6 +424,7 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
float x,y,w,h;
int id;
int i;
int sub = 0;
for (i = 0; i < count; ++i) {
x = boxes[i].x;
@ -432,13 +433,16 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
h = boxes[i].h;
id = boxes[i].id;
if ((w < .001 || h < .001)) continue;
if ((w < .001 || h < .001)) {
++sub;
continue;
}
truth[i*5+0] = x;
truth[i*5+1] = y;
truth[i*5+2] = w;
truth[i*5+3] = h;
truth[i*5+4] = id;
truth[(i-sub)*5+0] = x;
truth[(i-sub)*5+1] = y;
truth[(i-sub)*5+2] = w;
truth[(i-sub)*5+3] = h;
truth[(i-sub)*5+4] = id;
}
free(boxes);
}
@ -907,7 +911,7 @@ data load_data_swag(char **paths, int n, int classes, float jitter)
d.X.vals = calloc(d.X.rows, sizeof(float*));
d.X.cols = h*w*3;
int k = (4+classes)*30;
int k = (4+classes)*90;
d.y = make_matrix(1, k);
int dw = w*jitter;

View File

@ -15,6 +15,22 @@ static size_t get_workspace_size(layer l){
return (size_t)l.h*l.w*l.size*l.size*l.n*sizeof(float);
}
/*
 * Writes a separable tent-shaped kernel into part of l.weights.
 *
 * For every filter f, only the size x size slice belonging to input channel
 * (f % l.c) is written; each tap gets (1 - |col - center|) * (1 - |row - center|)
 * with center = (size - 1) / 2. All other weights are left untouched.
 */
void bilinear_init(layer l)
{
    float center = (l.size-1) / 2.;
    int filter, row, col;
    for(filter = 0; filter < l.n; ++filter){
        for(row = 0; row < l.size; ++row){
            double wy = 1 - fabs(row - center);
            for(col = 0; col < l.size; ++col){
                double wx = 1 - fabs(col - center);
                int chan = filter % l.c;
                int offset = filter*l.size*l.size*l.c + chan*l.size*l.size + row*l.size + col;
                l.weights[offset] = wx * wy;
            }
        }
    }
}
layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int adam)
{
@ -38,9 +54,11 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size
l.biases = calloc(n, sizeof(float));
l.bias_updates = calloc(n, sizeof(float));
//float scale = n/(size*size*c);
//printf("scale: %f\n", scale);
float scale = .02;
printf("scale: %f\n", scale);
for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_normal();
//bilinear_init(l);
for(i = 0; i < n; ++i){
l.biases[i] = 0;
}
@ -52,6 +70,8 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size
l.outputs = l.out_w * l.out_h * l.out_c;
l.inputs = l.w * l.h * l.c;
scal_cpu(l.nweights, (float)l.out_w*l.out_h/(l.w*l.h), l.weights, 1);
l.output = calloc(l.batch*l.outputs, sizeof(float));
l.delta = calloc(l.batch*l.outputs, sizeof(float));
@ -122,7 +142,7 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size
l.mean_delta_gpu = cuda_make_array(0, n);
l.variance_delta_gpu = cuda_make_array(0, n);
l.scales_gpu = cuda_make_array(0, n);
l.scales_gpu = cuda_make_array(l.scales, n);
l.scale_updates_gpu = cuda_make_array(0, n);
l.x_gpu = cuda_make_array(0, l.batch*l.out_h*l.out_w*n);

63
src/l2norm_layer.c Normal file
View File

@ -0,0 +1,63 @@
#include "l2norm_layer.h"
#include "activations.h"
#include "blas.h"
#include "cuda.h"
#include <float.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
/*
 * Creates an L2NORM layer whose output has the same element count as its
 * input.
 *
 * batch:  number of samples per batch.
 * inputs: number of elements per sample.
 *
 * Allocates zeroed output, scales and delta buffers of inputs*batch floats
 * (and their GPU mirrors when built with GPU). The spatial fields
 * (w/h/c, out_w/out_h/out_c) are not set here.
 */
layer make_l2norm_layer(int batch, int inputs)
{
    fprintf(stderr, "l2norm %4d\n", inputs);
    layer l = {0};
    l.type = L2NORM;
    l.batch = batch;
    l.inputs = inputs;
    l.outputs = inputs;
    l.output = calloc(inputs*batch, sizeof(float));
    l.scales = calloc(inputs*batch, sizeof(float));
    l.delta = calloc(inputs*batch, sizeof(float));

    l.forward = forward_l2norm_layer;
    l.backward = backward_l2norm_layer;
#ifdef GPU
    l.forward_gpu = forward_l2norm_layer_gpu;
    l.backward_gpu = backward_l2norm_layer_gpu;

    l.output_gpu = cuda_make_array(l.output, inputs*batch);
    /* Mirror the scales buffer itself; this was initialised from l.output. */
    l.scales_gpu = cuda_make_array(l.scales, inputs*batch);
    l.delta_gpu = cuda_make_array(l.delta, inputs*batch);
#endif
    return l;
}
/*
 * CPU forward pass: copies the network input into l.output, then
 * L2-normalises it in place with l.out_c as the filter count and
 * l.out_w*l.out_h as the spatial size. l.scales receives the second
 * output of l2normalize_cpu.
 * Relies on l.out_c / l.out_w / l.out_h having been set by the caller
 * (make_l2norm_layer leaves them at zero).
 */
void forward_l2norm_layer(const layer l, network net)
{
copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);
l2normalize_cpu(l.output, l.scales, l.batch, l.out_c, l.out_w*l.out_h);
}
/*
 * CPU backward pass: accumulates l.delta into net.delta unchanged.
 * NOTE(review): the step that adds l.scales into l.delta is disabled here
 * but enabled in backward_l2norm_layer_gpu, so the two paths differ --
 * confirm which one is intended.
 */
void backward_l2norm_layer(const layer l, network net)
{
//axpy_cpu(l.inputs*l.batch, 1, l.scales, 1, l.delta, 1);
axpy_cpu(l.inputs*l.batch, 1, l.delta, 1, net.delta, 1);
}
#ifdef GPU
/*
 * GPU forward pass: copies net.input_gpu into l.output_gpu, then
 * L2-normalises it in place with l.out_c as the filter count and
 * l.out_w*l.out_h as the spatial size; l.scales_gpu receives the second
 * output of l2normalize_gpu.
 */
void forward_l2norm_layer_gpu(const layer l, network net)
{
copy_gpu(l.outputs*l.batch, net.input_gpu, 1, l.output_gpu, 1);
l2normalize_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_w*l.out_h);
}
/*
 * GPU backward pass: first adds l.scales_gpu into l.delta_gpu (modifying the
 * layer's own delta), then accumulates l.delta_gpu into net.delta_gpu.
 * NOTE(review): the CPU backward pass has the scales step commented out, so
 * the two paths produce different gradients -- confirm which is intended.
 */
void backward_l2norm_layer_gpu(const layer l, network net)
{
axpy_gpu(l.batch*l.inputs, 1, l.scales_gpu, 1, l.delta_gpu, 1);
axpy_gpu(l.batch*l.inputs, 1, l.delta_gpu, 1, net.delta_gpu, 1);
}
#endif

15
src/l2norm_layer.h Normal file
View File

@ -0,0 +1,15 @@
/* Public interface of the L2 normalisation layer. */
#ifndef L2NORM_LAYER_H
#define L2NORM_LAYER_H
#include "layer.h"
#include "network.h"
/* Creates an L2NORM layer with `inputs` elements per sample and `batch` samples. */
layer make_l2norm_layer(int batch, int inputs);
/* CPU forward pass: reads net.input, writes l.output. */
void forward_l2norm_layer(const layer l, network net);
/* CPU backward pass: accumulates into net.delta. */
void backward_l2norm_layer(const layer l, network net);
#ifdef GPU
/* GPU counterparts, operating on the *_gpu buffers. */
void forward_l2norm_layer_gpu(const layer l, network net);
void backward_l2norm_layer_gpu(const layer l, network net);
#endif
#endif

View File

@ -60,13 +60,12 @@ void forward_logistic_layer_gpu(const layer l, network net)
logistic_x_ent_gpu(l.batch*l.inputs, l.output_gpu, net.truth_gpu, l.delta_gpu, l.loss_gpu);
cuda_pull_array(l.loss_gpu, l.loss, l.batch*l.inputs);
l.cost[0] = sum_array(l.loss, l.batch*l.inputs);
printf("hey: %f\n", l.cost[0]);
}
}
void backward_logistic_layer_gpu(const layer layer, network net)
void backward_logistic_layer_gpu(const layer l, network net)
{
axpy_gpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, net.delta_gpu, 1);
axpy_gpu(l.batch*l.inputs, 1, l.delta_gpu, 1, net.delta_gpu, 1);
}
#endif

View File

@ -378,6 +378,8 @@ int resize_network(network *net, int w, int h)
resize_region_layer(&l, w, h);
}else if(l.type == ROUTE){
resize_route_layer(&l, net);
}else if(l.type == SHORTCUT){
resize_shortcut_layer(&l, w, h);
}else if(l.type == UPSAMPLE){
resize_upsample_layer(&l, w, h);
}else if(l.type == REORG){

View File

@ -5,6 +5,7 @@
#include "activation_layer.h"
#include "logistic_layer.h"
#include "l2norm_layer.h"
#include "activations.h"
#include "avgpool_layer.h"
#include "batchnorm_layer.h"
@ -56,6 +57,7 @@ LAYER_TYPE string_to_layer_type(char * type)
|| strcmp(type, "[deconvolutional]")==0) return DECONVOLUTIONAL;
if (strcmp(type, "[activation]")==0) return ACTIVE;
if (strcmp(type, "[logistic]")==0) return LOGXENT;
if (strcmp(type, "[l2norm]")==0) return L2NORM;
if (strcmp(type, "[net]")==0
|| strcmp(type, "[network]")==0) return NETWORK;
if (strcmp(type, "[crnn]")==0) return CRNN;
@ -307,7 +309,7 @@ layer parse_region(list *options, size_params params)
l.softmax = option_find_int(options, "softmax", 0);
l.background = option_find_int_quiet(options, "background", 0);
l.max_boxes = option_find_int_quiet(options, "max",30);
l.max_boxes = option_find_int_quiet(options, "max",90);
l.jitter = option_find_float(options, "jitter", .2);
l.rescore = option_find_int_quiet(options, "rescore",0);
@ -356,7 +358,7 @@ detection_layer parse_detection(list *options, size_params params)
layer.softmax = option_find_int(options, "softmax", 0);
layer.sqrt = option_find_int(options, "sqrt", 0);
layer.max_boxes = option_find_int_quiet(options, "max",30);
layer.max_boxes = option_find_int_quiet(options, "max",90);
layer.coord_scale = option_find_float(options, "coord_scale", 1);
layer.forced = option_find_int(options, "forced", 0);
layer.object_scale = option_find_float(options, "object_scale", 1);
@ -496,10 +498,20 @@ layer parse_shortcut(list *options, size_params params, network *net)
}
/*
 * Builds an l2norm layer from the running size parameters.
 * The layer keeps the incoming geometry: input and output height, width and
 * channel count are all taken from params. `options` is not consulted.
 */
layer parse_l2norm(list *options, size_params params)
{
    layer l = make_l2norm_layer(params.batch, params.inputs);

    l.out_h = params.h;
    l.out_w = params.w;
    l.out_c = params.c;

    l.h = l.out_h;
    l.w = l.out_w;
    l.c = l.out_c;

    return l;
}
layer parse_logistic(list *options, size_params params)
{
layer l = make_logistic_layer(params.batch, params.inputs);
l.w = l.out_h = params.h;
l.h = l.out_h = params.h;
l.w = l.out_w = params.w;
l.c = l.out_c = params.c;
return l;
@ -512,12 +524,9 @@ layer parse_activation(list *options, size_params params)
layer l = make_activation_layer(params.batch, params.inputs, activation);
l.out_h = params.h;
l.out_w = params.w;
l.out_c = params.c;
l.h = params.h;
l.w = params.w;
l.c = params.c;
l.h = l.out_h = params.h;
l.w = l.out_w = params.w;
l.c = l.out_c = params.c;
return l;
}
@ -614,6 +623,7 @@ void parse_net_options(list *options, network *net)
net->max_ratio = option_find_float_quiet(options, "max_ratio", (float) net->max_crop / net->w);
net->min_ratio = option_find_float_quiet(options, "min_ratio", (float) net->min_crop / net->w);
net->center = option_find_int_quiet(options, "center",0);
net->clip = option_find_float_quiet(options, "clip", 0);
net->angle = option_find_float_quiet(options, "angle", 0);
net->aspect = option_find_float_quiet(options, "aspect", 1);
@ -714,6 +724,8 @@ network *parse_network_cfg(char *filename)
l = parse_activation(options, params);
}else if(lt == LOGXENT){
l = parse_logistic(options, params);
}else if(lt == L2NORM){
l = parse_l2norm(options, params);
}else if(lt == RNN){
l = parse_rnn(options, params);
}else if(lt == GRU){
@ -762,6 +774,7 @@ network *parse_network_cfg(char *filename)
}else{
fprintf(stderr, "Type not recognized: %s\n", s->type);
}
l.clip = net->clip;
l.truth = option_find_int_quiet(options, "truth", 0);
l.onlyforward = option_find_int_quiet(options, "onlyforward", 0);
l.stopbackward = option_find_int_quiet(options, "stopbackward", 0);

View File

@ -39,7 +39,7 @@ layer make_region_layer(int batch, int w, int h, int n, int total, int *mask, in
l.bias_updates = calloc(n*2, sizeof(float));
l.outputs = h*w*n*(classes + coords + 1);
l.inputs = l.outputs;
l.truths = 30*(l.coords + 1);
l.truths = 90*(l.coords + 1);
l.delta = calloc(batch*l.outputs, sizeof(float));
l.output = calloc(batch*l.outputs, sizeof(float));
for(i = 0; i < total*2; ++i){
@ -213,7 +213,7 @@ void forward_region_layer(const layer l, network net)
for (b = 0; b < l.batch; ++b) {
if(l.softmax_tree){
int onlyclass = 0;
for(t = 0; t < 30; ++t){
for(t = 0; t < l.max_boxes; ++t){
box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1);
if(!truth.x) break;
int class = net.truth[t*(l.coords + 1) + b*l.truths + l.coords];
@ -250,7 +250,7 @@ void forward_region_layer(const layer l, network net)
int box_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0);
box pred = get_region_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, net.w, net.h, l.w*l.h);
float best_iou = 0;
for(t = 0; t < 30; ++t){
for(t = 0; t < l.max_boxes; ++t){
box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1);
if(!truth.x) break;
float iou = box_iou(pred, truth);
@ -279,7 +279,7 @@ void forward_region_layer(const layer l, network net)
}
}
}
for(t = 0; t < 30; ++t){
for(t = 0; t < l.max_boxes; ++t){
box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1);
if(!truth.x) break;

View File

@ -8,7 +8,7 @@
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2)
{
fprintf(stderr,"Shortcut Layer: %d\n", index);
fprintf(stderr, "res %3d %4d x%4d x%4d -> %4d x%4d x%4d\n",index, w2,h2,c2, w,h,c);
layer l = {0};
l.type = SHORTCUT;
l.batch = batch;
@ -38,6 +38,27 @@ layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int
return l;
}
/*
 * Resizes a shortcut layer to a new spatial size w x h.
 * Input and output dimensions are required to match on entry (asserted) and
 * are both set to the new size. The output and delta buffers are reallocated
 * for outputs*batch floats; with GPU enabled the device buffers are freed and
 * recreated from the host buffers.
 */
void resize_shortcut_layer(layer *l, int w, int h)
{
    assert(l->w == l->out_w);
    assert(l->h == l->out_h);

    l->out_w = w;
    l->w = w;
    l->out_h = h;
    l->h = h;

    l->outputs = w*h*l->out_c;
    l->inputs = l->outputs;

    int total = l->outputs*l->batch;
    l->delta = realloc(l->delta, total*sizeof(float));
    l->output = realloc(l->output, total*sizeof(float));

#ifdef GPU
    cuda_free(l->output_gpu);
    cuda_free(l->delta_gpu);
    l->output_gpu = cuda_make_array(l->output, total);
    l->delta_gpu = cuda_make_array(l->delta, total);
#endif
}
void forward_shortcut_layer(const layer l, network net)
{
copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);

View File

@ -7,6 +7,7 @@
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2);
void forward_shortcut_layer(const layer l, network net);
void backward_shortcut_layer(const layer l, network net);
void resize_shortcut_layer(layer *l, int w, int h);
#ifdef GPU
void forward_shortcut_layer_gpu(const layer l, network net);

View File

@ -12,10 +12,16 @@ layer make_upsample_layer(int batch, int w, int h, int c, int stride)
l.w = w;
l.h = h;
l.c = c;
l.stride = stride;
l.out_w = w*stride;
l.out_h = h*stride;
l.out_c = c;
if(stride < 0){
stride = -stride;
l.reverse=1;
l.out_w = w/stride;
l.out_h = h/stride;
}
l.stride = stride;
l.outputs = l.out_w*l.out_h*l.out_c;
l.inputs = l.w*l.h*l.c;
l.delta = calloc(l.outputs*batch, sizeof(float));
@ -30,7 +36,8 @@ layer make_upsample_layer(int batch, int w, int h, int c, int stride)
l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
#endif
fprintf(stderr, "upsample %2dx %4d x%4d x%4d -> %4d x%4d x%4d\n", stride, w, h, c, l.out_w, l.out_h, l.out_c);
if(l.reverse) fprintf(stderr, "downsample %2dx %4d x%4d x%4d -> %4d x%4d x%4d\n", stride, w, h, c, l.out_w, l.out_h, l.out_c);
else fprintf(stderr, "upsample %2dx %4d x%4d x%4d -> %4d x%4d x%4d\n", stride, w, h, c, l.out_w, l.out_h, l.out_c);
return l;
}
@ -40,6 +47,10 @@ void resize_upsample_layer(layer *l, int w, int h)
l->h = h;
l->out_w = w*l->stride;
l->out_h = h*l->stride;
if(l->reverse){
l->out_w = w/l->stride;
l->out_h = h/l->stride;
}
l->outputs = l->out_w*l->out_h*l->out_c;
l->inputs = l->h*l->w*l->c;
l->delta = realloc(l->delta, l->outputs*l->batch*sizeof(float));
@ -56,44 +67,40 @@ void resize_upsample_layer(layer *l, int w, int h)
void forward_upsample_layer(const layer l, network net)
{
int i, j, k, b;
for(b = 0; b < l.batch; ++b){
for(k = 0; k < l.c; ++k){
for(j = 0; j < l.h*l.stride; ++j){
for(i = 0; i < l.w*l.stride; ++i){
int in_index = b*l.inputs + k*l.w*l.h + (j/l.stride)*l.w + i/l.stride;
int out_index = b*l.inputs + k*l.w*l.h + j*l.w + i;
l.output[out_index] = net.input[in_index];
}
}
}
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
if(l.reverse){
upsample_cpu(l.output, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, net.input);
}else{
upsample_cpu(net.input, l.w, l.h, l.c, l.batch, l.stride, 1, l.output);
}
}
void backward_upsample_layer(const layer l, network net)
{
int i, j, k, b;
for(b = 0; b < l.batch; ++b){
for(k = 0; k < l.c; ++k){
for(j = 0; j < l.h*l.stride; ++j){
for(i = 0; i < l.w*l.stride; ++i){
int in_index = b*l.inputs + k*l.w*l.h + (j/l.stride)*l.w + i/l.stride;
int out_index = b*l.inputs + k*l.w*l.h + j*l.w + i;
net.delta[in_index] += l.delta[out_index];
}
}
}
if(l.reverse){
upsample_cpu(l.delta, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, net.delta);
}else{
upsample_cpu(net.delta, l.w, l.h, l.c, l.batch, l.stride, 0, l.delta);
}
}
#ifdef GPU
/*
 * GPU forward pass of the upsample layer.
 * Clears l.output_gpu, then either upsamples net.input_gpu into it or, when
 * l.reverse is set, runs the kernel in its accumulate direction so the input
 * is reduced into the smaller output.
 *
 * The leftover unconditional upsample_gpu call that preceded the fill has
 * been removed: its result was wiped by the fill in the normal case, and in
 * reverse mode it ran with the non-reverse geometry.
 */
void forward_upsample_layer_gpu(const layer l, network net)
{
    fill_gpu(l.outputs*l.batch, 0, l.output_gpu, 1);
    if(l.reverse){
        upsample_gpu(l.output_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, net.input_gpu);
    }else{
        upsample_gpu(net.input_gpu, l.w, l.h, l.c, l.batch, l.stride, 1, l.output_gpu);
    }
}
/*
 * GPU backward pass of the upsample layer.
 * Propagates l.delta_gpu into net.delta_gpu, using the direction that
 * mirrors the forward pass (selected by l.reverse).
 *
 * The leftover unconditional upsample_gpu call before the branch has been
 * removed: in the non-reverse case it made the same accumulation run twice,
 * doubling the gradient passed to the previous layer.
 */
void backward_upsample_layer_gpu(const layer l, network net)
{
    if(l.reverse){
        upsample_gpu(l.delta_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, net.delta_gpu);
    }else{
        upsample_gpu(net.delta_gpu, l.w, l.h, l.c, l.batch, l.stride, 0, l.delta_gpu);
    }
}
#endif