diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index f5050521..25528d48 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -888,6 +888,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) if(l.c % 32 == 0) { + //printf(" l.index = %d - new XNOR \n", l.index); + int ldb_align = l.lda_align; size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8; size_t t_intput_size = new_ldb * l.bit_align;// n; @@ -906,7 +908,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) free(re_packed_input); - // convolution the packed inputs and weights: float x 32 by channel (as in cuDNN) + // slow - convolution the packed inputs and weights: float x 32 by channel (as in cuDNN) //convolution_repacked((uint32_t *)bin_re_packed_input, (uint32_t *)l.align_bit_weights, l.output, // l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr); @@ -920,10 +922,11 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) int new_k = l.size*l.size*l.c / 32; -// gemm_nn_bin_32bit_packed(m, n, new_k, 1, -// l.align_bit_weights, l.new_lda/32, -// b, n, -// c, n, l.mean_arr); + // good for (l.c == 64) + //gemm_nn_bin_32bit_packed(m, n, new_k, 1, + // l.align_bit_weights, l.new_lda/32, + // b, n, + // c, n, l.mean_arr); // // then exit from if() @@ -951,6 +954,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) else { // else (l.c % 32 != 0) //-------------------------------------------------------- + //printf(" l.index = %d - old XNOR \n", l.index); //im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align); im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align); @@ -993,6 +997,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) } else { + //printf(" l.index = %d - FP32 \n", l.index); im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b); gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n); diff --git a/src/gemm.c b/src/gemm.c index bf52a118..87de2159 100644 --- a/src/gemm.c +++ b/src/gemm.c @@ -489,9 +489,9 @@ void transpose_bin(uint32_t *A, uint32_t *B, const int n, const int m, } static inline int popcnt_32(uint32_t val32) { -#ifdef WIN32 // Windows +#ifdef WIN32 // Windows MSVS int tmp_count = __popcnt(val32); -#else // Linux +#else // Linux GCC int tmp_count = __builtin_popcount(val32); #endif return tmp_count; @@ -755,39 +755,15 @@ void gemm_nn_bin_32bit_packed(int M, int N, int K, float ALPHA, __m256i all_1 = _mm256_set1_epi8(255); __m256i xnor256 = _mm256_andnot_si256(xor256, all_1); // xnor = not(xor(a,b)) - //_m256 count = _mm256_set_ps( - /* - __m256i count = _mm256_setr_epi32( - (int)popcnt_32(xnor256.m256i_u32[0]), - (int)popcnt_32(xnor256.m256i_u32[1]), - (int)popcnt_32(xnor256.m256i_u32[2]), - (int)popcnt_32(xnor256.m256i_u32[3]), - (int)popcnt_32(xnor256.m256i_u32[4]), - (int)popcnt_32(xnor256.m256i_u32[5]), - (int)popcnt_32(xnor256.m256i_u32[6]), - (int)popcnt_32(xnor256.m256i_u32[7])); - - __m256i val2 = _mm256_set1_epi32(2); - count = _mm256_mullo_epi32(count, val2); - - __m256i val32 = _mm256_set1_epi32(32); - count = _mm256_sub_epi32(count, val32); - - int z; - for (z = 0; z < 8; ++z) { - C[i*ldc + j + z] += count.m256i_i32[z] * mean_val; - } - */ - __m256 count = _mm256_setr_ps( - popcnt_32(xnor256.m256i_u32[0]), - popcnt_32(xnor256.m256i_u32[1]), - popcnt_32(xnor256.m256i_u32[2]), - popcnt_32(xnor256.m256i_u32[3]), - popcnt_32(xnor256.m256i_u32[4]), - popcnt_32(xnor256.m256i_u32[5]), - popcnt_32(xnor256.m256i_u32[6]), - popcnt_32(xnor256.m256i_u32[7])); + popcnt_32(_mm256_extract_epi32(xnor256, 0)), + popcnt_32(_mm256_extract_epi32(xnor256, 1)), + popcnt_32(_mm256_extract_epi32(xnor256, 2)), + popcnt_32(_mm256_extract_epi32(xnor256, 3)), + popcnt_32(_mm256_extract_epi32(xnor256, 4)), + popcnt_32(_mm256_extract_epi32(xnor256, 5)), + popcnt_32(_mm256_extract_epi32(xnor256, 6)), + popcnt_32(_mm256_extract_epi32(xnor256, 7))); __m256 val2 = _mm256_set1_ps(2); count = _mm256_mul_ps(count, val2); // count * 2 @@ -2274,17 +2250,19 @@ void gemm_nn_bin_transposed_32bit_packed(int M, int N, int K, float ALPHA, for (i = 0; i < M; ++i) { // l.n int j, s; float mean_val = mean_arr[i]; - for (s = 0; s < K; ++s) // l.size*l.size*l.c/32 or (l.size*l.size*l.c) + for (j = 0; j < N; ++j) // out_h*out_w; { - register uint32_t A_PART = ((uint32_t*)A)[i*lda + s]; - for (j = 0; j < N; ++j) // out_h*out_w; + float val = 0; + for (s = 0; s < K; ++s) // l.size*l.size*l.c/32 or (l.size*l.size*l.c) { + register uint32_t A_PART = ((uint32_t*)A)[i*lda + s]; register uint32_t B_PART = ((uint32_t*)B)[j*ldb + s]; uint32_t xnor_result = ~(A_PART ^ B_PART); int32_t count = popcnt_32(xnor_result); // must be Signed int - C[i*ldc + j] += (2 * count - 32) * mean_val; + val += (2 * count - 32) * mean_val; } + C[i*ldc + j] += val; } } } @@ -2422,6 +2400,8 @@ void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, } } + is_avx(); // initialize static variable + is_fma_avx2(); int t; #pragma omp parallel for for (t = 0; t < M; ++t) { diff --git a/src/gemm.h b/src/gemm.h index dd727830..c34b4b35 100644 --- a/src/gemm.h +++ b/src/gemm.h @@ -22,6 +22,9 @@ static inline unsigned char get_bit(unsigned char const*const src, size_t index) return val; } +int is_avx(); +int is_fma_avx2(); + void float_to_bit(float *src, unsigned char *dst, size_t size); void transpose_block_SSE4x4(float *A, float *B, const int n, const int m, diff --git a/src/network.c b/src/network.c index b91edd5b..cf0d9351 100644 --- a/src/network.c +++ b/src/network.c @@ -190,6 +190,8 @@ network make_network(int n) return net; } +double get_time_point(); + void forward_network(network net, network_state state) { state.workspace = net.workspace; @@ -200,7 +202,9 @@ void forward_network(network net, network_state state) if(l.delta){ scal_cpu(l.outputs * l.batch, 0, l.delta, 1); } + //double time = get_time_point(); l.forward(l, state); + //printf("%d - Predicted in %lf milli-seconds.\n", i, ((double)get_time_point() - time) / 1000); state.input = l.output; } }