diff --git a/src/cnn.c b/src/cnn.c index 24b2b110..188cf14e 100644 --- a/src/cnn.c +++ b/src/cnn.c @@ -355,6 +355,16 @@ void train_cifar10() free_data(train); } +void compare_nist(char *p1,char *p2) +{ + srand(222222); + network n1 = parse_network_cfg(p1); + network n2 = parse_network_cfg(p2); + data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10); + normalize_data_rows(test); + compare_networks(n1, n2, test); +} + void test_nist(char *path) { srand(222222); @@ -377,7 +387,7 @@ void train_nist(char *cfgfile) normalize_data_rows(test); int count = 0; int iters = 60000/net.batch + 1; - while(++count <= 200){ + while(++count <= 10){ clock_t start = clock(), end; float loss = train_network_sgd(net, train, iters); end = clock(); @@ -625,6 +635,11 @@ int main(int argc, char **argv) else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]); else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]); else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]); + else if(argc < 4){ + fprintf(stderr, "usage: %s \n", argv[0]); + return 0; + } + else if(0==strcmp(argv[1], "compare")) compare_nist(argv[2], argv[3]); fprintf(stderr, "Success!\n"); return 0; } diff --git a/src/network.c b/src/network.c index 829bb6ed..ac166a6f 100644 --- a/src/network.c +++ b/src/network.c @@ -645,6 +645,28 @@ void print_network(network net) } } +void compare_networks(network n1, network n2, data test) +{ + matrix g1 = network_predict_data(n1, test); + matrix g2 = network_predict_data(n2, test); + int i; + int a,b,c,d; + a = b = c = d = 0; + for(i = 0; i < g1.rows; ++i){ + int truth = max_index(test.y.vals[i], test.y.cols); + int p1 = max_index(g1.vals[i], g1.cols); + int p2 = max_index(g2.vals[i], g2.cols); + if(p1 == truth){ + if(p2 == truth) ++d; + else ++c; + }else{ + if(p2 == truth) ++b; + else ++a; + } + } + printf("%5d %5d\n%5d %5d\n", a, b, c, d); +} + float network_accuracy(network net, data d) { matrix guess = network_predict_data(net, d); diff --git a/src/network.h b/src/network.h index 6eb75451..7a401bdf 100644 --- a/src/network.h +++ b/src/network.h @@ -40,6 +40,8 @@ float train_network_datum_gpu(network net, float *x, float *y); float *network_predict_gpu(network net, float *input); #endif +void compare_networks(network n1, network n2, data d); + network make_network(int n, int batch); void forward_network(network net, float *input, float *truth, int train); void backward_network(network net, float *input); diff --git a/src/utils.c b/src/utils.c index 365faebc..682d3043 100644 --- a/src/utils.c +++ b/src/utils.c @@ -243,6 +243,7 @@ void scale_array(float *a, int n, float s) a[i] *= s; } } + int max_index(float *a, int n) { if(n <= 0) return -1;