darknet/src/dropout_layer.c

58 lines
1.4 KiB
C
Raw Normal View History

2014-08-08 23:04:15 +04:00
#include "dropout_layer.h"
2015-03-12 08:20:15 +03:00
#include "params.h"
2014-11-19 00:51:04 +03:00
#include "utils.h"
2015-01-23 03:38:24 +03:00
#include "cuda.h"
2014-11-19 00:51:04 +03:00
#include <stdlib.h>
#include <stdio.h>
2014-08-08 23:04:15 +04:00
2015-05-11 23:46:49 +03:00
dropout_layer make_dropout_layer(int batch, int inputs, float probability)
2014-08-08 23:04:15 +04:00
{
fprintf(stderr, "Dropout Layer: %d inputs, %f probability\n", inputs, probability);
2015-05-11 23:46:49 +03:00
dropout_layer l = {0};
l.type = DROPOUT;
l.probability = probability;
l.inputs = inputs;
l.outputs = inputs;
l.batch = batch;
l.rand = calloc(inputs*batch, sizeof(float));
l.scale = 1./(1.-probability);
2014-12-13 23:01:21 +03:00
#ifdef GPU
2015-05-11 23:46:49 +03:00
l.rand_gpu = cuda_make_array(l.rand, inputs*batch);
2014-11-19 00:51:04 +03:00
#endif
2015-05-11 23:46:49 +03:00
return l;
2014-08-08 23:04:15 +04:00
}
2015-05-11 23:46:49 +03:00
/*
 * Resize a dropout layer to a new per-sample input count.
 *
 * Updates l->inputs and l->outputs (dropout is element-wise, matching
 * make_dropout_layer) and reallocates the host mask buffer l->rand to hold
 * batch*inputs floats. In GPU builds the device mirror is freed and
 * recreated from the resized host buffer.
 *
 * The mask contents after a resize are not meaningful. forward_dropout_layer
 * rewrites every element of l->rand whenever it runs in training mode.
 *
 * Aborts the process if the reallocation fails. Continuing with a stale,
 * too-small buffer would cause out-of-bounds writes in the forward pass.
 */
void resize_dropout_layer(dropout_layer *l, int inputs)
{
    size_t count = (size_t)inputs * (size_t)l->batch;
    /* Use a temporary so the original buffer is not lost if realloc fails. */
    float *resized = realloc(l->rand, count * sizeof *resized);
    if (!resized && count) {
        fprintf(stderr, "resize_dropout_layer: failed to allocate %zu floats\n", count);
        exit(EXIT_FAILURE);
    }
    l->rand = resized;
    /* The original code forgot these, so the buffer was sized with the OLD input count. */
    l->inputs = inputs;
    l->outputs = inputs;
#ifdef GPU
    cuda_free(l->rand_gpu);
    l->rand_gpu = cuda_make_array(l->rand, inputs*l->batch);
#endif
}
2015-05-11 23:46:49 +03:00
/*
 * Training-time forward pass: apply inverted dropout to state.input in place.
 *
 * For each of the batch*inputs elements a fresh draw from rand_uniform()
 * is recorded in l.rand. This lets backward_dropout_layer reapply the same
 * mask. If the draw is below l.probability the input is zeroed; otherwise
 * it is multiplied by l.scale. Outside training (state.train == 0) the
 * input is left untouched.
 *
 * NOTE(review): assumes rand_uniform() returns values in [0, 1); confirm
 * against utils.c.
 */
void forward_dropout_layer(dropout_layer l, network_state state)
{
    int k;
    int count;

    if (!state.train) {
        return;
    }

    count = l.batch * l.inputs;
    for (k = 0; k < count; ++k) {
        const float draw = rand_uniform();
        l.rand[k] = draw;
        state.input[k] = (draw < l.probability) ? 0 : state.input[k] * l.scale;
    }
}
2014-12-13 23:01:21 +03:00
2015-05-11 23:46:49 +03:00
/*
 * Backward pass: gate the incoming gradient with the mask recorded by the
 * most recent forward_dropout_layer call.
 *
 * Elements whose stored draw in l.rand is below l.probability get a zero
 * gradient; all others are scaled by l.scale, mirroring the forward pass.
 * Does nothing when state.delta is NULL.
 */
void backward_dropout_layer(dropout_layer l, network_state state)
{
    int k = 0;

    if (state.delta) {
        const int count = l.batch * l.inputs;
        while (k < count) {
            if (l.rand[k] < l.probability) {
                state.delta[k] = 0;
            } else {
                state.delta[k] *= l.scale;
            }
            ++k;
        }
    }
}
2014-11-19 00:51:04 +03:00