XNOR-net on CPU AVX2

This commit is contained in:
AlexeyAB
2018-08-07 23:20:25 +03:00
parent e6c97a53a7
commit 0a326e7afe
10 changed files with 710 additions and 77 deletions

View File

@ -249,8 +249,8 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference)
if (l->bf_algo == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED) bf = 2;
//printf("Tensor Cores - Backward-filter enabled: l->bf_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED \n");
if (fw == 2 && bd == 2 && bf == 2) printf("TF ");
else if (fw == 1 && bd == 1 && bf == 1) printf("TH ");
//if (fw == 2 && bd == 2 && bf == 2) printf("TF ");
//else if (fw == 1 && bd == 1 && bf == 1) printf("TH ");
}
}
#endif
@ -543,6 +543,85 @@ void backward_bias(float *bias_updates, float *delta, int batch, int n, int size
}
}
void gemm_nn_custom(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i, j, k;
for (i = 0; i < M; ++i) {
for (k = 0; k < K; ++k) {
register float A_PART = ALPHA*A[i*lda + k];
//printf("\n weight = %f \n", A_PART);
for (j = 0; j < N; ++j) {
C[i*ldc + j] += A_PART*B[k*ldb + j];
}
}
}
}
void get_mean_array(float *src, size_t size, size_t filters, float *mean_arr) {
size_t i, counter;
counter = 0;
for (i = 0; i < size; i += size / filters) {
mean_arr[counter++] = fabs(src[i]);
}
}
/*
void float_to_bit(float *src, unsigned char *dst, size_t size) {
size_t dst_size = size / 8 + 1;
memset(dst, 0, dst_size);
size_t i, dst_i, dst_shift;
for (i = 0; i < size; ++i) {
if (src[i] > 0) set_bit(dst, i);
}
}
*/
void bit_to_float(unsigned char *src, float *dst, size_t size, size_t filters, float *mean_arr) {
memset(dst, 0, size *sizeof(float));
size_t i, src_i, src_shift;
for (i = 0; i < size; ++i) {
float mean_val = 1;
if(mean_arr != NULL) mean_val = fabs(mean_arr[i / (size / filters)]);
if(get_bit(src, i)) dst[i] = mean_val;
else dst[i] = -mean_val;
}
}
void binary_transpose_align_weights(convolutional_layer *l, size_t ldb_align)
{
int m = l->n;
int k = l->size*l->size*l->c;
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
binarize_weights(l->weights, m, k, l->binary_weights);
size_t align_weights_size = new_ldb * m;
size_t align_bit_weights_size = align_weights_size / 8;// +1;
float *align_weights = calloc(align_weights_size, sizeof(float));
l->align_bit_weights = calloc(align_bit_weights_size, sizeof(char));
size_t i, j;
// align A without transpose
for (i = 0; i < m; ++i) {
for (j = 0; j < k; ++j) {
align_weights[i*new_ldb + j] = l->binary_weights[i*k + j];
}
}
float_to_bit(align_weights, l->align_bit_weights, align_weights_size);
l->mean_arr = calloc(l->n, sizeof(float));
get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
free(align_weights);
}
void forward_convolutional_layer(convolutional_layer l, network_state state)
{
int out_h = convolutional_out_height(l);
@ -552,7 +631,10 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
if(l.xnor){
if (!l.align_bit_weights) {
binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
//printf("\n binarize_weights l.align_bit_weights = %p \n", l.align_bit_weights);
}
swap_binary(&l);
binarize_cpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input);
state.input = l.binary_input;
@ -562,15 +644,122 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
int k = l.size*l.size*l.c;
int n = out_h*out_w;
float *a = l.weights;
float *b = state.workspace;
float *c = l.output;
static int u = 0;
u++;
for(i = 0; i < l.batch; ++i){
im2col_cpu(state.input, l.c, l.h, l.w,
l.size, l.stride, l.pad, b);
//gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
//gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
if (l.xnor) {
size_t output_size = l.outputs;
//float *count_output = calloc(output_size, sizeof(float));
//size_t bit_output_size = output_size / 8 + 1;
//char *bit_output = calloc(bit_output_size, sizeof(char));
size_t intput_size = n * k; // (out_h*out_w) X (l.size*l.size*l.c) : after im2col()
size_t bit_input_size = intput_size / 8 + 1;
//char *bit_input = calloc(bit_input_size, sizeof(char));
size_t weights_size = k * m; //l.size*l.size*l.c*l.n;
size_t bit_weights_size = weights_size / 8 + 1;
//char *bit_weights = calloc(bit_weights_size, sizeof(char));
//float *mean_arr = calloc(l.n, sizeof(float));
// test: float->bit->float
//get_mean_array(l.weights, weights_size, l.n, mean_arr);
//float_to_bit(l.weights, bit_weights, weights_size);
//memset(l.weights, 0, weights_size * sizeof(float));
//bit_to_float(bit_weights, l.weights, weights_size, l.n, mean_arr); // just for test float->bit->float
//float_to_bit(b, bit_input, intput_size);
//memset(b, 0, intput_size * sizeof(float));
//bit_to_float(bit_input, b, intput_size, 1, NULL); // just for test float->bit->float
// transpose B from NxK to KxN (x-axis (ldb = l.size*l.size*l.c) - should be multiple of 8 bits)
{
size_t ldb_align = 256;// 8;
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
size_t t_intput_size = new_ldb * n;
size_t t_bit_input_size = t_intput_size / 8;// +1;
float *t_input = calloc(t_intput_size, sizeof(float));
char *t_bit_input = calloc(t_bit_input_size, sizeof(char));
//printf("\n bit_input_size = %d, n = %d, k = %d, ldb = %d \n", bit_input_size, n, k, n);
//printf("\n t_bit_input_size = %d, k = %d, n = %d, new_ldb = %d \n", t_bit_input_size, k, n, new_ldb);
//printf("\n align_weights_size = %d, k = %d, m = %d, lda = %d \n", align_weights_size, k, m, k);
//printf("\n align_bit_weights_size = %d, k = %d, m = %d, new_lda = %d \n", align_bit_weights_size, k, m, new_ldb);
// transpose and align B
int i, j;
for (i = 0; i < n; ++i) {
for (j = 0; j < k; ++j) {
t_input[i*new_ldb + j] = b[j*n + i];
}
}
float_to_bit(t_input, t_bit_input, t_intput_size);
if (!l.align_bit_weights)
{
size_t align_weights_size = new_ldb * m;
size_t align_bit_weights_size = align_weights_size / 8;// +1;
float *align_weights = calloc(align_weights_size, sizeof(float));
l.align_bit_weights = calloc(align_bit_weights_size, sizeof(char));
// align A without transpose
for (i = 0; i < m; ++i) {
for (j = 0; j < k; ++j) {
align_weights[i*new_ldb + j] = a[i*k + j];
}
}
float_to_bit(align_weights, l.align_bit_weights, align_weights_size);
l.mean_arr = calloc(l.n, sizeof(float));
get_mean_array(align_weights, align_weights_size, l.n, l.mean_arr);
free(align_weights);
}
gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);
//gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);
free(t_input);
free(t_bit_input);
//free(align_bit_weights);
}
// for bit_input: (k * n)
//if (u == 8) gemm_nn_custom_bin_mean(m, n, k, 1, bit_weights, k, bit_input, n, c, n, mean_arr); // last xnor layer
//else gemm_nn_custom_bin_mean(m, n, k, 1, bit_weights, k, bit_input, n, c, n, NULL);
//gemm_nn_custom_bin_mean(m, n, k, 1, bit_weights, k, bit_input, n, c, n, mean_arr);
//printf("\n u = %d \n", u);
//gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
//int j;
//if (u != 8) for (j = 0; j < l.n; ++j) l.biases[j] = l.biases[j] / (mean_arr[j]*2);
//free(count_output);
//free(bit_input);
//free(bit_weights);
//free(mean_arr);
}
else {
gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n);
// bit-count to float
}
c += n*m;
state.input += l.c*l.h*l.w;
}

