/*
 * darknet/src/dropout_layer.c
 * 67 lines, 1.9 KiB, C (Raw / Normal View / History)
 *
 * 2014-08-08 23:04:15 +04:00
 */
#include "dropout_layer.h"
/* 2014-11-19 00:51:04 +03:00 */
#include "utils.h"
#include <stdlib.h>
#include <stdio.h>
/* 2014-08-08 23:04:15 +04:00 */
/*
 * Allocate and initialise a dropout layer.
 *
 * batch       - number of samples per batch
 * inputs      - number of values per sample
 * probability - chance that each value is zeroed by the forward pass
 *
 * Returns a heap-allocated, zero-initialised layer. In GPU builds it also
 * allocates a host buffer of inputs*batch floats (filled with random draws by
 * the GPU forward pass) and a matching device array created from it.
 *
 * If a host allocation fails, the process prints a message and exits. This
 * replaces the original behaviour of dereferencing a NULL pointer.
 */
dropout_layer *make_dropout_layer(int batch, int inputs, float probability)
{
    fprintf(stderr, "Dropout Layer: %d inputs, %f probability\n", inputs, probability);
    dropout_layer *layer = calloc(1, sizeof *layer);
    if(!layer){
        fprintf(stderr, "make_dropout_layer: out of memory\n");
        exit(EXIT_FAILURE);
    }
    layer->probability = probability;
    layer->inputs = inputs;
    layer->batch = batch;
#ifdef GPU
    int size = inputs*batch;
    layer->rand = calloc(size, sizeof(float));
    /* calloc may legitimately return NULL for a zero-sized request. */
    if(!layer->rand && size > 0){
        fprintf(stderr, "make_dropout_layer: out of memory\n");
        exit(EXIT_FAILURE);
    }
    layer->rand_cl = cl_make_array(layer->rand, size);
#endif
    return layer;
}
/*
 * CPU forward pass: apply dropout to `input` in place.
 *
 * Each of the layer.batch * layer.inputs values gets its own rand_uniform()
 * draw, taken in ascending index order:
 *   - if the draw is below layer.probability, the value is set to 0;
 *   - otherwise it is divided by (1 - layer.probability), which rescales the
 *     survivors ("inverted" dropout). Assuming rand_uniform() samples
 *     uniformly from [0, 1], each output's expected value equals its input.
 *
 * The draws are not recorded anywhere, so the backward pass cannot recover
 * which elements were dropped.
 */
void forward_dropout_layer(dropout_layer layer, float *input)
{
    /* `layer` is a by-value copy, so the element count cannot change mid-loop. */
    int total = layer.batch * layer.inputs;
    int k = 0;

    while(k < total){
        if(rand_uniform() < layer.probability){
            input[k] = 0;
        }else{
            input[k] /= (1-layer.probability);
        }
        ++k;
    }
}
/*
 * CPU backward pass for the dropout layer. It is currently a no-op: neither
 * `input` nor `delta` is read or modified.
 *
 * NOTE(review): a standard inverted-dropout backward pass would set delta[i]
 * to 0 for every element dropped in the forward pass and divide the remaining
 * entries by (1 - layer.probability). Doing that needs the per-element drop
 * decision, which the CPU forward pass does not keep. As written, no masking
 * or rescaling of the gradient happens here. TODO: confirm how callers expect
 * this layer's delta to be handled.
 */
void backward_dropout_layer(dropout_layer layer, float *input, float *delta)
{
    /* Intentionally unused until a real backward pass is implemented. */
    (void)layer;
    (void)input;
    (void)delta;
}
/* 2014-11-19 00:51:04 +03:00 */
#ifdef GPU
/*
 * Return the OpenCL kernel "forward" from src/dropout_layer.cl.
 *
 * The kernel is obtained through get_kernel() on the first call and cached in
 * a function-local static. Later calls return the cached handle. The lazy
 * initialisation uses a plain static flag with no locking.
 */
cl_kernel get_dropout_kernel()
{
    static cl_kernel dropout_kernel;
    static int kernel_ready = 0;

    if(kernel_ready) return dropout_kernel;

    dropout_kernel = get_kernel("src/dropout_layer.cl", "forward", 0);
    kernel_ready = 1;
    return dropout_kernel;
}
/*
 * GPU forward pass: apply dropout to the device buffer `input`.
 *
 * The function works in three steps:
 *   1. Draw one rand_uniform() sample per element on the host into
 *      layer.rand. Because layer.rand is a heap pointer, the samples stay
 *      there after this by-value copy of the layer goes away.
 *   2. Upload the samples to the device array layer.rand_cl.
 *   3. Launch the cached "forward" kernel with one work item per element.
 *
 * The kernel arguments (input, rand_cl, probability) must match the parameter
 * order of the "forward" kernel in src/dropout_layer.cl.
 *
 * Each OpenCL status is checked with check_error() right after the call that
 * produced it. The original code overwrote cl.error from the first two
 * clSetKernelArg calls before checking, which hid failures on those arguments.
 */
void forward_dropout_layer_gpu(dropout_layer layer, cl_mem input)
{
    int j;
    int size = layer.inputs*layer.batch;

    /* Nothing to do. A zero global work size is rejected by
       clEnqueueNDRangeKernel in OpenCL 1.x. */
    if(size <= 0) return;

    for(j = 0; j < size; ++j) layer.rand[j] = rand_uniform();
    cl_write_array(layer.rand_cl, layer.rand, size);

    cl_kernel kernel = get_dropout_kernel();
    cl_command_queue queue = cl.queue;
    cl_uint i = 0;

    cl.error = clSetKernelArg(kernel, i++, sizeof(input), (void*) &input);
    check_error(cl);
    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.rand_cl), (void*) &layer.rand_cl);
    check_error(cl);
    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.probability), (void*) &layer.probability);
    check_error(cl);

    const size_t global_size[] = {size};
    cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
    check_error(cl);
}
#endif