improve XNOR Tensor Cores GEMM - N 2x unrolled - minor performance improvement
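This change makes each warp of the XNOR Tensor-Core kernel produce two adjacent WMMA_N-wide output tiles (WMMA_Nx2 = 16 columns) per pass over K: two accumulator fragments (c1_frag, c2_frag), two B-tile loads and two bmma_sync calls per 128-bit K step, a doubled C_s staging buffer in shared memory, and a host-side grid sized with N aligned to 16 instead of 8. A small standalone illustration of the 2*count - K conversion used in the kernel's epilogue is appended after the diff.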
@@ -1226,6 +1226,189 @@ int warpAllReduceSum(int val) {
#define WMMA_K 128
#define WMMA_K32 (WMMA_K/32)

#define WMMA_Nx2 (WMMA_N*2)

// Tensor Cores are used for XOR-GEMM
__global__ void gemm_nn_custom_bin_mean_transposed_tensor_kernel(int M, int N, int K,
    unsigned char *A, int lda,
    unsigned char *B, int ldb,
    float *C, int ldc, float *mean_arr, float *bias_arr, int leaky_activation)
{
    // total 57%
    int index = blockIdx.x*blockDim.x + threadIdx.x;

    __shared__ int C_s[WMMA_N * WMMA_M * 32 * 2];   // 2 * 8 KB - Temporary result of GEMM WMMA for 32 warps

    const int lane_id = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;
    const int global_warp_id = index / 32;

    const int N_aligned = N + WMMA_Nx2 - (N % WMMA_Nx2);

    /*
    __syncthreads();
    __shared__ uint32_t A_s[8 * 512];   // 8x512 = 8 x 16384 bits, instead of 8x4
    const int start_global_warp_id = blockIdx.x*blockDim.x / 32;
    int start_i = start_global_warp_id / (N_aligned / WMMA_N);
    start_i = start_i * WMMA_M;
    if (start_i + WMMA_M > M) start_i = M - WMMA_M;   // must be: i+7 < M
    for (int tmp_index = threadIdx.x; tmp_index < (8 * 512); tmp_index += blockDim.x)
    {
        int k_tmp = tmp_index % 512;
        int local_i = tmp_index / 512;

        uint32_t a_val = ((uint32_t *)(A))[(start_i + local_i)*lda/32 + k_tmp];
        A_s[local_i * 512 + k_tmp] = a_val;
    }
    __syncthreads();
    */


    int i, j, k, h;
    // 47% = 29 + 10 + 8
    j = global_warp_id % (N_aligned / WMMA_Nx2);
    j = j * WMMA_Nx2;
    {   // out_h*out_w - one channel output size [169 - 173056]
        i = global_warp_id / (N_aligned / WMMA_Nx2);
        i = i * WMMA_M;
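        // Each warp owns one WMMA_M x WMMA_Nx2 (8 x 16) output tile: global_warp_id is split into
        // a column index j (in steps of WMMA_Nx2) and a row index i (in steps of WMMA_M);
        // the factor of two in N is the unrolling added by this commit.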

        int count = 0;
        k = 0;

        if (i < M)  //if (i < M)   // l.n - filters [16 - 55 - 1024]
        {
            if (j + WMMA_Nx2 > N) j = N - WMMA_Nx2;   // must be: j+7 < N
            if (i + WMMA_M > M) i = M - WMMA_M;   // must be: i+7 < M

#if __CUDA_ARCH__ >= 730
            // Tensor Cores
            using namespace nvcuda;

            wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, wmma::experimental::precision::b1, wmma::row_major> a_frag;
            wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, wmma::experimental::precision::b1, wmma::col_major> b_frag;
            wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, int> c1_frag, c2_frag;
            wmma::fill_fragment(c1_frag, 0);   // !!!! XOR isn't XNOR !!!!!!!!!!
            wmma::fill_fragment(c2_frag, 0);   // !!!! XOR isn't XNOR !!!!!!!!!!
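            // bmma_sync with the experimental 1-bit (b1) type accumulates popc(A ^ B), i.e. XOR + popcount
            // rather than the XNOR + popcount an XNOR-net wants - hence the warnings above.
            // Two accumulators are kept so that a single pass over K fills both adjacent
            // 8-column halves of the warp's 8 x 16 output tile.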

            // 8 x 8 x 4 (uint32_t, 4 * 32 = 128 bit)
            for (; k < K; k += 128)   // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
            {
                int64_t A_cur_index = (i*lda + k) / 8;    // index in bits
                int64_t B1_cur_index = (j*ldb + k) / 8;   // index in bits
                int64_t B2_cur_index = ((j + 8)*ldb + k) / 8;   // index in bits

                // try to use A that is cached in shared memory - poor performance
                //if (i == start_i) wmma::load_matrix_sync(a_frag, &A_s[k / 32], (512 * 32));   // lda = (128*32) bits
                //else wmma::load_matrix_sync(a_frag, (uint32_t *)(A + A_cur_index), lda);   // lda = M

                // lda, ldb - are in bits
                wmma::load_matrix_sync(a_frag, (uint32_t *)(A + A_cur_index), lda);   // lda = M

                wmma::load_matrix_sync(b_frag, (uint32_t *)(B + B1_cur_index), ldb);   // ldb = K
                wmma::bmma_sync(c1_frag, a_frag, b_frag, c1_frag);   // XOR-GEMM

                wmma::load_matrix_sync(b_frag, (uint32_t *)(B + B2_cur_index), ldb);   // ldb = K
                wmma::bmma_sync(c2_frag, a_frag, b_frag, c2_frag);   // XOR-GEMM
            }
            // C[i*ldc + j]
            wmma::store_matrix_sync(&C_s[warp_id*WMMA_M*WMMA_N], c1_frag, WMMA_N, wmma::mem_row_major);
            wmma::store_matrix_sync(&C_s[warp_id*WMMA_M*WMMA_N + WMMA_M*WMMA_N*32], c2_frag, WMMA_N, wmma::mem_row_major);
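            // C_s layout: 32 warps x (WMMA_M x WMMA_N) ints for the first (c1) tile, followed by
            // another 32 x 64 ints for the second (c2) tile - matching the doubled __shared__ buffer above.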
#else // __CUDA_ARCH__ >= 730

            // Custom XOR-GEMM
            int k_d = lane_id % 4;
            int i_d = lane_id / 4;
            int j_d = lane_id / 4;
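            // Fallback without Tensor Cores: lane = i_d*4 + k_d, where k_d selects one 32-bit word of the
            // current 128-bit K step and i_d (= j_d) selects one A row / B column. Warp shuffles broadcast
            // the B words so every lane computes popcount partials for all 8 columns, a second shuffle pass
            // sums the four word-partials per row, and the lanes with k_d == 0 end up holding the finished
            // 8 x 8 tile that is written to C_s below.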

            int32_t accum_c_val[8*2];   // wmma::fill_fragment(c_frag, 0);
            for (int local_j = 0; local_j < 8*2; ++local_j) {
                accum_c_val[local_j] = 0;
            }

            // 8 x 8 x 4 (uint32_t, 4 * 32 = 128 bit)
            for (; k < K; k += 128)   // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
            {
                int64_t A_cur_index = (i*lda + k) / 8;
                //int64_t A_cur_index = (local_i*lda + k) / 8;
                int64_t B_cur_index = (j*ldb + k) / 8;

                // lda, ldb - are in bits
                // 8*4 = 32
                // 8*8 = 64
                int k_d = lane_id % 4;
                int i_d = lane_id / 4;
                int j_d = lane_id / 4;
                uint32_t a_val = *(uint32_t *)(A + ((i + i_d)*lda + (k + k_d*32)) / 8);   // wmma::load_matrix_sync(a_frag, (uint32_t *)(A + A_cur_index), lda);

                for (int c_x = 0; c_x < 2; c_x++)
                {
                    uint32_t b_val = *(uint32_t *)(B + ((c_x * 8 + j + j_d)*ldb + (k + k_d * 32)) / 8);   // wmma::load_matrix_sync(b_frag, (uint32_t *)(B + B_cur_index), ldb);

                    // wmma::bmma_sync(c_frag, a_frag, b_frag, c_frag);
                    int32_t c_val[8];   // 8 x 32 threads = 256
                    #pragma UNROLL
                    for (int local_j = 0; local_j < 8; ++local_j)
                    {
                        uint32_t b_val_cur = __shfl_custom(b_val, local_j * 4 + k_d);
                        c_val[local_j] = __popc(xor_int32(a_val, b_val_cur));
                    }

                    #pragma UNROLL
                    for (int local_j = 0; local_j < 8; ++local_j)
                    {
                        #pragma UNROLL
                        for (int local_k = 0; local_k < 4; ++local_k) {
                            accum_c_val[local_j + c_x*8] += __shfl_custom(c_val[local_j], i_d * 4 + local_k);
                        }
                    }
                }
            }

            // only the first 8 threads (i) contain 8 good values each, in c_val[8] (j) = 8 x 8 = 64
            // wmma::store_matrix_sync(&C_s[warp_id*WMMA_M*WMMA_N], c_frag, WMMA_N, wmma::mem_row_major);
            if (k_d == 0) {
                for (int c_x = 0; c_x < 2; c_x++)
                {
                    for (int local_j = 0; local_j < 8; ++local_j)
                    {
                        C_s[warp_id*WMMA_M*WMMA_N + i_d*WMMA_N + local_j + WMMA_M*WMMA_N*32 * c_x] = accum_c_val[local_j + c_x*8];
                    }
                }
            }
#endif // __CUDA_ARCH__ >= 730

            for (int c_x = 0; c_x < 2; c_x++)
            {
                int j_d = lane_id % WMMA_N;
                {
                    #pragma UNROLL
                    for (int i_d = lane_id / WMMA_N; i_d < WMMA_M; i_d += WMMA_M / 2)
                    {
                        int count = C_s[warp_id*WMMA_M*WMMA_N + i_d*WMMA_N + j_d + WMMA_M*WMMA_N*32*c_x];

                        const int bit_step = 128;
                        int f1 = (K % bit_step == 0) ? 0 : (bit_step - (K % bit_step));
                        count = count - f1;   // remove extra bits (from empty space for align only)

                        count = (2 * count - K);
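                        // The two steps above turn the accumulated popcount into a +/-1 dot product:
                        // with x equal bits out of K, the dot product is x - (K - x) = 2*x - K;
                        // f1 first drops the counts contributed only by the 128-bit alignment padding of K.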

                        float mean_val = mean_arr[i + i_d];
                        float bias_val = bias_arr[i + i_d];
                        float dst_val = count * mean_val + bias_val;
                        if (leaky_activation)
                            dst_val = (dst_val > 0) ? (dst_val) : (0.1f*dst_val);   // Leaky activation

                        C[(i + i_d)*ldc + (c_x*8 + j + j_d)] = dst_val;
                    }

                }
            }
        }
    }
}
#endif // CUDART_VERSION >= 10000

/*
// Tensor Cores are used for XOR-GEMM
__global__ void gemm_nn_custom_bin_mean_transposed_tensor_kernel(int M, int N, int K,
    unsigned char *A, int lda,
@@ -1368,7 +1551,8 @@ __global__ void gemm_nn_custom_bin_mean_transposed_tensor_kernel(int M, int N, i
        }
    }
}
#endif // CUDART_VERSION >= 10000
*/


// Coalescing
// A (weights) in the shared_memory - GOOD
@@ -1583,8 +1767,8 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
    if (1)
    {
        const int M_aligned = M + (8 - (M % 8));
        const int N_aligned = N + (8 - (N % 8));
        int size = (M_aligned / 8)*(N_aligned / 8)*WARP_SIZE;
        const int N_aligned = N + (16 - (N % 16));
        int size = (M_aligned / 8)*(N_aligned / 16)*WARP_SIZE;
        const int num_blocks = get_number_of_blocks(size, BLOCK);

        //printf(" lda = %d, ldb = %d, ldc = %d, lda/32 = %d, ldb/32 = %d, ldc/32 = %d \n", lda, ldb, ldc, lda / 32, ldb / 32, ldc / 32);
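In this last hunk the launch-size computation switches from 8-aligned to 16-aligned N (the first N_aligned/size pair is the removed code, the second the added code), so one warp of WARP_SIZE threads is launched per 8 x 16 output tile of the kernel above.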
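For reference, a minimal standalone sketch (not part of the commit) of the bit-count identity behind the kernel's 2*count - K epilogue; the example values and the plain-C popcount helper are illustrative only:

/* Illustration: the +/-1 dot product of two K-bit sign vectors equals
   K - 2*popcount(a ^ b), i.e. 2*(number of equal bits) - K -- the same
   "2*count - K" form used in the kernel's epilogue. */
#include <stdint.h>
#include <stdio.h>

/* portable popcount for one 32-bit word */
static int popcount32(uint32_t x) { int c = 0; while (x) { x &= x - 1; ++c; } return c; }

int main(void)
{
    const int K = 32;                 /* one 32-bit word for simplicity      */
    uint32_t a = 0xF0F0F0F0u;         /* bit = 1 encodes +1, bit = 0 encodes -1 */
    uint32_t b = 0xFF00FF00u;

    int mismatches = popcount32(a ^ b);        /* XOR counts differing bits  */
    int matches    = K - mismatches;           /* XNOR would count these     */

    int dot_from_xor  = K - 2 * mismatches;    /* matches - mismatches       */
    int dot_from_xnor = 2 * matches - K;       /* the kernel's 2*count - K   */

    printf("dot = %d (xor form), %d (xnor form)\n", dot_from_xor, dot_from_xnor);
    return 0;
}

Compiled with any C compiler (or nvcc), both forms print the same dot product.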