working stuffs

This commit is contained in:
Joseph Redmon 2014-12-18 11:28:42 -08:00
parent ecbeec86bf
commit 47914146d9
4 changed files with 41 additions and 1 deletions

View File

@ -355,6 +355,16 @@ void train_cifar10()
free_data(train); free_data(train);
} }
void compare_nist(char *p1,char *p2)
{
srand(222222);
network n1 = parse_network_cfg(p1);
network n2 = parse_network_cfg(p2);
data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
normalize_data_rows(test);
compare_networks(n1, n2, test);
}
void test_nist(char *path) void test_nist(char *path)
{ {
srand(222222); srand(222222);
@ -377,7 +387,7 @@ void train_nist(char *cfgfile)
normalize_data_rows(test); normalize_data_rows(test);
int count = 0; int count = 0;
int iters = 60000/net.batch + 1; int iters = 60000/net.batch + 1;
while(++count <= 200){ while(++count <= 10){
clock_t start = clock(), end; clock_t start = clock(), end;
float loss = train_network_sgd(net, train, iters); float loss = train_network_sgd(net, train, iters);
end = clock(); end = clock();
@ -625,6 +635,11 @@ int main(int argc, char **argv)
else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]); else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]);
else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]); else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]);
else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]); else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]);
else if(argc < 4){
fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]);
return 0;
}
else if(0==strcmp(argv[1], "compare")) compare_nist(argv[2], argv[3]);
fprintf(stderr, "Success!\n"); fprintf(stderr, "Success!\n");
return 0; return 0;
} }

View File

@ -645,6 +645,28 @@ void print_network(network net)
} }
} }
void compare_networks(network n1, network n2, data test)
{
matrix g1 = network_predict_data(n1, test);
matrix g2 = network_predict_data(n2, test);
int i;
int a,b,c,d;
a = b = c = d = 0;
for(i = 0; i < g1.rows; ++i){
int truth = max_index(test.y.vals[i], test.y.cols);
int p1 = max_index(g1.vals[i], g1.cols);
int p2 = max_index(g2.vals[i], g2.cols);
if(p1 == truth){
if(p2 == truth) ++d;
else ++c;
}else{
if(p2 == truth) ++b;
else ++a;
}
}
printf("%5d %5d\n%5d %5d\n", a, b, c, d);
}
float network_accuracy(network net, data d) float network_accuracy(network net, data d)
{ {
matrix guess = network_predict_data(net, d); matrix guess = network_predict_data(net, d);

View File

@ -40,6 +40,8 @@ float train_network_datum_gpu(network net, float *x, float *y);
float *network_predict_gpu(network net, float *input); float *network_predict_gpu(network net, float *input);
#endif #endif
void compare_networks(network n1, network n2, data d);
network make_network(int n, int batch); network make_network(int n, int batch);
void forward_network(network net, float *input, float *truth, int train); void forward_network(network net, float *input, float *truth, int train);
void backward_network(network net, float *input); void backward_network(network net, float *input);

View File

@ -243,6 +243,7 @@ void scale_array(float *a, int n, float s)
a[i] *= s; a[i] *= s;
} }
} }
int max_index(float *a, int n) int max_index(float *a, int n)
{ {
if(n <= 0) return -1; if(n <= 0) return -1;