mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
XNOR-net 21 FPS on CPU yolov2-tiny.cfg
This commit is contained in:
@ -132,6 +132,7 @@ size_t get_workspace_size(layer l){
|
||||
return most;
|
||||
}
|
||||
#endif
|
||||
if(l.xnor) return (size_t)l.bit_align*l.size*l.size*l.c * sizeof(float);
|
||||
return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
|
||||
}
|
||||
|
||||
@ -305,6 +306,10 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
|
||||
if(xnor){
|
||||
l.binary_weights = calloc(c*n*size*size, sizeof(float));
|
||||
l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
|
||||
|
||||
int align = 8;
|
||||
int src_align = l.out_h*l.out_w;
|
||||
l.bit_align = src_align + (align - src_align % align);
|
||||
}
|
||||
|
||||
if(batch_normalize){
|
||||
@ -622,7 +627,7 @@ void binary_align_weights(convolutional_layer *l)
|
||||
}
|
||||
|
||||
// further optimizations: im2col_bin() for XNOR, and then transpose_aling_bin()
|
||||
size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align)
|
||||
size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align, int bit_align)
|
||||
{
|
||||
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
|
||||
size_t t_intput_size = new_ldb * n;
|
||||
@ -637,12 +642,17 @@ size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input,
|
||||
//printf("\n align_weights_size = %d, k = %d, m = %d, lda = %d \n", align_weights_size, k, m, k);
|
||||
//printf("\n align_bit_weights_size = %d, k = %d, m = %d, new_lda = %d \n", align_bit_weights_size, k, m, new_ldb);
|
||||
|
||||
int blocksize = 64;
|
||||
transpose_block_SSE4x4(b, t_input, k, n, n, new_ldb, blocksize);
|
||||
int src_size = k * bit_align;
|
||||
|
||||
//printf("\n blocksize = %d \n", blocksize);
|
||||
float_to_bit(b, t_input, src_size);
|
||||
|
||||
// b - [bit_align, k] - [l.bit_align, l.size*l.size*l.c] = src_size
|
||||
// t_input - [bit_align, k] - [n', k]
|
||||
// t_bit_input - [new_ldb, n] - [k', n]
|
||||
|
||||
transpose_bin(t_input, *t_bit_input, k, n, bit_align, new_ldb, 8);
|
||||
//transpose_bin(b, *t_bit_input, k, n, bit_align, new_ldb, 8);
|
||||
|
||||
float_to_bit(t_input, *t_bit_input, t_intput_size);
|
||||
free(t_input);
|
||||
|
||||
return t_intput_size;
|
||||
@ -691,12 +701,16 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
|
||||
//if (l.xnor && l.size == 3 && l.stride == 1 && l.pad == 1) {}
|
||||
//else
|
||||
// further optimizations: im2col_bin() for XNOR, and then transpose_aling_bin()
|
||||
im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
|
||||
//im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
|
||||
|
||||
|
||||
//gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
|
||||
//gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
|
||||
if (l.xnor) {
|
||||
//im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
|
||||
memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float));
|
||||
im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
|
||||
|
||||
size_t output_size = l.outputs;
|
||||
//float *count_output = calloc(output_size, sizeof(float));
|
||||
//size_t bit_output_size = output_size / 8 + 1;
|
||||
@ -790,7 +804,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
|
||||
int ldb_align = l.lda_align;
|
||||
size_t new_ldb = k + (ldb_align - k%ldb_align);
|
||||
char *t_bit_input = NULL;
|
||||
size_t t_intput_size = binary_transpose_align_input(k, n, b, &t_bit_input, ldb_align);
|
||||
size_t t_intput_size = binary_transpose_align_input(k, n, b, &t_bit_input, ldb_align, l.bit_align);
|
||||
//char *t_bit_input = calloc(new_ldb * n, sizeof(char)); // for im2col_cpu_custom_transpose() only
|
||||
//float_to_bit(t_input, t_bit_input, new_ldb * n); // for im2col_cpu_custom_transpose() only
|
||||
|
||||
@ -825,6 +839,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
|
||||
//free(mean_arr);
|
||||
}
|
||||
else {
|
||||
im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
|
||||
|
||||
gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n);
|
||||
// bit-count to float
|
||||
}
|
||||
|
216
src/gemm.c
216
src/gemm.c
@ -6,6 +6,7 @@
|
||||
#include <stdio.h>
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
#include <string.h>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
@ -595,7 +596,7 @@ void convolution_2d(int w, int h, int ksize, int n, int c, int pad, int stride,
|
||||
static int max_num_threads = 0;
|
||||
if (max_num_threads == 0) {
|
||||
max_num_threads = omp_get_max_threads();
|
||||
omp_set_num_threads( max_num_threads / 2);
|
||||
//omp_set_num_threads( max_num_threads / 2);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1037,27 +1038,212 @@ void im2col_cpu_custom(float* data_im,
|
||||
}
|
||||
}
|
||||
|
||||
void transpose_8x8_bits(unsigned char A[8], unsigned char B[8], int m, int n)
|
||||
|
||||
//From Berkeley Vision's Caffe!
|
||||
//https://github.com/BVLC/caffe/blob/master/LICENSE
|
||||
void im2col_cpu_custom_bin(float* data_im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float* data_col, int bit_align)
|
||||
{
|
||||
int c, h, w;
|
||||
int height_col = (height + 2 * pad - ksize) / stride + 1;
|
||||
int width_col = (width + 2 * pad - ksize) / stride + 1;
|
||||
int channels_col = channels * ksize * ksize;
|
||||
|
||||
// optimized version
|
||||
if (height_col == height && width_col == width && stride == 1 && pad == 1 && is_fma_avx2())
|
||||
{
|
||||
__m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
|
||||
|
||||
//int algin = 8;
|
||||
//int ldb = width_col * height_col;
|
||||
//int new_ldb = ldb + (algin - ldb % algin);
|
||||
int new_ldb = bit_align;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (c = 0; c < channels_col; ++c) {
|
||||
int w_offset = c % ksize;
|
||||
int h_offset = (c / ksize) % ksize;
|
||||
int c_im = c / ksize / ksize;
|
||||
for (h = pad; h < height_col - pad; ++h) {
|
||||
for (w = pad; w < width_col - pad - 8; w += 8) {
|
||||
int im_row = h_offset + h - pad;
|
||||
int im_col = w_offset + w - pad;
|
||||
//int col_index = (c * height_col + h) * width_col + w;
|
||||
int col_index = c * new_ldb + h * width_col + w;
|
||||
|
||||
//data_col[col_index] = data_im[im_col + width*(im_row + height*c_im)];
|
||||
__m256 src256 = _mm256_loadu_ps((float *)(&data_im[im_col + width*(im_row + height*c_im)]));
|
||||
_mm256_storeu_ps(&data_col[col_index], src256);
|
||||
|
||||
/*/
|
||||
__m256i src256 = _mm256_loadu_si256((__m256i *)(&data_im[im_col + width*(im_row + height*c_im)]));
|
||||
__m256i result256 = _mm256_and_si256(src256, all256_sing1); // check sign in 8 x 32-bit floats
|
||||
|
||||
uint32_t mask = _mm256_movemask_ps(_mm256_castsi256_ps(result256)); // (val >= 0) ? 0 : 1
|
||||
mask = ~mask; // inverse mask, (val >= 0) ? 1 : 0
|
||||
|
||||
data_col[col_index / 8] = mask; // dst[i / 8] = mask;
|
||||
*/
|
||||
}
|
||||
|
||||
for (; w < width_col - pad; ++w) {
|
||||
int im_row = h_offset + h - pad;
|
||||
int im_col = w_offset + w - pad;
|
||||
//int col_index = (c * height_col + h) * width_col + w;
|
||||
int col_index = c * new_ldb + h * width_col + w;
|
||||
|
||||
data_col[col_index] = data_im[im_col + width*(im_row + height*c_im)];
|
||||
float val = data_im[im_col + width*(im_row + height*c_im)];
|
||||
//if(val > 0) set_bit(data_col, col_index);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
w = 0;
|
||||
for (h = 0; h < height_col; ++h) {
|
||||
int im_row = h_offset + h;
|
||||
int im_col = w_offset + w;
|
||||
//int col_index = (c * height_col + h) * width_col + w;
|
||||
int col_index = c * new_ldb + h * width_col + w;
|
||||
|
||||
data_col[col_index] = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
float val = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
//if (val > 0) set_bit(data_col, col_index);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
w = width_col - 1;
|
||||
for (h = 0; h < height_col; ++h) {
|
||||
int im_row = h_offset + h;
|
||||
int im_col = w_offset + w;
|
||||
//int col_index = (c * height_col + h) * width_col + w;
|
||||
int col_index = c * new_ldb + h * width_col + w;
|
||||
|
||||
data_col[col_index] = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
float val = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
//if (val > 0) set_bit(data_col, col_index);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
h = 0;
|
||||
for (w = 0; w < width_col; ++w) {
|
||||
int im_row = h_offset + h;
|
||||
int im_col = w_offset + w;
|
||||
//int col_index = (c * height_col + h) * width_col + w;
|
||||
int col_index = c * new_ldb + h * width_col + w;
|
||||
|
||||
data_col[col_index] = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
float val = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
//if (val > 0) set_bit(data_col, col_index);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
h = height_col - 1;
|
||||
for (w = 0; w < width_col; ++w) {
|
||||
int im_row = h_offset + h;
|
||||
int im_col = w_offset + w;
|
||||
//int col_index = (c * height_col + h) * width_col + w;
|
||||
int col_index = c * new_ldb + h * width_col + w;
|
||||
|
||||
data_col[col_index] = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
float val = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
|
||||
//if (val > 0) set_bit(data_col, col_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
else {
|
||||
printf("\n Error: is no non-optimized version \n");
|
||||
//im2col_cpu(data_im, channels, height, width, ksize, stride, pad, data_col);
|
||||
}
|
||||
|
||||
/*
|
||||
int src_size = bit_align*channels_col;
|
||||
char *bit_arr = calloc(src_size, sizeof(float));
|
||||
float_to_bit(data_col, bit_arr, src_size);
|
||||
memcpy(data_col, bit_arr, src_size * sizeof(float));
|
||||
free(bit_arr);
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
void transpose_8x8_bits_my(unsigned char *A, unsigned char *B, int lda, int ldb)
|
||||
{
|
||||
unsigned x, y, t;
|
||||
for (y = 0; y < 8; ++y) {
|
||||
for (x = 0; x < 8; ++x) {
|
||||
if(A[y * lda] & (1 << x)) B[x * ldb] |= 1 << y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsigned char reverse_byte_1(char a)
|
||||
{
|
||||
return ((a & 0x1) << 7) | ((a & 0x2) << 5) |
|
||||
((a & 0x4) << 3) | ((a & 0x8) << 1) |
|
||||
((a & 0x10) >> 1) | ((a & 0x20) >> 3) |
|
||||
((a & 0x40) >> 5) | ((a & 0x80) >> 7);
|
||||
}
|
||||
|
||||
unsigned char reverse_byte_2(unsigned char a)
|
||||
{
|
||||
return ((a * 0x0802LU & 0x22110LU) | (a * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16;
|
||||
}
|
||||
|
||||
static unsigned char lookup[16] = {
|
||||
0x0, 0x8, 0x4, 0xc, 0x2, 0xa, 0x6, 0xe,
|
||||
0x1, 0x9, 0x5, 0xd, 0x3, 0xb, 0x7, 0xf, };
|
||||
|
||||
unsigned char reverse_byte(unsigned char n) {
|
||||
// Reverse the top and bottom nibble then swap them.
|
||||
return (lookup[n & 0b1111] << 4) | lookup[n >> 4];
|
||||
}
|
||||
|
||||
|
||||
void transpose8rS32_reversed_diagonale(unsigned char* A, int m, int n, unsigned char* B)
|
||||
{
|
||||
unsigned x, y, t;
|
||||
|
||||
// Load the array and pack it into x and y.
|
||||
|
||||
x = (A[0] << 24) | (A[m] << 16) | (A[2 * m] << 8) | A[3 * m];
|
||||
y = (A[4 * m] << 24) | (A[5 * m] << 16) | (A[6 * m] << 8) | A[7 * m];
|
||||
|
||||
t = (x ^ (x >> 7)) & 0x00AA00AA; x = x ^ t ^ (t << 7);
|
||||
t = (y ^ (y >> 7)) & 0x00AA00AA; y = y ^ t ^ (t << 7);
|
||||
t = (x ^ (x >> 7)) & 0x00AA00AA; x = x ^ t ^ (t << 7);
|
||||
t = (y ^ (y >> 7)) & 0x00AA00AA; y = y ^ t ^ (t << 7);
|
||||
|
||||
t = (x ^ (x >> 14)) & 0x0000CCCC; x = x ^ t ^ (t << 14);
|
||||
t = (y ^ (y >> 14)) & 0x0000CCCC; y = y ^ t ^ (t << 14);
|
||||
t = (x ^ (x >> 14)) & 0x0000CCCC; x = x ^ t ^ (t << 14);
|
||||
t = (y ^ (y >> 14)) & 0x0000CCCC; y = y ^ t ^ (t << 14);
|
||||
|
||||
t = (x & 0xF0F0F0F0) | ((y >> 4) & 0x0F0F0F0F);
|
||||
y = ((x << 4) & 0xF0F0F0F0) | (y & 0x0F0F0F0F);
|
||||
x = t;
|
||||
|
||||
B[0] = x >> 24; B[n] = x >> 16; B[2 * n] = x >> 8; B[3 * n] = x;
|
||||
B[4 * n] = y >> 24; B[5 * n] = y >> 16; B[6 * n] = y >> 8; B[7 * n] = y;
|
||||
B[7 * n] = reverse_byte(x >> 24); B[6 * n] = reverse_byte(x >> 16); B[5 * n] = reverse_byte(x >> 8); B[4 * n] = reverse_byte(x);
|
||||
B[3 * n] = reverse_byte(y >> 24); B[2 * n] = reverse_byte(y >> 16); B[1 * n] = reverse_byte(y >> 8); B[0 * n] = reverse_byte(y);
|
||||
}
|
||||
|
||||
void transpose_bin(char *A, char *B, const int n, const int m,
|
||||
const int lda, const int ldb, const int block_size)
|
||||
{
|
||||
int i;
|
||||
#pragma omp parallel for
|
||||
for (i = 0; i < n; i += 8) {
|
||||
int j;
|
||||
for (j = 0; j < m - 8; j += 8) {
|
||||
int a_index = i*lda + j;
|
||||
int b_index = j*ldb + i;
|
||||
//transpose_8x8_bits_my(&A[a_index/8], &B[b_index/8], lda/8, ldb/8);
|
||||
transpose8rS32_reversed_diagonale(&A[a_index / 8], lda / 8, ldb / 8, &B[b_index / 8]);
|
||||
}
|
||||
for (; j < m; ++j) {
|
||||
if (get_bit(A, i*lda + j)) set_bit(B, j*ldb + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void activate_array_cpu_custom(float *x, const int n, const ACTIVATION a)
|
||||
@ -1102,14 +1288,18 @@ void float_to_bit(float *src, unsigned char *dst, size_t size)
|
||||
|
||||
size_t i;
|
||||
__m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
|
||||
__m256 float_zero256 = _mm256_set1_ps(0.00001);
|
||||
|
||||
for (i = 0; i < size; i+=8)
|
||||
{
|
||||
__m256i src256 = _mm256_loadu_si256((__m256i *)(&src[i]));
|
||||
__m256i result256 = _mm256_and_si256(src256, all256_sing1); // check sign in 8 x 32-bit floats
|
||||
//__m256i src256 = _mm256_loadu_si256((__m256i *)(&src[i]));
|
||||
//__m256i result256 = _mm256_and_si256(src256, all256_sing1); // check sign in 8 x 32-bit floats
|
||||
//uint32_t mask = _mm256_movemask_ps(_mm256_castsi256_ps(result256)); // (val >= 0) ? 0 : 1
|
||||
//mask = ~mask; // inverse mask, (val >= 0) ? 1 : 0
|
||||
|
||||
uint32_t mask = _mm256_movemask_ps(_mm256_castsi256_ps(result256)); // (val >= 0) ? 0 : 1
|
||||
mask = ~mask; // inverse mask, (val >= 0) ? 1 : 0
|
||||
__m256 src256 = _mm256_loadu_ps((float *)(&src[i]));
|
||||
__m256 result256 = _mm256_cmp_ps(src256, float_zero256, _CMP_GT_OS);
|
||||
uint32_t mask = _mm256_movemask_ps(result256); // (val > 0) ? 0 : 1
|
||||
|
||||
dst[i / 8] = mask;
|
||||
}
|
||||
|
@ -11,12 +11,14 @@ static inline void set_bit(unsigned char *const dst, size_t index) {
|
||||
size_t dst_i = index / 8;
|
||||
int dst_shift = index % 8;
|
||||
dst[dst_i] |= 1 << dst_shift;
|
||||
//dst[dst_i] |= 1 << (8 - dst_shift);
|
||||
}
|
||||
|
||||
static inline unsigned char get_bit(unsigned char const*const src, size_t index) {
|
||||
size_t src_i = index / 8;
|
||||
int src_shift = index % 8;
|
||||
unsigned char val = (src[src_i] & (1 << src_shift)) > 0;
|
||||
//unsigned char val = (src[src_i] & (1 << (8 - src_shift))) > 0;
|
||||
return val;
|
||||
}
|
||||
|
||||
@ -25,6 +27,9 @@ void float_to_bit(float *src, unsigned char *dst, size_t size);
|
||||
void transpose_block_SSE4x4(float *A, float *B, const int n, const int m,
|
||||
const int lda, const int ldb, const int block_size);
|
||||
|
||||
void transpose_bin(char *A, char *B, const int n, const int m,
|
||||
const int lda, const int ldb, const int block_size);
|
||||
|
||||
void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
|
||||
unsigned char *A, int lda,
|
||||
unsigned char *B, int ldb,
|
||||
@ -34,6 +39,10 @@ void im2col_cpu_custom(float* data_im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float* data_col);
|
||||
|
||||
void im2col_cpu_custom_bin(float* data_im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float* data_col, int bit_align);
|
||||
|
||||
void im2col_cpu_custom_transpose(float* data_im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float* data_col, int ldb_align);
|
||||
|
@ -182,6 +182,7 @@ struct layer{
|
||||
char *align_bit_weights;
|
||||
float *mean_arr;
|
||||
int lda_align;
|
||||
int bit_align;
|
||||
|
||||
float *col_image;
|
||||
int * input_layers;
|
||||
|
Reference in New Issue
Block a user