mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
working stuffs
This commit is contained in:
parent
ecbeec86bf
commit
47914146d9
17
src/cnn.c
17
src/cnn.c
@ -355,6 +355,16 @@ void train_cifar10()
|
||||
free_data(train);
|
||||
}
|
||||
|
||||
void compare_nist(char *p1,char *p2)
|
||||
{
|
||||
srand(222222);
|
||||
network n1 = parse_network_cfg(p1);
|
||||
network n2 = parse_network_cfg(p2);
|
||||
data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
|
||||
normalize_data_rows(test);
|
||||
compare_networks(n1, n2, test);
|
||||
}
|
||||
|
||||
void test_nist(char *path)
|
||||
{
|
||||
srand(222222);
|
||||
@ -377,7 +387,7 @@ void train_nist(char *cfgfile)
|
||||
normalize_data_rows(test);
|
||||
int count = 0;
|
||||
int iters = 60000/net.batch + 1;
|
||||
while(++count <= 200){
|
||||
while(++count <= 10){
|
||||
clock_t start = clock(), end;
|
||||
float loss = train_network_sgd(net, train, iters);
|
||||
end = clock();
|
||||
@ -625,6 +635,11 @@ int main(int argc, char **argv)
|
||||
else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]);
|
||||
else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]);
|
||||
else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]);
|
||||
else if(argc < 4){
|
||||
fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]);
|
||||
return 0;
|
||||
}
|
||||
else if(0==strcmp(argv[1], "compare")) compare_nist(argv[2], argv[3]);
|
||||
fprintf(stderr, "Success!\n");
|
||||
return 0;
|
||||
}
|
||||
|
@ -645,6 +645,28 @@ void print_network(network net)
|
||||
}
|
||||
}
|
||||
|
||||
void compare_networks(network n1, network n2, data test)
|
||||
{
|
||||
matrix g1 = network_predict_data(n1, test);
|
||||
matrix g2 = network_predict_data(n2, test);
|
||||
int i;
|
||||
int a,b,c,d;
|
||||
a = b = c = d = 0;
|
||||
for(i = 0; i < g1.rows; ++i){
|
||||
int truth = max_index(test.y.vals[i], test.y.cols);
|
||||
int p1 = max_index(g1.vals[i], g1.cols);
|
||||
int p2 = max_index(g2.vals[i], g2.cols);
|
||||
if(p1 == truth){
|
||||
if(p2 == truth) ++d;
|
||||
else ++c;
|
||||
}else{
|
||||
if(p2 == truth) ++b;
|
||||
else ++a;
|
||||
}
|
||||
}
|
||||
printf("%5d %5d\n%5d %5d\n", a, b, c, d);
|
||||
}
|
||||
|
||||
float network_accuracy(network net, data d)
|
||||
{
|
||||
matrix guess = network_predict_data(net, d);
|
||||
|
@ -40,6 +40,8 @@ float train_network_datum_gpu(network net, float *x, float *y);
|
||||
float *network_predict_gpu(network net, float *input);
|
||||
#endif
|
||||
|
||||
void compare_networks(network n1, network n2, data d);
|
||||
|
||||
network make_network(int n, int batch);
|
||||
void forward_network(network net, float *input, float *truth, int train);
|
||||
void backward_network(network net, float *input);
|
||||
|
@ -243,6 +243,7 @@ void scale_array(float *a, int n, float s)
|
||||
a[i] *= s;
|
||||
}
|
||||
}
|
||||
|
||||
int max_index(float *a, int n)
|
||||
{
|
||||
if(n <= 0) return -1;
|
||||
|
Loading…
Reference in New Issue
Block a user