diff --git a/Makefile b/Makefile
index 0ae99ecc..f50f279d 100644
--- a/Makefile
+++ b/Makefile
@@ -57,7 +57,7 @@ CFLAGS+= -DCUDNN
 LDFLAGS+= -lcudnn
 endif
 
-OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o
+OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o l2norm_layer.o
 EXECOBJA=captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o darknet.o
 ifeq ($(GPU), 1) 
 LDFLAGS+= -lstdc++ 
diff --git a/cfg/darknet.cfg b/cfg/darknet.cfg
index 9bdee83b..ec902830 100644
--- a/cfg/darknet.cfg
+++ b/cfg/darknet.cfg
@@ -1,10 +1,10 @@
 [net]
 # Train
-# batch=128
-# subdivisions=1
+ batch=128
+ subdivisions=1
 # Test
-batch=1
-subdivisions=1
+#batch=1
+#subdivisions=1
 height=256
 width=256
 channels=3
@@ -88,7 +88,6 @@ activation=leaky
 [maxpool]
 size=2
 stride=2
-padding=1
 
 [convolutional]
 batch_normalize=1
@@ -110,6 +109,3 @@ activation=leaky
 [softmax]
 groups=1
 
-[cost]
-type=sse
-
diff --git a/cfg/darknet19.cfg b/cfg/darknet19.cfg
index bf73fb7b..f56a46e2 100644
--- a/cfg/darknet19.cfg
+++ b/cfg/darknet19.cfg
@@ -1,17 +1,31 @@
 [net]
-batch=128
-subdivisions=1
-height=224
-width=224
+# Training
+#batch=128
+#subdivisions=2
+
+# Testing
+ batch=1
+ subdivisions=1
+
+height=256
+width=256
+min_crop=128
+max_crop=448
 channels=3
 momentum=0.9
 decay=0.0005
-max_crop=448
+burn_in=1000
 
 learning_rate=0.1
 policy=poly
 power=4
-max_batches=1600000
+max_batches=800000
+
+angle=7
+hue=.1
+saturation=.75
+exposure=.75
+aspect=.75
 
 [convolutional]
 batch_normalize=1
diff --git a/examples/classifier.c b/examples/classifier.c
index 4dda9516..b09947cc 100644
--- a/examples/classifier.c
+++ b/examples/classifier.c
@@ -578,6 +578,8 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
         }
         image im = load_image_color(input, 0, 0);
         image r = letterbox_image(im, net->w, net->h);
+        //image r = resize_min(im, 320);
+        //printf("%d %d\n", r.w, r.h);
         //resize_network(net, r.w, r.h);
         //printf("%d %d\n", r.w, r.h);
 
diff --git a/examples/lsd.c b/examples/lsd.c
index 369f63da..247f639e 100644
--- a/examples/lsd.c
+++ b/examples/lsd.c
@@ -383,6 +383,92 @@ void train_pix2pix(char *cfg, char *weight, char *acfg, char *aweight, int clear
 }
 */
 
+void slerp(float *start, float *end, float s, int n, float *out)
+{
+    float omega = acos(dot_cpu(n, start, 1, end, 1));
+    float so = sin(omega);
+    fill_cpu(n, 0, out, 1);
+    axpy_cpu(n, sin((1-s)*omega)/so, start, 1, out, 1);
+    axpy_cpu(n, sin(s*omega)/so, end, 1, out, 1);
+
+    float mag = mag_array(out, n);
+    scale_array(out, n, 1./mag);
+}
+
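+// Note: slerp() above is standard spherical linear interpolation,
+//   out = sin((1-s)*omega)/sin(omega) * start + sin(s*omega)/sin(omega) * end,
+// with omega = acos(start . end); the final rescale by 1/|out| guards against
+// numerical drift when start and end are not exactly unit length.
+// random_unit_vector_image() below draws i.i.d. Gaussian entries and rescales
+// to unit L2 norm, which samples uniformly from the unit sphere.
+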
+image random_unit_vector_image(int w, int h, int c)
+{
+    image im = make_image(w, h, c);
+    int i;
+    for(i = 0; i < im.w*im.h*im.c; ++i){
+        im.data[i] = rand_normal();
+    }
+    float mag = mag_array(im.data, im.w*im.h*im.c);
+    scale_array(im.data, im.w*im.h*im.c, 1./mag);
+    return im;
+}
+
+void inter_dcgan(char *cfgfile, char *weightfile)
+{
+    network *net = load_network(cfgfile, weightfile, 0);
+    set_batch_network(net, 1);
+    srand(2222222);
+
+    clock_t time;
+    char buff[256];
+    char *input = buff;
+    int i, imlayer = 0;
+
+    for (i = 0; i < net->n; ++i) {
+        if (net->layers[i].out_c == 3) {
+            imlayer = i;
+            printf("%d\n", i);
+            break;
+        }
+    }
+    image start = random_unit_vector_image(net->w, net->h, net->c);
+    image end = random_unit_vector_image(net->w, net->h, net->c);
+    image im = make_image(net->w, net->h, net->c);
+    image orig = copy_image(start);
+
+    int c = 0;
+    int count = 0;
+    int max_count = 15;
+    while(1){
+        ++c;
+
+        if(count == max_count){
+            count = 0;
+            free_image(start);
+            start = end;
+            end = random_unit_vector_image(net->w, net->h, net->c);
+            if(c > 300){
+                end = orig;
+            }
+            if(c>300 + max_count) return;
+        }
+        ++count;
+
+        slerp(start.data, end.data, (float)count / max_count, im.w*im.h*im.c, im.data);
+
+        float *X = im.data;
+        time=clock();
+        network_predict(net, X);
+        image out = get_network_image_layer(net, imlayer);
+        //yuv_to_rgb(out);
+        normalize_image(out);
+        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
+        //char buff[256];
+        sprintf(buff, "out%05d", c);
+        show_image(out, "out");
+        save_image(out, "out");
+        save_image(out, buff);
+#ifdef OPENCV
+        //cvWaitKey(0);
+#endif
+
+    }
+}
+
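+// inter_dcgan renders a latent-space walk: every max_count frames it promotes
+// end to the new start, draws a fresh endpoint, and slerps between them; after
+// 300 frames it interpolates back toward the original vector and exits, so the
+// saved out%05d frames form a closed loop.
+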
 void test_dcgan(char *cfgfile, char *weightfile)
 {
     network *net = load_network(cfgfile, weightfile, 0);
@@ -409,7 +495,7 @@ void test_dcgan(char *cfgfile, char *weightfile)
             im.data[i] = rand_normal();
         }
         float mag = mag_array(im.data, im.w*im.h*im.c);
-        //scale_array(im.data, im.w*im.h*im.c, 1./mag);
+        scale_array(im.data, im.w*im.h*im.c, 1./mag);
 
         float *X = im.data;
         time=clock();
@@ -429,7 +515,7 @@
 }
 
-void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images)
+void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images, int maxbatch)
 {
 #ifdef GPU
     char *backup_directory = "/home/pjreddie/backup/";
@@ -441,7 +527,6 @@
     network *anet = load_network(acfg, aweight, clear);
     //float orig_rate = anet->learning_rate;
 
-    int start = 0;
     int i, j, k;
     layer imlayer = {0};
     for (i = 0; i < gnet->n; ++i) {
@@ -488,8 +573,8 @@
     //data generated = copy_data(train);
 
-    while (get_current_batch(gnet) < gnet->max_batches) {
-        start += 1;
+    if (maxbatch == 0) maxbatch = gnet->max_batches;
+    while (get_current_batch(gnet) < maxbatch) {
         i += 1;
         time=clock();
         pthread_join(load_thread, 0);
@@ -519,6 +604,13 @@
             float mag = mag_array(gnet->input + z*gnet->inputs, gnet->inputs);
             scale_array(gnet->input + z*gnet->inputs, gnet->inputs, 1./mag);
         }
+        /*
+        for(z = 0; z < 100; ++z){
+            printf("%f, ", gnet->input[z]);
+        }
+        printf("\n");
+        printf("input: %f %f\n", mean_array(gnet->input, x_size), variance_array(gnet->input, x_size));
+        */
 
         //cuda_push_array(gnet->input_gpu, gnet->input, x_size);
         //cuda_push_array(gnet->truth_gpu, gnet->truth, y_size);
@@ -533,18 +625,26 @@
         backward_network(anet);
 
         float genaloss = *anet->cost / anet->batch;
-        printf("%f\n", genaloss);
+        //printf("%f\n", genaloss);
 
         scal_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
         scal_gpu(imlayer.outputs*imlayer.batch, 0, gnet->layers[gnet->n-1].delta_gpu, 1);
 
-        printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
-        printf("features %f\n", cuda_mag_array(gnet->layers[gnet->n-1].delta_gpu, imlayer.outputs*imlayer.batch));
+        //printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
+        //printf("features %f\n", cuda_mag_array(gnet->layers[gnet->n-1].delta_gpu, imlayer.outputs*imlayer.batch));
 
         axpy_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, gnet->layers[gnet->n-1].delta_gpu, 1);
 
         backward_network(gnet);
 
+        /*
+        for(k = 0; k < gnet->n; ++k){
+            layer l = gnet->layers[k];
+            cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
+            printf("%d: %f %f\n", k, mean_array(l.output, l.outputs*l.batch), variance_array(l.output, l.outputs*l.batch));
+        }
+        */
+
         for(k = 0; k < gnet->batch; ++k){
             int index = j*gnet->batch + k;
             copy_cpu(gnet->outputs, gnet->output + k*gnet->outputs, 1, gen.X.vals[index], 1);
@@ -1104,6 +1204,7 @@ void run_lsd(int argc, char **argv)
 
     int clear = find_arg(argc, argv, "-clear");
     int display = find_arg(argc, argv, "-display");
+    int batches = find_int_arg(argc, argv, "-b", 0);
     char *file = find_char_arg(argc, argv, "-file", "/home/pjreddie/data/imagenet/imagenet1k.train.list");
 
     char *cfg = argv[3];
@@ -1115,9 +1216,10 @@
     //else if(0==strcmp(argv[2], "train2")) train_lsd2(cfg, weights, acfg, aweights, clear);
     //else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear);
     //else if(0==strcmp(argv[2], "train3")) train_lsd3(argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], clear);
-    if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear, display, file);
+    if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear, display, file, batches);
     else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear, display);
     else if(0==strcmp(argv[2], "gan")) test_dcgan(cfg, weights);
+    else if(0==strcmp(argv[2], "inter")) inter_dcgan(cfg, weights);
     else if(0==strcmp(argv[2], "test")) test_lsd(cfg, weights, filename, 0);
     else if(0==strcmp(argv[2], "color")) test_lsd(cfg, weights, filename, 1);
     /*
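// Usage sketch for the new lsd entry points (file names here are hypothetical):
//   ./darknet lsd traingan gen.cfg gen.weights disc.cfg disc.weights -b 60000
//   ./darknet lsd inter gen.cfg gen.weights
// The new -b flag bounds the number of training batches; with the default of
// 0, train_dcgan falls back to the generator cfg's max_batches.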
diff --git a/include/darknet.h b/include/darknet.h
index 0165284d..d8bec786 100644
--- a/include/darknet.h
+++ b/include/darknet.h
@@ -87,11 +87,12 @@ typedef enum {
     REORG,
     UPSAMPLE,
     LOGXENT,
+    L2NORM,
     BLANK
 } LAYER_TYPE;
 
 typedef enum{
-    SSE, MASKED, L1, SEG, SMOOTH
+    SSE, MASKED, L1, SEG, SMOOTH, WGAN
 } COST_TYPE;
 
 typedef struct{
@@ -162,6 +163,7 @@ struct layer{
     float shift;
     float ratio;
     float learning_rate_scale;
+    float clip;
     int softmax;
     int classes;
     int coords;
@@ -475,6 +477,7 @@ typedef struct network{
     int train;
     int index;
     float *cost;
+    float clip;
 
 #ifdef GPU
     float *input_gpu;
@@ -604,6 +607,7 @@ void backward_network(network *net);
 void update_network(network *net);
 
+float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
 void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
 void copy_cpu(int N, float *X, int INCX, float *Y, int INCY);
 void scal_cpu(int N, float ALPHA, float *X, int INCX);
diff --git a/src/blas.c b/src/blas.c
index bd8b5534..d1d01cf2 100644
--- a/src/blas.c
+++ b/src/blas.c
@@ -123,6 +123,27 @@ void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, fl
     }
 }
 
+void l2normalize_cpu(float *x, float *dx, int batch, int filters, int spatial)
+{
+    int b,f,i;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < spatial; ++i){
+            float sum = 0;
+            for(f = 0; f < filters; ++f){
+                int index = b*filters*spatial + f*spatial + i;
+                sum += powf(x[index], 2);
+            }
+            sum = sqrtf(sum);
+            for(f = 0; f < filters; ++f){
+                int index = b*filters*spatial + f*spatial + i;
+                x[index] /= sum;
+                dx[index] = (1 - x[index]) / sum;
+            }
+        }
+    }
+}
+
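+// l2normalize_cpu (above) rescales every spatial position to unit L2 norm
+// across the channel dimension; dx caches a per-element gradient term that
+// the new [l2norm] layer's GPU backward pass folds into its delta.
+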
 
 void normalize_cpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     int b, f, i;
@@ -310,3 +331,21 @@
     }
 }
 
+void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int i, j, k, b;
+    for(b = 0; b < batch; ++b){
+        for(k = 0; k < c; ++k){
+            for(j = 0; j < h*stride; ++j){
+                for(i = 0; i < w*stride; ++i){
+                    int in_index = b*w*h*c + k*w*h + (j/stride)*w + i/stride;
+                    int out_index = b*w*h*c*stride*stride + k*w*h*stride*stride + j*w*stride + i;
+                    if(forward) out[out_index] = in[in_index];
+                    else in[in_index] += out[out_index];
+                }
+            }
+        }
+    }
+}
+
+
diff --git a/src/blas.h b/src/blas.h
index ec827e4d..99e348be 100644
--- a/src/blas.h
+++ b/src/blas.h
@@ -19,7 +19,6 @@ void constrain_gpu(int N, float ALPHA, float * X, int INCX);
 void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
 void mul_cpu(int N, float *X, int INCX, float *Y, int INCY);
 
-float dot_cpu(int N, float *X, int INCX, float *Y, int INCY);
 int test_gpu_blas();
 void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
 
@@ -31,6 +30,7 @@ void backward_scale_cpu(float *x_norm, float *delta, int batch, int n, int size,
 void mean_delta_cpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta);
 void variance_delta_cpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta);
 void normalize_delta_cpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta);
+void l2normalize_cpu(float *x, float *dx, int batch, int filters, int spatial);
 void smooth_l1_cpu(int n, float *pred, float *truth, float *delta, float *error);
 void l2_cpu(int n, float *pred, float *truth, float *delta, float *error);
@@ -42,6 +42,7 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa
 void softmax(float *input, int n, float temp, int stride, float *output);
 void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output);
+void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out);
 
 #ifdef GPU
 #include "cuda.h"
@@ -62,6 +63,7 @@ void mul_gpu(int N, float *X, int INCX, float *Y, int INCY);
 void mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
 void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
 void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial);
+void l2normalize_gpu(float *x, float *dx, int batch, int filters, int spatial);
 void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta);
 
@@ -82,6 +84,7 @@ void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *er
 void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error);
 void l2_gpu(int n, float *pred, float *truth, float *delta, float *error);
 void l1_gpu(int n, float *pred, float *truth, float *delta, float *error);
+void wgan_gpu(int n, float *pred, float *truth, float *delta, float *error);
 void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc);
 void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c);
 void mult_add_into_gpu(int num, float *a, float *b, float *c);
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index b04c2461..56b311f1 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -164,8 +164,11 @@ __global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float
 {
     int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (index >= N) return;
+
+    float mhat = m[index] / (1.f - powf(B1, t));
+    float vhat = v[index] / (1.f - powf(B2, t));
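+    // Bias-corrected Adam: mhat = m/(1-B1^t), vhat = v/(1-B2^t), then
+    // x += rate * mhat / (sqrt(vhat) + eps). This matches the old fused
+    // expression up to where eps enters, so early updates differ slightly.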
 
-    x[index] = x[index] + (rate * sqrtf(1.f-powf(B2, t)) / (1.f-powf(B1, t)) * m[index] / (sqrtf(v[index]) + eps));
+    x[index] = x[index] + rate * mhat / (sqrtf(vhat) + eps);
 }
 
 extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
@@ -466,6 +469,35 @@ extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch,
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void l2norm_kernel(int N, float *x, float *dx, int batch, int filters, int spatial)
+{
+    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (index >= N) return;
+    int b = index / spatial;
+    int i = index % spatial;
+    int f;
+    float sum = 0;
+    for(f = 0; f < filters; ++f){
+        int index = b*filters*spatial + f*spatial + i;
+        sum += powf(x[index], 2);
+    }
+    sum = sqrtf(sum);
+    if(sum == 0) sum = 1;
+    //printf("%f\n", sum);
+    for(f = 0; f < filters; ++f){
+        int index = b*filters*spatial + f*spatial + i;
+        x[index] /= sum;
+        dx[index] = (1 - x[index]) / sum;
+    }
+}
+
+extern "C" void l2normalize_gpu(float *x, float *dx, int batch, int filters, int spatial)
+{
+    size_t N = batch*spatial;
+    l2norm_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, dx, batch, filters, spatial);
+    check_error(cudaPeekAtLastError());
+}
+
 __global__ void fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
 {
     const int threads = BLOCK;
@@ -757,7 +789,7 @@ __global__ void logistic_x_ent_kernel(int n, float *pred, float *truth, float *d
     if(i < n){
         float t = truth[i];
         float p = pred[i];
-        error[i] = -t*log(p) - (1-t)*log(1-p);
+        error[i] = -t*log(p+.0000001) - (1-t)*log(1-p+.0000001);
         delta[i] = t-p;
     }
 }
@@ -800,6 +832,21 @@ extern "C" void l1_gpu(int n, float *pred, float *truth, float *delta, float *er
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void wgan_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        error[i] = truth[i] ? -pred[i] : pred[i];
+        delta[i] = (truth[i] > 0) ? 1 : -1;
+    }
+}
+
+extern "C" void wgan_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    wgan_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    check_error(cudaPeekAtLastError());
+}
+
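+// WGAN critic loss: the cost is a raw score (real samples contribute -pred,
+// fakes +pred) with constant +/-1 gradients, rather than a log-likelihood.
+// It is meant to be paired with weight clipping (the new l.clip /
+// constrain_gpu in update_convolutional_layer_gpu below) to keep the critic
+// approximately Lipschitz.
+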
@@ -926,13 +973,13 @@ extern "C" void softmax_tree(float *input, int spatial, int batch, int stride, f
     int *tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
     int *tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
     /*
-    static int *tree_groups_size = 0;
-    static int *tree_groups_offset = 0;
-    if(!tree_groups_size){
-        tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
-        tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
-    }
-    */
+       static int *tree_groups_size = 0;
+       static int *tree_groups_offset = 0;
+       if(!tree_groups_size){
+           tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
+           tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
+       }
+     */
     int num = spatial*batch*hier.groups;
     softmax_tree_kernel<<<cuda_gridsize(num), BLOCK>>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
     check_error(cudaPeekAtLastError());
@@ -976,7 +1023,7 @@ __global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int bat
 
     int in_index = b*w*h*c + in_c*w*h + in_h*w + in_w;
 
-    if(forward) out[out_index] = x[in_index];
+    if(forward) out[out_index] += x[in_index];
     else atomicAdd(x+in_index, out[out_index]);
 }
 extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out)
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 56043e78..8fa2ab2e 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -314,6 +314,9 @@ void update_convolutional_layer_gpu(layer l, update_args a)
             scal_gpu(l.n, momentum, l.scale_updates_gpu, 1);
         }
     }
+    if(l.clip){
+        constrain_gpu(l.nweights, l.clip, l.weights_gpu, 1);
+    }
 }
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index aca2da38..e4fb9bde 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -322,7 +322,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
     l.workspace_size = get_workspace_size(l);
     l.activation = activation;
 
-    fprintf(stderr, "conv  %5d %2d x%2d /%2d  %4d x%4d x%4d   ->  %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c);
+    fprintf(stderr, "conv  %5d %2d x%2d /%2d  %4d x%4d x%4d   ->  %4d x%4d x%4d  %5.3f BFLOPs\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, (2.0 * l.n * l.size*l.size*l.c/l.groups * l.out_h*l.out_w)/1000000000.);
 
     return l;
 }
diff --git a/src/cost_layer.c b/src/cost_layer.c
index ebf332fe..2138ff26 100644
--- a/src/cost_layer.c
+++ b/src/cost_layer.c
@@ -14,6 +14,7 @@ COST_TYPE get_cost_type(char *s)
     if (strcmp(s, "masked")==0) return MASKED;
     if (strcmp(s, "smooth")==0) return SMOOTH;
     if (strcmp(s, "L1")==0) return L1;
+    if (strcmp(s, "wgan")==0) return WGAN;
     fprintf(stderr, "Couldn't find cost type %s, going with SSE\n", s);
     return SSE;
 }
@@ -31,6 +32,8 @@ char *get_cost_string(COST_TYPE a)
             return "smooth";
         case L1:
             return "L1";
+        case WGAN:
+            return "wgan";
     }
     return "sse";
 }
@@ -133,6 +136,8 @@ void forward_cost_layer_gpu(cost_layer l, network net)
         smooth_l1_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
     } else if (l.cost_type == L1){
         l1_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
+    } else if (l.cost_type == WGAN){
+        wgan_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
     } else {
         l2_gpu(l.batch*l.inputs, net.input_gpu, net.truth_gpu, l.delta_gpu, l.output_gpu);
     }
diff --git a/src/data.c b/src/data.c
index 246528e8..9c691465 100644
--- a/src/data.c
+++ b/src/data.c
@@ -230,7 +230,7 @@ void fill_truth_swag(char *path, float *truth, int classes, int flip, float dx, 
     int id;
     int i;
 
-    for (i = 0; i < count && i < 30; ++i) {
+    for (i = 0; i < count && i < 90; ++i) {
         x = boxes[i].x;
         y = boxes[i].y;
         w = boxes[i].w;
@@ -424,6 +424,7 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, 
     float x,y,w,h;
     int id;
     int i;
+    int sub = 0;
 
     for (i = 0; i < count; ++i) {
         x = boxes[i].x;
@@ -432,13 +433,16 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, 
         h = boxes[i].h;
         id = boxes[i].id;
 
-        if ((w < .001 || h < .001)) continue;
+        if ((w < .001 || h < .001)) {
+            ++sub;
+            continue;
+        }
 
-        truth[i*5+0] = x;
-        truth[i*5+1] = y;
-        truth[i*5+2] = w;
-        truth[i*5+3] = h;
-        truth[i*5+4] = id;
+        truth[(i-sub)*5+0] = x;
+        truth[(i-sub)*5+1] = y;
+        truth[(i-sub)*5+2] = w;
+        truth[(i-sub)*5+3] = h;
+        truth[(i-sub)*5+4] = id;
     }
     free(boxes);
 }
@@ -907,7 +911,7 @@ data load_data_swag(char **paths, int n, int classes, float jitter)
     d.X.vals = calloc(d.X.rows, sizeof(float*));
     d.X.cols = h*w*3;
 
-    int k = (4+classes)*30;
+    int k = (4+classes)*90;
     d.y = make_matrix(1, k);
 
     int dw = w*jitter;
diff --git a/src/deconvolutional_layer.c b/src/deconvolutional_layer.c
index 63e51ba6..00c0e857 100644
--- a/src/deconvolutional_layer.c
+++ b/src/deconvolutional_layer.c
@@ -15,6 +15,22 @@ static size_t get_workspace_size(layer l){
     return (size_t)l.h*l.w*l.size*l.size*l.n*sizeof(float);
 }
 
+void bilinear_init(layer l)
+{
+    int i,j,f;
+    float center = (l.size-1) / 2.;
+    for(f = 0; f < l.n; ++f){
+        for(j = 0; j < l.size; ++j){
+            for(i = 0; i < l.size; ++i){
+                float val = (1 - fabs(i - center)) * (1 - fabs(j - center));
+                int c = f%l.c;
+                int ind = f*l.size*l.size*l.c + c*l.size*l.size + j*l.size + i;
+                l.weights[ind] = val;
+            }
+        }
+    }
+}
+
 layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int adam)
 {
@@ -38,9 +54,11 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size
     l.biases = calloc(n, sizeof(float));
     l.bias_updates = calloc(n, sizeof(float));
 
+    //float scale = n/(size*size*c);
+    //printf("scale: %f\n", scale);
     float scale = .02;
-    printf("scale: %f\n", scale);
     for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_normal();
+    //bilinear_init(l);
     for(i = 0; i < n; ++i){
         l.biases[i] = 0;
     }
@@ -52,6 +70,8 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size
     l.outputs = l.out_w * l.out_h * l.out_c;
     l.inputs = l.w * l.h * l.c;
 
+    scal_cpu(l.nweights, (float)l.out_w*l.out_h/(l.w*l.h), l.weights, 1);
+
     l.output = calloc(l.batch*l.outputs, sizeof(float));
     l.delta = calloc(l.batch*l.outputs, sizeof(float));
 
@@ -122,7 +142,7 @@ layer make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size
         l.mean_delta_gpu = cuda_make_array(0, n);
         l.variance_delta_gpu = cuda_make_array(0, n);
 
-        l.scales_gpu = cuda_make_array(0, n);
+        l.scales_gpu = cuda_make_array(l.scales, n);
         l.scale_updates_gpu = cuda_make_array(0, n);
 
         l.x_gpu = cuda_make_array(0, l.batch*l.out_h*l.out_w*n);
diff --git a/src/l2norm_layer.c b/src/l2norm_layer.c
new file mode 100644
index 00000000..d099479b
--- /dev/null
+++ b/src/l2norm_layer.c
@@ -0,0 +1,63 @@
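+// New layer type: channelwise L2 normalization, exposed in cfgs as [l2norm].
+// Note the CPU and GPU backward passes differ: the GPU path adds the cached
+// l.scales gradient term into l.delta, while the CPU path leaves that axpy
+// commented out and only propagates l.delta.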
+#include "l2norm_layer.h" +#include "activations.h" +#include "blas.h" +#include "cuda.h" + +#include +#include +#include +#include +#include + +layer make_l2norm_layer(int batch, int inputs) +{ + fprintf(stderr, "l2norm %4d\n", inputs); + layer l = {0}; + l.type = L2NORM; + l.batch = batch; + l.inputs = inputs; + l.outputs = inputs; + l.output = calloc(inputs*batch, sizeof(float)); + l.scales = calloc(inputs*batch, sizeof(float)); + l.delta = calloc(inputs*batch, sizeof(float)); + + l.forward = forward_l2norm_layer; + l.backward = backward_l2norm_layer; + #ifdef GPU + l.forward_gpu = forward_l2norm_layer_gpu; + l.backward_gpu = backward_l2norm_layer_gpu; + + l.output_gpu = cuda_make_array(l.output, inputs*batch); + l.scales_gpu = cuda_make_array(l.output, inputs*batch); + l.delta_gpu = cuda_make_array(l.delta, inputs*batch); + #endif + return l; +} + +void forward_l2norm_layer(const layer l, network net) +{ + copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1); + l2normalize_cpu(l.output, l.scales, l.batch, l.out_c, l.out_w*l.out_h); +} + +void backward_l2norm_layer(const layer l, network net) +{ + //axpy_cpu(l.inputs*l.batch, 1, l.scales, 1, l.delta, 1); + axpy_cpu(l.inputs*l.batch, 1, l.delta, 1, net.delta, 1); +} + +#ifdef GPU + +void forward_l2norm_layer_gpu(const layer l, network net) +{ + copy_gpu(l.outputs*l.batch, net.input_gpu, 1, l.output_gpu, 1); + l2normalize_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_w*l.out_h); +} + +void backward_l2norm_layer_gpu(const layer l, network net) +{ + axpy_gpu(l.batch*l.inputs, 1, l.scales_gpu, 1, l.delta_gpu, 1); + axpy_gpu(l.batch*l.inputs, 1, l.delta_gpu, 1, net.delta_gpu, 1); +} + +#endif diff --git a/src/l2norm_layer.h b/src/l2norm_layer.h new file mode 100644 index 00000000..1ca6f710 --- /dev/null +++ b/src/l2norm_layer.h @@ -0,0 +1,15 @@ +#ifndef L2NORM_LAYER_H +#define L2NORM_LAYER_H +#include "layer.h" +#include "network.h" + +layer make_l2norm_layer(int batch, int inputs); +void forward_l2norm_layer(const layer l, network net); +void backward_l2norm_layer(const layer l, network net); + +#ifdef GPU +void forward_l2norm_layer_gpu(const layer l, network net); +void backward_l2norm_layer_gpu(const layer l, network net); +#endif + +#endif diff --git a/src/logistic_layer.c b/src/logistic_layer.c index 6e2835aa..b2b3d6b1 100644 --- a/src/logistic_layer.c +++ b/src/logistic_layer.c @@ -60,13 +60,12 @@ void forward_logistic_layer_gpu(const layer l, network net) logistic_x_ent_gpu(l.batch*l.inputs, l.output_gpu, net.truth_gpu, l.delta_gpu, l.loss_gpu); cuda_pull_array(l.loss_gpu, l.loss, l.batch*l.inputs); l.cost[0] = sum_array(l.loss, l.batch*l.inputs); - printf("hey: %f\n", l.cost[0]); } } -void backward_logistic_layer_gpu(const layer layer, network net) +void backward_logistic_layer_gpu(const layer l, network net) { - axpy_gpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, net.delta_gpu, 1); + axpy_gpu(l.batch*l.inputs, 1, l.delta_gpu, 1, net.delta_gpu, 1); } #endif diff --git a/src/network.c b/src/network.c index b52db448..f4966211 100644 --- a/src/network.c +++ b/src/network.c @@ -378,6 +378,8 @@ int resize_network(network *net, int w, int h) resize_region_layer(&l, w, h); }else if(l.type == ROUTE){ resize_route_layer(&l, net); + }else if(l.type == SHORTCUT){ + resize_shortcut_layer(&l, w, h); }else if(l.type == UPSAMPLE){ resize_upsample_layer(&l, w, h); }else if(l.type == REORG){ diff --git a/src/parser.c b/src/parser.c index ce045c4b..00767f07 100644 --- a/src/parser.c +++ b/src/parser.c @@ -5,6 +5,7 @@ #include 
"activation_layer.h" #include "logistic_layer.h" +#include "l2norm_layer.h" #include "activations.h" #include "avgpool_layer.h" #include "batchnorm_layer.h" @@ -56,6 +57,7 @@ LAYER_TYPE string_to_layer_type(char * type) || strcmp(type, "[deconvolutional]")==0) return DECONVOLUTIONAL; if (strcmp(type, "[activation]")==0) return ACTIVE; if (strcmp(type, "[logistic]")==0) return LOGXENT; + if (strcmp(type, "[l2norm]")==0) return L2NORM; if (strcmp(type, "[net]")==0 || strcmp(type, "[network]")==0) return NETWORK; if (strcmp(type, "[crnn]")==0) return CRNN; @@ -307,7 +309,7 @@ layer parse_region(list *options, size_params params) l.softmax = option_find_int(options, "softmax", 0); l.background = option_find_int_quiet(options, "background", 0); - l.max_boxes = option_find_int_quiet(options, "max",30); + l.max_boxes = option_find_int_quiet(options, "max",90); l.jitter = option_find_float(options, "jitter", .2); l.rescore = option_find_int_quiet(options, "rescore",0); @@ -356,7 +358,7 @@ detection_layer parse_detection(list *options, size_params params) layer.softmax = option_find_int(options, "softmax", 0); layer.sqrt = option_find_int(options, "sqrt", 0); - layer.max_boxes = option_find_int_quiet(options, "max",30); + layer.max_boxes = option_find_int_quiet(options, "max",90); layer.coord_scale = option_find_float(options, "coord_scale", 1); layer.forced = option_find_int(options, "forced", 0); layer.object_scale = option_find_float(options, "object_scale", 1); @@ -496,10 +498,20 @@ layer parse_shortcut(list *options, size_params params, network *net) } +layer parse_l2norm(list *options, size_params params) +{ + layer l = make_l2norm_layer(params.batch, params.inputs); + l.h = l.out_h = params.h; + l.w = l.out_w = params.w; + l.c = l.out_c = params.c; + return l; +} + + layer parse_logistic(list *options, size_params params) { layer l = make_logistic_layer(params.batch, params.inputs); - l.w = l.out_h = params.h; + l.h = l.out_h = params.h; l.w = l.out_w = params.w; l.c = l.out_c = params.c; return l; @@ -512,12 +524,9 @@ layer parse_activation(list *options, size_params params) layer l = make_activation_layer(params.batch, params.inputs, activation); - l.out_h = params.h; - l.out_w = params.w; - l.out_c = params.c; - l.h = params.h; - l.w = params.w; - l.c = params.c; + l.h = l.out_h = params.h; + l.w = l.out_w = params.w; + l.c = l.out_c = params.c; return l; } @@ -614,6 +623,7 @@ void parse_net_options(list *options, network *net) net->max_ratio = option_find_float_quiet(options, "max_ratio", (float) net->max_crop / net->w); net->min_ratio = option_find_float_quiet(options, "min_ratio", (float) net->min_crop / net->w); net->center = option_find_int_quiet(options, "center",0); + net->clip = option_find_float_quiet(options, "clip", 0); net->angle = option_find_float_quiet(options, "angle", 0); net->aspect = option_find_float_quiet(options, "aspect", 1); @@ -714,6 +724,8 @@ network *parse_network_cfg(char *filename) l = parse_activation(options, params); }else if(lt == LOGXENT){ l = parse_logistic(options, params); + }else if(lt == L2NORM){ + l = parse_l2norm(options, params); }else if(lt == RNN){ l = parse_rnn(options, params); }else if(lt == GRU){ @@ -762,6 +774,7 @@ network *parse_network_cfg(char *filename) }else{ fprintf(stderr, "Type not recognized: %s\n", s->type); } + l.clip = net->clip; l.truth = option_find_int_quiet(options, "truth", 0); l.onlyforward = option_find_int_quiet(options, "onlyforward", 0); l.stopbackward = option_find_int_quiet(options, "stopbackward", 0); diff --git 
diff --git a/src/region_layer.c b/src/region_layer.c
index 8313b609..7cf6be8d 100644
--- a/src/region_layer.c
+++ b/src/region_layer.c
@@ -39,7 +39,7 @@ layer make_region_layer(int batch, int w, int h, int n, int total, int *mask, in
     l.bias_updates = calloc(n*2, sizeof(float));
     l.outputs = h*w*n*(classes + coords + 1);
     l.inputs = l.outputs;
-    l.truths = 30*(l.coords + 1);
+    l.truths = 90*(l.coords + 1);
     l.delta = calloc(batch*l.outputs, sizeof(float));
     l.output = calloc(batch*l.outputs, sizeof(float));
     for(i = 0; i < total*2; ++i){
@@ -213,7 +213,7 @@ void forward_region_layer(const layer l, network net)
     for (b = 0; b < l.batch; ++b) {
         if(l.softmax_tree){
             int onlyclass = 0;
-            for(t = 0; t < 30; ++t){
+            for(t = 0; t < l.max_boxes; ++t){
                 box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1);
                 if(!truth.x) break;
                 int class = net.truth[t*(l.coords + 1) + b*l.truths + l.coords];
@@ -250,7 +250,7 @@ void forward_region_layer(const layer l, network net)
                 int box_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0);
                 box pred = get_region_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, net.w, net.h, l.w*l.h);
                 float best_iou = 0;
-                for(t = 0; t < 30; ++t){
+                for(t = 0; t < l.max_boxes; ++t){
                     box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1);
                     if(!truth.x) break;
                     float iou = box_iou(pred, truth);
@@ -279,7 +279,7 @@ void forward_region_layer(const layer l, network net)
                 }
             }
         }
-        for(t = 0; t < 30; ++t){
+        for(t = 0; t < l.max_boxes; ++t){
             box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1);
             if(!truth.x) break;
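// The ground-truth box cap moves from a hard-coded 30 to 90: l.truths in
// make_region_layer must match the 90-box truth buffers filled in src/data.c,
// and the loss loops now read l.max_boxes (parsed with a default of 90 in
// src/parser.c) instead of the literal 30.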
diff --git a/src/shortcut_layer.c b/src/shortcut_layer.c
index 0818ca7e..e1b9bc52 100644
--- a/src/shortcut_layer.c
+++ b/src/shortcut_layer.c
@@ -8,7 +8,7 @@
 layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2)
 {
-    fprintf(stderr,"Shortcut Layer: %d\n", index);
+    fprintf(stderr, "res  %3d                %4d x%4d x%4d   ->  %4d x%4d x%4d\n",index, w2,h2,c2, w,h,c);
     layer l = {0};
     l.type = SHORTCUT;
     l.batch = batch;
@@ -38,6 +38,27 @@ layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int
     return l;
 }
 
+void resize_shortcut_layer(layer *l, int w, int h)
+{
+    assert(l->w == l->out_w);
+    assert(l->h == l->out_h);
+    l->w = l->out_w = w;
+    l->h = l->out_h = h;
+    l->outputs = w*h*l->out_c;
+    l->inputs = l->outputs;
+    l->delta = realloc(l->delta, l->outputs*l->batch*sizeof(float));
+    l->output = realloc(l->output, l->outputs*l->batch*sizeof(float));
+
+#ifdef GPU
+    cuda_free(l->output_gpu);
+    cuda_free(l->delta_gpu);
+    l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
+    l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
+#endif
+
+}
+
+
 void forward_shortcut_layer(const layer l, network net)
 {
     copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);
diff --git a/src/shortcut_layer.h b/src/shortcut_layer.h
index 32e4ebdc..5f684fc1 100644
--- a/src/shortcut_layer.h
+++ b/src/shortcut_layer.h
@@ -7,6 +7,7 @@
 layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2);
 void forward_shortcut_layer(const layer l, network net);
 void backward_shortcut_layer(const layer l, network net);
+void resize_shortcut_layer(layer *l, int w, int h);
 
 #ifdef GPU
 void forward_shortcut_layer_gpu(const layer l, network net);
diff --git a/src/upsample_layer.c b/src/upsample_layer.c
index ab6a3b6e..2f5d1d0c 100644
--- a/src/upsample_layer.c
+++ b/src/upsample_layer.c
@@ -12,10 +12,16 @@ layer make_upsample_layer(int batch, int w, int h, int c, int stride)
     l.w = w;
     l.h = h;
     l.c = c;
-    l.stride = stride;
     l.out_w = w*stride;
     l.out_h = h*stride;
     l.out_c = c;
+    if(stride < 0){
+        stride = -stride;
+        l.reverse=1;
+        l.out_w = w/stride;
+        l.out_h = h/stride;
+    }
+    l.stride = stride;
     l.outputs = l.out_w*l.out_h*l.out_c;
     l.inputs = l.w*l.h*l.c;
     l.delta = calloc(l.outputs*batch, sizeof(float));
@@ -30,7 +36,8 @@ layer make_upsample_layer(int batch, int w, int h, int c, int stride)
     l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
     l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
 #endif
-    fprintf(stderr, "upsample          %2dx  %4d x%4d x%4d   ->  %4d x%4d x%4d\n", stride, w, h, c, l.out_w, l.out_h, l.out_c);
+    if(l.reverse) fprintf(stderr, "downsample        %2dx  %4d x%4d x%4d   ->  %4d x%4d x%4d\n", stride, w, h, c, l.out_w, l.out_h, l.out_c);
+    else fprintf(stderr, "upsample          %2dx  %4d x%4d x%4d   ->  %4d x%4d x%4d\n", stride, w, h, c, l.out_w, l.out_h, l.out_c);
     return l;
 }
 
@@ -40,6 +47,10 @@ void resize_upsample_layer(layer *l, int w, int h)
     l->h = h;
     l->out_w = w*l->stride;
     l->out_h = h*l->stride;
+    if(l->reverse){
+        l->out_w = w/l->stride;
+        l->out_h = h/l->stride;
+    }
     l->outputs = l->out_w*l->out_h*l->out_c;
     l->inputs = l->h*l->w*l->c;
     l->delta = realloc(l->delta, l->outputs*l->batch*sizeof(float));
@@ -56,44 +67,40 @@ void resize_upsample_layer(layer *l, int w, int h)
 void forward_upsample_layer(const layer l, network net)
 {
-    int i, j, k, b;
-    for(b = 0; b < l.batch; ++b){
-        for(k = 0; k < l.c; ++k){
-            for(j = 0; j < l.h*l.stride; ++j){
-                for(i = 0; i < l.w*l.stride; ++i){
-                    int in_index = b*l.inputs + k*l.w*l.h + (j/l.stride)*l.w + i/l.stride;
-                    int out_index = b*l.inputs + k*l.w*l.h + j*l.w + i;
-                    l.output[out_index] = net.input[in_index];
-                }
-            }
-        }
+    fill_cpu(l.outputs*l.batch, 0, l.output, 1);
+    if(l.reverse){
+        upsample_cpu(l.output, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, net.input);
+    }else{
+        upsample_cpu(net.input, l.w, l.h, l.c, l.batch, l.stride, 1, l.output);
     }
 }
 
 void backward_upsample_layer(const layer l, network net)
 {
-    int i, j, k, b;
-    for(b = 0; b < l.batch; ++b){
-        for(k = 0; k < l.c; ++k){
-            for(j = 0; j < l.h*l.stride; ++j){
-                for(i = 0; i < l.w*l.stride; ++i){
-                    int in_index = b*l.inputs + k*l.w*l.h + (j/l.stride)*l.w + i/l.stride;
-                    int out_index = b*l.inputs + k*l.w*l.h + j*l.w + i;
-                    net.delta[in_index] += l.delta[out_index];
-                }
-            }
-        }
+    if(l.reverse){
+        upsample_cpu(l.delta, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, net.delta);
+    }else{
+        upsample_cpu(net.delta, l.w, l.h, l.c, l.batch, l.stride, 0, l.delta);
     }
 }
 
 #ifdef GPU
 void forward_upsample_layer_gpu(const layer l, network net)
 {
-    upsample_gpu(net.input_gpu, l.w, l.h, l.c, l.batch, l.stride, 1, l.output_gpu);
+    fill_gpu(l.outputs*l.batch, 0, l.output_gpu, 1);
+    if(l.reverse){
+        upsample_gpu(l.output_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, net.input_gpu);
+    }else{
+        upsample_gpu(net.input_gpu, l.w, l.h, l.c, l.batch, l.stride, 1, l.output_gpu);
+    }
 }
 
 void backward_upsample_layer_gpu(const layer l, network net)
 {
-    upsample_gpu(net.delta_gpu, l.w, l.h, l.c, l.batch, l.stride, 0, l.delta_gpu);
+    if(l.reverse){
+        upsample_gpu(l.delta_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, net.delta_gpu);
+    }else{
+        upsample_gpu(net.delta_gpu, l.w, l.h, l.c, l.batch, l.stride, 0, l.delta_gpu);
+    }
 }
 #endif
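// With a negative stride an [upsample] section now acts as a downsample layer;
// a hypothetical cfg snippet:
//   [upsample]
//   stride=-2
// halves each spatial dimension, summing every stride x stride block of the
// input via the shared upsample_cpu/upsample_gpu helpers run with forward=0.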