mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
working stuffs
This commit is contained in:
parent
ecbeec86bf
commit
47914146d9
17
src/cnn.c
17
src/cnn.c
@ -355,6 +355,16 @@ void train_cifar10()
|
|||||||
free_data(train);
|
free_data(train);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void compare_nist(char *p1,char *p2)
|
||||||
|
{
|
||||||
|
srand(222222);
|
||||||
|
network n1 = parse_network_cfg(p1);
|
||||||
|
network n2 = parse_network_cfg(p2);
|
||||||
|
data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
|
||||||
|
normalize_data_rows(test);
|
||||||
|
compare_networks(n1, n2, test);
|
||||||
|
}
|
||||||
|
|
||||||
void test_nist(char *path)
|
void test_nist(char *path)
|
||||||
{
|
{
|
||||||
srand(222222);
|
srand(222222);
|
||||||
@ -377,7 +387,7 @@ void train_nist(char *cfgfile)
|
|||||||
normalize_data_rows(test);
|
normalize_data_rows(test);
|
||||||
int count = 0;
|
int count = 0;
|
||||||
int iters = 60000/net.batch + 1;
|
int iters = 60000/net.batch + 1;
|
||||||
while(++count <= 200){
|
while(++count <= 10){
|
||||||
clock_t start = clock(), end;
|
clock_t start = clock(), end;
|
||||||
float loss = train_network_sgd(net, train, iters);
|
float loss = train_network_sgd(net, train, iters);
|
||||||
end = clock();
|
end = clock();
|
||||||
@ -625,6 +635,11 @@ int main(int argc, char **argv)
|
|||||||
else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]);
|
else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]);
|
||||||
else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]);
|
else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]);
|
||||||
else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]);
|
else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]);
|
||||||
|
else if(argc < 4){
|
||||||
|
fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else if(0==strcmp(argv[1], "compare")) compare_nist(argv[2], argv[3]);
|
||||||
fprintf(stderr, "Success!\n");
|
fprintf(stderr, "Success!\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -645,6 +645,28 @@ void print_network(network net)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void compare_networks(network n1, network n2, data test)
|
||||||
|
{
|
||||||
|
matrix g1 = network_predict_data(n1, test);
|
||||||
|
matrix g2 = network_predict_data(n2, test);
|
||||||
|
int i;
|
||||||
|
int a,b,c,d;
|
||||||
|
a = b = c = d = 0;
|
||||||
|
for(i = 0; i < g1.rows; ++i){
|
||||||
|
int truth = max_index(test.y.vals[i], test.y.cols);
|
||||||
|
int p1 = max_index(g1.vals[i], g1.cols);
|
||||||
|
int p2 = max_index(g2.vals[i], g2.cols);
|
||||||
|
if(p1 == truth){
|
||||||
|
if(p2 == truth) ++d;
|
||||||
|
else ++c;
|
||||||
|
}else{
|
||||||
|
if(p2 == truth) ++b;
|
||||||
|
else ++a;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("%5d %5d\n%5d %5d\n", a, b, c, d);
|
||||||
|
}
|
||||||
|
|
||||||
float network_accuracy(network net, data d)
|
float network_accuracy(network net, data d)
|
||||||
{
|
{
|
||||||
matrix guess = network_predict_data(net, d);
|
matrix guess = network_predict_data(net, d);
|
||||||
|
@ -40,6 +40,8 @@ float train_network_datum_gpu(network net, float *x, float *y);
|
|||||||
float *network_predict_gpu(network net, float *input);
|
float *network_predict_gpu(network net, float *input);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void compare_networks(network n1, network n2, data d);
|
||||||
|
|
||||||
network make_network(int n, int batch);
|
network make_network(int n, int batch);
|
||||||
void forward_network(network net, float *input, float *truth, int train);
|
void forward_network(network net, float *input, float *truth, int train);
|
||||||
void backward_network(network net, float *input);
|
void backward_network(network net, float *input);
|
||||||
|
@ -243,6 +243,7 @@ void scale_array(float *a, int n, float s)
|
|||||||
a[i] *= s;
|
a[i] *= s;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int max_index(float *a, int n)
|
int max_index(float *a, int n)
|
||||||
{
|
{
|
||||||
if(n <= 0) return -1;
|
if(n <= 0) return -1;
|
||||||
|
Loading…
Reference in New Issue
Block a user