mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Midway through lots of fixes, checkpoint
This commit is contained in:
@ -272,7 +272,9 @@ float calculate_error_network(network net, float *truth)
|
||||
for(i = 0; i < get_network_output_size(net)*net.batch; ++i){
|
||||
//if(i %get_network_output_size(net) == 0) printf("\n");
|
||||
//printf("%5.2f %5.2f, ", out[i], truth[i]);
|
||||
//if(i == get_network_output_size(net)) printf("\n");
|
||||
delta[i] = truth[i] - out[i];
|
||||
//printf("%f, ", delta[i]);
|
||||
sum += delta[i]*delta[i];
|
||||
}
|
||||
//printf("\n");
|
||||
@ -382,20 +384,20 @@ float train_network_sgd(network net, data d, int n, float step, float momentum,f
|
||||
}
|
||||
float train_network_batch(network net, data d, int n, float step, float momentum,float decay)
|
||||
{
|
||||
int i;
|
||||
int correct = 0;
|
||||
int i,j;
|
||||
float sum = 0;
|
||||
int batch = 2;
|
||||
for(i = 0; i < n; ++i){
|
||||
int index = rand()%d.X.rows;
|
||||
float *x = d.X.vals[index];
|
||||
float *y = d.y.vals[index];
|
||||
forward_network(net, x, 1);
|
||||
int class = get_predicted_class_network(net);
|
||||
backward_network(net, x, y);
|
||||
correct += (y[class]?1:0);
|
||||
for(j = 0; j < batch; ++j){
|
||||
int index = rand()%d.X.rows;
|
||||
float *x = d.X.vals[index];
|
||||
float *y = d.y.vals[index];
|
||||
forward_network(net, x, 1);
|
||||
sum += backward_network(net, x, y);
|
||||
}
|
||||
update_network(net, step, momentum, decay);
|
||||
}
|
||||
update_network(net, step, momentum, decay);
|
||||
return (float)correct/n;
|
||||
|
||||
return (float)sum/(n*batch);
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user