From f92b20580a21663c5db9eb8608f8cabd7adbeb10 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Tue, 14 Aug 2018 01:51:31 +0300 Subject: [PATCH] Some fixes for AVX support on CPU --- src/convolutional_layer.c | 7 +++++-- src/gemm.c | 26 +++++++++++++++----------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 927eb993..e7c8e3f2 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -621,7 +621,7 @@ void binary_align_weights(convolutional_layer *l) free(align_weights); } - +// further optimizations: im2col_bin() for XNOR, and then transpose_align_bin() size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align) { size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8; @@ -690,7 +690,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) //} //if (l.xnor && l.size == 3 && l.stride == 1 && l.pad == 1) {} //else - im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b); + // further optimizations: im2col_bin() for XNOR, and then transpose_align_bin() + im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b); //gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); @@ -793,6 +794,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) //char *t_bit_input = calloc(new_ldb * n, sizeof(char)); // for im2col_cpu_custom_transpose() only //float_to_bit(t_input, t_bit_input, new_ldb * n); // for im2col_cpu_custom_transpose() only + // 5x faster than gemm()-float32 + // further optimizations: accelerate maxpool-layer with OpenMP/AVX gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr); //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr); diff --git a/src/gemm.c b/src/gemm.c index b9098141..d233e9c7 100644 --- a/src/gemm.c +++ b/src/gemm.c @@ -674,6 
+674,8 @@ static inline int popcnt256_custom(__m256i n) { + _mm256_extract_epi64(val, 3); } +// 5x faster than gemm()-float32 +// further optimizations: do mean-mult only for the last layer void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED, unsigned char *A, int lda, unsigned char *B, int ldb, @@ -873,7 +875,7 @@ void im2col_cpu_custom(float* data_im, int channels_col = channels * ksize * ksize; // optimized version - if (height_col == height && width_col == width && stride == 1 && pad == 1) + if (height_col == height && width_col == width && stride == 1 && pad == 1 && is_fma_avx()) { #pragma omp parallel for for (c = 0; c < channels_col; ++c) { @@ -954,24 +956,26 @@ void im2col_cpu_custom(float* data_im, void activate_array_cpu_custom(float *x, const int n, const ACTIVATION a) { - int i; + int i = 0; if (a == LINEAR) {} else if (a == LEAKY) { - __m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000); - __m256 all256_01 = _mm256_set1_ps(0.1F); + if (is_fma_avx()) { + __m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000); + __m256 all256_01 = _mm256_set1_ps(0.1F); - for (i = 0; i < n-8; i += 8) { - //x[i] = (x[i]>0) ? x[i] : .1*x[i]; + for (i = 0; i < n - 8; i += 8) { + //x[i] = (x[i]>0) ? x[i] : .1*x[i]; - __m256 src256 = _mm256_loadu_ps(&x[i]); - __m256 mult256 = _mm256_mul_ps((src256), all256_01); // mult * 0.1 + __m256 src256 = _mm256_loadu_ps(&x[i]); + __m256 mult256 = _mm256_mul_ps((src256), all256_01); // mult * 0.1 - __m256i sign256 = _mm256_and_si256(_mm256_castps_si256(src256), all256_sing1); // check sign in 8 x 32-bit floats + __m256i sign256 = _mm256_and_si256(_mm256_castps_si256(src256), all256_sing1); // check sign in 8 x 32-bit floats - __m256 result256 = _mm256_blendv_ps(src256, mult256, _mm256_castsi256_ps(sign256)); // (sign>0) ? 
src : mult; - _mm256_storeu_ps(&x[i], result256); + __m256 result256 = _mm256_blendv_ps(src256, mult256, _mm256_castsi256_ps(sign256)); // (sign>0) ? src : mult; + _mm256_storeu_ps(&x[i], result256); + } } for (; i < n; ++i) {