View File

@ -35,6 +35,8 @@ void binarize_weights(float *weights, int n, int size, float *binary);
void swap_binary(convolutional_layer *l);
void binarize_weights2(float *weights, int n, int size, char *binary, float *scales);
void binary_transpose_align_weights(convolutional_layer *l, size_t ldb_align);
void backward_convolutional_layer(convolutional_layer layer, network_state state);
void add_bias(float *output, float *biases, int batch, int n, int size);

View File

@ -146,6 +146,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int
}
//set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
calculate_binary_weights(net);
srand(2222222);
if(filename){

View File

@ -568,6 +568,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float
}
//set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
calculate_binary_weights(net);
srand(time(0));
list *plist = get_paths(valid_images);
@ -1094,6 +1095,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
}
//set_batch_network(&net, 1);
fuse_conv_batchnorm(net);
calculate_binary_weights(net);
if (net.layers[net.n - 1].classes != names_size) {
printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
name_list, names_size, net.layers[net.n - 1].classes, cfgfile);

View File

@ -71,6 +71,234 @@ void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
gemm_cpu( TA, TB, M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
}
//--------------------------------------------
// XNOR bitwise GEMM for binary neural network
//--------------------------------------------
#include <stdint.h>
static inline unsigned char xnor(unsigned char a, unsigned char b) {
//return a == b;
return !(a^b);
}
// INT-32
static inline uint32_t get_bit_int32(uint32_t const*const src, size_t index) {
size_t src_i = index / 32;
int src_shift = index % 32;
unsigned char val = (src[src_i] & (1 << src_shift)) > 0;
return val;
}
static inline uint32_t xnor_int32(uint32_t a, uint32_t b) {
return ~(a^b);
}
static inline uint64_t xnor_int64(uint64_t a, uint64_t b) {
return ~(a^b);
}
static inline uint32_t fill_bit_int32(char src) {
if (src == 0) return 0x00000000;
else return 0xFFFFFFFF;
}
static inline uint64_t fill_bit_int64(char src) {
if (src == 0) return 0x0000000000000000;
else return 0xFFFFFFFFFFFFFFFF;
}
void binary_int32_printf(uint32_t src) {
int i;
for (i = 0; i < 32; ++i) {
if (src & 1) printf("1");
else printf("0");
src = src >> 1;
}
printf("\n");
}
void binary_int64_printf(uint64_t src) {
int i;
for (i = 0; i < 64; ++i) {
if (src & 1) printf("1");
else printf("0");
src = src >> 1;
}
printf("\n");
}
/*
void gemm_nn_custom_bin_mean(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr)
{
int *count_arr = calloc(M*N, sizeof(int));
int i, j, k;
for (i = 0; i < M; ++i) { // l.n - filters [16 - 55 - 1024]
for (k = 0; k < K; ++k) { // l.size*l.size*l.c - one filter size [27 - 9216]
char a_bit = get_bit(A, i*lda + k);
for (j = 0; j < N; ++j) { // out_h*out_w - one channel output size [169 - 173056]
char b_bit = get_bit(B, k*ldb + j);
count_arr[i*ldc + j] += xnor(a_bit, b_bit);
}
}
}
for (i = 0; i < M; ++i) {
float mean_val = mean_arr[i];
for (j = 0; j < N; ++j) {
C[i*ldc + j] = (2 * count_arr[i*ldc + j] - K) * mean_val;
}
}
free(count_arr);
}
*/
/*
void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr)
{
int *count_arr = calloc(M*N, sizeof(int));
int i, j, k;
for (i = 0; i < M; ++i) { // l.n - filters [16 - 55 - 1024]
for (j = 0; j < N; ++j) { // out_h*out_w - one channel output size [169 - 173056]
for (k = 0; k < K; ++k) { // l.size*l.size*l.c - one filter size [27 - 9216]
char a_bit = get_bit(A, i*lda + k);
char b_bit = get_bit(B, j*ldb + k);
count_arr[i*ldc + j] += xnor(a_bit, b_bit);
}
}
}
for (i = 0; i < M; ++i) {
float mean_val = mean_arr[i];
for (j = 0; j < N; ++j) {
C[i*ldc + j] = (2 * count_arr[i*ldc + j] - K) * mean_val;
}
}
free(count_arr);
}
*/
/*
void gemm_nn_custom_bin_mean(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr)
{
int *count_arr = calloc(M*N, sizeof(int));
int i, j, k, h;
#pragma omp parallel for
for (i = 0; i < M; ++i) { // l.n - filters [16 - 55 - 1024]
for (k = 0; k < K; ++k) { // l.size*l.size*l.c - one filter size [27 - 9216]
const char a_bit = get_bit(A, i*lda + k);
uint64_t a_bit64 = fill_bit_int64(a_bit);
int k_ldb = k*ldb;
for (j = 0; j < N; j += 64) { // out_h*out_w - one channel output size [169 - 173056]
if ((N - j > 64) && (k_ldb % 8 == 0)) {
uint64_t b_bit64 = *((uint64_t *)(B + (k_ldb + j) / 8));
uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64);
//printf("\n %d \n",__builtin_popcountll(c_bit64)); // gcc
printf("\n %d \n", __popcnt64(c_bit64)); // msvs
int h;
for (h = 0; h < 64; ++h)
if ((c_bit64 >> h) & 1) count_arr[i*ldc + j + h] += 1;
//binary_int64_printf(a_bit64);
//binary_int64_printf(b_bit64);
//binary_int64_printf(c_bit64);
}
else {
for (; j < N; ++j) { // out_h*out_w - one channel output size [169 - 173056]
char b_bit = get_bit(B, k_ldb + j);
if (xnor(a_bit, b_bit)) count_arr[i*ldc + j] += 1;
}
}
}
}
}
if (mean_arr) {
//int K_2 = K / 2;
for (i = 0; i < M; ++i) {
float mean_val = mean_arr[i];
//float mean_val2 = 2 * mean_val;
for (j = 0; j < N; ++j) {
C[i*ldc + j] = (2 * count_arr[i*ldc + j] - K) * mean_val;
//C[i*ldc + j] = (count_arr[i*ldc + j] - K_2) *mean_val2;
}
}
}
else {
for (i = 0; i < M; ++i) {
for (j = 0; j < N; ++j) {
C[i*ldc + j] = count_arr[i*ldc + j] - K / 2;
}
}
}
free(count_arr);
//getchar();
}
*/
/*
void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr)
{
int i, j, k, h;
#pragma omp parallel for
for (i = 0; i < M; ++i) { // l.n - filters [16 - 55 - 1024]
float mean_val = mean_arr[i];
for (j = 0; j < N; ++j) { // out_h*out_w - one channel output size [169 - 173056]
int count = 0;
for (k = 0; k < K; k += 64) { // l.size*l.size*l.c - one filter size [27 - 9216]
uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8));
uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8));
uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64);
#ifdef WIN32
int tmp_count = __popcnt64(c_bit64);
#else
int tmp_count = __builtin_popcountll(c_bit64);
#endif
if (K - k < 64) tmp_count = tmp_count - (64 - (K - k)); // remove extra bits
count += tmp_count;
//binary_int64_printf(c_bit64);
//printf(", count = %d \n\n", tmp_count);
}
C[i*ldc + j] = (2 * count - K) * mean_val;
}
}
}
*/
//----------------------------
#if (defined(__AVX__) && defined(__x86_64__)) || defined(_WIN64)
#define OSXSAVEFlag (1UL<<27)
@ -79,8 +307,6 @@ void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
#define CLMULFlag ((1UL<< 1)|AVXFlag|OSXSAVEFlag)
#define VAESFlag ((1UL<<25)|AVXFlag|OSXSAVEFlag)
#include <stdint.h>
#ifdef _WIN64
#include <intrin.h>
#include <ammintrin.h>
@ -196,6 +422,97 @@ void gemm_nn(int M, int N, int K, float ALPHA,
}
}
}
// http://graphics.stanford.edu/~seander/bithacks.html
// https://stackoverflow.com/questions/17354971/fast-counting-the-number-of-set-bits-in-m128i-register
// 2 x faster than popcnt: https://arxiv.org/pdf/1611.07612.pdf
static inline int popcnt128(__m128i n) {
const __m128i n_hi = _mm_unpackhi_epi64(n, n);
#ifdef _MSC_VER
return __popcnt64(_mm_cvtsi128_si64(n)) + __popcnt64(_mm_cvtsi128_si64(n_hi));
#else
return __popcntq(_mm_cvtsi128_si64(n)) + __popcntq(_mm_cvtsi128_si64(n_hi));
#endif
}
static inline int popcnt256(__m256i n) {
return popcnt128(_mm256_extractf128_si256(n, 0)) + popcnt128(_mm256_extractf128_si256(n, 1));
}
void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr)
{
__m256i all_1 = _mm256_set1_epi8(255);
int i, j, k, h;
#pragma omp parallel for
for (i = 0; i < M; ++i) { // l.n - filters [16 - 55 - 1024]
float mean_val = mean_arr[i];
for (j = 0; j < N; ++j) { // out_h*out_w - one channel output size [169 - 173056]
int count = 0;
const int bit_step = 256;
for (k = 0; k < K; k += bit_step) { // l.size*l.size*l.c - one filter size [27 - 9216]
//__m128i a_bit128 = _mm_loadu_si128((__m128i *)(A + (i*lda + k) / 8));
//__m128i b_bit128 = _mm_loadu_si128((__m128i *)(B + (j*ldb + k) / 8));
//__m128i xor128 = _mm_xor_si128(a_bit128, b_bit128);
//__m128i c_bit128 = _mm_andnot_si128(xor128, all_1);
//int tmp_count = popcnt128(c_bit128);
__m256i a_bit256 = _mm256_loadu_si256((__m256i *)(A + (i*lda + k) / 8));
__m256i b_bit256 = _mm256_loadu_si256((__m256i *)(B + (j*ldb + k) / 8));
__m256i xor256 = _mm256_xor_si256(a_bit256, b_bit256);
__m256i c_bit256 = _mm256_andnot_si256(xor256, all_1); //we can do NOT for wegihts once and do not do this NOT
int tmp_count = popcnt256(c_bit256);
if (K - k < bit_step) tmp_count = tmp_count - (bit_step - (K - k)); // remove extra bits
count += tmp_count;
//binary_int64_printf(c_bit64);
//printf(", count = %d \n\n", tmp_count);
}
C[i*ldc + j] = (2 * count - K) * mean_val;
}
}
}
void float_to_bit(float *src, unsigned char *dst, size_t size)
{
size_t dst_size = size / 8 + 1;
memset(dst, 0, dst_size);
size_t i;
__m128i all128_0 = _mm_set_epi32(0, 0, 0, 0);
__m256 all256_0 = _mm256_set1_ps(0);
__m256i bits_asc = _mm256_set_epi32(1, 2, 4, 8, 16, 32, 64, 128);
//for(i = 0; i < 8; ++i) bits_asc.m256i_i32[i] = 1 << i;
for (i = 0; i < size; i+=8)
{
__m256 src256 = _mm256_loadu_ps((__m256i *)(&src[i])); // load 256 bits
__m256 result256 = _mm256_cmp_ps(src256, all256_0, _CMP_GT_OS); // compare dst[i] = (float[i] > 0)
__m256i bits256 = _mm256_castps_si256(result256); // floats to ints32
__m256i and256 = _mm256_and_si256(bits256, bits_asc); // bitwise and
// sum all elements from single and256
__m128i tmp128 = _mm_hadd_epi32(_mm256_extractf128_si256(and256, 0), _mm256_extractf128_si256(and256, 1));
tmp128 = _mm_hadd_epi32(tmp128, all128_0);
tmp128 = _mm_hadd_epi32(tmp128, all128_0);
dst[i / 8] = tmp128.m128i_i32[0];
}
// int _mm256_movemask_epi8 (__m256i a)
}
#else
void gemm_nn(int M, int N, int K, float ALPHA,
@ -213,6 +530,72 @@ void gemm_nn(int M, int N, int K, float ALPHA,
}
}
}
void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr)
{
int i, j, k, h;
#pragma omp parallel for
for (i = 0; i < M; ++i) { // l.n - filters [16 - 55 - 1024]
float mean_val = mean_arr[i];
for (j = 0; j < N; ++j) { // out_h*out_w - one channel output size [169 - 173056]
int count = 0;
for (k = 0; k < K; k += 64) { // l.size*l.size*l.c - one filter size [27 - 9216]
uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8));
uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8));
uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64);
#ifdef WIN32
int tmp_count = __popcnt64(c_bit64);
#else
int tmp_count = __builtin_popcountll(c_bit64);
#endif
if (K - k < 64) tmp_count = tmp_count - (64 - (K - k)); // remove extra bits
count += tmp_count;
//binary_int64_printf(c_bit64);
//printf(", count = %d \n\n", tmp_count);
}
C[i*ldc + j] = (2 * count - K) * mean_val;
}
}
}
void float_to_bit(float *src, unsigned char *dst, size_t size)
{
size_t dst_size = size / 8 + 1;
memset(dst, 0, dst_size);
size_t i;
char *byte_arr = calloc(size, sizeof(char));
for (i = 0; i < size; ++i) {
if (src[i] > 0) byte_arr[i] = 1;
}
//for (i = 0; i < size; ++i) {
// dst[i / 8] |= byte_arr[i] << (i % 8);
//}
for (i = 0; i < size; i += 8) {
char dst_tmp = 0;
dst_tmp |= byte_arr[i + 0] << 0;
dst_tmp |= byte_arr[i + 1] << 1;
dst_tmp |= byte_arr[i + 2] << 2;
dst_tmp |= byte_arr[i + 3] << 3;
dst_tmp |= byte_arr[i + 4] << 4;
dst_tmp |= byte_arr[i + 5] << 5;
dst_tmp |= byte_arr[i + 6] << 6;
dst_tmp |= byte_arr[i + 7] << 7;
dst[i / 8] = dst_tmp;
}
free(byte_arr);
}
#endif // __x86_64
void gemm_nt(int M, int N, int K, float ALPHA,

View File

@ -1,6 +1,34 @@
#ifndef GEMM_H
#define GEMM_H
static inline void set_bit(unsigned char *const dst, size_t index) {
size_t dst_i = index / 8;
int dst_shift = index % 8;
dst[dst_i] |= 1 << dst_shift;
}
static inline unsigned char get_bit(unsigned char const*const src, size_t index) {
size_t src_i = index / 8;
int src_shift = index % 8;
unsigned char val = (src[src_i] & (1 << src_shift)) > 0;
return val;
}
void float_to_bit(float *src, unsigned char *dst, size_t size);
void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
float *C, int ldc, float *mean_arr);
//void gemm_nn_custom_bin_mean(int M, int N, int K, float ALPHA_UNUSED,
//unsigned char *A, int lda,
//unsigned char *B, int ldb,
//float *C, int ldc, float *mean_arr)
void gemm_bin(int M, int N, int K, float ALPHA,
char *A, int lda,
float *B, int ldb,

View File

@ -33,6 +33,8 @@ void free_layer(layer l)
if (l.scale_updates) free(l.scale_updates);
if (l.weights) free(l.weights);
if (l.weight_updates) free(l.weight_updates);
if (l.weights) free(l.align_bit_weights);
if (l.weights) free(l.mean_arr);
if (l.delta) free(l.delta);
if (l.output) free(l.output);
if (l.squared) free(l.squared);

View File

@ -179,6 +179,9 @@ struct layer{
float *weights;
float *weight_updates;
char *align_bit_weights;
float *mean_arr;
float *col_image;
int * input_layers;
int * input_sizes;

View File

@ -847,3 +847,25 @@ void fuse_conv_batchnorm(network net)
}
}
}
void calculate_binary_weights(network net)
{
int j;
for (j = 0; j < net.n; ++j) {
layer *l = &net.layers[j];
if (l->type == CONVOLUTIONAL) {
//printf(" Merges Convolutional-%d and batch_norm \n", j);
if (l->xnor) {
//printf("\n %d \n", j);
size_t ldb_align = 256; // 256bit for AVX2
binary_transpose_align_weights(l, ldb_align);
}
}
}
//printf("\n calculate_binary_weights Done! \n");
}

View File

@ -151,6 +151,7 @@ YOLODLL_API void optimize_picture(network *net, image orig, int max_layer, float
int get_network_nuisance(network net);
int get_network_background(network net);
YOLODLL_API void fuse_conv_batchnorm(network net);
YOLODLL_API void calculate_binary_weights(network net);
#ifdef __cplusplus
}