This commit is contained in:
Joseph Redmon 2016-03-16 04:44:44 -07:00
parent cff59ba135
commit 67794a52a1

View File

@ -176,7 +176,7 @@ void flip_board(float *board)
} }
} }
void test_go(char *filename, char *weightfile) void test_go(char *filename, char *weightfile, int multi)
{ {
network net = parse_network_cfg(filename); network net = parse_network_cfg(filename);
if(weightfile){ if(weightfile){
@ -191,25 +191,25 @@ void test_go(char *filename, char *weightfile)
float *output = network_predict(net, board); float *output = network_predict(net, board);
copy_cpu(19*19, output, 1, move, 1); copy_cpu(19*19, output, 1, move, 1);
int i; int i;
#ifdef GPU if(multi){
image bim = float_to_image(19, 19, 1, board); image bim = float_to_image(19, 19, 1, board);
for(i = 1; i < 8; ++i){ for(i = 1; i < 8; ++i){
rotate_image_cw(bim, i); rotate_image_cw(bim, i);
if(i >= 4) flip_image(bim); if(i >= 4) flip_image(bim);
float *output = network_predict(net, board); float *output = network_predict(net, board);
image oim = float_to_image(19, 19, 1, output); image oim = float_to_image(19, 19, 1, output);
if(i >= 4) flip_image(oim); if(i >= 4) flip_image(oim);
rotate_image_cw(oim, -i); rotate_image_cw(oim, -i);
axpy_cpu(19*19, 1, output, 1, move, 1); axpy_cpu(19*19, 1, output, 1, move, 1);
if(i >= 4) flip_image(bim); if(i >= 4) flip_image(bim);
rotate_image_cw(bim, -i); rotate_image_cw(bim, -i);
}
scal_cpu(19*19, 1./8., move, 1);
} }
scal_cpu(19*19, 1./8., move, 1);
#endif
for(i = 0; i < 19*19; ++i){ for(i = 0; i < 19*19; ++i){
if(board[i]) move[i] = 0; if(board[i]) move[i] = 0;
} }
@ -282,8 +282,9 @@ void run_go(int argc, char **argv)
char *cfg = argv[3]; char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0; char *weights = (argc > 4) ? argv[4] : 0;
int multi = find_arg(argc, argv, "-multi");
if(0==strcmp(argv[2], "train")) train_go(cfg, weights); if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights); else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
} }