Some fixes

This commit is contained in:
AlexeyAB
2018-02-03 15:35:13 +03:00
parent 063a23f637
commit db3a3c54c5
4 changed files with 11 additions and 7 deletions

View File

@ -50,6 +50,7 @@ float get_current_rate(network net)
int batch_num = get_current_batch(net);
int i;
float rate;
if (batch_num < net.burn_in) return net.learning_rate * pow((float)batch_num / net.burn_in, net.power);
switch (net.policy) {
case CONSTANT:
return net.learning_rate;
@ -66,8 +67,9 @@ float get_current_rate(network net)
case EXP:
return net.learning_rate * pow(net.gamma, batch_num);
case POLY:
if (batch_num < net.burn_in) return net.learning_rate * pow((float)batch_num / net.burn_in, net.power);
return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
//if (batch_num < net.burn_in) return net.learning_rate * pow((float)batch_num / net.burn_in, net.power);
//return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
case RANDOM:
return net.learning_rate * pow(rand_uniform(0,1), net.power);
case SIG: