diff --git a/Makefile b/Makefile index 415c5226..dc08b468 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ CC=gcc COMMON=-Wall `pkg-config --cflags opencv` -CFLAGS= $(COMMON) -O3 -ffast-math -flto +CFLAGS= $(COMMON) -Ofast -ffast-math -flto UNAME = $(shell uname) ifeq ($(UNAME), Darwin) COMMON += -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include @@ -10,12 +10,13 @@ endif #CFLAGS= $(COMMON) -O0 -g LDFLAGS=`pkg-config --libs opencv` -lm VPATH=./src/ +EXEC=cnn -OBJ=network.o image.o tests.o convolutional_layer.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o +OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o -all: cnn +all: $(EXEC) -cnn: $(OBJ) +$(EXEC): $(OBJ) $(CC) $(CFLAGS) $(LDFLAGS) $^ -o $@ %.o: %.c @@ -24,5 +25,5 @@ cnn: $(OBJ) .PHONY: clean clean: - rm -rf $(OBJ) cnn + rm -rf $(OBJ) $(EXEC) diff --git a/nist_basic.cfg b/nist_basic.cfg index 3b55166b..f5ea0a38 100644 --- a/nist_basic.cfg +++ b/nist_basic.cfg @@ -1,7 +1,11 @@ -[conn] -input=784 -output = 100 -activation=ramp +[conv] +width=28 +height=28 +channels=1 +filters=20 +size=5 +stride=1 +activation=linear [conn] output = 10 diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index ef48120f..53eb7bf1 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -1,17 +1,13 @@ #include "convolutional_layer.h" #include "utils.h" +#include "mini_blas.h" #include image get_convolutional_image(convolutional_layer layer) { int h,w,c; - if(layer.edge){ - h = (layer.h-1)/layer.stride + 1; - w = (layer.w-1)/layer.stride + 1; - }else{ - h = (layer.h - layer.size)/layer.stride+1; - w = (layer.h - layer.size)/layer.stride+1; - } + h = layer.out_h; + w = layer.out_w; c = layer.n; return double_to_image(h,w,c,layer.output); } @@ -19,13 +15,8 @@ image get_convolutional_image(convolutional_layer layer) image get_convolutional_delta(convolutional_layer layer) { int h,w,c; - if(layer.edge){ - h = (layer.h-1)/layer.stride + 1; - w = (layer.w-1)/layer.stride + 1; - }else{ - h = (layer.h - layer.size)/layer.stride+1; - w = (layer.h - layer.size)/layer.stride+1; - } + h = layer.out_h; + w = layer.out_w; c = layer.n; return double_to_image(h,w,c,layer.delta); } @@ -34,74 +25,114 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si { int i; int out_h,out_w; + size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter... convolutional_layer *layer = calloc(1, sizeof(convolutional_layer)); layer->h = h; layer->w = w; layer->c = c; layer->n = n; - layer->edge = 0; layer->stride = stride; - layer->kernels = calloc(n, sizeof(image)); - layer->kernel_updates = calloc(n, sizeof(image)); - layer->kernel_momentum = calloc(n, sizeof(image)); + layer->size = size; + + layer->filters = calloc(c*n*size*size, sizeof(double)); + layer->filter_updates = calloc(c*n*size*size, sizeof(double)); + layer->filter_momentum = calloc(c*n*size*size, sizeof(double)); + layer->biases = calloc(n, sizeof(double)); layer->bias_updates = calloc(n, sizeof(double)); layer->bias_momentum = calloc(n, sizeof(double)); double scale = 2./(size*size); + for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = rand_normal()*scale; for(i = 0; i < n; ++i){ //layer->biases[i] = rand_normal()*scale + scale; layer->biases[i] = 0; - layer->kernels[i] = make_random_kernel(size, c, scale); - layer->kernel_updates[i] = make_random_kernel(size, c, 0); - layer->kernel_momentum[i] = make_random_kernel(size, c, 0); } - layer->size = 2*(size/2)+1; - if(layer->edge){ - out_h = (layer->h-1)/layer->stride + 1; - out_w = (layer->w-1)/layer->stride + 1; - }else{ - out_h = (layer->h - layer->size)/layer->stride+1; - out_w = (layer->h - layer->size)/layer->stride+1; - } - fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); + out_h = (h-size)/stride + 1; + out_w = (w-size)/stride + 1; + + layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(double)); layer->output = calloc(out_h * out_w * n, sizeof(double)); layer->delta = calloc(out_h * out_w * n, sizeof(double)); - layer->upsampled = make_image(h,w,n); layer->activation = activation; + layer->out_h = out_h; + layer->out_w = out_w; + + fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n); + srand(0); return layer; } void forward_convolutional_layer(const convolutional_layer layer, double *in) { - image input = double_to_image(layer.h, layer.w, layer.c, in); - image output = get_convolutional_image(layer); - int i,j; - for(i = 0; i < layer.n; ++i){ - convolve(input, layer.kernels[i], layer.stride, i, output, layer.edge); - } - for(i = 0; i < output.c; ++i){ - for(j = 0; j < output.h*output.w; ++j){ - int index = i*output.h*output.w + j; - output.data[index] += layer.biases[i]; - output.data[index] = activate(output.data[index], layer.activation); - } - } + int m = layer.n; + int k = layer.size*layer.size*layer.c; + int n = ((layer.h-layer.size)/layer.stride + 1)* + ((layer.w-layer.size)/layer.stride + 1); + + memset(layer.output, 0, m*n*sizeof(double)); + + double *a = layer.filters; + double *b = layer.col_image; + double *c = layer.output; + + im2col_cpu(in, layer.c, layer.h, layer.w, layer.size, layer.stride, b); + gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); + } -void backward_convolutional_layer(convolutional_layer layer, double *input, double *delta) +void gradient_delta_convolutional_layer(convolutional_layer layer) { int i; - - image in_delta = double_to_image(layer.h, layer.w, layer.c, delta); - image out_delta = get_convolutional_delta(layer); - zero_image(in_delta); - - for(i = 0; i < layer.n; ++i){ - back_convolve(in_delta, layer.kernels[i], layer.stride, i, out_delta, layer.edge); + for(i = 0; i < layer.out_h*layer.out_w*layer.n; ++i){ + layer.delta[i] *= gradient(layer.output[i], layer.activation); } } +void learn_bias_convolutional_layer(convolutional_layer layer) +{ + int i,j; + int size = layer.out_h*layer.out_w; + for(i = 0; i < layer.n; ++i){ + double sum = 0; + for(j = 0; j < size; ++j){ + sum += layer.delta[j+i*size]; + } + layer.bias_updates[i] += sum/size; + } +} + +void learn_convolutional_layer(convolutional_layer layer) +{ + gradient_delta_convolutional_layer(layer); + learn_bias_convolutional_layer(layer); + int m = layer.n; + int n = layer.size*layer.size*layer.c; + int k = ((layer.h-layer.size)/layer.stride + 1)* + ((layer.w-layer.size)/layer.stride + 1); + + double *a = layer.delta; + double *b = layer.col_image; + double *c = layer.filter_updates; + + gemm(0,1,m,n,k,1,a,k,b,k,1,c,n); +} + +void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay) +{ + int i; + int size = layer.size*layer.size*layer.c*layer.n; + for(i = 0; i < layer.n; ++i){ + layer.biases[i] += step*layer.bias_updates[i]; + layer.bias_updates[i] *= momentum; + } + for(i = 0; i < size; ++i){ + layer.filters[i] += step*(layer.filter_updates[i] - decay*layer.filters[i]); + layer.filter_updates[i] *= momentum; + } +} +/* + void backward_convolutional_layer2(convolutional_layer layer, double *input, double *delta) { image in_delta = double_to_image(layer.h, layer.w, layer.c, delta); @@ -124,15 +155,6 @@ void backward_convolutional_layer2(convolutional_layer layer, double *input, dou } } -void gradient_delta_convolutional_layer(convolutional_layer layer) -{ - int i; - image out_delta = get_convolutional_delta(layer); - image out_image = get_convolutional_image(layer); - for(i = 0; i < out_image.h*out_image.w*out_image.c; ++i){ - out_delta.data[i] *= gradient(out_image.data[i], layer.activation); - } -} void learn_convolutional_layer(convolutional_layer layer, double *input) { @@ -163,8 +185,37 @@ void update_convolutional_layer(convolutional_layer layer, double step, double m zero_image(layer.kernel_updates[i]); } } +*/ -void visualize_convolutional_filters(convolutional_layer layer, char *window) +void test_convolutional_layer() +{ + convolutional_layer l = *make_convolutional_layer(4,4,1,1,3,1,LINEAR); + double input[] = {1,2,3,4, + 5,6,7,8, + 9,10,11,12, + 13,14,15,16}; + double filter[] = {.5, 0, .3, + 0 , 1, 0, + .2 , 0, 1}; + double delta[] = {1, 2, + 3, 4}; + l.filters = filter; + forward_convolutional_layer(l, input); + l.delta = delta; + learn_convolutional_layer(l); + image filter_updates = double_to_image(3,3,1,l.filter_updates); + print_image(filter_updates); +} + +image get_convolutional_filter(convolutional_layer layer, int i) +{ + int h = layer.size; + int w = layer.size; + int c = layer.c; + return double_to_image(h,w,c,layer.filters+i*h*w*c); +} + +void visualize_convolutional_layer(convolutional_layer layer, char *window) { int color = 1; int border = 1; @@ -172,7 +223,7 @@ void visualize_convolutional_filters(convolutional_layer layer, char *window) int size = layer.size; h = size; w = (size + border) * layer.n - border; - c = layer.kernels[0].c; + c = layer.c; if(c != 3 || !color){ h = (h+border)*c - border; c = 1; @@ -182,11 +233,13 @@ void visualize_convolutional_filters(convolutional_layer layer, char *window) int i,j; for(i = 0; i < layer.n; ++i){ int w_offset = i*(size+border); - image k = layer.kernels[i]; + image k = get_convolutional_filter(layer, i); + //printf("%f ** ", layer.biases[i]); + //print_image(k); image copy = copy_image(k); normalize_image(copy); for(j = 0; j < k.c; ++j){ - set_pixel(copy,0,0,j,layer.biases[i]); + //set_pixel(copy,0,0,j,layer.biases[i]); } if(c == 3 && color){ embed_image(copy, filters, 0, w_offset); @@ -211,15 +264,3 @@ void visualize_convolutional_filters(convolutional_layer layer, char *window) free_image(filters); } -void visualize_convolutional_layer(convolutional_layer layer) -{ - int i; - char buff[256]; - for(i = 0; i < layer.n; ++i){ - image k = layer.kernels[i]; - sprintf(buff, "Kernel %d", i); - if(k.c <= 3) show_image(k, buff); - else show_image_layers(k, buff); - } -} - diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index 135d9832..e2e6cdc4 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -6,36 +6,40 @@ typedef struct { int h,w,c; + int out_h, out_w, out_c; int n; int size; int stride; - image *kernels; - image *kernel_updates; - image *kernel_momentum; + double *filters; + double *filter_updates; + double *filter_momentum; + double *biases; double *bias_updates; double *bias_momentum; - image upsampled; + + double *col_image; double *delta; double *output; ACTIVATION activation; - int edge; } convolutional_layer; convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation); void forward_convolutional_layer(const convolutional_layer layer, double *in); -void backward_convolutional_layer(convolutional_layer layer, double *input, double *delta); -void learn_convolutional_layer(convolutional_layer layer, double *input); - +void learn_convolutional_layer(convolutional_layer layer); void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay); +void visualize_convolutional_layer(convolutional_layer layer, char *window); -void backpropagate_convolutional_layer_convolve(image input, convolutional_layer layer); -void visualize_convolutional_filters(convolutional_layer layer, char *window); -void visualize_convolutional_layer(convolutional_layer layer); +//void backward_convolutional_layer(convolutional_layer layer, double *input, double *delta); + +//void backpropagate_convolutional_layer_convolve(image input, convolutional_layer layer); +//void visualize_convolutional_filters(convolutional_layer layer, char *window); +//void visualize_convolutional_layer(convolutional_layer layer); image get_convolutional_image(convolutional_layer layer); image get_convolutional_delta(convolutional_layer layer); +image get_convolutional_filter(convolutional_layer layer, int i); #endif diff --git a/src/mini_blas.c b/src/mini_blas.c index b15ba8e6..3af36e5c 100644 --- a/src/mini_blas.c +++ b/src/mini_blas.c @@ -1,16 +1,44 @@ +#include +#include + +void pm(int M, int N, double *A) +{ + int i,j; + for(i =0 ; i < M; ++i){ + for(j = 0; j < N; ++j){ + printf("%10.6f, ", A[i*N+j]); + } + printf("\n"); + } + printf("\n"); +} + void gemm(int TA, int TB, int M, int N, int K, double ALPHA, double *A, int lda, double *B, int ldb, double BETA, double *C, int ldc) { - // Assume TA = TB = 0, beta = 1 LULZ + // Assume TA = 0, beta = 1 LULZ int i,j,k; - for(i = 0; i < M; ++i){ - for(k = 0; k < K; ++k){ + if(TB && !TA){ + for(i = 0; i < M; ++i){ for(j = 0; j < N; ++j){ - C[i*ldc+j] += ALPHA*A[i*lda+k]*B[k*ldb+j]; + register double sum = 0; + for(k = 0; k < K; ++k){ + sum += ALPHA*A[i*lda+k]*B[k+j*ldb]; + } + C[i*ldc+j] += sum; + } + } + }else{ + for(i = 0; i < M; ++i){ + for(k = 0; k < K; ++k){ + register double A_PART = ALPHA*A[i*lda+k]; + for(j = 0; j < N; ++j){ + C[i*ldc+j] += A_PART*B[k*ldb+j]; + } } } } @@ -59,7 +87,7 @@ void im2col(double *image, int h, int w, int c, int size, int stride, double *ma void im2col_cpu(double* data_im, const int channels, const int height, const int width, const int ksize, const int stride, double* data_col) - { +{ int c,h,w; int height_col = (height - ksize) / stride + 1; int width_col = (width - ksize) / stride + 1; diff --git a/src/mini_blas.h b/src/mini_blas.h index cdf3a86e..46a37d3d 100644 --- a/src/mini_blas.h +++ b/src/mini_blas.h @@ -1,3 +1,4 @@ +void pm(int M, int N, double *A); void gemm(int TA, int TB, int M, int N, int K, double ALPHA, double *A, int lda, double *B, int ldb, diff --git a/src/network.c b/src/network.c index 07ac6213..2ce13d83 100644 --- a/src/network.c +++ b/src/network.c @@ -6,6 +6,7 @@ #include "connected_layer.h" #include "convolutional_layer.h" +//#include "old_conv.h" #include "maxpool_layer.h" #include "softmax_layer.h" @@ -113,14 +114,17 @@ double *get_network_delta(network net) return get_network_delta_layer(net, net.n-1); } -void calculate_error_network(network net, double *truth) +double calculate_error_network(network net, double *truth) { + double sum = 0; double *delta = get_network_delta(net); double *out = get_network_output(net); int i, k = get_network_output_size(net); for(i = 0; i < k; ++i){ delta[i] = truth[i] - out[i]; + sum += delta[i]*delta[i]; } + return sum; } int get_predicted_class_network(network net) @@ -130,9 +134,9 @@ int get_predicted_class_network(network net) return max_index(out, k); } -void backward_network(network net, double *input, double *truth) +double backward_network(network net, double *input, double *truth) { - calculate_error_network(net, truth); + double error = calculate_error_network(net, truth); int i; double *prev_input; double *prev_delta; @@ -146,8 +150,9 @@ void backward_network(network net, double *input, double *truth) } if(net.types[i] == CONVOLUTIONAL){ convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - learn_convolutional_layer(layer, prev_input); - if(i != 0) backward_convolutional_layer(layer, prev_input, prev_delta); + learn_convolutional_layer(layer); + //learn_convolutional_layer(layer); + //if(i != 0) backward_convolutional_layer(layer, prev_input, prev_delta); } else if(net.types[i] == MAXPOOL){ maxpool_layer layer = *(maxpool_layer *)net.layers[i]; @@ -163,29 +168,31 @@ void backward_network(network net, double *input, double *truth) if(i != 0) backward_connected_layer(layer, prev_input, prev_delta); } } + return error; } -int train_network_datum(network net, double *x, double *y, double step, double momentum, double decay) +double train_network_datum(network net, double *x, double *y, double step, double momentum, double decay) { forward_network(net, x); int class = get_predicted_class_network(net); - backward_network(net, x, y); + double error = backward_network(net, x, y); update_network(net, step, momentum, decay); - return (y[class]?1:0); + //return (y[class]?1:0); + return error; } double train_network_sgd(network net, data d, int n, double step, double momentum,double decay) { int i; - int correct = 0; + double error = 0; for(i = 0; i < n; ++i){ int index = rand()%d.X.rows; - correct += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay); + error += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay); //if((i+1)%10 == 0){ // printf("%d: %f\n", (i+1), (double)correct/(i+1)); //} } - return (double)correct/n; + return error/n; } double train_network_batch(network net, data d, int n, double step, double momentum,double decay) { @@ -282,7 +289,7 @@ void visualize_network(network net) sprintf(buff, "Layer %d", i); if(net.types[i] == CONVOLUTIONAL){ convolutional_layer layer = *(convolutional_layer *)net.layers[i]; - visualize_convolutional_filters(layer, buff); + visualize_convolutional_layer(layer, buff); } } } diff --git a/src/network.h b/src/network.h index 975c3ddc..fa109dd0 100644 --- a/src/network.h +++ b/src/network.h @@ -22,7 +22,7 @@ typedef struct { network make_network(int n); void forward_network(network net, double *input); -void backward_network(network net, double *input, double *truth); +double backward_network(network net, double *input, double *truth); void update_network(network net, double step, double momentum, double decay); double train_network_sgd(network net, data d, int n, double step, double momentum,double decay); double train_network_batch(network net, data d, int n, double step, double momentum,double decay); diff --git a/src/tests.c b/src/tests.c index c459a362..af22ddb8 100644 --- a/src/tests.c +++ b/src/tests.c @@ -1,4 +1,5 @@ #include "connected_layer.h" +//#include "old_conv.h" #include "convolutional_layer.h" #include "maxpool_layer.h" #include "network.h" @@ -35,7 +36,7 @@ void test_convolve_matrix() printf("dog channels %d\n", dog.c); int size = 11; - int stride = 1; + int stride = 4; int n = 40; double *filters = make_random_image(size, size, dog.c*n).data; @@ -64,29 +65,6 @@ void test_color() show_image_layers(dog, "Test Color"); } -void test_convolutional_layer() -{ - srand(0); - image dog = load_image("dog.jpg"); - int i; - int n = 3; - int stride = 1; - int size = 3; - convolutional_layer layer = *make_convolutional_layer(dog.h, dog.w, dog.c, n, size, stride, RELU); - char buff[256]; - for(i = 0; i < n; ++i) { - sprintf(buff, "Kernel %d", i); - show_image(layer.kernels[i], buff); - } - forward_convolutional_layer(layer, dog.data); - - image output = get_convolutional_image(layer); - maxpool_layer mlayer = *make_maxpool_layer(output.h, output.w, output.c, 2); - forward_maxpool_layer(mlayer, layer.output); - - show_image_layers(get_maxpool_image(mlayer), "Test Maxpool Layer"); -} - void verify_convolutional_layer() { srand(0); @@ -117,7 +95,7 @@ void verify_convolutional_layer() image out_delta = get_convolutional_delta(layer); for(i = 0; i < out.h*out.w*out.c; ++i){ out_delta.data[i] = 1; - backward_convolutional_layer(layer, test.data, in_delta.data); + //backward_convolutional_layer(layer, test.data, in_delta.data); image partial = copy_image(in_delta); jacobian2[i] = partial.data; out_delta.data[i] = 0; @@ -240,16 +218,16 @@ void test_nist() double momentum = .9; double decay = 0.01; clock_t start = clock(), end; - while(++count <= 1000){ - double acc = train_network_sgd(net, train, 6400, lr, momentum, decay); - printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, 1.-acc, lr, momentum, decay); + while(++count <= 100){ + visualize_network(net); + double loss = train_network_sgd(net, train, 10000, lr, momentum, decay); + printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, loss, lr, momentum, decay); end = clock(); printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); start=end; - //visualize_network(net); - //cvWaitKey(100); + cvWaitKey(100); //lr /= 2; - if(count%5 == 0 && 0){ + if(count%5 == 0){ double train_acc = network_accuracy(net, train); fprintf(stderr, "\nTRAIN: %f\n", train_acc); double test_acc = network_accuracy(net, test); @@ -268,11 +246,9 @@ void test_ensemble() data test = load_categorical_data_csv("mnist/mnist_test.csv", 0,10); normalize_data_rows(test); data train = d; - /* - data *split = split_data(d, 1, 10); - data train = split[0]; - data test = split[1]; - */ + // data *split = split_data(d, 1, 10); + // data train = split[0]; + // data test = split[1]; matrix prediction = make_matrix(test.y.rows, test.y.cols); int n = 30; for(i = 0; i < n; ++i){ @@ -298,22 +274,6 @@ void test_ensemble() printf("Full Ensemble Accuracy: %lf\n", acc); } -void test_kernel_update() -{ - srand(0); - double delta[] = {.1}; - double input[] = {.3, .5, .3, .5, .5, .5, .5, .0, .5}; - double kernel[] = {1,2,3,4,5,6,7,8,9}; - convolutional_layer layer = *make_convolutional_layer(3, 3, 1, 1, 3, 1, LINEAR); - layer.kernels[0].data = kernel; - layer.delta = delta; - learn_convolutional_layer(layer, input); - print_image(layer.kernels[0]); - print_image(get_convolutional_delta(layer)); - print_image(layer.kernel_updates[0]); - -} - void test_random_classify() { network net = parse_network_cfg("connected.cfg"); @@ -380,7 +340,7 @@ double *random_matrix(int rows, int cols) void test_blas() { - int m = 6025, n = 20, k = 11*11*3; + int m = 1000, n = 1000, k = 1000; double *a = random_matrix(m,k); double *b = random_matrix(k,n); double *c = random_matrix(m,n); @@ -405,17 +365,16 @@ void test_im2row() double *matrix = calloc(msize, sizeof(double)); int i; for(i = 0; i < 1000; ++i){ - im2col_cpu(test.data, c, h, w, size, stride, matrix); - image render = double_to_image(mh, mw, mc, matrix); + im2col_cpu(test.data, c, h, w, size, stride, matrix); + image render = double_to_image(mh, mw, mc, matrix); } } int main() { //test_blas(); - //test_convolve_matrix(); -// test_im2row(); - //test_kernel_update(); + //test_convolve_matrix(); + // test_im2row(); //test_split(); //test_ensemble(); test_nist();