Subdivisions for batches

This commit is contained in:
Joseph Redmon
2015-03-22 09:56:40 -07:00
parent 9d418102f4
commit 664c5dd2f2
10 changed files with 44 additions and 83 deletions

View File

@ -106,10 +106,11 @@ void forward_network(network net, network_state state)
void update_network(network net)
{
int i;
int update_batch = net.batch*net.subdivisions;
for(i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
update_convolutional_layer(layer, net.learning_rate, net.momentum, net.decay);
update_convolutional_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
}
else if(net.types[i] == DECONVOLUTIONAL){
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
@ -117,7 +118,7 @@ void update_network(network net)
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
update_connected_layer(layer, net.learning_rate, net.momentum, net.decay);
update_connected_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
}
}
}
@ -281,7 +282,7 @@ float train_network_datum(network net, float *x, float *y)
forward_network(net, state);
backward_network(net, state);
float error = get_network_cost(net);
update_network(net);
if((net.seen/net.batch)%net.subdivisions == 0) update_network(net);
return error;
}
@ -294,6 +295,7 @@ float train_network_sgd(network net, data d, int n)
int i;
float sum = 0;
for(i = 0; i < n; ++i){
net.seen += batch;
get_random_batch(d, batch, X, y);
float err = train_network_datum(net, X, y);
sum += err;
@ -314,6 +316,7 @@ float train_network(network net, data d)
float sum = 0;
for(i = 0; i < n; ++i){
get_next_batch(d, batch, i*batch, X, y);
net.seen += batch;
float err = train_network_datum(net, X, y);
sum += err;
}