mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
so much need to commit
This commit is contained in:
@ -4,6 +4,7 @@
|
||||
|
||||
extern "C" {
|
||||
#include "convolutional_layer.h"
|
||||
#include "batchnorm_layer.h"
|
||||
#include "gemm.h"
|
||||
#include "blas.h"
|
||||
#include "im2col.h"
|
||||
@ -12,6 +13,41 @@ extern "C" {
|
||||
#include "cuda.h"
|
||||
}
|
||||
|
||||
__global__ void binarize_kernel(float *x, int n, float *binary)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
binary[i] = (x[i] > 0) ? 1 : -1;
|
||||
}
|
||||
|
||||
void binarize_gpu(float *x, int n, float *binary)
|
||||
{
|
||||
binarize_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, binary);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void binarize_input_kernel(float *input, int n, int size, float *binary)
|
||||
{
|
||||
int s = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (s >= size) return;
|
||||
int i = 0;
|
||||
float mean = 0;
|
||||
for(i = 0; i < n; ++i){
|
||||
mean += abs(input[i*size + s]);
|
||||
}
|
||||
mean = mean / n;
|
||||
for(i = 0; i < n; ++i){
|
||||
binary[i*size + s] = (input[i*size + s] > 0) ? mean : -mean;
|
||||
}
|
||||
}
|
||||
|
||||
void binarize_input_gpu(float *input, int n, int size, float *binary)
|
||||
{
|
||||
binarize_input_kernel<<<cuda_gridsize(size), BLOCK>>>(input, n, size, binary);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
__global__ void binarize_filters_kernel(float *filters, int n, int size, float *binary)
|
||||
{
|
||||
int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
@ -27,140 +63,12 @@ __global__ void binarize_filters_kernel(float *filters, int n, int size, float *
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
|
||||
{
|
||||
int offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int filter = blockIdx.y;
|
||||
int batch = blockIdx.z;
|
||||
|
||||
if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
|
||||
}
|
||||
|
||||
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
|
||||
{
|
||||
dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
|
||||
dim3 dimBlock(BLOCK, 1, 1);
|
||||
|
||||
scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
|
||||
{
|
||||
__shared__ float part[BLOCK];
|
||||
int i,b;
|
||||
int filter = blockIdx.x;
|
||||
int p = threadIdx.x;
|
||||
float sum = 0;
|
||||
for(b = 0; b < batch; ++b){
|
||||
for(i = 0; i < size; i += BLOCK){
|
||||
int index = p + i + size*(filter + n*b);
|
||||
sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
|
||||
}
|
||||
}
|
||||
part[p] = sum;
|
||||
__syncthreads();
|
||||
if (p == 0) {
|
||||
for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
|
||||
}
|
||||
}
|
||||
|
||||
void binarize_filters_gpu(float *filters, int n, int size, float *binary)
|
||||
{
|
||||
binarize_filters_kernel<<<cuda_gridsize(n), BLOCK>>>(filters, n, size, binary);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
|
||||
{
|
||||
backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
|
||||
{
|
||||
int offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int filter = blockIdx.y;
|
||||
int batch = blockIdx.z;
|
||||
|
||||
if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
|
||||
}
|
||||
|
||||
void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
|
||||
{
|
||||
dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
|
||||
dim3 dimBlock(BLOCK, 1, 1);
|
||||
|
||||
add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
|
||||
{
|
||||
__shared__ float part[BLOCK];
|
||||
int i,b;
|
||||
int filter = blockIdx.x;
|
||||
int p = threadIdx.x;
|
||||
float sum = 0;
|
||||
for(b = 0; b < batch; ++b){
|
||||
for(i = 0; i < size; i += BLOCK){
|
||||
int index = p + i + size*(filter + n*b);
|
||||
sum += (p+i < size) ? delta[index] : 0;
|
||||
}
|
||||
}
|
||||
part[p] = sum;
|
||||
__syncthreads();
|
||||
if (p == 0) {
|
||||
for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
|
||||
{
|
||||
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
int f1 = index / n;
|
||||
int f2 = index % n;
|
||||
if (f2 <= f1) return;
|
||||
|
||||
float sum = 0;
|
||||
float norm1 = 0;
|
||||
float norm2 = 0;
|
||||
int b, i;
|
||||
for(b = 0; b < batch; ++b){
|
||||
for(i = 0; i < size; ++i){
|
||||
int i1 = b * size * n + f1 * size + i;
|
||||
int i2 = b * size * n + f2 * size + i;
|
||||
sum += output[i1] * output[i2];
|
||||
norm1 += output[i1] * output[i1];
|
||||
norm2 += output[i2] * output[i2];
|
||||
}
|
||||
}
|
||||
norm1 = sqrt(norm1);
|
||||
norm2 = sqrt(norm2);
|
||||
float norm = norm1 * norm2;
|
||||
sum = sum / norm;
|
||||
for(b = 0; b < batch; ++b){
|
||||
for(i = 0; i < size; ++i){
|
||||
int i1 = b * size * n + f1 * size + i;
|
||||
int i2 = b * size * n + f2 * size + i;
|
||||
delta[i1] += - scale * sum * output[i2] / norm;
|
||||
delta[i2] += - scale * sum * output[i1] / norm;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dot_error_gpu(layer l)
|
||||
{
|
||||
dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
|
||||
{
|
||||
backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
{
|
||||
int i;
|
||||
@ -175,6 +83,16 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
swap_binary(&l);
|
||||
}
|
||||
|
||||
if(l.xnor){
|
||||
binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
|
||||
//binarize_gpu(l.filters_gpu, l.n*l.c*l.size*l.size, l.binary_filters_gpu);
|
||||
swap_binary(&l);
|
||||
for(i = 0; i < l.batch; ++i){
|
||||
binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
|
||||
}
|
||||
state.input = l.binary_input_gpu;
|
||||
}
|
||||
|
||||
for(i = 0; i < l.batch; ++i){
|
||||
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
|
||||
float * a = l.filters_gpu;
|
||||
@ -184,29 +102,13 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
}
|
||||
|
||||
if (l.batch_normalize) {
|
||||
if (state.train) {
|
||||
fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.mean_gpu);
|
||||
fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.variance_gpu);
|
||||
|
||||
scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1);
|
||||
axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
|
||||
scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1);
|
||||
axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
|
||||
|
||||
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
|
||||
normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w);
|
||||
copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
|
||||
} else {
|
||||
normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w);
|
||||
}
|
||||
|
||||
scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
|
||||
forward_batchnorm_layer_gpu(l, state);
|
||||
}
|
||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
|
||||
|
||||
activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
|
||||
if(l.dot > 0) dot_error_gpu(l);
|
||||
if(l.binary) swap_binary(&l);
|
||||
//if(l.dot > 0) dot_error_gpu(l);
|
||||
if(l.binary || l.xnor) swap_binary(&l);
|
||||
}
|
||||
|
||||
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
@ -222,15 +124,10 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
||||
backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
|
||||
|
||||
if(l.batch_normalize){
|
||||
backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu);
|
||||
|
||||
scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
|
||||
|
||||
fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.mean_delta_gpu);
|
||||
fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.variance_delta_gpu);
|
||||
normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu);
|
||||
backward_batchnorm_layer_gpu(l, state);
|
||||
}
|
||||
|
||||
if(l.xnor) state.input = l.binary_input_gpu;
|
||||
for(i = 0; i < l.batch; ++i){
|
||||
float * a = l.delta_gpu;
|
||||
float * b = l.col_image_gpu;
|
||||
@ -240,7 +137,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
||||
gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
|
||||
|
||||
if(state.delta){
|
||||
if(l.binary) swap_binary(&l);
|
||||
if(l.binary || l.xnor) swap_binary(&l);
|
||||
float * a = l.filters_gpu;
|
||||
float * b = l.delta_gpu;
|
||||
float * c = l.col_image_gpu;
|
||||
@ -248,7 +145,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
||||
gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
|
||||
|
||||
col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
|
||||
if(l.binary) swap_binary(&l);
|
||||
if(l.binary || l.xnor) swap_binary(&l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user