mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Stable on MNIST, about to change a lot
This commit is contained in:
@ -187,6 +187,24 @@ double train_network_sgd(network net, data d, int n, double step, double momentu
|
||||
}
|
||||
return (double)correct/n;
|
||||
}
|
||||
double train_network_batch(network net, data d, int n, double step, double momentum,double decay)
|
||||
{
|
||||
int i;
|
||||
int correct = 0;
|
||||
for(i = 0; i < n; ++i){
|
||||
int index = rand()%d.X.rows;
|
||||
double *x = d.X.vals[index];
|
||||
double *y = d.y.vals[index];
|
||||
forward_network(net, x);
|
||||
int class = get_predicted_class_network(net);
|
||||
backward_network(net, x, y);
|
||||
correct += (y[class]?1:0);
|
||||
}
|
||||
update_network(net, step, momentum, decay);
|
||||
return (double)correct/n;
|
||||
|
||||
}
|
||||
|
||||
|
||||
void train_network(network net, data d, double step, double momentum, double decay)
|
||||
{
|
||||
|
Reference in New Issue
Block a user