mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
add avgpool layer
This commit is contained in:
57
src/avgpool_layer_kernels.cu
Normal file
57
src/avgpool_layer_kernels.cu
Normal file
@ -0,0 +1,57 @@
|
||||
extern "C" {
|
||||
#include "avgpool_layer.h"
|
||||
#include "cuda.h"
|
||||
}
|
||||
|
||||
__global__ void forward_avgpool_layer_kernel(int n, int w, int h, int c, float *input, float *output)
|
||||
{
|
||||
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if(id >= n) return;
|
||||
|
||||
int k = id % c;
|
||||
id /= c;
|
||||
int b = id;
|
||||
|
||||
int i;
|
||||
int out_index = (k + c*b);
|
||||
output[out_index] = 0;
|
||||
for(i = 0; i < w*h; ++i){
|
||||
int in_index = i + h*w*(k + b*c);
|
||||
output[out_index] += input[in_index];
|
||||
}
|
||||
output[out_index] /= w*h;
|
||||
}
|
||||
|
||||
__global__ void backward_avgpool_layer_kernel(int n, int w, int h, int c, float *in_delta, float *out_delta)
|
||||
{
|
||||
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if(id >= n) return;
|
||||
|
||||
int k = id % c;
|
||||
id /= c;
|
||||
int b = id;
|
||||
|
||||
int i;
|
||||
int out_index = (k + c*b);
|
||||
for(i = 0; i < w*h; ++i){
|
||||
int in_index = i + h*w*(k + b*c);
|
||||
in_delta[in_index] = out_delta[out_index] / (w*h);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
|
||||
{
|
||||
size_t n = layer.c*layer.batch;
|
||||
|
||||
forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.input, layer.output_gpu);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
|
||||
{
|
||||
size_t n = layer.c*layer.batch;
|
||||
|
||||
backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.delta, layer.delta_gpu);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
Reference in New Issue
Block a user