add SGDR policy

This commit is contained in:
Josh Veitch-Michaelis
2019-03-18 23:26:04 +00:00
parent 8bcba6c105
commit d64693eb77
3 changed files with 14 additions and 1 deletions

View File

@ -117,6 +117,12 @@ float get_current_rate(network net)
return net.learning_rate * pow(rand_uniform(0,1), net.power);
case SIG:
return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
case SGDR:
rate = net.learning_rate_min +
0.5*(net.learning_rate_max-net.learning_rate_min)
* (1. + cos( (float) (batch_num % net.batches_per_cycle)*3.14159265 / net.batches_per_cycle));
return rate;
default:
fprintf(stderr, "Policy is weird!\n");
return net.learning_rate;