it's raining really hard outside :-( :rain: :storm: ☁️

This commit is contained in:
Joseph Redmon
2017-10-17 11:41:34 -07:00
parent 532c6e1481
commit cd5d393b46
27 changed files with 1340 additions and 1669 deletions

View File

@@ -124,7 +124,7 @@ void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ng
char *base = basecfg(cfgfile);
printf("%s\n", base);
printf("%d\n", ngpus);
network *nets = calloc(ngpus, sizeof(network));
network **nets = calloc(ngpus, sizeof(network*));
srand(time(0));
int seed = rand();
@@ -134,10 +134,10 @@ void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ng
cuda_set_device(gpus[i]);
#endif
nets[i] = load_network(cfgfile, weightfile, clear);
nets[i].learning_rate *= ngpus;
nets[i]->learning_rate *= ngpus;
}
network net = nets[0];
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
network *net = nets[0];
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
char *backup_directory = "/home/pjreddie/backup/";
@@ -147,11 +147,11 @@ void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ng
int N = m.n;
printf("Moves: %d\n", N);
int epoch = (*net.seen)/N;
while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
int epoch = (*net->seen)/N;
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
clock_t time=clock();
data train = random_go_moves(m, net.batch*net.subdivisions*ngpus);
data train = random_go_moves(m, net->batch*net->subdivisions*ngpus);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
@@ -169,9 +169,9 @@ void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ng
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.95 + loss*.05;
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
if(*net.seen/N > epoch){
epoch = *net.seen/N;
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
if(*net->seen/N > epoch){
epoch = *net->seen/N;
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
save_weights(net, buff);
@@ -281,7 +281,7 @@ void flip_board(float *board)
}
}
void predict_move(network net, float *board, float *move, int multi)
void predict_move(network *net, float *board, float *move, int multi)
{
float *output = network_predict(net, board);
copy_cpu(19*19+1, output, 1, move, 1);
@@ -370,7 +370,7 @@ int legal_go(float *b, char *ko, int p, int r, int c)
return 1;
}
int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
int generate_move(network *net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
{
int i, j;
int empty = 1;
@@ -383,7 +383,7 @@ int generate_move(network net, int player, float *board, int multi, float thresh
if(empty) {
return 72;
}
for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
float move[362];
if (player < 0) flip_board(board);
@@ -439,12 +439,9 @@ void valid_go(char *cfgfile, char *weightfile, int multi, char *filename)
srand(time(0));
char *base = basecfg(cfgfile);
printf("%s\n", base);
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
network *net = load_network(cfgfile, weightfile, 0);
set_batch_network(net, 1);
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
float *board = calloc(19*19, sizeof(float));
float *move = calloc(19*19+1, sizeof(float));
@@ -486,12 +483,9 @@ int print_game(float *board, FILE *fp)
void engine_go(char *filename, char *weightfile, int multi)
{
network net = parse_network_cfg(filename);
if(weightfile){
load_weights(&net, weightfile);
}
network *net = load_network(filename, weightfile, 0);
set_batch_network(net, 1);
srand(time(0));
set_batch_network(&net, 1);
float *board = calloc(19*19, sizeof(float));
char *one = calloc(91, sizeof(char));
char *two = calloc(91, sizeof(char));
@@ -679,12 +673,9 @@ void engine_go(char *filename, char *weightfile, int multi)
void test_go(char *cfg, char *weights, int multi)
{
network net = parse_network_cfg(cfg);
if(weights){
load_weights(&net, weights);
}
network *net = load_network(cfg, weights, 0);
set_batch_network(net, 1);
srand(time(0));
set_batch_network(&net, 1);
float *board = calloc(19*19, sizeof(float));
float *move = calloc(19*19+1, sizeof(float));
int color = 1;
@@ -785,23 +776,24 @@ float score_game(float *board)
void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
{
network net = parse_network_cfg(filename);
if(weightfile){
load_weights(&net, weightfile);
}
network *net = load_network(filename, weightfile, 0);
set_batch_network(net, 1);
network net2 = net;
if(f2){
network *net2;
if (f2) {
net2 = parse_network_cfg(f2);
if(w2){
load_weights(&net2, w2);
load_weights(net2, w2);
}
} else {
net2 = calloc(1, sizeof(network));
*net2 = *net;
}
srand(time(0));
char boards[600][93];
int count = 0;
set_batch_network(&net, 1);
set_batch_network(&net2, 1);
set_batch_network(net, 1);
set_batch_network(net2, 1);
float *board = calloc(19*19, sizeof(float));
char *one = calloc(91, sizeof(char));
char *two = calloc(91, sizeof(char));
@@ -819,15 +811,15 @@ void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
sleep(1);
/*
int i = (score > 0)? 0 : 1;
int j;
for(; i < count; i += 2){
for(j = 0; j < 93; ++j){
printf("%c", boards[i][j]);
}
printf("\n");
}
*/
int i = (score > 0)? 0 : 1;
int j;
for(; i < count; i += 2){
for(j = 0; j < 93; ++j){
printf("%c", boards[i][j]);
}
printf("\n");
}
*/
memset(board, 0, 19*19*sizeof(float));
player = 1;
done = 0;
@@ -837,7 +829,7 @@ void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
}
print_board(stderr, board, 1, 0);
//sleep(1);
network use = ((total%2==0) == (player==1)) ? net : net2;
network *use = ((total%2==0) == (player==1)) ? net : net2;
int index = generate_move(use, player, board, multi, .4, 1, two, 0);
if(index < 0){
done = 1;