mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Subdivisions for batches
This commit is contained in:
@ -106,10 +106,11 @@ void forward_network(network net, network_state state)
|
||||
void update_network(network net)
|
||||
{
|
||||
int i;
|
||||
int update_batch = net.batch*net.subdivisions;
|
||||
for(i = 0; i < net.n; ++i){
|
||||
if(net.types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
update_convolutional_layer(layer, net.learning_rate, net.momentum, net.decay);
|
||||
update_convolutional_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
@ -117,7 +118,7 @@ void update_network(network net)
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
update_connected_layer(layer, net.learning_rate, net.momentum, net.decay);
|
||||
update_connected_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -281,7 +282,7 @@ float train_network_datum(network net, float *x, float *y)
|
||||
forward_network(net, state);
|
||||
backward_network(net, state);
|
||||
float error = get_network_cost(net);
|
||||
update_network(net);
|
||||
if((net.seen/net.batch)%net.subdivisions == 0) update_network(net);
|
||||
return error;
|
||||
}
|
||||
|
||||
@ -294,6 +295,7 @@ float train_network_sgd(network net, data d, int n)
|
||||
int i;
|
||||
float sum = 0;
|
||||
for(i = 0; i < n; ++i){
|
||||
net.seen += batch;
|
||||
get_random_batch(d, batch, X, y);
|
||||
float err = train_network_datum(net, X, y);
|
||||
sum += err;
|
||||
@ -314,6 +316,7 @@ float train_network(network net, data d)
|
||||
float sum = 0;
|
||||
for(i = 0; i < n; ++i){
|
||||
get_next_batch(d, batch, i*batch, X, y);
|
||||
net.seen += batch;
|
||||
float err = train_network_datum(net, X, y);
|
||||
sum += err;
|
||||
}
|
||||
|
Reference in New Issue
Block a user