mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Test for XNOR-conv on CUDA
This commit is contained in:
@ -141,7 +141,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
size_t t_intput_size = new_ldb * n;
|
||||
size_t t_bit_input_size = t_intput_size / 8;// +1;
|
||||
|
||||
|
||||
/*
|
||||
int i = 0;
|
||||
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
|
||||
//cudaDeviceSynchronize();
|
||||
@ -160,23 +160,46 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
(unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu);
|
||||
//cudaDeviceSynchronize();
|
||||
//check_error(status);
|
||||
|
||||
*/
|
||||
|
||||
{
|
||||
//float_to_bit_gpu(state.input, (unsigned char *)l.align_workspace_gpu, input_size);
|
||||
//
|
||||
|
||||
/*
|
||||
float *input_cpu = (float *)calloc(input_size, sizeof(float));
|
||||
status = cudaMemcpy(input_cpu, state.input, input_size* sizeof(float), cudaMemcpyDeviceToHost);
|
||||
check_error(status);
|
||||
|
||||
convolve_bin_cpu(input_cpu, l.weights, l.output, l.w, l.h, l.c, l.n, l.size, l.pad); // CPU
|
||||
// swapped (binary_weights <-> l.weights)
|
||||
convolve_cpu(input_cpu, l.weights, l.output, l.w, l.h, l.c, l.n, l.size, l.pad); // CPU
|
||||
status = cudaMemcpy(l.output_gpu, l.output, l.outputs * sizeof(float), cudaMemcpyHostToDevice);
|
||||
check_error(status);
|
||||
free(input_cpu);
|
||||
*/
|
||||
|
||||
//convolve_bin_gpu(state.input, l.weights_gpu, l.output_gpu, l.w, l.h, l.c, l.n, l.size, l.pad);
|
||||
/*
|
||||
float *input_cpu = (float *)calloc(input_size, sizeof(float));
|
||||
float *input_bin_cpu = (float *)calloc(input_size, sizeof(char));
|
||||
//float *weights_bin_cpu = (float *)calloc(l.n*l.c*l.size*l.size, sizeof(char));
|
||||
status = cudaMemcpy(input_cpu, state.input, input_size * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
check_error(status);
|
||||
float_to_bit(input_cpu, (unsigned char *)input_bin_cpu, input_size);
|
||||
//float_to_bit(l.weights, (unsigned char *)weights_bin_cpu, l.n*l.c*l.size*l.size); // l.align_bit_weights
|
||||
|
||||
convolve_bin_cpu(input_bin_cpu, (float *)l.align_bit_weights, l.output, l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr); // CPU
|
||||
status = cudaMemcpy(l.output_gpu, l.output, l.outputs * sizeof(float), cudaMemcpyHostToDevice);
|
||||
check_error(status);
|
||||
//free(weights_bin_cpu);
|
||||
free(input_bin_cpu);
|
||||
free(input_cpu);
|
||||
*/
|
||||
|
||||
float_to_bit_gpu(state.input, (unsigned char *)l.align_workspace_gpu, input_size);
|
||||
convolve_bin_gpu(l.align_workspace_gpu, (float *)l.align_bit_weights_gpu, l.output_gpu, l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr_gpu);
|
||||
|
||||
|
||||
//convolve_gpu(state.input, l.weights_gpu, l.output_gpu, l.w, l.h, l.c, l.n, l.size, l.pad);
|
||||
|
||||
//cudaDeviceSynchronize();
|
||||
//check_error(status);
|
||||
|
||||
@ -309,10 +332,16 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
int k = l.size*l.size*l.c;
|
||||
int n = l.out_w*l.out_h;
|
||||
for(i = 0; i < l.batch; ++i){
|
||||
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
|
||||
float *im = state.input + i*l.c*l.h*l.w;
|
||||
float * a = l.weights_gpu;
|
||||
float * b = state.workspace;
|
||||
float * c = l.output_gpu;
|
||||
if (l.size == 1) {
|
||||
b = im;
|
||||
}
|
||||
else {
|
||||
im2col_ongpu(im, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
|
||||
}
|
||||
gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
|
||||
}
|
||||
#endif
|
||||
|
@ -605,6 +605,7 @@ void binary_align_weights(convolutional_layer *l)
|
||||
int m = l->n;
|
||||
int k = l->size*l->size*l->c;
|
||||
size_t new_lda = k + (l->lda_align - k % l->lda_align); // (k / 8 + 1) * 8;
|
||||
l->new_lda = new_lda;
|
||||
|
||||
binarize_weights(l->weights, m, k, l->binary_weights);
|
||||
|
||||
|
10
src/im2col.h
10
src/im2col.h
@ -29,9 +29,15 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
|
||||
unsigned char *B, int ldb,
|
||||
float *C, int ldc, float *mean_arr);
|
||||
|
||||
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
|
||||
void convolve_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
|
||||
|
||||
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
|
||||
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad,
|
||||
int new_lda, float *mean_arr_gpu);
|
||||
|
||||
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n,
|
||||
int size, int pad, int new_lda, float *mean_arr_gpu);
|
||||
|
||||
void convolve_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@ -258,7 +258,17 @@ void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size) {
|
||||
// --------------------------------
|
||||
|
||||
// Fixed-width integer aliases used by the bit-packing helpers in this file.
// NOTE(review): these reuse the <stdint.h> names; confirm that no translation
// unit containing this code also includes <stdint.h> with different underlying types.
typedef unsigned long long int uint64_t;
typedef unsigned int uint32_t;
typedef unsigned char uint8_t;
// Plain `char` has implementation-defined signedness (it is unsigned on common
// ARM ABIs), so the signed 8-bit alias has to name `signed char` explicitly.
typedef signed char int8_t;
|
||||
|
||||
// Expands a single flag into a 64-bit mask: all 64 bits are set when `src`
// is non-zero, and the result is 0 when `src` is 0.
__device__ __host__ static inline uint64_t broadcast_bit_1_to_64(uint8_t src) {
    const uint64_t all_ones = 0xFFFFFFFFFFFFFFFF;
    if (src > 0) {
        return all_ones;
    }
    return 0;
}
|
||||
|
||||
// XNOR of the least-significant bits of `a` and `b`: returns 1 when the two
// low bits are equal and 0 when they differ. All higher bits are ignored.
__device__ __host__ static inline uint8_t xnor_bit1(uint8_t a, uint8_t b) {
    const uint8_t low_bits_differ = (a ^ b) & 1;
    return low_bits_differ ? 0 : 1;
}
|
||||
|
||||
__device__ __host__ static inline uint64_t xnor_int64(uint64_t a, uint64_t b) {
|
||||
return ~(a^b);
|
||||
@ -462,78 +472,9 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
|
||||
C, ldc,
|
||||
mean_arr);
|
||||
}
|
||||
|
||||
// --------------------------------
|
||||
|
||||
|
||||
// Direct float convolution, one thread per output element.
//
// The flat thread index is decoded as
//   x = index % in_w,  y = (index / in_w) % in_h,  fil = index / (in_w * in_h)
// and threads whose `fil` is not below `n` do nothing.
//
// input   - read at  chan*in_w*in_h + input_y*in_w + input_x
// weights - read at  fil*in_c*size*size + chan*size*size + f_y*size + f_x
// output  - written at fil*in_w*in_h + y*in_w + x (overwritten, not accumulated)
// Filter taps that fall outside the input plane (after shifting by `pad`) are skipped.
__global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;

    // Decode (filter, row, column) from the flat thread index.
    const int x = index % in_w;
    const int row_and_filter = index / in_w;
    const int y = row_and_filter % in_h;
    const int fil = row_and_filter / in_h;

    if (fil < n)
    {
        float sum = 0;

        for (int chan = 0; chan < in_c; ++chan)
        {
            const int weights_base = fil*in_c*size*size + chan*size*size;
            const int input_base = chan*in_w*in_h;

            for (int f_y = 0; f_y < size; ++f_y)
            {
                const int input_y = y + f_y - pad;

                for (int f_x = 0; f_x < size; ++f_x)
                {
                    const int input_x = x + f_x - pad;

                    // Only taps that land inside the input plane contribute.
                    const int inside = (input_y >= 0 && input_x >= 0 && input_y < in_h && input_x < in_w);
                    if (inside)
                    {
                        const int input_index = input_base + input_y*in_w + input_x;
                        const int weights_index = weights_base + f_y*size + f_x;
                        sum += input[input_index] * weights[weights_index];
                    }
                }
            }
        }

        // One store per thread after all channels have been accumulated.
        output[fil*in_w*in_h + y*in_w + x] = sum;
    }
}
|
||||
|
||||
// Host launcher for the 9-argument convolve_bin_gpu_kernel.
//
// Launches one thread per output element (in_w * in_h * n) in blocks of BLOCK
// threads on the stream returned by get_cuda_stream(). The launch is
// asynchronous; no synchronization is performed here.
// NOTE(review): the kernel decodes its position from an `int` index, so the
// element count is still expected to fit in an int.
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
{
    // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
    if (in_w <= 0 || in_h <= 0 || n <= 0) return;

    // width X height X filters, widened before multiplying so the product is not computed in int
    const size_t array_size = (size_t)in_w * in_h * n;
    // Ceiling division: exactly enough blocks to cover every output element.
    const int num_blocks = (int)((array_size + BLOCK - 1) / BLOCK);

    convolve_bin_gpu_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(input, weights, output, in_w, in_h, in_c, n, size, pad);
}
|
||||
|
||||
|
||||
|
||||
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
|
||||
void convolve_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
|
||||
{
|
||||
int fil;
|
||||
// filter index
|
||||
@ -576,4 +517,325 @@ void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
// --------------------------------
|
||||
|
||||
|
||||
// CPU reference implementation of the XNOR (binary) convolution.
//
// input        - bit-packed input, read with get_bit at bit
//                chan*in_w*in_h + input_y*in_w + input_x
// weights      - bit-packed weights, read with get_bit at bit
//                fil*new_lda + chan*size*size + f_y*size + f_x
// output       - float buffer written at fil*in_w*in_h + y*in_w + x
// in_w, in_h   - width and height used for both the input plane and the output plane
// in_c, n      - number of input channels and number of filters
// size, pad    - filter side length and the offset subtracted from tap coordinates
// new_lda      - distance, in bits, between consecutive filters in `weights`
// mean_arr_gpu - per-filter scale; read on the host here despite its name
//
// For every output element the matching bits (XNOR == 1) and the number of
// in-bounds taps are counted over all channels; the stored value is
// (matches - mismatches) * mean_arr_gpu[fil].
//
// The result is assigned rather than added, so it no longer depends on the
// previous contents of `output`; this matches convolve_bin_gpu_kernel, which
// also overwrites each element. With a zero-filled `output` the value is the
// same as the former per-channel `+=` accumulation.
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n,
    int size, int pad, int new_lda, float *mean_arr_gpu)
{
    int fil;
    // filter index - each filter writes a disjoint slice of `output`
    #pragma omp parallel for
    for (fil = 0; fil < n; ++fil) {
        const float mean_val = mean_arr_gpu[fil];
        int chan, y, x, f_y, f_x;

        for (y = 0; y < in_h; ++y)
            for (x = 0; x < in_w; ++x)
            {
                const int output_index = fil*in_w*in_h + y*in_w + x;
                int sum = 0;        // taps where input bit == weight bit
                int good_val = 0;   // taps that fall inside the input plane

                for (chan = 0; chan < in_c; ++chan)
                {
                    const int input_pre_index = chan*in_w*in_h;
                    const int weights_pre_index = fil*new_lda + chan*size*size;

                    // filter - y
                    for (f_y = 0; f_y < size; ++f_y)
                    {
                        const int input_y = y + f_y - pad;
                        // filter - x
                        for (f_x = 0; f_x < size; ++f_x)
                        {
                            const int input_x = x + f_x - pad;
                            if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;

                            const int input_index = input_pre_index + input_y*in_w + input_x;
                            const int weights_index = weights_pre_index + f_y*size + f_x;

                            const uint8_t in_bit = get_bit((uint8_t *)input, input_index);
                            const uint8_t w_bit = get_bit((uint8_t *)weights, weights_index);
                            sum += xnor_bit1(in_bit, w_bit);
                            good_val++;
                        }
                    }
                }

                // matches - mismatches, i.e. the +1/-1 dot product of the two bit vectors
                sum = sum - (good_val - sum);
                output[output_index] = sum*mean_val;
            }
    }
}
|
||||
// --------------------------------
|
||||
|
||||
// Direct float convolution, one thread per output element.
//
// Thread-to-output mapping (flat 1-D index):
//   column = index % in_w,  row = (index / in_w) % in_h,  filter = index / (in_w * in_h)
// Threads whose filter number is not below `n` exit without writing.
//
// input   - read at  chan*in_w*in_h + input_y*in_w + input_x
// weights - read at  fil*in_c*size*size + chan*size*size + f_y*size + f_x
// output  - written at fil*in_w*in_h + y*in_w + x (overwritten, not accumulated)
// Taps that fall outside the input plane after shifting by `pad` are skipped.
__global__ void convolve_gpu_kernel(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
{
    const int tid = blockIdx.x*blockDim.x + threadIdx.x;

    const int col = tid % in_w;
    const int rest = tid / in_w;
    const int row = rest % in_h;
    const int filter = rest / in_h;

    // Tail threads of the last block have no output element.
    if (filter >= n) return;

    float acc = 0;

    for (int c = 0; c < in_c; ++c) {
        const int w_base = filter*in_c*size*size + c*size*size;
        const int in_base = c*in_w*in_h;

        for (int ky = 0; ky < size; ++ky) {
            const int iy = row + ky - pad;

            for (int kx = 0; kx < size; ++kx) {
                const int ix = col + kx - pad;

                // Skip taps outside the input plane.
                if (iy < 0 || ix < 0 || iy >= in_h || ix >= in_w) continue;

                acc += input[in_base + iy*in_w + ix] * weights[w_base + ky*size + kx];
            }
        }
    }

    output[filter*in_w*in_h + row*in_w + col] = acc;
}
|
||||
|
||||
// Host launcher for convolve_gpu_kernel.
//
// Launches one thread per output element (in_w * in_h * n) in blocks of BLOCK
// threads on the stream returned by get_cuda_stream(). The launch is
// asynchronous; no synchronization is performed here.
// NOTE(review): the kernel decodes its position from an `int` index, so the
// element count is still expected to fit in an int.
void convolve_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
{
    // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
    if (in_w <= 0 || in_h <= 0 || n <= 0) return;

    // width X height X filters, widened before multiplying so the product is not computed in int
    const size_t array_size = (size_t)in_w * in_h * n;
    // Ceiling division: exactly enough blocks to cover every output element.
    const int num_blocks = (int)((array_size + BLOCK - 1) / BLOCK);

    convolve_gpu_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(input, weights, output, in_w, in_h, in_c, n, size, pad);
}
|
||||
|
||||
// --------------------------------
|
||||
|
||||
/*
|
||||
__global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n,
|
||||
int size, int pad, int new_lda, float *mean_arr_gpu)
|
||||
{
|
||||
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
|
||||
int fil;
|
||||
// filter index
|
||||
//for (fil = 0; fil < n; ++fil)
|
||||
int chan, y, x, f_y, f_x;
|
||||
// channel index
|
||||
//for (chan = 0; chan < in_c; ++chan)
|
||||
// input - y
|
||||
//for (y = 0; y < in_h; ++y)
|
||||
// input - x
|
||||
//for (x = 0; x < in_w; ++x)
|
||||
x = index % in_w;
|
||||
int index2 = index / in_w;
|
||||
y = index2 % in_h;
|
||||
fil = index2 / in_h;
|
||||
if (fil < n) // (1-6 for one BLOCK)
|
||||
{
|
||||
//float mean_val = mean_arr_gpu[fil];
|
||||
int const output_index = fil*in_w*in_h + y*in_w + x;
|
||||
int sum = 0;
|
||||
int good_val = 0;
|
||||
|
||||
for (chan = 0; chan < in_c; ++chan)
|
||||
{
|
||||
//int const weights_pre_index = fil*in_c*size*size + chan*size*size;
|
||||
int const weights_pre_index = fil*new_lda + chan*size*size;
|
||||
int const input_pre_index = chan*in_w*in_h;
|
||||
|
||||
// filter - y
|
||||
for (f_y = 0; f_y < size; ++f_y)
|
||||
{
|
||||
int input_y = y + f_y - pad;
|
||||
// filter - x
|
||||
for (f_x = 0; f_x < size; ++f_x)
|
||||
{
|
||||
int input_x = x + f_x - pad;
|
||||
if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;
|
||||
|
||||
int input_index = input_pre_index + input_y*in_w + input_x;
|
||||
int weights_index = weights_pre_index + f_y*size + f_x;
|
||||
//int weights_index = fil*in_c*size*size + chan*size*size + f_y*size + f_x;
|
||||
//int weights_index = fil*new_lda + chan*size*size + f_y*size + f_x;
|
||||
|
||||
uint8_t in_bit = get_bit((uint8_t *)input, input_index);
|
||||
uint8_t w_bit = get_bit((uint8_t *)weights, weights_index);
|
||||
int res = xnor_bit1(in_bit, w_bit);
|
||||
sum += res;
|
||||
good_val++;
|
||||
|
||||
//sum += input[input_index] *weights[weights_index];
|
||||
|
||||
}
|
||||
}
|
||||
// l.output[filters][width][height] +=
|
||||
// state.input[channels][width][height] *
|
||||
// l.weights[filters][channels][filter_width][filter_height];
|
||||
//output[output_index] += sum;
|
||||
}
|
||||
sum = sum - (good_val - sum);
|
||||
output[output_index] = sum * mean_arr_gpu[fil]; // atomicAdd for inter-BLOCK sum
|
||||
}
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
// XNOR (binary) convolution kernel, one thread per output element.
//
// Launch layout: 1-D grid of 1-D blocks. The flat thread index is decoded as
//   x = index % in_w,  y = (index / in_w) % in_h,  fil = index / (in_w * in_h).
// Threads with fil >= n write nothing, but they still take part in the weight
// staging and in the block barrier.
//
// input        - bit-packed input, read with get_bit at bit
//                chan*in_w*in_h + input_y*in_w + input_x
// weights      - bit-packed weights; filter `fil` starts at bit fil*new_lda.
//                Staged as 32-bit words, so this assumes new_lda is a multiple
//                of 32 and `weights` is 4-byte aligned - TODO confirm with the allocator.
// output       - each element is overwritten with (matches - mismatches) * mean_arr_gpu[fil]
// new_lda      - distance, in bits, between consecutive filters in `weights`
// mean_arr_gpu - per-filter scale
//
// Shared memory: a static buffer of (3*3*1024*6/32 + 1) 32-bit words holding the
// weights of every filter this block touches. When those weights do not fit, the
// whole block reads the weight bits straight from global memory instead.
//
// The former per-channel `input_shared` staging was removed: its contents were
// never read, it was filled from `weights` rather than `input` (reading past the
// end of the weights buffer for large input planes), and it could overflow its
// fixed 416*416/32 capacity.
__global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n,
    int size, int pad, int new_lda, float *mean_arr_gpu)
{
    // 7 KB (6 filters of 3*3*1024 bits)
    const int weights_shared_words = 3*3*1024*6/32 + 1;
    __shared__ uint32_t weights_shared[3*3*1024*6/32 + 1];

    const int index = blockIdx.x*blockDim.x + threadIdx.x;

    // Decode (filter, row, column) from the flat thread index.
    const int x = index % in_w;
    const int index2 = index / in_w;
    const int y = index2 % in_h;
    const int fil = index2 / in_h;

    // Range of filters covered by this block; identical for every thread of the block.
    const int min_index = blockIdx.x*blockDim.x;
    const int min_fil = (min_index / in_w) / in_h;
    const int max_index = (blockIdx.x + 1)*blockDim.x - 1;
    int max_fil = (max_index / in_w) / in_h;
    // The last block may extend past the final filter; never stage weights that do not exist.
    if (max_fil > n - 1) max_fil = n - 1;

    // Words copied per filter (the +1 covers a trailing partial word).
    const int weights_size = size*size*in_c / 32 + 1;

    // Block-uniform decision: stage only when every touched filter fits in the buffer.
    const int use_shared = (max_fil >= min_fil) &&
        ((max_fil - min_fil)*new_lda / 32 + weights_size <= weights_shared_words);

    if (use_shared) {
        for (int f = min_fil; f <= max_fil; ++f) {
            for (int s = threadIdx.x; s < weights_size; s += blockDim.x) {
                weights_shared[s + (f - min_fil)*new_lda / 32] = ((uint32_t *)weights)[f*new_lda / 32 + s];
            }
        }
    }
    // Reached by every thread of the block, including those with fil >= n:
    // a barrier inside a divergent branch is undefined behaviour.
    __syncthreads();

    if (fil >= n) return;

    // Where this thread's weight bits live, and the bit offset of its filter there.
    uint8_t *weight_bits = use_shared ? (uint8_t *)weights_shared : (uint8_t *)weights;
    const int filter_bit_offset = use_shared ? (fil - min_fil)*new_lda : fil*new_lda;

    int sum = 0;        // taps where input bit == weight bit
    int good_val = 0;   // taps that fall inside the input plane

    for (int chan = 0; chan < in_c; ++chan)
    {
        const int weights_pre_index = filter_bit_offset + chan*size*size;
        const int input_pre_index = chan*in_w*in_h;

        // filter - y
        for (int f_y = 0; f_y < size; ++f_y)
        {
            const int input_y = y + f_y - pad;
            // filter - x
            for (int f_x = 0; f_x < size; ++f_x)
            {
                const int input_x = x + f_x - pad;
                if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;

                const int input_index = input_pre_index + input_y*in_w + input_x;
                const int weights_index = weights_pre_index + f_y*size + f_x;

                const uint8_t in_bit = get_bit((uint8_t *)input, input_index);
                const uint8_t w_bit = get_bit(weight_bits, weights_index);

                sum += xnor_bit1(in_bit, w_bit);
                good_val++;
            }
        }
    }

    // matches - mismatches, i.e. the +1/-1 dot product of the two bit vectors
    sum = sum - (good_val - sum);
    output[fil*in_w*in_h + y*in_w + x] = sum * mean_arr_gpu[fil];
}
|
||||
|
||||
// Host launcher for the XNOR convolve_bin_gpu_kernel.
//
// Launches one thread per output element (in_w * in_h * n) in blocks of BLOCK
// threads on the stream returned by get_cuda_stream(). `new_lda` and
// `mean_arr_gpu` are forwarded to the kernel unchanged. The launch is
// asynchronous; no synchronization is performed here.
// NOTE(review): the kernel decodes its position from an `int` index, so the
// element count is still expected to fit in an int.
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n,
    int size, int pad, int new_lda, float *mean_arr_gpu)
{
    // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
    if (in_w <= 0 || in_h <= 0 || n <= 0) return;

    // width X height X filters, widened before multiplying so the product is not computed in int
    const size_t array_size = (size_t)in_w * in_h * n;
    // Ceiling division: exactly enough blocks to cover every output element.
    const int num_blocks = (int)((array_size + BLOCK - 1) / BLOCK);

    convolve_bin_gpu_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(input, weights, output, in_w, in_h, in_c, n, size, pad, new_lda, mean_arr_gpu);
}
|
||||
|
||||
// --------------------------------
|
||||
|
||||
|
@ -189,6 +189,7 @@ struct layer{
|
||||
float *mean_arr;
|
||||
int align_bit_weights_size;
|
||||
int lda_align;
|
||||
int new_lda;
|
||||
int bit_align;
|
||||
|
||||
float *col_image;
|
||||
|
Reference in New Issue
Block a user