mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
ADAM
This commit is contained in:
@ -233,7 +233,6 @@ void push_convolutional_layer(convolutional_layer layer)
|
||||
void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
|
||||
{
|
||||
int size = layer.size*layer.size*layer.c*layer.n;
|
||||
|
||||
axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
|
||||
scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
|
||||
|
||||
@ -242,9 +241,23 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
|
||||
scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
|
||||
}
|
||||
|
||||
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
|
||||
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
|
||||
scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
|
||||
if(layer.adam){
|
||||
scal_ongpu(size, layer.B1, layer.m_gpu, 1);
|
||||
scal_ongpu(size, layer.B2, layer.v_gpu, 1);
|
||||
|
||||
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
|
||||
|
||||
axpy_ongpu(size, -(1-layer.B1), layer.weight_updates_gpu, 1, layer.m_gpu, 1);
|
||||
mul_ongpu(size, layer.weight_updates_gpu, 1, layer.weight_updates_gpu, 1);
|
||||
axpy_ongpu(size, (1-layer.B2), layer.weight_updates_gpu, 1, layer.v_gpu, 1);
|
||||
|
||||
adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
|
||||
fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
|
||||
}else{
|
||||
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
|
||||
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
|
||||
scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user