diff --git a/Makefile b/Makefile index 3374dcc1..6878eef0 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ CFLAGS+= -DCUDNN LDFLAGS+= -lcudnn endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o super.o voxel.o ifeq ($(GPU), 1) LDFLAGS+= -lstdc++ OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o diff --git a/cfg/darknet.cfg b/cfg/darknet.cfg index a96f4d0c..7c0d28a3 100644 --- a/cfg/darknet.cfg +++ b/cfg/darknet.cfg @@ -11,9 +11,10 @@ max_crop=320 learning_rate=0.1 policy=poly power=4 -max_batches=500000 +max_batches=1600000 [convolutional] +batch_normalize=1 filters=16 size=3 stride=1 @@ -25,6 +26,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=32 size=3 stride=1 @@ -36,6 +38,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=64 size=3 stride=1 @@ -47,6 +50,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=128 size=3 stride=1 @@ -58,6 +62,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=256 size=3 stride=1 @@ -69,6 +74,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=512 size=3 stride=1 @@ -80,18 +86,22 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=1024 size=3 stride=1 pad=1 activation=leaky -[avgpool] - -[connected] -output=1000 +[convolutional] +filters=1000 +size=1 +stride=1 +pad=1 activation=leaky +[avgpool] + [softmax] groups=1 diff --git a/cfg/extraction.cfg b/cfg/extraction.cfg index f9668a57..94e10675 100644 --- a/cfg/extraction.cfg +++ b/cfg/extraction.cfg @@ -1,26 +1,20 @@ [net] batch=128 subdivisions=1 -height=256 -width=256 +height=224 +width=224 +max_crop=320 channels=3 momentum=0.9 decay=0.0005 -learning_rate=0.5 +learning_rate=0.1 policy=poly -power=6 -max_batches=500000 - -[crop] -crop_height=224 -crop_width=224 -flip=1 -saturation=1 -exposure=1 -angle=0 +power=4 +max_batches=1600000 [convolutional] +batch_normalize=1 filters=64 size=7 stride=2 @@ -32,6 +26,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=192 size=3 stride=1 @@ -43,6 +38,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=128 size=1 stride=1 @@ -50,6 +46,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=256 size=3 stride=1 @@ -57,6 +54,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=256 size=1 stride=1 @@ -64,6 +62,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=512 size=3 stride=1 @@ -75,6 +74,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=256 size=1 stride=1 @@ -82,6 +82,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=512 size=3 stride=1 @@ -89,6 +90,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=256 size=1 stride=1 @@ -96,6 +98,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=512 size=3 stride=1 @@ -103,6 +106,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=256 size=1 stride=1 @@ -110,6 +114,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=512 size=3 stride=1 @@ -117,6 +122,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=256 size=1 stride=1 @@ -124,6 +130,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=512 size=3 stride=1 @@ -131,6 +138,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=512 size=1 stride=1 @@ -138,6 +146,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=1024 size=3 stride=1 @@ -149,6 +158,7 @@ size=2 stride=2 [convolutional] +batch_normalize=1 filters=512 size=1 stride=1 @@ -156,6 +166,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=1024 size=3 stride=1 @@ -163,6 +174,7 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=512 size=1 stride=1 @@ -170,18 +182,22 @@ pad=1 activation=leaky [convolutional] +batch_normalize=1 filters=1024 size=3 stride=1 pad=1 activation=leaky +[convolutional] +filters=1000 +size=1 +stride=1 +pad=1 +activation=leaky + [avgpool] -[connected] -output=1000 -activation=leaky - [softmax] groups=1 diff --git a/src/blas.c b/src/blas.c index 00f0c3a3..9d42562b 100644 --- a/src/blas.c +++ b/src/blas.c @@ -1,6 +1,27 @@ #include "blas.h" #include "math.h" #include +#include +#include +#include + +void reorg(float *x, int size, int layers, int batch, int forward) +{ + float *swap = calloc(size*layers*batch, sizeof(float)); + int i,c,b; + for(b = 0; b < batch; ++b){ + for(c = 0; c < layers; ++c){ + for(i = 0; i < size; ++i){ + int i1 = b*layers*size + c*size + i; + int i2 = b*layers*size + i*layers + c; + if (forward) swap[i2] = x[i1]; + else swap[i1] = x[i2]; + } + } + } + memcpy(x, swap, size*layers*batch*sizeof(float)); + free(swap); +} void weighted_sum_cpu(float *a, float *b, float *s, int n, float *c) { diff --git a/src/blas.h b/src/blas.h index b4cfcf2e..4fdaa412 100644 --- a/src/blas.h +++ b/src/blas.h @@ -1,5 +1,6 @@ #ifndef BLAS_H #define BLAS_H +void reorg(float *x, int size, int layers, int batch, int forward); void pm(int M, int N, float *A); float *random_matrix(int rows, int cols); void time_random_matrix(int TA, int TB, int m, int k, int n); @@ -69,6 +70,7 @@ void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, floa void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c); void mult_add_into_gpu(int num, float *a, float *b, float *c); +void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out); #endif #endif diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index ac537d88..3f7f1f95 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -312,6 +312,38 @@ __global__ void variance_kernel(float *x, float *mean, int batch, int filters, i variance[i] *= scale; } +__global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, int stride, int forward, float *out) +{ + int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if(i >= N) return; + int in_index = i; + int in_w = i%w; + i = i/w; + int in_h = i%h; + i = i/h; + int in_c = i%c; + i = i/c; + int b = i%batch; + + int out_c = c/(stride*stride); + + int c2 = in_c % out_c; + int offset = in_c / out_c; + int w2 = in_w*stride + offset % stride; + int h2 = in_h*stride + offset / stride; + //printf("%d\n", offset); + int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b)); + + // printf("%d %d %d\n", w2, h2, c2); + //printf("%d %d\n", in_index, out_index); + //if(out_index >= N || out_index < 0) printf("bad bad bad \n"); + + if(forward) out[out_index] = x[in_index]; + else out[in_index] = x[out_index]; + //if(forward) out[1] = x[1]; + //else out[0] = x[0]; +} + __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -488,6 +520,13 @@ extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * check_error(cudaPeekAtLastError()); } +extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out) +{ + int size = w*h*c*batch; + reorg_kernel<<>>(size, x, w, h, c, batch, stride, forward, out); + check_error(cudaPeekAtLastError()); +} + extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask) { mask_kernel<<>>(N, X, mask_num, mask); diff --git a/src/classifier.c b/src/classifier.c index 2d0d0e0c..608e3ab5 100644 --- a/src/classifier.c +++ b/src/classifier.c @@ -3,6 +3,7 @@ #include "parser.h" #include "option_list.h" #include "blas.h" +#include "assert.h" #include "classifier.h" #include @@ -40,6 +41,9 @@ list *read_data_cfg(char *filename) void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) { + int nthreads = 2; + int i; + data_seed = time(0); srand(time(0)); float avg_loss = -1; @@ -51,7 +55,8 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) } if(clear) *net.seen = 0; printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); - int imgs = net.batch*net.subdivisions; + int imgs = net.batch*net.subdivisions/nthreads; + assert(net.batch*net.subdivisions % nthreads == 0); list *options = read_data_cfg(datacfg); @@ -66,9 +71,10 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) printf("%d\n", plist->size); int N = plist->size; clock_t time; - pthread_t load_thread; - data train; - data buffer; + + pthread_t *load_threads = calloc(nthreads, sizeof(pthread_t)); + data *trains = calloc(nthreads, sizeof(data)); + data *buffers = calloc(nthreads, sizeof(data)); load_args args = {0}; args.w = net.w; @@ -83,17 +89,27 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) args.n = imgs; args.m = N; args.labels = labels; - args.d = &buffer; args.type = CLASSIFICATION_DATA; - load_thread = load_data_in_thread(args); + for(i = 0; i < nthreads; ++i){ + args.d = buffers + i; + load_threads[i] = load_data_in_thread(args); + } + int epoch = (*net.seen)/N; while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ time=clock(); - pthread_join(load_thread, 0); - train = buffer; + for(i = 0; i < nthreads; ++i){ + pthread_join(load_threads[i], 0); + trains[i] = buffers[i]; + } + data train = concat_datas(trains, nthreads); + + for(i = 0; i < nthreads; ++i){ + args.d = buffers + i; + load_threads[i] = load_data_in_thread(args); + } - load_thread = load_data_in_thread(args); printf("Loaded: %lf seconds\n", sec(clock()-time)); time=clock(); @@ -111,6 +127,9 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) avg_loss = avg_loss*.9 + loss*.1; printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); free_data(train); + for(i = 0; i < nthreads; ++i){ + free_data(trains[i]); + } if(*net.seen/N > epoch){ epoch = *net.seen/N; char buff[256]; @@ -127,8 +146,14 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) sprintf(buff, "%s/%s.weights", backup_directory, base); save_weights(net, buff); - pthread_join(load_thread, 0); - free_data(buffer); + for(i = 0; i < nthreads; ++i){ + pthread_join(load_threads[i], 0); + free_data(buffers[i]); + } + free(buffers); + free(trains); + free(load_threads); + free_network(net); free_ptrs((void**)labels, classes); free_ptrs((void**)paths, plist->size); @@ -136,7 +161,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear) free(base); } -void validate_classifier(char *datacfg, char *filename, char *weightfile) +void validate_classifier_crop(char *datacfg, char *filename, char *weightfile) { int i = 0; network net = parse_network_cfg(filename); @@ -708,10 +733,10 @@ void run_classifier(int argc, char **argv) else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename); else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer); else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights); - else if(0==strcmp(argv[2], "valid")) validate_classifier(data, cfg, weights); - else if(0==strcmp(argv[2], "valid10")) validate_classifier_10(data, cfg, weights); + else if(0==strcmp(argv[2], "valid")) validate_classifier_single(data, cfg, weights); else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights); - else if(0==strcmp(argv[2], "validsingle")) validate_classifier_single(data, cfg, weights); + else if(0==strcmp(argv[2], "valid10")) validate_classifier_10(data, cfg, weights); + else if(0==strcmp(argv[2], "validcrop")) validate_classifier_crop(data, cfg, weights); else if(0==strcmp(argv[2], "validfull")) validate_classifier_full(data, cfg, weights); } diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index e8ae49c9..006dc4c6 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -104,36 +104,37 @@ image get_convolutional_delta(convolutional_layer l) size_t get_workspace_size(layer l){ #ifdef CUDNN - size_t most = 0; - size_t s = 0; - cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(), - l.srcTensorDesc, - l.filterDesc, - l.convDesc, - l.dstTensorDesc, - l.fw_algo, - &s); - if (s > most) most = s; - cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(), - l.srcTensorDesc, - l.ddstTensorDesc, - l.convDesc, - l.dfilterDesc, - l.bf_algo, - &s); - if (s > most) most = s; - cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), - l.filterDesc, - l.ddstTensorDesc, - l.convDesc, - l.dsrcTensorDesc, - l.bd_algo, - &s); - if (s > most) most = s; - return most; -#else + if(gpu_index >= 0){ + size_t most = 0; + size_t s = 0; + cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(), + l.srcTensorDesc, + l.filterDesc, + l.convDesc, + l.dstTensorDesc, + l.fw_algo, + &s); + if (s > most) most = s; + cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(), + l.srcTensorDesc, + l.ddstTensorDesc, + l.convDesc, + l.dfilterDesc, + l.bf_algo, + &s); + if (s > most) most = s; + cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(), + l.filterDesc, + l.ddstTensorDesc, + l.convDesc, + l.dsrcTensorDesc, + l.bd_algo, + &s); + if (s > most) most = s; + return most; + } + #endif return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float); -#endif } #ifdef GPU @@ -240,49 +241,51 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int } #ifdef GPU - l.filters_gpu = cuda_make_array(l.filters, c*n*size*size); - l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size); + if(gpu_index >= 0){ + l.filters_gpu = cuda_make_array(l.filters, c*n*size*size); + l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size); - l.biases_gpu = cuda_make_array(l.biases, n); - l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); + l.biases_gpu = cuda_make_array(l.biases, n); + l.bias_updates_gpu = cuda_make_array(l.bias_updates, n); - l.scales_gpu = cuda_make_array(l.scales, n); - l.scale_updates_gpu = cuda_make_array(l.scale_updates, n); + l.scales_gpu = cuda_make_array(l.scales, n); + l.scale_updates_gpu = cuda_make_array(l.scale_updates, n); - l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n); - l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n); + l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); - if(binary){ - l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size); - } - if(xnor){ - l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size); - l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch); - } + if(binary){ + l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size); + } + if(xnor){ + l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size); + l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch); + } - if(batch_normalize){ - l.mean_gpu = cuda_make_array(l.mean, n); - l.variance_gpu = cuda_make_array(l.variance, n); + if(batch_normalize){ + l.mean_gpu = cuda_make_array(l.mean, n); + l.variance_gpu = cuda_make_array(l.variance, n); - l.rolling_mean_gpu = cuda_make_array(l.mean, n); - l.rolling_variance_gpu = cuda_make_array(l.variance, n); + l.rolling_mean_gpu = cuda_make_array(l.mean, n); + l.rolling_variance_gpu = cuda_make_array(l.variance, n); - l.mean_delta_gpu = cuda_make_array(l.mean, n); - l.variance_delta_gpu = cuda_make_array(l.variance, n); + l.mean_delta_gpu = cuda_make_array(l.mean, n); + l.variance_delta_gpu = cuda_make_array(l.variance, n); - l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); - l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); - } + l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n); + } #ifdef CUDNN - cudnnCreateTensorDescriptor(&l.srcTensorDesc); - cudnnCreateTensorDescriptor(&l.dstTensorDesc); - cudnnCreateFilterDescriptor(&l.filterDesc); - cudnnCreateTensorDescriptor(&l.dsrcTensorDesc); - cudnnCreateTensorDescriptor(&l.ddstTensorDesc); - cudnnCreateFilterDescriptor(&l.dfilterDesc); - cudnnCreateConvolutionDescriptor(&l.convDesc); - cudnn_convolutional_setup(&l); + cudnnCreateTensorDescriptor(&l.srcTensorDesc); + cudnnCreateTensorDescriptor(&l.dstTensorDesc); + cudnnCreateFilterDescriptor(&l.filterDesc); + cudnnCreateTensorDescriptor(&l.dsrcTensorDesc); + cudnnCreateTensorDescriptor(&l.ddstTensorDesc); + cudnnCreateFilterDescriptor(&l.dfilterDesc); + cudnnCreateConvolutionDescriptor(&l.convDesc); + cudnn_convolutional_setup(&l); #endif + } #endif l.workspace_size = get_workspace_size(l); l.activation = activation; diff --git a/src/darknet.c b/src/darknet.c index 49c9747e..c367abff 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -12,6 +12,7 @@ #include "opencv2/highgui/highgui_c.h" #endif +extern void run_voxel(int argc, char **argv); extern void run_imagenet(int argc, char **argv); extern void run_yolo(int argc, char **argv); extern void run_detector(int argc, char **argv); @@ -28,6 +29,7 @@ extern void run_tag(int argc, char **argv); extern void run_cifar(int argc, char **argv); extern void run_go(int argc, char **argv); extern void run_art(int argc, char **argv); +extern void run_super(int argc, char **argv); void change_rate(char *filename, float scale, float add) { @@ -89,6 +91,23 @@ void average(int argc, char *argv[]) save_weights(sum, outfile); } +void speed(char *cfgfile, int tics) +{ + if (tics == 0) tics = 1000; + network net = parse_network_cfg(cfgfile); + set_batch_network(&net, 1); + int i; + time_t start = time(0); + image im = make_image(net.w, net.h, net.c); + for(i = 0; i < tics; ++i){ + network_predict(net, im.data); + } + double t = difftime(time(0), start); + printf("\n%d evals, %f Seconds\n", tics, t); + printf("Speed: %f sec/eval\n", t/tics); + printf("Speed: %f Hz\n", tics/t); +} + void operations(char *cfgfile) { gpu_index = -1; @@ -314,6 +333,10 @@ int main(int argc, char **argv) average(argc, argv); } else if (0 == strcmp(argv[1], "yolo")){ run_yolo(argc, argv); + } else if (0 == strcmp(argv[1], "voxel")){ + run_voxel(argc, argv); + } else if (0 == strcmp(argv[1], "super")){ + run_super(argc, argv); } else if (0 == strcmp(argv[1], "detector")){ run_detector(argc, argv); } else if (0 == strcmp(argv[1], "cifar")){ @@ -339,7 +362,7 @@ int main(int argc, char **argv) } else if (0 == strcmp(argv[1], "writing")){ run_writing(argc, argv); } else if (0 == strcmp(argv[1], "3d")){ - composite_3d(argv[2], argv[3], argv[4]); + composite_3d(argv[2], argv[3], argv[4], (argc > 5) ? atof(argv[5]) : 0); } else if (0 == strcmp(argv[1], "test")){ test_resize(argv[2]); } else if (0 == strcmp(argv[1], "captcha")){ @@ -360,6 +383,8 @@ int main(int argc, char **argv) rescale_net(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "ops")){ operations(argv[2]); + } else if (0 == strcmp(argv[1], "speed")){ + speed(argv[2], (argc > 3) ? atoi(argv[3]) : 0); } else if (0 == strcmp(argv[1], "partial")){ partial(argv[2], argv[3], argv[4], atoi(argv[5])); } else if (0 == strcmp(argv[1], "average")){ diff --git a/src/data.c b/src/data.c index afd7506f..231fb935 100644 --- a/src/data.c +++ b/src/data.c @@ -8,6 +8,7 @@ #include unsigned int data_seed; +pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; list *get_paths(char *filename) { @@ -26,12 +27,14 @@ char **get_random_paths_indexes(char **paths, int n, int m, int *indexes) { char **random_paths = calloc(n, sizeof(char*)); int i; + pthread_mutex_lock(&mutex); for(i = 0; i < n; ++i){ int index = rand_r(&data_seed)%m; indexes[i] = index; random_paths[i] = paths[index]; if(i == 0) printf("%s\n", paths[index]); } + pthread_mutex_unlock(&mutex); return random_paths; } @@ -39,11 +42,13 @@ char **get_random_paths(char **paths, int n, int m) { char **random_paths = calloc(n, sizeof(char*)); int i; + pthread_mutex_lock(&mutex); for(i = 0; i < n; ++i){ int index = rand_r(&data_seed)%m; random_paths[i] = paths[index]; if(i == 0) printf("%s\n", paths[index]); } + pthread_mutex_unlock(&mutex); return random_paths; } @@ -105,7 +110,7 @@ matrix load_image_cropped_paths(char **paths, int n, int min, int max, int size) for(i = 0; i < n; ++i){ image im = load_image_color(paths[i], 0, 0); - image crop = random_crop_image(im, min, max, size); + image crop = random_resize_crop_image(im, min, max, size); int flip = rand_r(&data_seed)%2; if (flip) flip_image(crop); /* @@ -667,6 +672,8 @@ void *load_thread(void *ptr) *a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h); } else if (a.type == CLASSIFICATION_DATA){ *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size); + } else if (a.type == SUPER_DATA){ + *a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale); } else if (a.type == STUDY_DATA){ *a.d = load_data_study(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size); } else if (a.type == WRITING_DATA){ @@ -737,6 +744,36 @@ data load_data_study(char **paths, int n, int m, char **labels, int k, int min, return d; } +data load_data_super(char **paths, int n, int m, int w, int h, int scale) +{ + if(m) paths = get_random_paths(paths, n, m); + data d = {0}; + d.shallow = 0; + + int i; + d.X.rows = n; + d.X.vals = calloc(n, sizeof(float*)); + d.X.cols = w*h*3; + + d.y.rows = n; + d.y.vals = calloc(n, sizeof(float*)); + d.y.cols = w*scale * h*scale * 3; + + for(i = 0; i < n; ++i){ + image im = load_image_color(paths[i], 0, 0); + image crop = random_crop_image(im, w*scale, h*scale); + int flip = rand_r(&data_seed)%2; + if (flip) flip_image(crop); + image resize = resize_image(crop, w, h); + d.X.vals[i] = resize.data; + d.y.vals[i] = crop.data; + free_image(im); + } + + if(m) free(paths); + return d; +} + data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size) { if(m) paths = get_random_paths(paths, n, m); @@ -786,6 +823,19 @@ data concat_data(data d1, data d2) return d; } +data concat_datas(data *d, int n) +{ + int i; + data out = {0}; + out.shallow = 1; + for(i = 0; i < n; ++i){ + data new = concat_data(d[i], out); + free_data(out); + out = new; + } + return out; +} + data load_categorical_data_csv(char *filename, int target, int k) { data d = {0}; diff --git a/src/data.h b/src/data.h index 11363f19..75123a50 100644 --- a/src/data.h +++ b/src/data.h @@ -30,7 +30,7 @@ typedef struct{ } data; typedef enum { - CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA + CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA } data_type; typedef struct load_args{ @@ -49,6 +49,7 @@ typedef struct load_args{ int min, max, size; int classes; int background; + int scale; float jitter; data *d; image *im; @@ -73,6 +74,7 @@ data load_data(char **paths, int n, int m, char **labels, int k, int w, int h); data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter); data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size); data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size); +data load_data_super(char **paths, int n, int m, int w, int h, int scale); data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size); data load_go(char *filename); @@ -94,6 +96,7 @@ void translate_data_rows(data d, float s); void randomize_data(data d); data *split_data(data d, int part, int total); data concat_data(data d1, data d2); +data concat_datas(data *d, int n); void fill_truth(char *path, char **labels, int k, float *truth); #endif diff --git a/src/detector.c b/src/detector.c new file mode 100644 index 00000000..64adaf37 --- /dev/null +++ b/src/detector.c @@ -0,0 +1,398 @@ +#include "network.h" +#include "detection_layer.h" +#include "cost_layer.h" +#include "utils.h" +#include "parser.h" +#include "box.h" + +#ifdef OPENCV +#include "opencv2/highgui/highgui_c.h" +#endif + +static char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; +static image voc_labels[20]; + +void train_detector(char *cfgfile, char *weightfile) +{ + char *train_images = "/data/voc/train.txt"; + char *backup_directory = "/home/pjreddie/backup/"; + srand(time(0)); + data_seed = time(0); + char *base = basecfg(cfgfile); + printf("%s\n", base); + float avg_loss = -1; + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = net.batch*net.subdivisions; + int i = *net.seen/imgs; + data train, buffer; + + layer l = net.layers[net.n - 1]; + + int classes = l.classes; + float jitter = l.jitter; + + list *plist = get_paths(train_images); + //int N = plist->size; + char **paths = (char **)list_to_array(plist); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.paths = paths; + args.n = imgs; + args.m = plist->size; + args.classes = classes; + args.jitter = jitter; + args.num_boxes = l.max_boxes; + args.d = &buffer; + args.type = DETECTION_DATA; + + pthread_t load_thread = load_data_in_thread(args); + clock_t time; + //while(i*imgs < N*120){ + while(get_current_batch(net) < net.max_batches){ + i += 1; + time=clock(); + pthread_join(load_thread, 0); + train = buffer; + load_thread = load_data_in_thread(args); + +/* + int k; + for(k = 0; k < l.max_boxes; ++k){ + box b = float_to_box(train.y.vals[10] + 1 + k*5); + if(!b.x) break; + printf("loaded: %f %f %f %f\n", b.x, b.y, b.w, b.h); + } + image im = float_to_image(448, 448, 3, train.X.vals[10]); + int k; + for(k = 0; k < l.max_boxes; ++k){ + box b = float_to_box(train.y.vals[10] + 1 + k*5); + printf("%d %d %d %d\n", truth.x, truth.y, truth.w, truth.h); + draw_bbox(im, b, 8, 1,0,0); + } + save_image(im, "truth11"); +*/ + + printf("Loaded: %lf seconds\n", sec(clock()-time)); + + time=clock(); + float loss = train_network(net, train); + if (avg_loss < 0) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + + printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs); + if(i%1000==0 || (i < 1000 && i%100 == 0)){ + char buff[256]; + sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); + save_weights(net, buff); + } + free_data(train); + } + char buff[256]; + sprintf(buff, "%s/%s_final.weights", backup_directory, base); + save_weights(net, buff); +} + +static void convert_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness) +{ + int i,j,n; + //int per_cell = 5*num+classes; + for (i = 0; i < side*side; ++i){ + int row = i / side; + int col = i % side; + for(n = 0; n < num; ++n){ + int index = i*num + n; + int p_index = index * (classes + 5) + 4; + float scale = predictions[p_index]; + int box_index = index * (classes + 5); + boxes[index].x = (predictions[box_index + 0] + col + .5) / side * w; + boxes[index].y = (predictions[box_index + 1] + row + .5) / side * h; + boxes[index].w = pow(logistic_activate(predictions[box_index + 2]), (square?2:1)) * w; + boxes[index].h = pow(logistic_activate(predictions[box_index + 3]), (square?2:1)) * h; + for(j = 0; j < classes; ++j){ + int class_index = index * (classes + 5) + 5; + float prob = scale*predictions[class_index+j]; + probs[index][j] = (prob > thresh) ? prob : 0; + } + if(only_objectness){ + probs[index][0] = scale; + } + } + } +} + +void print_detector_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h) +{ + int i, j; + for(i = 0; i < total; ++i){ + float xmin = boxes[i].x - boxes[i].w/2.; + float xmax = boxes[i].x + boxes[i].w/2.; + float ymin = boxes[i].y - boxes[i].h/2.; + float ymax = boxes[i].y + boxes[i].h/2.; + + if (xmin < 0) xmin = 0; + if (ymin < 0) ymin = 0; + if (xmax > w) xmax = w; + if (ymax > h) ymax = h; + + for(j = 0; j < classes; ++j){ + if (probs[i][j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, probs[i][j], + xmin, ymin, xmax, ymax); + } + } +} + +void validate_detector(char *cfgfile, char *weightfile) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + srand(time(0)); + + char *base = "results/comp4_det_test_"; + //list *plist = get_paths("data/voc.2007.test"); + list *plist = get_paths("/home/pjreddie/data/voc/2007_test.txt"); + //list *plist = get_paths("data/voc.2012.test"); + char **paths = (char **)list_to_array(plist); + + layer l = net.layers[net.n-1]; + int classes = l.classes; + int side = l.w; + + int j; + FILE **fps = calloc(classes, sizeof(FILE *)); + for(j = 0; j < classes; ++j){ + char buff[1024]; + snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]); + fps[j] = fopen(buff, "w"); + } + box *boxes = calloc(side*side*l.n, sizeof(box)); + float **probs = calloc(side*side*l.n, sizeof(float *)); + for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *)); + + int m = plist->size; + int i=0; + int t; + + float thresh = .001; + float nms = .5; + + int nthreads = 2; + image *val = calloc(nthreads, sizeof(image)); + image *val_resized = calloc(nthreads, sizeof(image)); + image *buf = calloc(nthreads, sizeof(image)); + image *buf_resized = calloc(nthreads, sizeof(image)); + pthread_t *thr = calloc(nthreads, sizeof(pthread_t)); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.type = IMAGE_DATA; + + for(t = 0; t < nthreads; ++t){ + args.path = paths[i+t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); + } + time_t start = time(0); + for(i = nthreads; i < m+nthreads; i += nthreads){ + fprintf(stderr, "%d\n", i); + for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ + pthread_join(thr[t], 0); + val[t] = buf[t]; + val_resized[t] = buf_resized[t]; + } + for(t = 0; t < nthreads && i+t < m; ++t){ + args.path = paths[i+t]; + args.im = &buf[t]; + args.resized = &buf_resized[t]; + thr[t] = load_data_in_thread(args); + } + for(t = 0; t < nthreads && i+t-nthreads < m; ++t){ + char *path = paths[i+t-nthreads]; + char *id = basecfg(path); + float *X = val_resized[t].data; + float *predictions = network_predict(net, X); + int w = val[t].w; + int h = val[t].h; + convert_detections(predictions, classes, l.n, 0, side, w, h, thresh, probs, boxes, 0); + if (nms) do_nms_sort(boxes, probs, side*side*l.n, classes, nms); + print_detector_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h); + free(id); + free_image(val[t]); + free_image(val_resized[t]); + } + } + fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); +} + +void validate_detector_recall(char *cfgfile, char *weightfile) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + srand(time(0)); + + char *base = "results/comp4_det_test_"; + list *plist = get_paths("data/voc.2007.test"); + char **paths = (char **)list_to_array(plist); + + layer l = net.layers[net.n-1]; + int classes = l.classes; + int square = l.sqrt; + int side = l.side; + + int j, k; + FILE **fps = calloc(classes, sizeof(FILE *)); + for(j = 0; j < classes; ++j){ + char buff[1024]; + snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]); + fps[j] = fopen(buff, "w"); + } + box *boxes = calloc(side*side*l.n, sizeof(box)); + float **probs = calloc(side*side*l.n, sizeof(float *)); + for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *)); + + int m = plist->size; + int i=0; + + float thresh = .001; + float iou_thresh = .5; + float nms = .4; + + int total = 0; + int correct = 0; + int proposals = 0; + float avg_iou = 0; + + for(i = 0; i < m; ++i){ + char *path = paths[i]; + image orig = load_image_color(path, 0, 0); + image sized = resize_image(orig, net.w, net.h); + char *id = basecfg(path); + float *predictions = network_predict(net, sized.data); + convert_detections(predictions, classes, l.n, square, l.w, 1, 1, thresh, probs, boxes, 1); + if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms); + + char *labelpath = find_replace(path, "images", "labels"); + labelpath = find_replace(labelpath, "JPEGImages", "labels"); + labelpath = find_replace(labelpath, ".jpg", ".txt"); + labelpath = find_replace(labelpath, ".JPEG", ".txt"); + + int num_labels = 0; + box_label *truth = read_boxes(labelpath, &num_labels); + for(k = 0; k < side*side*l.n; ++k){ + if(probs[k][0] > thresh){ + ++proposals; + } + } + for (j = 0; j < num_labels; ++j) { + ++total; + box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h}; + float best_iou = 0; + for(k = 0; k < side*side*l.n; ++k){ + float iou = box_iou(boxes[k], t); + if(probs[k][0] > thresh && iou > best_iou){ + best_iou = iou; + } + } + avg_iou += best_iou; + if(best_iou > iou_thresh){ + ++correct; + } + } + + fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total); + free(id); + free_image(orig); + free_image(sized); + } +} + +void test_detector(char *cfgfile, char *weightfile, char *filename, float thresh) +{ + + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + detection_layer l = net.layers[net.n-1]; + l.side = l.w; + set_batch_network(&net, 1); + srand(2222222); + clock_t time; + char buff[256]; + char *input = buff; + int j; + float nms=.4; + box *boxes = calloc(l.side*l.side*l.n, sizeof(box)); + float **probs = calloc(l.side*l.side*l.n, sizeof(float *)); + for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *)); + while(1){ + if(filename){ + strncpy(input, filename, 256); + } else { + printf("Enter Image Path: "); + fflush(stdout); + input = fgets(input, 256, stdin); + if(!input) return; + strtok(input, "\n"); + } + image im = load_image_color(input,0,0); + image sized = resize_image(im, net.w, net.h); + float *X = sized.data; + time=clock(); + float *predictions = network_predict(net, X); + printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + convert_detections(predictions, l.classes, l.n, 0, l.w, 1, 1, thresh, probs, boxes, 0); + if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms); + //draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20); + draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20); + save_image(im, "predictions"); + show_image(im, "predictions"); + + free_image(im); + free_image(sized); +#ifdef OPENCV + cvWaitKey(0); + cvDestroyAllWindows(); +#endif + if (filename) break; + } +} + +void run_detector(int argc, char **argv) +{ + int i; + for(i = 0; i < 20; ++i){ + char buff[256]; + sprintf(buff, "data/labels/%s.png", voc_names[i]); + voc_labels[i] = load_image_color(buff, 0, 0); + } + + float thresh = find_float_arg(argc, argv, "-thresh", .2); + if(argc < 4){ + fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); + return; + } + + char *cfg = argv[3]; + char *weights = (argc > 4) ? argv[4] : 0; + char *filename = (argc > 5) ? argv[5]: 0; + if(0==strcmp(argv[2], "test")) test_detector(cfg, weights, filename, thresh); + else if(0==strcmp(argv[2], "train")) train_detector(cfg, weights); + else if(0==strcmp(argv[2], "valid")) validate_detector(cfg, weights); + else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights); +} diff --git a/src/image.c b/src/image.c index 98e80c90..fd890d0f 100644 --- a/src/image.c +++ b/src/image.c @@ -347,23 +347,6 @@ void show_image_cv(image p, const char *name) #endif } - void save_image(image im, const char *name) - { - char buff[256]; - //sprintf(buff, "%s (%d)", name, windows); - sprintf(buff, "%s.png", name); - unsigned char *data = calloc(im.w*im.h*im.c, sizeof(char)); - int i,k; - for(k = 0; k < im.c; ++k){ - for(i = 0; i < im.w*im.h; ++i){ - data[i*im.c+k] = (unsigned char) (255*im.data[i + k*im.w*im.h]); - } - } - int success = stbi_write_png(buff, im.w, im.h, im.c, data, im.w*im.c); - free(data); - if(!success) fprintf(stderr, "Failed to write image %s\n", buff); - } - #ifdef OPENCV image get_image_from_stream(CvCapture *cap) { @@ -376,7 +359,7 @@ void show_image_cv(image p, const char *name) #endif #ifdef OPENCV - void save_image_jpg(image p, char *name) + void save_image_jpg(image p, const char *name) { image copy = copy_image(p); rgbgr_image(copy); @@ -400,6 +383,28 @@ void show_image_cv(image p, const char *name) } #endif + void save_image(image im, const char *name) + { + #ifdef OPENCV + save_image_jpg(im, name); + #else + char buff[256]; + //sprintf(buff, "%s (%d)", name, windows); + sprintf(buff, "%s.png", name); + unsigned char *data = calloc(im.w*im.h*im.c, sizeof(char)); + int i,k; + for(k = 0; k < im.c; ++k){ + for(i = 0; i < im.w*im.h; ++i){ + data[i*im.c+k] = (unsigned char) (255*im.data[i + k*im.w*im.h]); + } + } + int success = stbi_write_png(buff, im.w, im.h, im.c, data, im.w*im.c); + free(data); + if(!success) fprintf(stderr, "Failed to write image %s\n", buff); + #endif + } + + void show_image_layers(image p, char *name) { int i; @@ -539,7 +544,7 @@ int best_3d_shift(image a, image b, int min, int max) return best; } -void composite_3d(char *f1, char *f2, char *out) +void composite_3d(char *f1, char *f2, char *out, int delta) { if(!out) out = "out"; image a = load_image(f1, 0,0,0); @@ -551,7 +556,7 @@ void composite_3d(char *f1, char *f2, char *out) image c2 = crop_image(b, -10, shift, b.w, b.h); float d2 = dist_array(c2.data, a.data, a.w*a.h*a.c, 100); - if(d2 < d1){ + if(d2 < d1 && 0){ image swap = a; a = b; b = swap; @@ -562,7 +567,7 @@ void composite_3d(char *f1, char *f2, char *out) printf("%d\n", shift); } - image c = crop_image(b, 0, shift, a.w, a.h); + image c = crop_image(b, delta, shift, a.w, a.h); int i; for(i = 0; i < c.w*c.h; ++i){ c.data[i] = a.data[i]; @@ -590,7 +595,15 @@ image resize_min(image im, int min) return resized; } -image random_crop_image(image im, int low, int high, int size) +image random_crop_image(image im, int w, int h) +{ + int dx = rand_int(0, im.w - w); + int dy = rand_int(0, im.h - h); + image crop = crop_image(im, dx, dy, w, h); + return crop; +} + +image random_resize_crop_image(image im, int low, int high, int size) { int r = rand_int(low, high); image resized = resize_min(im, r); diff --git a/src/image.h b/src/image.h index ece7cb6a..e4eecd5e 100644 --- a/src/image.h +++ b/src/image.h @@ -30,7 +30,8 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs, image image_distance(image a, image b); void scale_image(image m, float s); image crop_image(image im, int dx, int dy, int w, int h); -image random_crop_image(image im, int low, int high, int size); +image random_crop_image(image im, int w, int h); +image random_resize_crop_image(image im, int low, int high, int size); image resize_image(image im, int w, int h); image resize_min(image im, int min); void translate_image(image m, float s); @@ -44,7 +45,8 @@ void saturate_exposure_image(image im, float sat, float exposure); void hsv_to_rgb(image im); void rgbgr_image(image im); void constrain_image(image im); -void composite_3d(char *f1, char *f2, char *out); +void composite_3d(char *f1, char *f2, char *out, int delta); +int best_3d_shift_r(image a, image b, int min, int max); image grayscale_image(image im); image threshold_image(image im, float thresh); @@ -61,7 +63,7 @@ void show_image_layers(image p, char *name); void show_image_collapsed(image p, char *name); #ifdef OPENCV -void save_image_jpg(image p, char *name); +void save_image_jpg(image p, const char *name); image get_image_from_stream(CvCapture *cap); image ipl_to_image(IplImage* src); #endif diff --git a/src/layer.h b/src/layer.h index 7182acd7..10d64e50 100644 --- a/src/layer.h +++ b/src/layer.h @@ -30,6 +30,7 @@ typedef enum { NETWORK, XNOR, REGION, + REORG, BLANK } LAYER_TYPE; @@ -80,6 +81,7 @@ struct layer{ int does_cost; int joint; int noadjust; + int reorg; float alpha; float beta; diff --git a/src/network.c b/src/network.c index 6ed82cea..91baafe4 100644 --- a/src/network.c +++ b/src/network.c @@ -20,6 +20,7 @@ #include "normalization_layer.h" #include "batchnorm_layer.h" #include "maxpool_layer.h" +#include "reorg_layer.h" #include "avgpool_layer.h" #include "cost_layer.h" #include "softmax_layer.h" @@ -98,6 +99,8 @@ char *get_layer_string(LAYER_TYPE a) return "crnn"; case MAXPOOL: return "maxpool"; + case REORG: + return "reorg"; case AVGPOOL: return "avgpool"; case SOFTMAX: @@ -181,6 +184,8 @@ void forward_network(network net, network_state state) forward_softmax_layer(l, state); } else if(l.type == MAXPOOL){ forward_maxpool_layer(l, state); + } else if(l.type == REORG){ + forward_reorg_layer(l, state); } else if(l.type == AVGPOOL){ forward_avgpool_layer(l, state); } else if(l.type == DROPOUT){ @@ -222,7 +227,7 @@ void update_network(network net) float *get_network_output(network net) { #ifdef GPU - return get_network_output_gpu(net); + if (gpu_index >= 0) return get_network_output_gpu(net); #endif int i; for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break; @@ -279,6 +284,8 @@ void backward_network(network net, network_state state) backward_batchnorm_layer(l, state); } else if(l.type == MAXPOOL){ if(i != 0) backward_maxpool_layer(l, state); + } else if(l.type == REORG){ + backward_reorg_layer(l, state); } else if(l.type == AVGPOOL){ backward_avgpool_layer(l, state); } else if(l.type == DROPOUT){ @@ -366,6 +373,7 @@ float train_network(network net, data d) return (float)sum/(n*batch); } + float train_network_batch(network net, data d, int n) { int i,j; @@ -422,6 +430,8 @@ int resize_network(network *net, int w, int h) resize_crop_layer(&l, w, h); }else if(l.type == MAXPOOL){ resize_maxpool_layer(&l, w, h); + }else if(l.type == REORG){ + resize_reorg_layer(&l, w, h); }else if(l.type == AVGPOOL){ resize_avgpool_layer(&l, w, h); }else if(l.type == NORMALIZATION){ @@ -439,11 +449,16 @@ int resize_network(network *net, int w, int h) if(l.type == AVGPOOL) break; } #ifdef GPU + if(gpu_index >= 0){ cuda_free(net->workspace); net->workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1); -#else + }else { free(net->workspace); net->workspace = calloc(1, workspace_size); + } +#else + free(net->workspace); + net->workspace = calloc(1, workspace_size); #endif //fprintf(stderr, " Done!\n"); return 0; @@ -659,10 +674,10 @@ void free_network(network net) free_layer(net.layers[i]); } free(net.layers); - #ifdef GPU +#ifdef GPU if(*net.input_gpu) cuda_free(*net.input_gpu); if(*net.truth_gpu) cuda_free(*net.truth_gpu); if(net.input_gpu) free(net.input_gpu); if(net.truth_gpu) free(net.truth_gpu); - #endif +#endif } diff --git a/src/network.h b/src/network.h index af64e06f..41573873 100644 --- a/src/network.h +++ b/src/network.h @@ -41,6 +41,8 @@ typedef struct network{ int max_crop; int min_crop; + int gpu_index; + #ifdef GPU float **input_gpu; float **truth_gpu; diff --git a/src/network_kernels.cu b/src/network_kernels.cu index e1d41295..3e01019e 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -24,6 +24,7 @@ extern "C" { #include "activation_layer.h" #include "deconvolutional_layer.h" #include "maxpool_layer.h" +#include "reorg_layer.h" #include "avgpool_layer.h" #include "normalization_layer.h" #include "batchnorm_layer.h" @@ -82,6 +83,8 @@ void forward_network_gpu(network net, network_state state) forward_batchnorm_layer_gpu(l, state); } else if(l.type == MAXPOOL){ forward_maxpool_layer_gpu(l, state); + } else if(l.type == REORG){ + forward_reorg_layer_gpu(l, state); } else if(l.type == AVGPOOL){ forward_avgpool_layer_gpu(l, state); } else if(l.type == DROPOUT){ @@ -122,6 +125,8 @@ void backward_network_gpu(network net, network_state state) backward_local_layer_gpu(l, state); } else if(l.type == MAXPOOL){ if(i != 0) backward_maxpool_layer_gpu(l, state); + } else if(l.type == REORG){ + backward_reorg_layer_gpu(l, state); } else if(l.type == AVGPOOL){ if(i != 0) backward_avgpool_layer_gpu(l, state); } else if(l.type == DROPOUT){ @@ -179,7 +184,7 @@ void update_network_gpu(network net) } } -float train_network_datum_gpu(network net, float *x, float *y) +void forward_backward_network_gpu(network net, float *x, float *y) { network_state state; state.index = 0; @@ -200,12 +205,64 @@ float train_network_datum_gpu(network net, float *x, float *y) state.train = 1; forward_network_gpu(net, state); backward_network_gpu(net, state); +} + +float train_network_datum_gpu(network net, float *x, float *y) +{ + forward_backward_network_gpu(net, x, y); float error = get_network_cost(net); if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net); return error; } +typedef struct { + network net; + float *X; + float *y; +} train_args; + +void *train_thread(void *ptr) +{ + train_args args = *(train_args*)ptr; + + cudaError_t status = cudaSetDevice(args.net.gpu_index); + check_error(status); + forward_backward_network_gpu(args.net, args.X, args.y); + free(ptr); + return 0; +} + +pthread_t train_network_in_thread(train_args args) +{ + pthread_t thread; + train_args *ptr = (train_args *)calloc(1, sizeof(train_args)); + *ptr = args; + if(pthread_create(&thread, 0, train_thread, ptr)) error("Thread creation failed"); + return thread; +} + +float train_networks(network *nets, int n, data d) +{ + int batch = nets[0].batch; + float **X = (float **) calloc(n, sizeof(float *)); + float **y = (float **) calloc(n, sizeof(float *)); + pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t)); + + int i; + float sum = 0; + for(i = 0; i < n; ++i){ + X[i] = (float *) calloc(batch*d.X.cols, sizeof(float)); + y[i] = (float *) calloc(batch*d.y.cols, sizeof(float)); + get_next_batch(d, batch, i*batch, X[i], y[i]); + float err = train_network_datum(nets[i], X[i], y[i]); + sum += err; + } + free(X); + free(y); + return (float)sum/(n*batch); +} + float *get_network_output_layer_gpu(network net, int i) { layer l = net.layers[i]; diff --git a/src/parser.c b/src/parser.c index b5c399fc..503e7cf5 100644 --- a/src/parser.c +++ b/src/parser.c @@ -3,6 +3,7 @@ #include #include "parser.h" +#include "assert.h" #include "activations.h" #include "crop_layer.h" #include "cost_layer.h" @@ -16,6 +17,7 @@ #include "gru_layer.h" #include "crnn_layer.h" #include "maxpool_layer.h" +#include "reorg_layer.h" #include "softmax_layer.h" #include "dropout_layer.h" #include "detection_layer.h" @@ -43,6 +45,7 @@ int is_rnn(section *s); int is_gru(section *s); int is_crnn(section *s); int is_maxpool(section *s); +int is_reorg(section *s); int is_avgpool(section *s); int is_dropout(section *s); int is_softmax(section *s); @@ -115,13 +118,6 @@ deconvolutional_layer parse_deconvolutional(list *options, size_params params) deconvolutional_layer layer = make_deconvolutional_layer(batch,h,w,c,n,size,stride,activation); - char *weights = option_find_str(options, "weights", 0); - char *biases = option_find_str(options, "biases", 0); - parse_data(weights, layer.filters, c*n*size*size); - parse_data(biases, layer.biases, n); - #ifdef GPU - if(weights || biases) push_deconvolutional_layer(layer); - #endif return layer; } @@ -169,13 +165,6 @@ convolutional_layer parse_convolutional(list *options, size_params params) layer.flipped = option_find_int_quiet(options, "flipped", 0); layer.dot = option_find_float_quiet(options, "dot", 0); - char *weights = option_find_str(options, "weights", 0); - char *biases = option_find_str(options, "biases", 0); - parse_data(weights, layer.filters, c*n*size*size); - parse_data(biases, layer.biases, n); - #ifdef GPU - if(weights || biases) push_convolutional_layer(layer); - #endif return layer; } @@ -229,13 +218,6 @@ connected_layer parse_connected(list *options, size_params params) connected_layer layer = make_connected_layer(params.batch, params.inputs, output, activation, batch_normalize); - char *weights = option_find_str(options, "weights", 0); - char *biases = option_find_str(options, "biases", 0); - parse_data(biases, layer.biases, output); - parse_data(weights, layer.weights, params.inputs*output); - #ifdef GPU - if(weights || biases) push_connected_layer(layer); - #endif return layer; } @@ -286,6 +268,7 @@ detection_layer parse_detection(list *options, size_params params) layer.class_scale = option_find_float(options, "class_scale", 1); layer.jitter = option_find_float(options, "jitter", .2); layer.random = option_find_int_quiet(options, "random", 0); + layer.reorg = option_find_int_quiet(options, "reorg", 0); return layer; } @@ -322,6 +305,21 @@ crop_layer parse_crop(list *options, size_params params) return l; } +layer parse_reorg(list *options, size_params params) +{ + int stride = option_find_int(options, "stride",1); + + int batch,h,w,c; + h = params.h; + w = params.w; + c = params.c; + batch=params.batch; + if(!(h && w && c)) error("Layer before reorg layer must output image."); + + layer layer = make_reorg_layer(batch,w,h,c,stride); + return layer; +} + maxpool_layer parse_maxpool(list *options, size_params params) { int stride = option_find_int(options, "stride",1); @@ -590,6 +588,8 @@ network parse_network_cfg(char *filename) l = parse_batchnorm(options, params); }else if(is_maxpool(s)){ l = parse_maxpool(options, params); + }else if(is_reorg(s)){ + l = parse_reorg(options, params); }else if(is_avgpool(s)){ l = parse_avgpool(options, params); }else if(is_route(s)){ @@ -626,9 +626,13 @@ network parse_network_cfg(char *filename) net.outputs = get_network_output_size(net); net.output = get_network_output(net); if(workspace_size){ - //printf("%ld\n", workspace_size); + //printf("%ld\n", workspace_size); #ifdef GPU - net.workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1); + if(gpu_index >= 0){ + net.workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1); + }else { + net.workspace = calloc(1, workspace_size); + } #else net.workspace = calloc(1, workspace_size); #endif @@ -659,6 +663,7 @@ LAYER_TYPE string_to_layer_type(char * type) || strcmp(type, "[connected]")==0) return CONNECTED; if (strcmp(type, "[max]")==0 || strcmp(type, "[maxpool]")==0) return MAXPOOL; + if (strcmp(type, "[reorg]")==0) return REORG; if (strcmp(type, "[avg]")==0 || strcmp(type, "[avgpool]")==0) return AVGPOOL; if (strcmp(type, "[dropout]")==0) return DROPOUT; @@ -731,6 +736,10 @@ int is_connected(section *s) return (strcmp(s->type, "[conn]")==0 || strcmp(s->type, "[connected]")==0); } +int is_reorg(section *s) +{ + return (strcmp(s->type, "[reorg]")==0); +} int is_maxpool(section *s) { return (strcmp(s->type, "[max]")==0 diff --git a/src/region_layer.c b/src/region_layer.c new file mode 100644 index 00000000..5fe37c5f --- /dev/null +++ b/src/region_layer.c @@ -0,0 +1,286 @@ +#include "region_layer.h" +#include "activations.h" +#include "softmax_layer.h" +#include "blas.h" +#include "box.h" +#include "cuda.h" +#include "utils.h" +#include +#include +#include +#include + +region_layer make_region_layer(int batch, int w, int h, int n, int classes, int coords) +{ + region_layer l = {0}; + l.type = REGION; + + l.n = n; + l.batch = batch; + l.h = h; + l.w = w; + l.classes = classes; + l.coords = coords; + l.cost = calloc(1, sizeof(float)); + l.outputs = h*w*n*(classes + coords + 1); + l.inputs = l.outputs; + l.truths = 30*(5); + l.delta = calloc(batch*l.outputs, sizeof(float)); + l.output = calloc(batch*l.outputs, sizeof(float)); +#ifdef GPU + l.output_gpu = cuda_make_array(l.output, batch*l.outputs); + l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs); +#endif + + fprintf(stderr, "Region Layer\n"); + srand(0); + + return l; +} + +box get_region_box2(float *x, int index, int i, int j, int w, int h) +{ + float aspect = exp(x[index+0]); + float scale = logistic_activate(x[index+1]); + float move_x = x[index+2]; + float move_y = x[index+3]; + + box b; + b.w = sqrt(scale * aspect); + b.h = b.w * 1./aspect; + b.x = move_x * b.w + (i + .5)/w; + b.y = move_y * b.h + (j + .5)/h; + return b; +} + +float delta_region_box2(box truth, float *output, int index, int i, int j, int w, int h, float *delta) +{ + box pred = get_region_box2(output, index, i, j, w, h); + float iou = box_iou(pred, truth); + float true_aspect = truth.w/truth.h; + float true_scale = truth.w*truth.h; + + float true_dx = (truth.x - (i+.5)/w) / truth.w; + float true_dy = (truth.y - (j+.5)/h) / truth.h; + delta[index + 0] = (true_aspect - exp(output[index + 0])) * exp(output[index + 0]); + delta[index + 1] = (true_scale - logistic_activate(output[index + 1])) * logistic_gradient(logistic_activate(output[index + 1])); + delta[index + 2] = true_dx - output[index + 2]; + delta[index + 3] = true_dy - output[index + 3]; + return iou; +} + +box get_region_box(float *x, int index, int i, int j, int w, int h, int adjust, int logistic) +{ + box b; + b.x = (x[index + 0] + i + .5)/w; + b.y = (x[index + 1] + j + .5)/h; + b.w = x[index + 2]; + b.h = x[index + 3]; + if(logistic){ + b.w = logistic_activate(x[index + 2]); + b.h = logistic_activate(x[index + 3]); + } + if(adjust && b.w < .01) b.w = .01; + if(adjust && b.h < .01) b.h = .01; + return b; +} + +float delta_region_box(box truth, float *output, int index, int i, int j, int w, int h, float *delta, int logistic, float scale) +{ + box pred = get_region_box(output, index, i, j, w, h, 0, logistic); + float iou = box_iou(pred, truth); + + delta[index + 0] = scale * (truth.x - pred.x); + delta[index + 1] = scale * (truth.y - pred.y); + delta[index + 2] = scale * ((truth.w - pred.w)*(logistic ? logistic_gradient(pred.w) : 1)); + delta[index + 3] = scale * ((truth.h - pred.h)*(logistic ? logistic_gradient(pred.h) : 1)); + return iou; +} + +float logit(float x) +{ + return log(x/(1.-x)); +} + +float tisnan(float x) +{ + return (x != x); +} + +#define LOG 1 + +void forward_region_layer(const region_layer l, network_state state) +{ + int i,j,b,t,n; + int size = l.coords + l.classes + 1; + memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float)); + reorg(l.output, l.w*l.h, size*l.n, l.batch, 1); + for (b = 0; b < l.batch; ++b){ + for(i = 0; i < l.h*l.w*l.n; ++i){ + int index = size*i + b*l.outputs; + l.output[index + 4] = logistic_activate(l.output[index + 4]); + if(l.softmax){ + softmax_array(l.output + index + 5, l.classes, 1, l.output + index + 5); + } + } + } + if(!state.train) return; + memset(l.delta, 0, l.outputs * l.batch * sizeof(float)); + float avg_iou = 0; + float avg_cat = 0; + float avg_obj = 0; + float avg_anyobj = 0; + int count = 0; + *(l.cost) = 0; + for (b = 0; b < l.batch; ++b) { + for (j = 0; j < l.h; ++j) { + for (i = 0; i < l.w; ++i) { + for (n = 0; n < l.n; ++n) { + int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs; + box pred = get_region_box(l.output, index, i, j, l.w, l.h, 1, LOG); + float best_iou = 0; + for(t = 0; t < 30; ++t){ + box truth = float_to_box(state.truth + t*5 + b*l.truths); + if(!truth.x) break; + float iou = box_iou(pred, truth); + if (iou > best_iou) best_iou = iou; + } + avg_anyobj += l.output[index + 4]; + l.delta[index + 4] = l.noobject_scale * ((0 - l.output[index + 4]) * logistic_gradient(l.output[index + 4])); + if(best_iou > .5) l.delta[index + 4] = 0; + + if(*(state.net.seen) < 6400){ + box truth = {0}; + truth.x = (i + .5)/l.w; + truth.y = (j + .5)/l.h; + truth.w = .5; + truth.h = .5; + delta_region_box(truth, l.output, index, i, j, l.w, l.h, l.delta, LOG, 1); + } + } + } + } + for(t = 0; t < 30; ++t){ + box truth = float_to_box(state.truth + t*5 + b*l.truths); + int class = state.truth[t*5 + b*l.truths + 4]; + if(!truth.x) break; + float best_iou = 0; + int best_index = 0; + int best_n = 0; + i = (truth.x * l.w); + j = (truth.y * l.h); + //printf("%d %f %d %f\n", i, truth.x*l.w, j, truth.y*l.h); + box truth_shift = truth; + truth_shift.x = 0; + truth_shift.y = 0; + printf("index %d %d\n",i, j); + for(n = 0; n < l.n; ++n){ + int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs; + box pred = get_region_box(l.output, index, i, j, l.w, l.h, 1, LOG); + printf("pred: (%f, %f) %f x %f\n", pred.x, pred.y, pred.w, pred.h); + pred.x = 0; + pred.y = 0; + float iou = box_iou(pred, truth_shift); + if (iou > best_iou){ + best_index = index; + best_iou = iou; + best_n = n; + } + } + printf("%d %f (%f, %f) %f x %f\n", best_n, best_iou, truth.x, truth.y, truth.w, truth.h); + + float iou = delta_region_box(truth, l.output, best_index, i, j, l.w, l.h, l.delta, LOG, l.coord_scale); + avg_iou += iou; + + //l.delta[best_index + 4] = iou - l.output[best_index + 4]; + avg_obj += l.output[best_index + 4]; + l.delta[best_index + 4] = l.object_scale * (1 - l.output[best_index + 4]) * logistic_gradient(l.output[best_index + 4]); + if (l.rescore) { + l.delta[best_index + 4] = l.object_scale * (iou - l.output[best_index + 4]) * logistic_gradient(l.output[best_index + 4]); + } + //printf("%f\n", l.delta[best_index+1]); + /* + if(isnan(l.delta[best_index+1])){ + printf("%f\n", true_scale); + printf("%f\n", l.output[best_index + 1]); + printf("%f\n", truth.w); + printf("%f\n", truth.h); + error("bad"); + } + */ + for(n = 0; n < l.classes; ++n){ + l.delta[best_index + 5 + n] = l.class_scale * (((n == class)?1 : 0) - l.output[best_index + 5 + n]); + if(n == class) avg_cat += l.output[best_index + 5 + n]; + } + /* + if(0){ + printf("truth: %f %f %f %f\n", truth.x, truth.y, truth.w, truth.h); + printf("pred: %f %f %f %f\n\n", pred.x, pred.y, pred.w, pred.h); + float aspect = exp(true_aspect); + float scale = logistic_activate(true_scale); + float move_x = true_dx; + float move_y = true_dy; + + box b; + b.w = sqrt(scale * aspect); + b.h = b.w * 1./aspect; + b.x = move_x * b.w + (i + .5)/l.w; + b.y = move_y * b.h + (j + .5)/l.h; + printf("%f %f\n", b.x, truth.x); + printf("%f %f\n", b.y, truth.y); + printf("%f %f\n", b.w, truth.w); + printf("%f %f\n", b.h, truth.h); + //printf("%f\n", box_iou(b, truth)); + } + */ + ++count; + } + } + printf("\n"); + reorg(l.delta, l.w*l.h, size*l.n, l.batch, 0); + *(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2); + printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), count); +} + +void backward_region_layer(const region_layer l, network_state state) +{ + axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1); +} + +#ifdef GPU + +void forward_region_layer_gpu(const region_layer l, network_state state) +{ + /* + if(!state.train){ + copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1); + return; + } + */ + + float *in_cpu = calloc(l.batch*l.inputs, sizeof(float)); + float *truth_cpu = 0; + if(state.truth){ + int num_truth = l.batch*l.truths; + truth_cpu = calloc(num_truth, sizeof(float)); + cuda_pull_array(state.truth, truth_cpu, num_truth); + } + cuda_pull_array(state.input, in_cpu, l.batch*l.inputs); + network_state cpu_state = state; + cpu_state.train = state.train; + cpu_state.truth = truth_cpu; + cpu_state.input = in_cpu; + forward_region_layer(l, cpu_state); + cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs); + cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs); + free(cpu_state.input); + if(cpu_state.truth) free(cpu_state.truth); +} + +void backward_region_layer_gpu(region_layer l, network_state state) +{ + axpy_ongpu(l.batch*l.outputs, 1, l.delta_gpu, 1, state.delta, 1); + //copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1); +} +#endif + diff --git a/src/region_layer.h b/src/region_layer.h new file mode 100644 index 00000000..a4156fd0 --- /dev/null +++ b/src/region_layer.h @@ -0,0 +1,18 @@ +#ifndef REGION_LAYER_H +#define REGION_LAYER_H + +#include "layer.h" +#include "network.h" + +typedef layer region_layer; + +region_layer make_region_layer(int batch, int h, int w, int n, int classes, int coords); +void forward_region_layer(const region_layer l, network_state state); +void backward_region_layer(const region_layer l, network_state state); + +#ifdef GPU +void forward_region_layer_gpu(const region_layer l, network_state state); +void backward_region_layer_gpu(region_layer l, network_state state); +#endif + +#endif diff --git a/src/reorg_layer.c b/src/reorg_layer.c new file mode 100644 index 00000000..55b425f1 --- /dev/null +++ b/src/reorg_layer.c @@ -0,0 +1,111 @@ +#include "reorg_layer.h" +#include "cuda.h" +#include "blas.h" +#include + + +layer make_reorg_layer(int batch, int h, int w, int c, int stride) +{ + layer l = {0}; + l.type = REORG; + l.batch = batch; + l.stride = stride; + l.h = h; + l.w = w; + l.c = c; + l.out_w = w*stride; + l.out_h = h*stride; + l.out_c = c/(stride*stride); + fprintf(stderr, "Reorg Layer: %d x %d x %d image -> %d x %d x %d image, \n", w,h,c,l.out_w, l.out_h, l.out_c); + l.outputs = l.out_h * l.out_w * l.out_c; + l.inputs = h*w*c; + int output_size = l.out_h * l.out_w * l.out_c * batch; + l.output = calloc(output_size, sizeof(float)); + l.delta = calloc(output_size, sizeof(float)); + #ifdef GPU + l.output_gpu = cuda_make_array(l.output, output_size); + l.delta_gpu = cuda_make_array(l.delta, output_size); + #endif + return l; +} + +void resize_reorg_layer(layer *l, int w, int h) +{ + int stride = l->stride; + + l->h = h; + l->w = w; + + l->out_w = w*stride; + l->out_h = h*stride; + + l->outputs = l->out_h * l->out_w * l->out_c; + l->inputs = l->outputs; + int output_size = l->outputs * l->batch; + + l->output = realloc(l->output, output_size * sizeof(float)); + l->delta = realloc(l->delta, output_size * sizeof(float)); + + #ifdef GPU + cuda_free(l->output_gpu); + cuda_free(l->delta_gpu); + l->output_gpu = cuda_make_array(l->output, output_size); + l->delta_gpu = cuda_make_array(l->delta, output_size); + #endif +} + +void forward_reorg_layer(const layer l, network_state state) +{ + int b,i,j,k; + + for(b = 0; b < l.batch; ++b){ + for(k = 0; k < l.c; ++k){ + for(j = 0; j < l.h; ++j){ + for(i = 0; i < l.w; ++i){ + int in_index = i + l.w*(j + l.h*(k + l.c*b)); + + int c2 = k % l.out_c; + int offset = k / l.out_c; + int w2 = i*l.stride + offset % l.stride; + int h2 = j*l.stride + offset / l.stride; + int out_index = w2 + l.out_w*(h2 + l.out_h*(c2 + l.out_c*b)); + l.output[out_index] = state.input[in_index]; + } + } + } + } +} + +void backward_reorg_layer(const layer l, network_state state) +{ + int b,i,j,k; + + for(b = 0; b < l.batch; ++b){ + for(k = 0; k < l.c; ++k){ + for(j = 0; j < l.h; ++j){ + for(i = 0; i < l.w; ++i){ + int in_index = i + l.w*(j + l.h*(k + l.c*b)); + + int c2 = k % l.out_c; + int offset = k / l.out_c; + int w2 = i*l.stride + offset % l.stride; + int h2 = j*l.stride + offset / l.stride; + int out_index = w2 + l.out_w*(h2 + l.out_h*(c2 + l.out_c*b)); + state.delta[in_index] = l.delta[out_index]; + } + } + } + } +} + +#ifdef GPU +void forward_reorg_layer_gpu(layer l, network_state state) +{ + reorg_ongpu(state.input, l.w, l.h, l.c, l.batch, l.stride, 1, l.output_gpu); +} + +void backward_reorg_layer_gpu(layer l, network_state state) +{ + reorg_ongpu(l.delta_gpu, l.w, l.h, l.c, l.batch, l.stride, 0, state.delta); +} +#endif diff --git a/src/reorg_layer.h b/src/reorg_layer.h new file mode 100644 index 00000000..659bc7cc --- /dev/null +++ b/src/reorg_layer.h @@ -0,0 +1,20 @@ +#ifndef REORG_LAYER_H +#define REORG_LAYER_H + +#include "image.h" +#include "cuda.h" +#include "layer.h" +#include "network.h" + +layer make_reorg_layer(int batch, int h, int w, int c, int stride); +void resize_reorg_layer(layer *l, int w, int h); +void forward_reorg_layer(const layer l, network_state state); +void backward_reorg_layer(const layer l, network_state state); + +#ifdef GPU +void forward_reorg_layer_gpu(layer l, network_state state); +void backward_reorg_layer_gpu(layer l, network_state state); +#endif + +#endif + diff --git a/src/super.c b/src/super.c new file mode 100644 index 00000000..67b941ff --- /dev/null +++ b/src/super.c @@ -0,0 +1,132 @@ +#include "network.h" +#include "cost_layer.h" +#include "utils.h" +#include "parser.h" + +#ifdef OPENCV +#include "opencv2/highgui/highgui_c.h" +#endif + +void train_super(char *cfgfile, char *weightfile) +{ + char *train_images = "/data/imagenet/imagenet1k.train.list"; + char *backup_directory = "/home/pjreddie/backup/"; + srand(time(0)); + data_seed = time(0); + char *base = basecfg(cfgfile); + printf("%s\n", base); + float avg_loss = -1; + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = net.batch*net.subdivisions; + int i = *net.seen/imgs; + data train, buffer; + + + list *plist = get_paths(train_images); + //int N = plist->size; + char **paths = (char **)list_to_array(plist); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.scale = 4; + args.paths = paths; + args.n = imgs; + args.m = plist->size; + args.d = &buffer; + args.type = SUPER_DATA; + + pthread_t load_thread = load_data_in_thread(args); + clock_t time; + //while(i*imgs < N*120){ + while(get_current_batch(net) < net.max_batches){ + i += 1; + time=clock(); + pthread_join(load_thread, 0); + train = buffer; + load_thread = load_data_in_thread(args); + + printf("Loaded: %lf seconds\n", sec(clock()-time)); + + time=clock(); + float loss = train_network(net, train); + if (avg_loss < 0) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + + printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs); + if(i%1000==0){ + char buff[256]; + sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); + save_weights(net, buff); + } + if(i%100==0){ + char buff[256]; + sprintf(buff, "%s/%s.backup", backup_directory, base); + save_weights(net, buff); + } + free_data(train); + } + char buff[256]; + sprintf(buff, "%s/%s_final.weights", backup_directory, base); + save_weights(net, buff); +} + +void test_super(char *cfgfile, char *weightfile, char *filename) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + srand(2222222); + + clock_t time; + char buff[256]; + char *input = buff; + while(1){ + if(filename){ + strncpy(input, filename, 256); + }else{ + printf("Enter Image Path: "); + fflush(stdout); + input = fgets(input, 256, stdin); + if(!input) return; + strtok(input, "\n"); + } + image im = load_image_color(input, 0, 0); + resize_network(&net, im.w, im.h); + printf("%d %d\n", im.w, im.h); + + float *X = im.data; + time=clock(); + network_predict(net, X); + image out = get_network_image(net); + printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + save_image(out, "out"); + + free_image(im); + if (filename) break; + } +} + + +void run_super(int argc, char **argv) +{ + if(argc < 4){ + fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); + return; + } + + char *cfg = argv[3]; + char *weights = (argc > 4) ? argv[4] : 0; + char *filename = (argc > 5) ? argv[5] : 0; + if(0==strcmp(argv[2], "train")) train_super(cfg, weights); + else if(0==strcmp(argv[2], "test")) test_super(cfg, weights, filename); + /* + else if(0==strcmp(argv[2], "valid")) validate_super(cfg, weights); + */ +} diff --git a/src/utils.c b/src/utils.c index 90af5cf7..73863057 100644 --- a/src/utils.c +++ b/src/utils.c @@ -521,6 +521,11 @@ int max_index(float *a, int n) int rand_int(int min, int max) { + if (max < min){ + int s = min; + min = max; + max = s; + } int r = (rand()%(max - min + 1)) + min; return r; } diff --git a/src/voxel.c b/src/voxel.c new file mode 100644 index 00000000..b41cf77d --- /dev/null +++ b/src/voxel.c @@ -0,0 +1,169 @@ +#include "network.h" +#include "cost_layer.h" +#include "utils.h" +#include "parser.h" + +#ifdef OPENCV +#include "opencv2/highgui/highgui_c.h" +#endif + +void extract_voxel(char *lfile, char *rfile, char *prefix) +{ + int w = 1920; + int h = 1080; +#ifdef OPENCV + int shift = 0; + int count = 0; + CvCapture *lcap = cvCaptureFromFile(lfile); + CvCapture *rcap = cvCaptureFromFile(rfile); + while(1){ + image l = get_image_from_stream(lcap); + image r = get_image_from_stream(rcap); + if(!l.w || !r.w) break; + if(count%100 == 0) { + shift = best_3d_shift_r(l, r, -l.h/100, l.h/100); + printf("%d\n", shift); + } + image ls = crop_image(l, (l.w - w)/2, (l.h - h)/2, w, h); + image rs = crop_image(r, 105 + (r.w - w)/2, (r.h - h)/2 + shift, w, h); + char buff[256]; + sprintf(buff, "%s_%05d_l", prefix, count); + save_image(ls, buff); + sprintf(buff, "%s_%05d_r", prefix, count); + save_image(rs, buff); + free_image(l); + free_image(r); + free_image(ls); + free_image(rs); + ++count; + } + +#else +printf("need OpenCV for extraction\n"); +#endif +} + +void train_voxel(char *cfgfile, char *weightfile) +{ + char *train_images = "/data/imagenet/imagenet1k.train.list"; + char *backup_directory = "/home/pjreddie/backup/"; + srand(time(0)); + data_seed = time(0); + char *base = basecfg(cfgfile); + printf("%s\n", base); + float avg_loss = -1; + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = net.batch*net.subdivisions; + int i = *net.seen/imgs; + data train, buffer; + + + list *plist = get_paths(train_images); + //int N = plist->size; + char **paths = (char **)list_to_array(plist); + + load_args args = {0}; + args.w = net.w; + args.h = net.h; + args.scale = 4; + args.paths = paths; + args.n = imgs; + args.m = plist->size; + args.d = &buffer; + args.type = SUPER_DATA; + + pthread_t load_thread = load_data_in_thread(args); + clock_t time; + //while(i*imgs < N*120){ + while(get_current_batch(net) < net.max_batches){ + i += 1; + time=clock(); + pthread_join(load_thread, 0); + train = buffer; + load_thread = load_data_in_thread(args); + + printf("Loaded: %lf seconds\n", sec(clock()-time)); + + time=clock(); + float loss = train_network(net, train); + if (avg_loss < 0) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + + printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs); + if(i%1000==0){ + char buff[256]; + sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); + save_weights(net, buff); + } + if(i%100==0){ + char buff[256]; + sprintf(buff, "%s/%s.backup", backup_directory, base); + save_weights(net, buff); + } + free_data(train); + } + char buff[256]; + sprintf(buff, "%s/%s_final.weights", backup_directory, base); + save_weights(net, buff); +} + +void test_voxel(char *cfgfile, char *weightfile, char *filename) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + srand(2222222); + + clock_t time; + char buff[256]; + char *input = buff; + while(1){ + if(filename){ + strncpy(input, filename, 256); + }else{ + printf("Enter Image Path: "); + fflush(stdout); + input = fgets(input, 256, stdin); + if(!input) return; + strtok(input, "\n"); + } + image im = load_image_color(input, 0, 0); + resize_network(&net, im.w, im.h); + printf("%d %d\n", im.w, im.h); + + float *X = im.data; + time=clock(); + network_predict(net, X); + image out = get_network_image(net); + printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + save_image(out, "out"); + + free_image(im); + if (filename) break; + } +} + + +void run_voxel(int argc, char **argv) +{ + if(argc < 4){ + fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); + return; + } + + char *cfg = argv[3]; + char *weights = (argc > 4) ? argv[4] : 0; + char *filename = (argc > 5) ? argv[5] : 0; + if(0==strcmp(argv[2], "train")) train_voxel(cfg, weights); + else if(0==strcmp(argv[2], "test")) test_voxel(cfg, weights, filename); + else if(0==strcmp(argv[2], "extract")) extract_voxel(argv[3], argv[4], argv[5]); + /* + else if(0==strcmp(argv[2], "valid")) validate_voxel(cfg, weights); + */ +}