diff --git a/Makefile b/Makefile index 26c4076c..2180094f 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ OPENCV=1 DEBUG=0 ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20 +ARCH= -arch=sm_52 --use_fast_math VPATH=./src/ EXEC=darknet @@ -36,7 +37,7 @@ endif OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o classifier.o ifeq ($(GPU), 1) -OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o +OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o swag_kernels.o endif OBJS = $(addprefix $(OBJDIR), $(OBJ)) diff --git a/cfg/strided.cfg b/cfg/strided.cfg index 4fd71e88..a52700b4 100644 --- a/cfg/strided.cfg +++ b/cfg/strided.cfg @@ -4,10 +4,16 @@ subdivisions=4 height=256 width=256 channels=3 -learning_rate=0.01 momentum=0.9 decay=0.0005 +learning_rate=0.01 +policy=steps +scales=.1,.1,.1 +steps=200000,300000,400000 +max_batches=800000 + + [crop] crop_height=224 crop_width=224 @@ -15,6 +21,7 @@ flip=1 angle=0 saturation=1 exposure=1 +shift=.2 [convolutional] filters=64 @@ -160,9 +167,6 @@ activation=ramp size=3 stride=2 -[dropout] -probability=0.5 - [connected] output=4096 activation=ramp diff --git a/cfg/yolo.cfg b/cfg/yolo.cfg index 140de88a..ee167269 100644 --- a/cfg/yolo.cfg +++ b/cfg/yolo.cfg @@ -1,210 +1,235 @@ [net] batch=64 -subdivisions=4 +subdivisions=2 height=448 width=448 channels=3 -learning_rate=0.01 momentum=0.9 decay=0.0005 +learning_rate=0.001 policy=steps -steps=20000 -scales=.1 -max_batches = 35000 +steps=100,200,300,400,500,600,700,20000,30000 +scales=2,2,1.25,1.25,1.25,1.25,1.03,.1,.1 +max_batches = 40000 [crop] crop_width=448 crop_height=448 flip=0 angle=0 -saturation = 2 -exposure = 2 +saturation = 1.5 +exposure = 1.5 [convolutional] filters=64 size=7 stride=2 pad=1 -activation=ramp +activation=leaky + +[maxpool] +size=2 +stride=2 [convolutional] filters=192 size=3 -stride=2 +stride=1 pad=1 -activation=ramp +activation=leaky + +[maxpool] +size=2 +stride=2 [convolutional] filters=128 size=1 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=256 size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +filters=256 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +[maxpool] +size=2 stride=2 -pad=1 -activation=ramp - -[convolutional] -filters=128 -size=1 -stride=1 -pad=1 -activation=ramp - -[convolutional] -filters=256 -size=3 -stride=1 -pad=1 -activation=ramp - -[convolutional] -filters=128 -size=1 -stride=1 -pad=1 -activation=ramp - -[convolutional] -filters=512 -size=3 -stride=2 -pad=1 -activation=ramp [convolutional] filters=256 size=1 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=512 size=3 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=256 size=1 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=512 size=3 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=256 size=1 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=512 size=3 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=256 size=1 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=512 size=3 stride=1 pad=1 -activation=ramp - -[convolutional] -filters=256 -size=1 -stride=1 -pad=1 -activation=ramp - -[convolutional] -filters=1024 -size=3 -stride=2 -pad=1 -activation=ramp +activation=leaky [convolutional] filters=512 size=1 stride=1 pad=1 -activation=ramp +activation=leaky [convolutional] filters=1024 size=3 stride=1 pad=1 -activation=ramp +activation=leaky + +[maxpool] +size=2 +stride=2 + +[convolutional] +filters=512 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +filters=1024 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +filters=512 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +filters=1024 +size=3 +stride=1 +pad=1 +activation=leaky + +####### [convolutional] size=3 stride=1 pad=1 filters=1024 -activation=ramp +activation=leaky [convolutional] size=3 stride=2 pad=1 filters=1024 -activation=ramp +activation=leaky [convolutional] size=3 stride=1 pad=1 filters=1024 -activation=ramp +activation=leaky [convolutional] size=3 stride=1 pad=1 filters=1024 -activation=ramp +activation=leaky [connected] output=4096 -activation=ramp +activation=leaky [dropout] probability=.5 [connected] -output=1225 -activation=logistic +output= 1470 +activation=linear -[detection] +[region] classes=20 coords=4 -rescore=0 -joint=0 -objectness=1 -background=0 +rescore=1 +side=7 +num=2 +softmax=0 +sqrt=1 +jitter=.2 + +object_scale=1 +noobject_scale=.5 +class_scale=1 +coord_scale=5 diff --git a/src/blas.c b/src/blas.c index 8d93dc74..37859376 100644 --- a/src/blas.c +++ b/src/blas.c @@ -1,6 +1,51 @@ #include "blas.h" #include "math.h" +void mean_cpu(float *x, int batch, int filters, int spatial, float *mean) +{ + float scale = 1./(batch * spatial); + int i,j,k; + for(i = 0; i < filters; ++i){ + mean[i] = 0; + for(j = 0; j < batch; ++j){ + for(k = 0; k < spatial; ++k){ + int index = j*filters*spatial + i*spatial + k; + mean[i] += x[index]; + } + } + mean[i] *= scale; + } +} + +void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance) +{ + float scale = 1./(batch * spatial); + int i,j,k; + for(i = 0; i < filters; ++i){ + variance[i] = 0; + for(j = 0; j < batch; ++j){ + for(k = 0; k < spatial; ++k){ + int index = j*filters*spatial + i*spatial + k; + variance[i] += pow((x[index] - mean[i]), 2); + } + } + variance[i] *= scale; + } +} + +void normalize_cpu(float *x, float *mean, float *variance, int batch, int filters, int spatial) +{ + int b, f, i; + for(b = 0; b < batch; ++b){ + for(f = 0; f < filters; ++f){ + for(i = 0; i < spatial; ++i){ + int index = b*filters*spatial + f*spatial + i; + x[index] = (x[index] - mean[f])/(sqrt(variance[f])); + } + } + } +} + void const_cpu(int N, float ALPHA, float *X, int INCX) { int i; diff --git a/src/blas.h b/src/blas.h index 99099253..be7da00b 100644 --- a/src/blas.h +++ b/src/blas.h @@ -16,6 +16,10 @@ void scal_cpu(int N, float ALPHA, float *X, int INCX); float dot_cpu(int N, float *X, int INCX, float *Y, int INCY); void test_gpu_blas(); +void mean_cpu(float *x, int batch, int filters, int spatial, float *mean); +void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance); +void normalize_cpu(float *x, float *mean, float *variance, int batch, int filters, int spatial); + #ifdef GPU void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY); void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY); @@ -26,6 +30,20 @@ void mask_ongpu(int N, float * X, float mask_num, float * mask); void const_ongpu(int N, float ALPHA, float *X, int INCX); void pow_ongpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY); void mul_ongpu(int N, float *X, int INCX, float *Y, int INCY); +void fill_ongpu(int N, float ALPHA, float * X, int INCX); +void mean_gpu(float *x, int batch, int filters, int spatial, float *mean); +void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance); +void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial); + +void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta); +void variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta); +void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta); + +void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *spatial_mean_delta, float *mean_delta); +void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *spatial_variance_delta, float *variance_delta); + +void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *spatial_variance, float *variance); +void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *spatial_mean, float *mean); #endif #endif diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index 0c89c475..b990ca33 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -4,6 +4,181 @@ extern "C" { #include "utils.h" } +__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial) +{ + int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (index >= N) return; + int f = (index/spatial)%filters; + + x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .00001f); +} + +__global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta) +{ + int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (index >= N) return; + int f = (index/spatial)%filters; + + delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch); +} + +extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta) +{ + size_t N = batch*filters*spatial; + normalize_delta_kernel<<>>(N, x, mean, variance, mean_delta, variance_delta, batch, filters, spatial, delta); + check_error(cudaPeekAtLastError()); +} + +__global__ void variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta) +{ + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= filters) return; + int j,k; + variance_delta[i] = 0; + for(j = 0; j < batch; ++j){ + for(k = 0; k < spatial; ++k){ + int index = j*filters*spatial + i*spatial + k; + variance_delta[i] += delta[index]*(x[index] - mean[i]); + } + } + variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.)); +} + +__global__ void spatial_variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *spatial_variance_delta) +{ + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= batch*filters) return; + int f = i%filters; + int b = i/filters; + + int k; + spatial_variance_delta[i] = 0; + for (k = 0; k < spatial; ++k) { + int index = b*filters*spatial + f*spatial + k; + spatial_variance_delta[i] += delta[index]*(x[index] - mean[f]); + } + spatial_variance_delta[i] *= -.5 * pow(variance[f] + .00001f, (float)(-3./2.)); +} + +extern "C" void variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta) +{ + variance_delta_kernel<<>>(x, delta, mean, variance, batch, filters, spatial, variance_delta); + check_error(cudaPeekAtLastError()); +} + +__global__ void accumulate_kernel(float *x, int n, int groups, float *sum) +{ + int k; + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= groups) return; + sum[i] = 0; + for(k = 0; k < n; ++k){ + sum[i] += x[k*groups + i]; + } +} + +extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *spatial_variance_delta, float *variance_delta) +{ + spatial_variance_delta_kernel<<>>(x, delta, mean, variance, batch, filters, spatial, spatial_variance_delta); + check_error(cudaPeekAtLastError()); + accumulate_kernel<<>>(spatial_variance_delta, batch, filters, variance_delta); + check_error(cudaPeekAtLastError()); +} + +__global__ void spatial_mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *spatial_mean_delta) +{ + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= batch*filters) return; + int f = i%filters; + int b = i/filters; + + int k; + spatial_mean_delta[i] = 0; + for (k = 0; k < spatial; ++k) { + int index = b*filters*spatial + f*spatial + k; + spatial_mean_delta[i] += delta[index]; + } + spatial_mean_delta[i] *= (-1./sqrt(variance[f] + .00001f)); +} + +extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *spatial_mean_delta, float *mean_delta) +{ + spatial_mean_delta_kernel<<>>(delta, variance, batch, filters, spatial, spatial_mean_delta); + check_error(cudaPeekAtLastError()); + accumulate_kernel<<>>(spatial_mean_delta, batch, filters, mean_delta); + check_error(cudaPeekAtLastError()); +} + +__global__ void mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta) +{ + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= filters) return; + int j,k; + mean_delta[i] = 0; + for (j = 0; j < batch; ++j) { + for (k = 0; k < spatial; ++k) { + int index = j*filters*spatial + i*spatial + k; + mean_delta[i] += delta[index]; + } + } + mean_delta[i] *= (-1./sqrt(variance[i] + .00001f)); +} + +extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta) +{ + mean_delta_kernel<<>>(delta, variance, batch, filters, spatial, mean_delta); + check_error(cudaPeekAtLastError()); +} + +__global__ void mean_kernel(float *x, int batch, int filters, int spatial, float *mean) +{ + float scale = 1./(batch * spatial); + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= filters) return; + int j,k; + mean[i] = 0; + for(j = 0; j < batch; ++j){ + for(k = 0; k < spatial; ++k){ + int index = j*filters*spatial + i*spatial + k; + mean[i] += x[index]; + } + } + mean[i] *= scale; +} + +__global__ void spatial_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance) +{ + float scale = 1./(spatial*batch-1); + int k; + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= batch*filters) return; + int f = i%filters; + int b = i/filters; + + variance[i] = 0; + for(k = 0; k < spatial; ++k){ + int index = b*filters*spatial + f*spatial + k; + variance[i] += pow((x[index] - mean[f]), 2); + } + variance[i] *= scale; +} + +__global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance) +{ + float scale = 1./(batch * spatial); + int j,k; + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (i >= filters) return; + variance[i] = 0; + for(j = 0; j < batch; ++j){ + for(k = 0; k < spatial; ++k){ + int index = j*filters*spatial + i*spatial + k; + variance[i] += pow((x[index] - mean[i]), 2); + } + } + variance[i] *= scale; +} + __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -28,6 +203,12 @@ __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX) if(i < N) X[i*INCX] *= ALPHA; } +__global__ void fill_kernel(int N, float ALPHA, float *X, int INCX) +{ + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if(i < N) X[i*INCX] = ALPHA; +} + __global__ void mask_kernel(int n, float *x, float mask_num, float *mask) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -46,6 +227,41 @@ __global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY) if(i < N) Y[i*INCY] *= X[i*INCX]; } +extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial) +{ + size_t N = batch*filters*spatial; + normalize_kernel<<>>(N, x, mean, variance, batch, filters, spatial); + check_error(cudaPeekAtLastError()); +} + +extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean) +{ + mean_kernel<<>>(x, batch, filters, spatial, mean); + check_error(cudaPeekAtLastError()); +} + +extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *spatial_mean, float *mean) +{ + mean_kernel<<>>(x, 1, filters*batch, spatial, spatial_mean); + check_error(cudaPeekAtLastError()); + mean_kernel<<>>(spatial_mean, batch, filters, 1, mean); + check_error(cudaPeekAtLastError()); +} + +extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *spatial_variance, float *variance) +{ + spatial_variance_kernel<<>>(x, mean, batch, filters, spatial, spatial_variance); + check_error(cudaPeekAtLastError()); + accumulate_kernel<<>>(spatial_variance, batch, filters, variance); + check_error(cudaPeekAtLastError()); +} + +extern "C" void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance) +{ + variance_kernel<<>>(x, mean, batch, filters, spatial, variance); + check_error(cudaPeekAtLastError()); +} + extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY) { axpy_ongpu_offset(N, ALPHA, X, 0, INCX, Y, 0, INCY); @@ -97,3 +313,9 @@ extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX) scal_kernel<<>>(N, ALPHA, X, INCX); check_error(cudaPeekAtLastError()); } + +extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX) +{ + fill_kernel<<>>(N, ALPHA, X, INCX); + check_error(cudaPeekAtLastError()); +} diff --git a/src/box.c b/src/box.c index b99300d3..95685991 100644 --- a/src/box.c +++ b/src/box.c @@ -1,6 +1,7 @@ #include "box.h" #include #include +#include box float_to_box(float *f) { @@ -229,6 +230,52 @@ dbox diou(box a, box b) return dd; } +typedef struct{ + int index; + int class; + float **probs; +} sortable_bbox; + +int nms_comparator(const void *pa, const void *pb) +{ + sortable_bbox a = *(sortable_bbox *)pa; + sortable_bbox b = *(sortable_bbox *)pb; + float diff = a.probs[a.index][b.class] - b.probs[b.index][b.class]; + if(diff < 0) return 1; + else if(diff > 0) return -1; + return 0; +} + +void do_nms_sort(box *boxes, float **probs, int total, int classes, float thresh) +{ + int i, j, k; + sortable_bbox *s = calloc(total, sizeof(sortable_bbox)); + + for(i = 0; i < total; ++i){ + s[i].index = i; + s[i].class = 0; + s[i].probs = probs; + } + + for(k = 0; k < classes; ++k){ + for(i = 0; i < total; ++i){ + s[i].class = k; + } + qsort(s, total, sizeof(sortable_bbox), nms_comparator); + for(i = 0; i < total; ++i){ + if(probs[s[i].index][k] == 0) continue; + box a = boxes[s[i].index]; + for(j = i+1; j < total; ++j){ + box b = boxes[s[j].index]; + if (box_iou(a, b) > thresh){ + probs[s[j].index][k] = 0; + } + } + } + } + free(s); +} + void do_nms(box *boxes, float **probs, int total, int classes, float thresh) { int i, j, k; diff --git a/src/box.h b/src/box.h index 9b57fb45..a5f8cee3 100644 --- a/src/box.h +++ b/src/box.h @@ -14,6 +14,7 @@ float box_iou(box a, box b); float box_rmse(box a, box b); dbox diou(box a, box b); void do_nms(box *boxes, float **probs, int total, int classes, float thresh); +void do_nms_sort(box *boxes, float **probs, int total, int classes, float thresh); box decode_box(box b, box anchor); box encode_box(box b, box anchor); diff --git a/src/classifier.c b/src/classifier.c new file mode 100644 index 00000000..e2439659 --- /dev/null +++ b/src/classifier.c @@ -0,0 +1,316 @@ +#include "network.h" +#include "utils.h" +#include "parser.h" +#include "option_list.h" + +#ifdef OPENCV +#include "opencv2/highgui/highgui_c.h" +#endif + +list *read_data_cfg(char *filename) +{ + FILE *file = fopen(filename, "r"); + if(file == 0) file_error(filename); + char *line; + int nu = 0; + list *options = make_list(); + while((line=fgetl(file)) != 0){ + ++ nu; + strip(line); + switch(line[0]){ + case '\0': + case '#': + case ';': + free(line); + break; + default: + if(!read_option(line, options)){ + fprintf(stderr, "Config file error line %d, could parse: %s\n", nu, line); + free(line); + } + break; + } + } + fclose(file); + return options; +} + +void train_classifier(char *datacfg, char *cfgfile, char *weightfile) +{ + data_seed = time(0); + srand(time(0)); + float avg_loss = -1; + char *base = basecfg(cfgfile); + printf("%s\n", base); + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = 1024; + + list *options = read_data_cfg(datacfg); + + char *backup_directory = option_find_str(options, "backup", "/backup/"); + char *label_list = option_find_str(options, "labels", "data/labels.list"); + char *train_list = option_find_str(options, "train", "data/train.list"); + int classes = option_find_int(options, "classes", 2); + + char **labels = get_labels(label_list); + list *plist = get_paths(train_list); + char **paths = (char **)list_to_array(plist); + printf("%d\n", plist->size); + int N = plist->size; + clock_t time; + pthread_t load_thread; + data train; + data buffer; + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.classes = classes; + args.n = imgs; + args.m = N; + args.labels = labels; + args.d = &buffer; + args.type = CLASSIFICATION_DATA; + + load_thread = load_data_in_thread(args); + int epoch = (*net.seen)/N; + while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ + time=clock(); + pthread_join(load_thread, 0); + train = buffer; + + load_thread = load_data_in_thread(args); + printf("Loaded: %lf seconds\n", sec(clock()-time)); + time=clock(); + float loss = train_network(net, train); + if(avg_loss == -1) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); + free_data(train); + if(*net.seen/N > epoch){ + epoch = *net.seen/N; + char buff[256]; + sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); + save_weights(net, buff); + } + } + char buff[256]; + sprintf(buff, "%s/%s.weights", backup_directory, base); + save_weights(net, buff); + + pthread_join(load_thread, 0); + free_data(buffer); + free_network(net); + free_ptrs((void**)labels, classes); + free_ptrs((void**)paths, plist->size); + free_list(plist); + free(base); +} + +void validate_classifier(char *datacfg, char *filename, char *weightfile) +{ + int i = 0; + network net = parse_network_cfg(filename); + if(weightfile){ + load_weights(&net, weightfile); + } + srand(time(0)); + + list *options = read_data_cfg(datacfg); + + char *label_list = option_find_str(options, "labels", "data/labels.list"); + char *valid_list = option_find_str(options, "valid", "data/train.list"); + int classes = option_find_int(options, "classes", 2); + int topk = option_find_int(options, "topk", 1); + + char **labels = get_labels(label_list); + list *plist = get_paths(valid_list); + + char **paths = (char **)list_to_array(plist); + int m = plist->size; + free_list(plist); + + clock_t time; + float avg_acc = 0; + float avg_topk = 0; + int splits = 50; + int num = (i+1)*m/splits - i*m/splits; + + data val, buffer; + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.classes = classes; + args.n = num; + args.m = 0; + args.labels = labels; + args.d = &buffer; + args.type = CLASSIFICATION_DATA; + + pthread_t load_thread = load_data_in_thread(args); + for(i = 1; i <= splits; ++i){ + time=clock(); + + pthread_join(load_thread, 0); + val = buffer; + + num = (i+1)*m/splits - i*m/splits; + char **part = paths+(i*m/splits); + if(i != splits){ + args.paths = part; + load_thread = load_data_in_thread(args); + } + printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time)); + + time=clock(); + float *acc = network_accuracies(net, val, topk); + avg_acc += acc[0]; + avg_topk += acc[1]; + printf("%d: top 1: %f, top %d: %f, %lf seconds, %d images\n", i, avg_acc/i, topk, avg_topk/i, sec(clock()-time), val.X.rows); + free_data(val); + } +} + +void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + srand(2222222); + + list *options = read_data_cfg(datacfg); + + char *label_list = option_find_str(options, "labels", "data/labels.list"); + int top = option_find_int(options, "top", 1); + + int i = 0; + char **names = get_labels(label_list); + clock_t time; + int indexes[10]; + char buff[256]; + char *input = buff; + while(1){ + if(filename){ + strncpy(input, filename, 256); + }else{ + printf("Enter Image Path: "); + fflush(stdout); + input = fgets(input, 256, stdin); + if(!input) return; + strtok(input, "\n"); + } + image im = load_image_color(input, 256, 256); + float *X = im.data; + time=clock(); + float *predictions = network_predict(net, X); + top_predictions(net, top, indexes); + printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + for(i = 0; i < top; ++i){ + int index = indexes[i]; + printf("%s: %f\n", names[index], predictions[index]); + } + free_image(im); + if (filename) break; + } +} + +void test_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int target_layer) +{ + int curr = 0; + network net = parse_network_cfg(filename); + if(weightfile){ + load_weights(&net, weightfile); + } + srand(time(0)); + + list *options = read_data_cfg(datacfg); + + char *test_list = option_find_str(options, "test", "data/test.list"); + char *label_list = option_find_str(options, "labels", "data/labels.list"); + int classes = option_find_int(options, "classes", 2); + + char **labels = get_labels(label_list); + list *plist = get_paths(test_list); + + char **paths = (char **)list_to_array(plist); + int m = plist->size; + free_list(plist); + + clock_t time; + + data val, buffer; + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.classes = classes; + args.n = net.batch; + args.m = 0; + args.labels = labels; + args.d = &buffer; + args.type = CLASSIFICATION_DATA; + + pthread_t load_thread = load_data_in_thread(args); + for(curr = net.batch; curr < m; curr += net.batch){ + time=clock(); + + pthread_join(load_thread, 0); + val = buffer; + + if(curr < m){ + args.paths = paths + curr; + if (curr + net.batch > m) args.n = m - curr; + load_thread = load_data_in_thread(args); + } + fprintf(stderr, "Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time)); + + time=clock(); + matrix pred = network_predict_data(net, val); + + int i; + if (target_layer >= 0){ + //layer l = net.layers[target_layer]; + } + + for(i = 0; i < val.X.rows; ++i){ + + } + + free_matrix(pred); + + fprintf(stderr, "%lf seconds, %d images\n", sec(clock()-time), val.X.rows); + free_data(val); + } +} + + +void run_classifier(int argc, char **argv) +{ + if(argc < 4){ + fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); + return; + } + + char *data = argv[3]; + char *cfg = argv[4]; + char *weights = (argc > 5) ? argv[5] : 0; + char *filename = (argc > 6) ? argv[6]: 0; + char *layer_s = (argc > 7) ? argv[7]: 0; + int layer = layer_s ? atoi(layer_s) : -1; + if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename); + else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights); + else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights,filename, layer); + else if(0==strcmp(argv[2], "valid")) validate_classifier(data, cfg, weights); +} + + diff --git a/src/coco.c b/src/coco.c index f6b135f2..e30eeb7e 100644 --- a/src/coco.c +++ b/src/coco.c @@ -1,7 +1,7 @@ #include #include "network.h" -#include "detection_layer.h" +#include "region_layer.h" #include "cost_layer.h" #include "utils.h" #include "parser.h" @@ -15,32 +15,27 @@ char *coco_classes[] = {"person","bicycle","car","motorcycle","airplane","bus"," int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90}; -void draw_coco(image im, float *pred, int side, char *label) +void draw_coco(image im, int num, float thresh, box *boxes, float **probs, char *label) { - int classes = 1; - int elems = 4+classes; - int j; - int r, c; + int classes = 80; + int i; - for(r = 0; r < side; ++r){ - for(c = 0; c < side; ++c){ - j = (r*side + c) * elems; - int class = max_index(pred+j, classes); - if (pred[j+class] > 0.2){ - int width = pred[j+class]*5 + 1; - printf("%f %s\n", pred[j+class], "object"); //coco_classes[class-1]); - float red = get_color(0,class,classes); - float green = get_color(1,class,classes); - float blue = get_color(2,class,classes); + for(i = 0; i < num; ++i){ + int class = max_index(probs[i], classes); + float prob = probs[i][class]; + if(prob > thresh){ + int width = sqrt(prob)*5 + 1; + printf("%f %s\n", prob, coco_classes[class]); + float red = get_color(0,class,classes); + float green = get_color(1,class,classes); + float blue = get_color(2,class,classes); + box b = boxes[i]; - j += classes; - - box predict = {pred[j+0], pred[j+1], pred[j+2], pred[j+3]}; - predict.x = (predict.x+c)/side; - predict.y = (predict.y+r)/side; - - draw_bbox(im, predict, width, red, green, blue); - } + int left = (b.x-b.w/2.)*im.w; + int right = (b.x+b.w/2.)*im.w; + int top = (b.y-b.h/2.)*im.h; + int bot = (b.y+b.h/2.)*im.h; + draw_box_width(im, left, top, right, bot, width, red, green, blue); } } show_image(im, label); @@ -48,8 +43,8 @@ void draw_coco(image im, float *pred, int side, char *label) void train_coco(char *cfgfile, char *weightfile) { - //char *train_images = "/home/pjreddie/data/coco/train.txt"; - char *train_images = "/home/pjreddie/data/voc/test/train.txt"; + //char *train_images = "/home/pjreddie/data/voc/test/train.txt"; + char *train_images = "/home/pjreddie/data/coco/train.txt"; char *backup_directory = "/home/pjreddie/backup/"; srand(time(0)); data_seed = time(0); @@ -61,7 +56,7 @@ void train_coco(char *cfgfile, char *weightfile) load_weights(&net, weightfile); } printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - int imgs = 128; + int imgs = net.batch*net.subdivisions; int i = *net.seen/imgs; data train, buffer; @@ -70,9 +65,10 @@ void train_coco(char *cfgfile, char *weightfile) int side = l.side; int classes = l.classes; + float jitter = l.jitter; list *plist = get_paths(train_images); - int N = plist->size; + //int N = plist->size; char **paths = (char **)list_to_array(plist); load_args args = {0}; @@ -82,13 +78,15 @@ void train_coco(char *cfgfile, char *weightfile) args.n = imgs; args.m = plist->size; args.classes = classes; + args.jitter = jitter; args.num_boxes = side; args.d = &buffer; args.type = REGION_DATA; pthread_t load_thread = load_data_in_thread(args); clock_t time; - while(i*imgs < N*120){ + //while(i*imgs < N*120){ + while(get_current_batch(net) < net.max_batches){ i += 1; time=clock(); pthread_join(load_thread, 0); @@ -97,20 +95,20 @@ void train_coco(char *cfgfile, char *weightfile) printf("Loaded: %lf seconds\n", sec(clock()-time)); -/* - image im = float_to_image(net.w, net.h, 3, train.X.vals[113]); - image copy = copy_image(im); - draw_coco(copy, train.y.vals[113], 7, "truth"); - cvWaitKey(0); - free_image(copy); - */ + /* + image im = float_to_image(net.w, net.h, 3, train.X.vals[113]); + image copy = copy_image(im); + draw_coco(copy, train.y.vals[113], 7, "truth"); + cvWaitKey(0); + free_image(copy); + */ time=clock(); float loss = train_network(net, train); if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; - printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); + printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs); if(i%1000==0){ char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); @@ -123,60 +121,38 @@ void train_coco(char *cfgfile, char *weightfile) save_weights(net, buff); } -void get_probs(float *predictions, int total, int classes, int inc, float **probs) +void convert_coco_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness) { - int i,j; - for (i = 0; i < total; ++i){ - int index = i*inc; - float scale = predictions[index]; - probs[i][0] = scale; - for(j = 0; j < classes; ++j){ - probs[i][j] = scale*predictions[index+j+1]; + int i,j,n; + //int per_cell = 5*num+classes; + for (i = 0; i < side*side; ++i){ + int row = i / side; + int col = i % side; + for(n = 0; n < num; ++n){ + int index = i*num + n; + int p_index = side*side*classes + i*num + n; + float scale = predictions[p_index]; + int box_index = side*side*(classes + num) + (i*num + n)*4; + boxes[index].x = (predictions[box_index + 0] + col) / side * w; + boxes[index].y = (predictions[box_index + 1] + row) / side * h; + boxes[index].w = pow(predictions[box_index + 2], (square?2:1)) * w; + boxes[index].h = pow(predictions[box_index + 3], (square?2:1)) * h; + for(j = 0; j < classes; ++j){ + int class_index = i*classes; + float prob = scale*predictions[class_index+j]; + probs[index][j] = (prob > thresh) ? prob : 0; + } + if(only_objectness){ + probs[index][0] = scale; + } } } } -void get_boxes(float *predictions, int n, int num_boxes, int per_box, box *boxes) -{ - int i,j; - for (i = 0; i < num_boxes*num_boxes; ++i){ - for(j = 0; j < n; ++j){ - int index = i*n+j; - int offset = index*per_box; - int row = i / num_boxes; - int col = i % num_boxes; - boxes[index].x = (predictions[offset + 0] + col) / num_boxes; - boxes[index].y = (predictions[offset + 1] + row) / num_boxes; - boxes[index].w = predictions[offset + 2]; - boxes[index].h = predictions[offset + 3]; - } - } -} - -void convert_cocos(float *predictions, int classes, int num_boxes, int num, int w, int h, float thresh, float **probs, box *boxes) -{ - int i,j; - int per_box = 4+classes; - for (i = 0; i < num_boxes*num_boxes*num; ++i){ - int offset = i*per_box; - for(j = 0; j < classes; ++j){ - float prob = predictions[offset+j]; - probs[i][j] = (prob > thresh) ? prob : 0; - } - int row = i / num_boxes; - int col = i % num_boxes; - offset += classes; - boxes[i].x = (predictions[offset + 0] + col) / num_boxes; - boxes[i].y = (predictions[offset + 1] + row) / num_boxes; - boxes[i].w = predictions[offset + 2]; - boxes[i].h = predictions[offset + 3]; - } -} - void print_cocos(FILE *fp, int image_id, box *boxes, float **probs, int num_boxes, int classes, int w, int h) { int i, j; - for(i = 0; i < num_boxes*num_boxes; ++i){ + for(i = 0; i < num_boxes; ++i){ float xmin = boxes[i].x - boxes[i].w/2.; float xmax = boxes[i].x + boxes[i].w/2.; float ymin = boxes[i].y - boxes[i].h/2.; @@ -204,201 +180,6 @@ int get_coco_image_id(char *filename) return atoi(p+1); } -void validate_recall(char *cfgfile, char *weightfile) -{ - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - set_batch_network(&net, 1); - fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - srand(time(0)); - - char *val_images = "/home/pjreddie/data/voc/test/2007_test.txt"; - list *plist = get_paths(val_images); - char **paths = (char **)list_to_array(plist); - - layer l = net.layers[net.n - 1]; - - int num_boxes = l.side; - int num = l.n; - int classes = l.classes; - - int j; - - box *boxes = calloc(num_boxes*num_boxes*num, sizeof(box)); - float **probs = calloc(num_boxes*num_boxes*num, sizeof(float *)); - for(j = 0; j < num_boxes*num_boxes*num; ++j) probs[j] = calloc(classes+1, sizeof(float *)); - - int N = plist->size; - int i=0; - int k; - - float iou_thresh = .5; - float thresh = .1; - int total = 0; - int correct = 0; - float avg_iou = 0; - int nms = 1; - int proposals = 0; - int save = 1; - - for (i = 0; i < N; ++i) { - char *path = paths[i]; - image orig = load_image_color(path, 0, 0); - image resized = resize_image(orig, net.w, net.h); - - float *X = resized.data; - float *predictions = network_predict(net, X); - get_boxes(predictions+1+classes, num, num_boxes, 5+classes, boxes); - get_probs(predictions, num*num_boxes*num_boxes, classes, 5+classes, probs); - if (nms) do_nms(boxes, probs, num*num_boxes*num_boxes, (classes>0) ? classes : 1, iou_thresh); - - char *labelpath = find_replace(path, "images", "labels"); - labelpath = find_replace(labelpath, "JPEGImages", "labels"); - labelpath = find_replace(labelpath, ".jpg", ".txt"); - labelpath = find_replace(labelpath, ".JPEG", ".txt"); - - int num_labels = 0; - box_label *truth = read_boxes(labelpath, &num_labels); - for(k = 0; k < num_boxes*num_boxes*num; ++k){ - if(probs[k][0] > thresh){ - ++proposals; - if(save){ - char buff[256]; - sprintf(buff, "/data/extracted/nms_preds/%d", proposals); - int dx = (boxes[k].x - boxes[k].w/2) * orig.w; - int dy = (boxes[k].y - boxes[k].h/2) * orig.h; - int w = boxes[k].w * orig.w; - int h = boxes[k].h * orig.h; - image cropped = crop_image(orig, dx, dy, w, h); - image sized = resize_image(cropped, 224, 224); -#ifdef OPENCV - save_image_jpg(sized, buff); -#endif - free_image(sized); - free_image(cropped); - sprintf(buff, "/data/extracted/nms_pred_boxes/%d.txt", proposals); - char *im_id = basecfg(path); - FILE *fp = fopen(buff, "w"); - fprintf(fp, "%s %d %d %d %d\n", im_id, dx, dy, dx+w, dy+h); - fclose(fp); - free(im_id); - } - } - } - for (j = 0; j < num_labels; ++j) { - ++total; - box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h}; - float best_iou = 0; - for(k = 0; k < num_boxes*num_boxes*num; ++k){ - float iou = box_iou(boxes[k], t); - if(probs[k][0] > thresh && iou > best_iou){ - best_iou = iou; - } - } - avg_iou += best_iou; - if(best_iou > iou_thresh){ - ++correct; - } - } - free(truth); - free_image(orig); - free_image(resized); - fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total); - } -} - -void extract_boxes(char *cfgfile, char *weightfile) -{ - network net = parse_network_cfg(cfgfile); - if(weightfile){ - load_weights(&net, weightfile); - } - set_batch_network(&net, 1); - fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - srand(time(0)); - - char *val_images = "/home/pjreddie/data/voc/test/train.txt"; - list *plist = get_paths(val_images); - char **paths = (char **)list_to_array(plist); - - layer l = net.layers[net.n - 1]; - - int num_boxes = l.side; - int num = l.n; - int classes = l.classes; - - int j; - - box *boxes = calloc(num_boxes*num_boxes*num, sizeof(box)); - float **probs = calloc(num_boxes*num_boxes*num, sizeof(float *)); - for(j = 0; j < num_boxes*num_boxes*num; ++j) probs[j] = calloc(classes+1, sizeof(float *)); - - int N = plist->size; - int i=0; - int k; - - int count = 0; - float iou_thresh = .3; - - for (i = 0; i < N; ++i) { - fprintf(stderr, "%5d %5d\n", i, count); - char *path = paths[i]; - image orig = load_image_color(path, 0, 0); - image resized = resize_image(orig, net.w, net.h); - - float *X = resized.data; - float *predictions = network_predict(net, X); - get_boxes(predictions+1+classes, num, num_boxes, 5+classes, boxes); - get_probs(predictions, num*num_boxes*num_boxes, classes, 5+classes, probs); - - char *labelpath = find_replace(path, "images", "labels"); - labelpath = find_replace(labelpath, "JPEGImages", "labels"); - labelpath = find_replace(labelpath, ".jpg", ".txt"); - labelpath = find_replace(labelpath, ".JPEG", ".txt"); - - int num_labels = 0; - box_label *truth = read_boxes(labelpath, &num_labels); - FILE *label = stdin; - for(k = 0; k < num_boxes*num_boxes*num; ++k){ - int overlaps = 0; - for (j = 0; j < num_labels; ++j) { - box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h}; - float iou = box_iou(boxes[k], t); - if (iou > iou_thresh){ - if (!overlaps) { - char buff[256]; - sprintf(buff, "/data/extracted/labels/%d.txt", count); - label = fopen(buff, "w"); - overlaps = 1; - } - fprintf(label, "%d %f\n", truth[j].id, iou); - } - } - if (overlaps) { - char buff[256]; - sprintf(buff, "/data/extracted/imgs/%d", count++); - int dx = (boxes[k].x - boxes[k].w/2) * orig.w; - int dy = (boxes[k].y - boxes[k].h/2) * orig.h; - int w = boxes[k].w * orig.w; - int h = boxes[k].h * orig.h; - image cropped = crop_image(orig, dx, dy, w, h); - image sized = resize_image(cropped, 224, 224); -#ifdef OPENCV - save_image_jpg(sized, buff); -#endif - free_image(sized); - free_image(cropped); - fclose(label); - } - } - free(truth); - free_image(orig); - free_image(resized); - } -} - void validate_coco(char *cfgfile, char *weightfile) { network net = parse_network_cfg(cfgfile); @@ -409,13 +190,16 @@ void validate_coco(char *cfgfile, char *weightfile) fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); srand(time(0)); - char *base = "/home/pjreddie/backup/"; + char *base = "results/"; list *plist = get_paths("data/coco_val_5k.list"); + //list *plist = get_paths("/home/pjreddie/data/people-art/test.txt"); + //list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt"); char **paths = (char **)list_to_array(plist); - int num_boxes = 9; - int num = 4; - int classes = 1; + layer l = net.layers[net.n-1]; + int classes = l.classes; + int square = l.sqrt; + int side = l.side; int j; char buff[1024]; @@ -423,29 +207,30 @@ void validate_coco(char *cfgfile, char *weightfile) FILE *fp = fopen(buff, "w"); fprintf(fp, "[\n"); - box *boxes = calloc(num_boxes*num_boxes*num, sizeof(box)); - float **probs = calloc(num_boxes*num_boxes*num, sizeof(float *)); - for(j = 0; j < num_boxes*num_boxes*num; ++j) probs[j] = calloc(classes, sizeof(float *)); + box *boxes = calloc(side*side*l.n, sizeof(box)); + float **probs = calloc(side*side*l.n, sizeof(float *)); + for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *)); int m = plist->size; int i=0; int t; - float thresh = .01; + float thresh = .001; int nms = 1; float iou_thresh = .5; - load_args args = {0}; - args.w = net.w; - args.h = net.h; - args.type = IMAGE_DATA; - int nthreads = 8; image *val = calloc(nthreads, sizeof(image)); image *val_resized = calloc(nthreads, sizeof(image)); image *buf = calloc(nthreads, sizeof(image)); image *buf_resized = calloc(nthreads, sizeof(image)); pthread_t *thr = calloc(nthreads, sizeof(pthread_t)); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.type = IMAGE_DATA; + for(t = 0; t < nthreads; ++t){ args.path = paths[i+t]; args.im = &buf[t]; @@ -473,9 +258,9 @@ void validate_coco(char *cfgfile, char *weightfile) float *predictions = network_predict(net, X); int w = val[t].w; int h = val[t].h; - convert_cocos(predictions, classes, num_boxes, num, w, h, thresh, probs, boxes); - if (nms) do_nms(boxes, probs, num_boxes, classes, iou_thresh); - print_cocos(fp, image_id, boxes, probs, num_boxes, classes, w, h); + convert_coco_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0); + if (nms) do_nms_sort(boxes, probs, side*side*l.n, classes, iou_thresh); + print_cocos(fp, image_id, boxes, probs, side*side*l.n, classes, w, h); free_image(val[t]); free_image(val_resized[t]); } @@ -483,21 +268,114 @@ void validate_coco(char *cfgfile, char *weightfile) fseek(fp, -2, SEEK_CUR); fprintf(fp, "\n]\n"); fclose(fp); + fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); } -void test_coco(char *cfgfile, char *weightfile, char *filename) +void validate_coco_recall(char *cfgfile, char *weightfile) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + srand(time(0)); + + char *base = "results/comp4_det_test_"; + list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt"); + char **paths = (char **)list_to_array(plist); + + layer l = net.layers[net.n-1]; + int classes = l.classes; + int square = l.sqrt; + int side = l.side; + + int j, k; + FILE **fps = calloc(classes, sizeof(FILE *)); + for(j = 0; j < classes; ++j){ + char buff[1024]; + snprintf(buff, 1024, "%s%s.txt", base, coco_classes[j]); + fps[j] = fopen(buff, "w"); + } + box *boxes = calloc(side*side*l.n, sizeof(box)); + float **probs = calloc(side*side*l.n, sizeof(float *)); + for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *)); + + int m = plist->size; + int i=0; + + float thresh = .001; + int nms = 0; + float iou_thresh = .5; + float nms_thresh = .5; + + int total = 0; + int correct = 0; + int proposals = 0; + float avg_iou = 0; + + for(i = 0; i < m; ++i){ + char *path = paths[i]; + image orig = load_image_color(path, 0, 0); + image sized = resize_image(orig, net.w, net.h); + char *id = basecfg(path); + float *predictions = network_predict(net, sized.data); + convert_coco_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1); + if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh); + + char *labelpath = find_replace(path, "images", "labels"); + labelpath = find_replace(labelpath, "JPEGImages", "labels"); + labelpath = find_replace(labelpath, ".jpg", ".txt"); + labelpath = find_replace(labelpath, ".JPEG", ".txt"); + + int num_labels = 0; + box_label *truth = read_boxes(labelpath, &num_labels); + for(k = 0; k < side*side*l.n; ++k){ + if(probs[k][0] > thresh){ + ++proposals; + } + } + for (j = 0; j < num_labels; ++j) { + ++total; + box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h}; + float best_iou = 0; + for(k = 0; k < side*side*l.n; ++k){ + float iou = box_iou(boxes[k], t); + if(probs[k][0] > thresh && iou > best_iou){ + best_iou = iou; + } + } + avg_iou += best_iou; + if(best_iou > iou_thresh){ + ++correct; + } + } + + fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total); + free(id); + free_image(orig); + free_image(sized); + } +} + +void test_coco(char *cfgfile, char *weightfile, char *filename, float thresh) { network net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); } + region_layer l = net.layers[net.n-1]; set_batch_network(&net, 1); srand(2222222); clock_t time; char buff[256]; char *input = buff; + int j; + box *boxes = calloc(l.side*l.side*l.n, sizeof(box)); + float **probs = calloc(l.side*l.side*l.n, sizeof(float *)); + for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *)); while(1){ if(filename){ strncpy(input, filename, 256); @@ -514,7 +392,10 @@ void test_coco(char *cfgfile, char *weightfile, char *filename) time=clock(); float *predictions = network_predict(net, X); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - draw_coco(im, predictions, 7, "predictions"); + convert_coco_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); + draw_coco(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions"); + + show_image(sized, "resized"); free_image(im); free_image(sized); #ifdef OPENCV @@ -527,6 +408,7 @@ void test_coco(char *cfgfile, char *weightfile, char *filename) void run_coco(int argc, char **argv) { + float thresh = find_float_arg(argc, argv, "-thresh", .2); if(argc < 4){ fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); return; @@ -535,8 +417,8 @@ void run_coco(int argc, char **argv) char *cfg = argv[3]; char *weights = (argc > 4) ? argv[4] : 0; char *filename = (argc > 5) ? argv[5]: 0; - if(0==strcmp(argv[2], "test")) test_coco(cfg, weights, filename); + if(0==strcmp(argv[2], "test")) test_coco(cfg, weights, filename, thresh); else if(0==strcmp(argv[2], "train")) train_coco(cfg, weights); - else if(0==strcmp(argv[2], "extract")) extract_boxes(cfg, weights); - else if(0==strcmp(argv[2], "valid")) validate_recall(cfg, weights); + else if(0==strcmp(argv[2], "valid")) validate_coco(cfg, weights); + else if(0==strcmp(argv[2], "recall")) validate_coco_recall(cfg, weights); } diff --git a/src/compare.c b/src/compare.c index 76e0b60e..a1f494e0 100644 --- a/src/compare.c +++ b/src/compare.c @@ -307,7 +307,7 @@ void BattleRoyaleWithCheese(char *filename, char *weightfile) qsort(boxes, N, sizeof(sortable_bbox), elo_comparator); N /= 2; - for(round = 1; round <= 20; ++round){ + for(round = 1; round <= 100; ++round){ clock_t round_time=clock(); printf("Round: %d\n", round); @@ -316,7 +316,7 @@ void BattleRoyaleWithCheese(char *filename, char *weightfile) bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class); } qsort(boxes, N, sizeof(sortable_bbox), elo_comparator); - N = (N*9/10)/2*2; + if(round <= 20) N = (N*9/10)/2*2; printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N); } diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index a150c205..60a18795 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -8,21 +8,65 @@ extern "C" { #include "cuda.h" } -__global__ void bias_output_kernel(float *output, float *biases, int n, int size) +__global__ void scale_bias_kernel(float *output, float *biases, int n, int size) { int offset = blockIdx.x * blockDim.x + threadIdx.x; int filter = blockIdx.y; int batch = blockIdx.z; - if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter]; + if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter]; } -void bias_output_gpu(float *output, float *biases, int batch, int n, int size) +void scale_bias_gpu(float *output, float *biases, int batch, int n, int size) { dim3 dimGrid((size-1)/BLOCK + 1, n, batch); dim3 dimBlock(BLOCK, 1, 1); - bias_output_kernel<<>>(output, biases, n, size); + scale_bias_kernel<<>>(output, biases, n, size); + check_error(cudaPeekAtLastError()); +} + +__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates) +{ + __shared__ float part[BLOCK]; + int i,b; + int filter = blockIdx.x; + int p = threadIdx.x; + float sum = 0; + for(b = 0; b < batch; ++b){ + for(i = 0; i < size; i += BLOCK){ + int index = p + i + size*(filter + n*b); + sum += (p+i < size) ? delta[index]*x_norm[index] : 0; + } + } + part[p] = sum; + __syncthreads(); + if (p == 0) { + for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i]; + } +} + +void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates) +{ + backward_scale_kernel<<>>(x_norm, delta, batch, n, size, scale_updates); + check_error(cudaPeekAtLastError()); +} + +__global__ void add_bias_kernel(float *output, float *biases, int n, int size) +{ + int offset = blockIdx.x * blockDim.x + threadIdx.x; + int filter = blockIdx.y; + int batch = blockIdx.z; + + if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter]; +} + +void add_bias_gpu(float *output, float *biases, int batch, int n, int size) +{ + dim3 dimGrid((size-1)/BLOCK + 1, n, batch); + dim3 dimBlock(BLOCK, 1, 1); + + add_bias_kernel<<>>(output, biases, n, size); check_error(cudaPeekAtLastError()); } @@ -41,7 +85,7 @@ __global__ void backward_bias_kernel(float *bias_updates, float *delta, int batc } part[p] = sum; __syncthreads(); - if(p == 0){ + if (p == 0) { for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i]; } } @@ -52,53 +96,88 @@ void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int check_error(cudaPeekAtLastError()); } -void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state) +void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) { int i; - int m = layer.n; - int k = layer.size*layer.size*layer.c; - int n = convolutional_out_height(layer)* - convolutional_out_width(layer); + int m = l.n; + int k = l.size*l.size*l.c; + int n = convolutional_out_height(l)* + convolutional_out_width(l); - bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n); - for(i = 0; i < layer.batch; ++i){ - im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu); - float * a = layer.filters_gpu; - float * b = layer.col_image_gpu; - float * c = layer.output_gpu; + fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1); + for(i = 0; i < l.batch; ++i){ + im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu); + float * a = l.filters_gpu; + float * b = l.col_image_gpu; + float * c = l.output_gpu; gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n); } - activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation); + + if(l.batch_normalize){ + if(state.train){ + fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_mean_gpu, l.mean_gpu); + fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_variance_gpu, l.variance_gpu); + + scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1); + axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1); + scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1); + axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1); + + // cuda_pull_array(l.variance_gpu, l.mean, l.n); + // printf("%f\n", l.mean[0]); + + copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1); + normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w); + copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1); + } else { + normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w); + } + + scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w); + } + add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n); + + activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation); } -void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state) +void backward_convolutional_layer_gpu(convolutional_layer l, network_state state) { int i; - int m = layer.n; - int n = layer.size*layer.size*layer.c; - int k = convolutional_out_height(layer)* - convolutional_out_width(layer); + int m = l.n; + int n = l.size*l.size*l.c; + int k = convolutional_out_height(l)* + convolutional_out_width(l); - gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu); - backward_bias_gpu(layer.bias_updates_gpu, layer.delta_gpu, layer.batch, layer.n, k); + gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu); - for(i = 0; i < layer.batch; ++i){ - float * a = layer.delta_gpu; - float * b = layer.col_image_gpu; - float * c = layer.filter_updates_gpu; + backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k); - im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu); + if(l.batch_normalize){ + backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu); + + scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w); + + fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_mean_delta_gpu, l.mean_delta_gpu); + fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_variance_delta_gpu, l.variance_delta_gpu); + normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu); + } + + for(i = 0; i < l.batch; ++i){ + float * a = l.delta_gpu; + float * b = l.col_image_gpu; + float * c = l.filter_updates_gpu; + + im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu); gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n); if(state.delta){ - - float * a = layer.filters_gpu; - float * b = layer.delta_gpu; - float * c = layer.col_image_gpu; + float * a = l.filters_gpu; + float * b = l.delta_gpu; + float * c = l.col_image_gpu; gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k); - col2im_ongpu(layer.col_image_gpu, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, state.delta + i*layer.c*layer.h*layer.w); + col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w); } } } @@ -109,6 +188,11 @@ void pull_convolutional_layer(convolutional_layer layer) cuda_pull_array(layer.biases_gpu, layer.biases, layer.n); cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size); cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n); + if (layer.batch_normalize){ + cuda_pull_array(layer.scales_gpu, layer.scales, layer.n); + cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); + cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); + } } void push_convolutional_layer(convolutional_layer layer) @@ -117,6 +201,11 @@ void push_convolutional_layer(convolutional_layer layer) cuda_push_array(layer.biases_gpu, layer.biases, layer.n); cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size); cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n); + if (layer.batch_normalize){ + cuda_push_array(layer.scales_gpu, layer.scales, layer.n); + cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); + cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); + } } void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay) @@ -126,8 +215,12 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1); scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1); + axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1); + scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1); + axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1); axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1); scal_ongpu(size, momentum, layer.filter_updates_gpu, 1); } + diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index f3609eaf..b9fd3c95 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -41,7 +41,7 @@ image get_convolutional_delta(convolutional_layer l) return float_to_image(w,h,c,l.delta); } -convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation) +convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize) { int i; convolutional_layer l = {0}; @@ -55,18 +55,17 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int l.stride = stride; l.size = size; l.pad = pad; + l.batch_normalize = batch_normalize; l.filters = calloc(c*n*size*size, sizeof(float)); l.filter_updates = calloc(c*n*size*size, sizeof(float)); l.biases = calloc(n, sizeof(float)); l.bias_updates = calloc(n, sizeof(float)); + // float scale = 1./sqrt(size*size*c); float scale = sqrt(2./(size*size*c)); for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale; - for(i = 0; i < n; ++i){ - l.biases[i] = scale; - } int out_h = convolutional_out_height(l); int out_w = convolutional_out_width(l); l.out_h = out_h; @@ -79,17 +78,55 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int l.output = calloc(l.batch*out_h * out_w * n, sizeof(float)); l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float)); - #ifdef GPU + if(batch_normalize){ + l.scales = calloc(n, sizeof(float)); + l.scale_updates = calloc(n, sizeof(float)); + for(i = 0; i < n; ++i){ + l.scales[i] = 1; + } + + l.mean = calloc(n, sizeof(float)); + l.spatial_mean = calloc(n*l.batch, sizeof(float)); + + l.variance = calloc(n, sizeof(float)); + l.rolling_mean = calloc(n, sizeof(float)); + l.rolling_variance = calloc(n, sizeof(float)); + } + +#ifdef GPU l.filters_gpu = cuda_make_array(l.filters, c*n*size*size); l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size); l.biases_gpu = cuda_make_array(l.biases, n); l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); + l.scales_gpu = cuda_make_array(l.scales, n); + l.scale_updates_gpu = cuda_make_array(l.scale_updates, n); + l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c); l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n); l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); - #endif + + if(batch_normalize){ + l.mean_gpu = cuda_make_array(l.mean, n); + l.variance_gpu = cuda_make_array(l.variance, n); + + l.rolling_mean_gpu = cuda_make_array(l.mean, n); + l.rolling_variance_gpu = cuda_make_array(l.variance, n); + + l.spatial_mean_gpu = cuda_make_array(l.spatial_mean, n*l.batch); + l.spatial_variance_gpu = cuda_make_array(l.spatial_mean, n*l.batch); + + l.spatial_mean_delta_gpu = cuda_make_array(l.spatial_mean, n*l.batch); + l.spatial_variance_delta_gpu = cuda_make_array(l.spatial_mean, n*l.batch); + + l.mean_delta_gpu = cuda_make_array(l.mean, n); + l.variance_delta_gpu = cuda_make_array(l.variance, n); + + l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + } +#endif l.activation = activation; fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); @@ -97,6 +134,42 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int return l; } +void denormalize_convolutional_layer(convolutional_layer l) +{ + int i, j; + for(i = 0; i < l.n; ++i){ + float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001); + for(j = 0; j < l.c*l.size*l.size; ++j){ + l.filters[i*l.c*l.size*l.size + j] *= scale; + } + l.biases[i] -= l.rolling_mean[i] * scale; + } +} + +void test_convolutional_layer() +{ + convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1); + l.batch_normalize = 1; + float data[] = {1,1,1,1,1, + 1,1,1,1,1, + 1,1,1,1,1, + 1,1,1,1,1, + 1,1,1,1,1, + 2,2,2,2,2, + 2,2,2,2,2, + 2,2,2,2,2, + 2,2,2,2,2, + 2,2,2,2,2, + 3,3,3,3,3, + 3,3,3,3,3, + 3,3,3,3,3, + 3,3,3,3,3, + 3,3,3,3,3}; + network_state state = {0}; + state.input = data; + forward_convolutional_layer(l, state); +} + void resize_convolutional_layer(convolutional_layer *l, int w, int h) { l->w = w; @@ -111,13 +184,13 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h) l->inputs = l->w * l->h * l->c; l->col_image = realloc(l->col_image, - out_h*out_w*l->size*l->size*l->c*sizeof(float)); + out_h*out_w*l->size*l->size*l->c*sizeof(float)); l->output = realloc(l->output, - l->batch*out_h * out_w * l->n*sizeof(float)); + l->batch*out_h * out_w * l->n*sizeof(float)); l->delta = realloc(l->delta, - l->batch*out_h * out_w * l->n*sizeof(float)); + l->batch*out_h * out_w * l->n*sizeof(float)); - #ifdef GPU +#ifdef GPU cuda_free(l->col_image_gpu); cuda_free(l->delta_gpu); cuda_free(l->output_gpu); @@ -125,7 +198,7 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h) l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c); l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n); l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n); - #endif +#endif } void bias_output(float *output, float *biases, int batch, int n, int size) @@ -150,7 +223,6 @@ void backward_bias(float *bias_updates, float *delta, int batch, int n, int size } } - void forward_convolutional_layer(const convolutional_layer l, network_state state) { int out_h = convolutional_out_height(l); @@ -169,11 +241,18 @@ void forward_convolutional_layer(const convolutional_layer l, network_state stat for(i = 0; i < l.batch; ++i){ im2col_cpu(state.input, l.c, l.h, l.w, - l.size, l.stride, l.pad, b); + l.size, l.stride, l.pad, b); gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); c += n*m; state.input += l.c*l.h*l.w; } + + if(l.batch_normalize){ + mean_cpu(l.output, l.batch, l.n, l.out_h*l.out_w, l.mean); + variance_cpu(l.output, l.mean, l.batch, l.n, l.out_h*l.out_w, l.variance); + normalize_cpu(l.output, l.mean, l.variance, l.batch, l.n, l.out_h*l.out_w); + } + activate_array(l.output, m*n*l.batch, l.activation); } diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 7452c3c5..70a3d052 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -17,11 +17,12 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float void push_convolutional_layer(convolutional_layer layer); void pull_convolutional_layer(convolutional_layer layer); -void bias_output_gpu(float *output, float *biases, int batch, int n, int size); +void add_bias_gpu(float *output, float *biases, int batch, int n, int size); void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size); #endif -convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation); +convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalization); +void denormalize_convolutional_layer(convolutional_layer l); void resize_convolutional_layer(convolutional_layer *layer, int w, int h); void forward_convolutional_layer(const convolutional_layer layer, network_state state); void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay); diff --git a/src/crop_layer_kernels.cu b/src/crop_layer_kernels.cu index fc7fcbdb..fdebd1b3 100644 --- a/src/crop_layer_kernels.cu +++ b/src/crop_layer_kernels.cu @@ -91,7 +91,7 @@ __device__ float bilinear_interpolate_kernel(float *image, int w, int h, float x return val; } -__global__ void levels_image_kernel(float *image, float *rand, int batch, int w, int h, int train, float saturation, float exposure, float translate, float scale) +__global__ void levels_image_kernel(float *image, float *rand, int batch, int w, int h, int train, float saturation, float exposure, float translate, float scale, float shift) { int size = batch * w * h; int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -100,6 +100,9 @@ __global__ void levels_image_kernel(float *image, float *rand, int batch, int w, id /= w; int y = id % h; id /= h; + float rshift = rand[0]; + float gshift = rand[1]; + float bshift = rand[2]; float r0 = rand[8*id + 0]; float r1 = rand[8*id + 1]; float r2 = rand[8*id + 2]; @@ -121,10 +124,12 @@ __global__ void levels_image_kernel(float *image, float *rand, int batch, int w, hsv.y *= saturation; hsv.z *= exposure; rgb = hsv_to_rgb_kernel(hsv); + } else { + shift = 0; } - image[x + w*(y + h*0)] = rgb.x*scale + translate; - image[x + w*(y + h*1)] = rgb.y*scale + translate; - image[x + w*(y + h*2)] = rgb.z*scale + translate; + image[x + w*(y + h*0)] = rgb.x*scale + translate + (rshift - .5)*shift; + image[x + w*(y + h*1)] = rgb.y*scale + translate + (gshift - .5)*shift; + image[x + w*(y + h*2)] = rgb.z*scale + translate + (bshift - .5)*shift; } __global__ void forward_crop_layer_kernel(float *input, float *rand, int size, int c, int h, int w, int crop_height, int crop_width, int train, int flip, float angle, float *output) @@ -186,7 +191,7 @@ extern "C" void forward_crop_layer_gpu(crop_layer layer, network_state state) int size = layer.batch * layer.w * layer.h; - levels_image_kernel<<>>(state.input, layer.rand_gpu, layer.batch, layer.w, layer.h, state.train, layer.saturation, layer.exposure, translate, scale); + levels_image_kernel<<>>(state.input, layer.rand_gpu, layer.batch, layer.w, layer.h, state.train, layer.saturation, layer.exposure, translate, scale, layer.shift); check_error(cudaPeekAtLastError()); size = layer.batch*layer.c*layer.crop_width*layer.crop_height; diff --git a/src/darknet.c b/src/darknet.c index 073156b5..78146119 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -141,6 +141,47 @@ void rgbgr_net(char *cfgfile, char *weightfile, char *outfile) save_weights(net, outfile); } +void normalize_net(char *cfgfile, char *weightfile, char *outfile) +{ + gpu_index = -1; + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + int i, j; + for(i = 0; i < net.n; ++i){ + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + net.layers[i].batch_normalize=1; + net.layers[i].scales = calloc(l.n, sizeof(float)); + for(j = 0; j < l.n; ++j){ + net.layers[i].scales[i] = 1; + } + net.layers[i].rolling_mean = calloc(l.n, sizeof(float)); + net.layers[i].rolling_variance = calloc(l.n, sizeof(float)); + } + } + save_weights(net, outfile); +} + +void denormalize_net(char *cfgfile, char *weightfile, char *outfile) +{ + gpu_index = -1; + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + int i; + for(i = 0; i < net.n; ++i){ + layer l = net.layers[i]; + if(l.type == CONVOLUTIONAL){ + denormalize_convolutional_layer(l); + net.layers[i].batch_normalize=0; + } + } + save_weights(net, outfile); +} + void visualize(char *cfgfile, char *weightfile) { network net = parse_network_cfg(cfgfile); @@ -202,6 +243,10 @@ int main(int argc, char **argv) change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0); } else if (0 == strcmp(argv[1], "rgbgr")){ rgbgr_net(argv[2], argv[3], argv[4]); + } else if (0 == strcmp(argv[1], "denormalize")){ + denormalize_net(argv[2], argv[3], argv[4]); + } else if (0 == strcmp(argv[1], "normalize")){ + normalize_net(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "rescale")){ rescale_net(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "partial")){ diff --git a/src/data.c b/src/data.c index 92c3d950..df15dc56 100644 --- a/src/data.c +++ b/src/data.c @@ -153,7 +153,9 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int { char *labelpath = find_replace(path, "images", "labels"); labelpath = find_replace(labelpath, "JPEGImages", "labels"); + labelpath = find_replace(labelpath, ".jpg", ".txt"); + labelpath = find_replace(labelpath, ".JPG", ".txt"); labelpath = find_replace(labelpath, ".JPEG", ".txt"); int count = 0; box_label *boxes = read_boxes(labelpath, &count); @@ -547,7 +549,7 @@ void *load_thread(void *ptr) check_error(status); #endif - printf("Loading data: %d\n", rand_r(&data_seed)); + //printf("Loading data: %d\n", rand_r(&data_seed)); load_args a = *(struct load_args*)ptr; if (a.type == CLASSIFICATION_DATA){ *a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h); diff --git a/src/deconvolutional_kernels.cu b/src/deconvolutional_kernels.cu index aeab2c3f..a74fb782 100644 --- a/src/deconvolutional_kernels.cu +++ b/src/deconvolutional_kernels.cu @@ -20,7 +20,7 @@ extern "C" void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, n int n = layer.h*layer.w; int k = layer.c; - bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, size); + fill_ongpu(layer.outputs*layer.batch, 0, layer.output_gpu, 1); for(i = 0; i < layer.batch; ++i){ float *a = layer.filters_gpu; @@ -31,6 +31,7 @@ extern "C" void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, n col2im_ongpu(c, layer.n, out_h, out_w, layer.size, layer.stride, 0, layer.output_gpu+i*layer.n*size); } + add_bias_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, size); activate_array(layer.output_gpu, layer.batch*layer.n*size, layer.activation); } diff --git a/src/image.c b/src/image.c index 861d8a2a..ac495292 100644 --- a/src/image.c +++ b/src/image.c @@ -215,7 +215,7 @@ void show_image_cv(image p, char *name) IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c); int step = disp->widthStep; - cvNamedWindow(buff, CV_WINDOW_AUTOSIZE); + cvNamedWindow(buff, CV_WINDOW_NORMAL); //cvMoveWindow(buff, 100*(windows%10) + 200*(windows/10), 100*(windows%10)); ++windows; for(y = 0; y < p.h; ++y){ @@ -696,7 +696,7 @@ image load_image_cv(char *filename, int channels) if( (src = cvLoadImage(filename, flag)) == 0 ) { - printf("Cannot load file image %s\n", filename); + printf("Cannot load image \"%s\"\n", filename); exit(0); } image out = ipl_to_image(src); @@ -713,7 +713,7 @@ image load_image_stb(char *filename, int channels) int w, h, c; unsigned char *data = stbi_load(filename, &w, &h, &c, channels); if (!data) { - fprintf(stderr, "Cannot load file image %s\nSTB Reason: %s\n", filename, stbi_failure_reason()); + fprintf(stderr, "Cannot load image \"%s\"\nSTB Reason: %s\n", filename, stbi_failure_reason()); exit(0); } if(channels) c = channels; diff --git a/src/imagenet.c b/src/imagenet.c index 1701a2ad..fa162519 100644 --- a/src/imagenet.c +++ b/src/imagenet.c @@ -92,6 +92,7 @@ void validate_imagenet(char *filename, char *weightfile) srand(time(0)); char **labels = get_labels("data/inet.labels.list"); + //list *plist = get_paths("data/inet.suppress.list"); list *plist = get_paths("data/inet.val.list"); char **paths = (char **)list_to_array(plist); diff --git a/src/layer.h b/src/layer.h index 49f144d0..2b136a05 100644 --- a/src/layer.h +++ b/src/layer.h @@ -27,6 +27,7 @@ typedef struct { LAYER_TYPE type; ACTIVATION activation; COST_TYPE cost_type; + int batch_normalize; int batch; int forced; int object_logistic; @@ -51,6 +52,7 @@ typedef struct { float jitter; float saturation; float exposure; + float shift; int softmax; int classes; int coords; @@ -71,6 +73,7 @@ typedef struct { float class_scale; int dontload; + int dontloadscales; float probability; float scale; @@ -84,6 +87,9 @@ typedef struct { float *biases; float *bias_updates; + float *scales; + float *scale_updates; + float *weights; float *weight_updates; @@ -95,18 +101,44 @@ typedef struct { float * squared; float * norms; + float * spatial_mean; + float * mean; + float * variance; + + float * rolling_mean; + float * rolling_variance; + #ifdef GPU int *indexes_gpu; float * filters_gpu; float * filter_updates_gpu; + float * spatial_mean_gpu; + float * spatial_variance_gpu; + + float * mean_gpu; + float * variance_gpu; + + float * rolling_mean_gpu; + float * rolling_variance_gpu; + + float * spatial_mean_delta_gpu; + float * spatial_variance_delta_gpu; + + float * variance_delta_gpu; + float * mean_delta_gpu; + float * col_image_gpu; + float * x_gpu; + float * x_norm_gpu; float * weights_gpu; float * biases_gpu; + float * scales_gpu; float * weight_updates_gpu; float * bias_updates_gpu; + float * scale_updates_gpu; float * output_gpu; float * delta_gpu; diff --git a/src/network.h b/src/network.h index 78ad0fe9..1caf838d 100644 --- a/src/network.h +++ b/src/network.h @@ -15,6 +15,7 @@ typedef struct { int n; int batch; int *seen; + float epoch; int subdivisions; float momentum; float decay; diff --git a/src/network_kernels.cu b/src/network_kernels.cu index cfc6e83a..d2c8bf9e 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -36,7 +36,7 @@ void forward_network_gpu(network net, network_state state) for(i = 0; i < net.n; ++i){ layer l = net.layers[i]; if(l.delta_gpu){ - scal_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1); } if(l.type == CONVOLUTIONAL){ forward_convolutional_layer_gpu(l, state); diff --git a/src/parser.c b/src/parser.c index a3400d03..254da5c2 100644 --- a/src/parser.c +++ b/src/parser.c @@ -124,8 +124,9 @@ convolutional_layer parse_convolutional(list *options, size_params params) c = params.c; batch=params.batch; if(!(h && w && c)) error("Layer before convolutional layer must output image."); + int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0); - convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation); + convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,pad,activation, batch_normalize); char *weights = option_find_str(options, "weights", 0); char *biases = option_find_str(options, "biases", 0); @@ -227,6 +228,7 @@ crop_layer parse_crop(list *options, size_params params) int noadjust = option_find_int_quiet(options, "noadjust",0); crop_layer l = make_crop_layer(batch,h,w,c,crop_height,crop_width,flip, angle, saturation, exposure); + l.shift = option_find_float(options, "shift", 0); l.noadjust = noadjust; return l; } @@ -452,6 +454,7 @@ network parse_network_cfg(char *filename) fprintf(stderr, "Type not recognized: %s\n", s->type); } l.dontload = option_find_int_quiet(options, "dontload", 0); + l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0); option_unused(options); net.layers[count] = l; free_section(s); @@ -633,19 +636,13 @@ void save_weights_upto(network net, char *filename, int cutoff) #endif int num = l.n*l.c*l.size*l.size; fwrite(l.biases, sizeof(float), l.n, fp); - fwrite(l.filters, sizeof(float), num, fp); - } - if(l.type == DECONVOLUTIONAL){ -#ifdef GPU - if(gpu_index >= 0){ - pull_deconvolutional_layer(l); + if (l.batch_normalize){ + fwrite(l.scales, sizeof(float), l.n, fp); + fwrite(l.rolling_mean, sizeof(float), l.n, fp); + fwrite(l.rolling_variance, sizeof(float), l.n, fp); } -#endif - int num = l.n*l.c*l.size*l.size; - fwrite(l.biases, sizeof(float), l.n, fp); fwrite(l.filters, sizeof(float), num, fp); - } - if(l.type == CONNECTED){ + } if(l.type == CONNECTED){ #ifdef GPU if(gpu_index >= 0){ pull_connected_layer(l); @@ -682,6 +679,11 @@ void load_weights_upto(network *net, char *filename, int cutoff) if(l.type == CONVOLUTIONAL){ int num = l.n*l.c*l.size*l.size; fread(l.biases, sizeof(float), l.n, fp); + if (l.batch_normalize && (!l.dontloadscales)){ + fread(l.scales, sizeof(float), l.n, fp); + fread(l.rolling_mean, sizeof(float), l.n, fp); + fread(l.rolling_variance, sizeof(float), l.n, fp); + } fread(l.filters, sizeof(float), num, fp); #ifdef GPU if(gpu_index >= 0){ diff --git a/src/region_layer.c b/src/region_layer.c index 3239f878..3fff22bf 100644 --- a/src/region_layer.c +++ b/src/region_layer.c @@ -226,6 +226,11 @@ void backward_region_layer(const region_layer l, network_state state) void forward_region_layer_gpu(const region_layer l, network_state state) { + if(!state.train){ + copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1); + return; + } + float *in_cpu = calloc(l.batch*l.inputs, sizeof(float)); float *truth_cpu = 0; if(state.truth){ diff --git a/src/swag.c b/src/swag.c index 8c9ce3cd..4dc6bf93 100644 --- a/src/swag.c +++ b/src/swag.c @@ -12,39 +12,28 @@ char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; -void draw_swag(image im, float *predictions, int side, int num, char *label, float thresh) +void draw_swag(image im, int num, float thresh, box *boxes, float **probs, char *label) { int classes = 20; - int i,n; + int i; - for(i = 0; i < side*side; ++i){ - int row = i / side; - int col = i % side; - for(n = 0; n < num; ++n){ - int p_index = side*side*classes + i*num + n; - int box_index = side*side*(classes + num) + (i*num + n)*4; - int class_index = i*classes; - float scale = predictions[p_index]; - int class = max_index(predictions+class_index, classes); - float prob = scale * predictions[class_index + class]; - if(prob > thresh){ - int width = sqrt(prob)*5 + 1; - printf("%f %s\n", prob, voc_names[class]); - float red = get_color(0,class,classes); - float green = get_color(1,class,classes); - float blue = get_color(2,class,classes); - box b = float_to_box(predictions+box_index); - b.x = (b.x + col)/side; - b.y = (b.y + row)/side; - b.w = b.w*b.w; - b.h = b.h*b.h; + for(i = 0; i < num; ++i){ + int class = max_index(probs[i], classes); + float prob = probs[i][class]; + if(prob > thresh){ + int width = pow(prob, 1./3.)*10 + 1; + printf("%f %s\n", prob, voc_names[class]); + float red = get_color(0,class,classes); + float green = get_color(1,class,classes); + float blue = get_color(2,class,classes); + //red = green = blue = 0; + box b = boxes[i]; - int left = (b.x-b.w/2)*im.w; - int right = (b.x+b.w/2)*im.w; - int top = (b.y-b.h/2)*im.h; - int bot = (b.y+b.h/2)*im.h; - draw_box_width(im, left, top, right, bot, width, red, green, blue); - } + int left = (b.x-b.w/2.)*im.w; + int right = (b.x+b.w/2.)*im.w; + int top = (b.y-b.h/2.)*im.h; + int bot = (b.y+b.h/2.)*im.h; + draw_box_width(im, left, top, right, bot, width, red, green, blue); } } show_image(im, label); @@ -52,7 +41,12 @@ void draw_swag(image im, float *predictions, int side, int num, char *label, flo void train_swag(char *cfgfile, char *weightfile) { + //char *train_images = "/home/pjreddie/data/voc/person_detection/2010_person.txt"; + //char *train_images = "/home/pjreddie/data/people-art/train.txt"; + //char *train_images = "/home/pjreddie/data/voc/test/2012_trainval.txt"; char *train_images = "/home/pjreddie/data/voc/test/train.txt"; + //char *train_images = "/home/pjreddie/data/voc/test/train_all.txt"; + //char *train_images = "/home/pjreddie/data/voc/test/2007_trainval.txt"; char *backup_directory = "/home/pjreddie/backup/"; srand(time(0)); data_seed = time(0); @@ -116,7 +110,7 @@ void train_swag(char *cfgfile, char *weightfile) if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; - printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); + printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs); if(i%1000==0){ char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); @@ -189,6 +183,9 @@ void validate_swag(char *cfgfile, char *weightfile) srand(time(0)); char *base = "results/comp4_det_test_"; + //base = "/home/pjreddie/comp4_det_test_"; + //list *plist = get_paths("/home/pjreddie/data/people-art/test.txt"); + //list *plist = get_paths("/home/pjreddie/data/cubist/test.txt"); list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt"); char **paths = (char **)list_to_array(plist); @@ -216,7 +213,7 @@ void validate_swag(char *cfgfile, char *weightfile) int nms = 1; float iou_thresh = .5; - int nthreads = 8; + int nthreads = 2; image *val = calloc(nthreads, sizeof(image)); image *val_resized = calloc(nthreads, sizeof(image)); image *buf = calloc(nthreads, sizeof(image)); @@ -256,7 +253,7 @@ void validate_swag(char *cfgfile, char *weightfile) int w = val[t].w; int h = val[t].h; convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0); - if (nms) do_nms(boxes, probs, side*side*l.n, classes, iou_thresh); + if (nms) do_nms_sort(boxes, probs, side*side*l.n, classes, iou_thresh); print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h); free(id); free_image(val[t]); @@ -315,8 +312,6 @@ void validate_swag_recall(char *cfgfile, char *weightfile) image sized = resize_image(orig, net.w, net.h); char *id = basecfg(path); float *predictions = network_predict(net, sized.data); - int w = orig.w; - int h = orig.h; convert_swag_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1); if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh); @@ -362,12 +357,17 @@ void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh) if(weightfile){ load_weights(&net, weightfile); } - region_layer layer = net.layers[net.n-1]; + region_layer l = net.layers[net.n-1]; set_batch_network(&net, 1); srand(2222222); clock_t time; char buff[256]; char *input = buff; + int j; + float nms=.5; + box *boxes = calloc(l.side*l.side*l.n, sizeof(box)); + float **probs = calloc(l.side*l.side*l.n, sizeof(float *)); + for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *)); while(1){ if(filename){ strncpy(input, filename, 256); @@ -384,7 +384,10 @@ void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh) time=clock(); float *predictions = network_predict(net, X); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - draw_swag(im, predictions, layer.side, layer.n, "predictions", thresh); + convert_swag_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); + if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms); + draw_swag(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions"); + show_image(sized, "resized"); free_image(im); free_image(sized); @@ -396,6 +399,48 @@ void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh) } } + +/* +#ifdef OPENCV +image ipl_to_image(IplImage* src); +#include "opencv2/highgui/highgui_c.h" +#include "opencv2/imgproc/imgproc_c.h" + +void demo_swag(char *cfgfile, char *weightfile, float thresh) +{ +network net = parse_network_cfg(cfgfile); +if(weightfile){ +load_weights(&net, weightfile); +} +region_layer layer = net.layers[net.n-1]; +CvCapture *capture = cvCaptureFromCAM(-1); +set_batch_network(&net, 1); +srand(2222222); +while(1){ +IplImage* frame = cvQueryFrame(capture); +image im = ipl_to_image(frame); +cvReleaseImage(&frame); +rgbgr_image(im); + +image sized = resize_image(im, net.w, net.h); +float *X = sized.data; +float *predictions = network_predict(net, X); +draw_swag(im, predictions, layer.side, layer.n, "predictions", thresh); +free_image(im); +free_image(sized); +cvWaitKey(10); +} +} +#else +void demo_swag(char *cfgfile, char *weightfile, float thresh){} +#endif + */ + +void demo_swag(char *cfgfile, char *weightfile, float thresh); +#ifndef GPU +void demo_swag(char *cfgfile, char *weightfile, float thresh){} +#endif + void run_swag(int argc, char **argv) { float thresh = find_float_arg(argc, argv, "-thresh", .2); @@ -411,4 +456,5 @@ void run_swag(int argc, char **argv) else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights); else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights); else if(0==strcmp(argv[2], "recall")) validate_swag_recall(cfg, weights); + else if(0==strcmp(argv[2], "demo")) demo_swag(cfg, weights, thresh); } diff --git a/src/swag_kernels.cu b/src/swag_kernels.cu new file mode 100644 index 00000000..5cba15cc --- /dev/null +++ b/src/swag_kernels.cu @@ -0,0 +1,61 @@ +extern "C" { +#include "network.h" +#include "region_layer.h" +#include "detection_layer.h" +#include "cost_layer.h" +#include "utils.h" +#include "parser.h" +#include "box.h" +#include "image.h" +} + +#ifdef OPENCV +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgproc/imgproc.hpp" +extern "C" image ipl_to_image(IplImage* src); +extern "C" void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness); +extern "C" void draw_swag(image im, int num, float thresh, box *boxes, float **probs, char *label); + +extern "C" void demo_swag(char *cfgfile, char *weightfile, float thresh) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + region_layer l = net.layers[net.n-1]; + cv::VideoCapture cap(0); + + set_batch_network(&net, 1); + srand(2222222); + float nms = .4; + int j; + box *boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box)); + float **probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *)); + for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *)); + + while(1){ + cv::Mat frame_m; + cap >> frame_m; + IplImage frame = frame_m; + image im = ipl_to_image(&frame); + rgbgr_image(im); + + image sized = resize_image(im, net.w, net.h); + float *X = sized.data; + float *predictions = network_predict(net, X); + convert_swag_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); + if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms); + printf("\033[2J"); + printf("\033[1;1H"); + printf("\nObjects:\n\n"); + draw_swag(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions"); + + free_image(im); + free_image(sized); + cvWaitKey(1); + } +} +#else +extern "C" void demo_swag(char *cfgfile, char *weightfile, float thresh){} +#endif +