首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于AVX2和SSE2的位向量运算

基于AVX2和SSE2的位向量运算
EN

Stack Overflow用户
提问于 2019-11-04 12:56:56
回答 1查看 264关注 0票数 3

我是AVX2和SSE2指令集的新手,我想了解更多关于如何使用这些指令集来加速位向量操作的知识。

到目前为止,我已经成功地使用它们来用双/浮动操作将代码矢量化。

在本例中,我有一个C++代码,在此之前检查一个条件,以便在位向量(使用无符号int)中将其设置为特定值:

代码语言:javascript
复制
int process_bit_vetcor(unsigned int *bitVector, float *value, const float threshold, const unsigned int dim)
{
       int sum = 0, cond = 0;

       for (unsigned int i = 0; i < dim; i++) {
            unsigned int *word = bitVector + i / 32;
            unsigned int bitValue = ((unsigned int)0x80000000 >> (i & 0x1f));
            cond = (value[i] <= threshold);
            (*word) = (cond) ? (*word) | bitValue : (*word);
            sum += cond;
        }

        return sum;
}

变量和只返回条件为真的情况数。

我试着用SSE2和AVX2重写这个例程,但是没有成功.:-(

可以用C++和SSE2重写这样的代码吗?对这种类型的位操作使用矢量化是否值得?位向量可能包含数千位,所以我希望使用SSE2和AVX2来加快速度是很有趣的。

提前感谢!

EN

回答 1

Stack Overflow用户

发布于 2019-11-07 14:41:42

如果dim是8的倍数(为了处理剩余部分,在末尾添加一个简单的循环),下面的代码应该可以工作。小API-更改:

  • 为循环索引使用long而不是unsigned int (这有助于clang展开循环)
  • 假设bitvector是小端点(如注释中所建议的)

在循环中,bitVector按字节顺序访问.可能值得将movemask和bit的2或4个结果结合起来--或者一次组合它们(可能取决于目标体系结构)。

为了计算sum,直接根据cmp_ps运算的结果计算了8个部分和。由于您无论如何都需要位掩码,所以可能值得使用popcnt (理想情况下,将2、4或8个字节组合在一起--这可能取决于您的目标体系结构)。

代码语言:javascript
复制
int process_bit_vector(uint32_t *bitVector32, float *value,
                       const float threshold_float, const long dim) {
  __m256i sum = _mm256_setzero_si256();
  __m256 threshold_vector = _mm256_set1_ps(threshold_float);
  uint8_t *bitVector8 = (uint8_t *)bitVector32;

  for (long i = 0; i <= dim-8; i += 8) {
    // compare next 8 values with threshold
    // (use threshold as first operand to allow loading other operand from memory)
    __m256 cmp_mask = _mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i), _CMP_GE_OQ);
    // true values are `-1` when interpreted as integers, subtract those from `sum`
    sum = _mm256_sub_epi32(sum, _mm256_castps_si256(cmp_mask));
    // extract bitmask
    int mask = _mm256_movemask_ps(cmp_mask);
    // bitwise-or current mask with result bit-vector
    *bitVector8++ |= mask;
  }

  // reduce 8 partial sums to a single sum and return
  __m128i sum_reduced = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum,1));
  sum_reduced = _mm_add_epi32(sum_reduced, _mm_srli_si128(sum_reduced, 8));
  sum_reduced = _mm_add_epi32(sum_reduced, _mm_srli_si128(sum_reduced, 4));

  return _mm_cvtsi128_si32(sum_reduced);
}

螺栓连接:https://godbolt.org/z/ABwDPe

由于某种原因,GCC不只是vpsubd ymm0, ymm0, ymm1.

  • Clang,而是做了vpsubd ymm2, ymm0, ymm1; vmovdqa ymm0, ymm2;,但是没有加入vcmpps (并且使用LE而不是GE比较) --如果您不关心NaNs是如何处理的,您可以使用_CMP_NLT_US而不是_CMP_GE_OQ.

具有大端输出的修订版(未经测试):

代码语言:javascript
复制
int process_bit_vector(uint32_t *bitVector32, float *value,
                       const float threshold_float, const long dim) {
  int sum = 0;
  __m256 threshold_vector = _mm256_set1_ps(threshold_float);

  for (long i = 0; i <= dim-32; i += 32) {
    // compare next 4x8 values with threshold
    // (use threshold as first operand to allow loading other operand from memory)
    __m256i cmp_maskA = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+ 0), _CMP_GE_OQ));
    __m256i cmp_maskB = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+ 8), _CMP_GE_OQ));
    __m256i cmp_maskC = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+16), _CMP_GE_OQ));
    __m256i cmp_maskD = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+24), _CMP_GE_OQ));

    __m256i cmp_mask = _mm256_packs_epi16(
        _mm256_packs_epi16(cmp_maskA,cmp_maskB), // b7b7b6b6'b5b5b4b4'a7a7a6a6'a5a5a4a4 b3b3b2b2'b1b1b0b0'a3a3a2a2'a1a1a0a0
        _mm256_packs_epi16(cmp_maskC,cmp_maskD)  // d7d7d6d6'd5d5d4d4'c7c7c6c6'c5c5c4c4 d3d3d2d2'd1d1d0d0'c3c3c2c2'c1c1c0c0
    );                                // cmp_mask = d7d6d5d4'c7c6c5c4'b7b6b5b4'a7a6a5a4 d3d2d1d0'c3c2c1c0'b3b2b1b0'a3a2a1a0

    cmp_mask = _mm256_permute4x64_epi64(cmp_mask, 0x8d);
                // cmp_mask = [b7b6b5b4'a7a6a5a4 b3b2b1b0'a3a2a1a0  d7d6d5d4'c7c6c5c4 d3d2d1d0'c3c2c1c0]
    __m256i shuff_idx = _mm256_broadcastsi128_si256(_mm_set_epi64x(0x00010203'08090a0b,0x04050607'0c0d0e0f));
    cmp_mask = _mm256_shuffle_epi8(cmp_mask, shuff_idx);

    // extract bitmask
    uint32_t mask = _mm256_movemask_epi8(cmp_mask);
    sum += _mm_popcnt_u32 (mask);
    // bitwise-or current mask with result bit-vector
    *bitVector32++ |= mask;
  }

  return sum;
}

这样做的目的是在应用vpmovmskb之前对字节进行洗牌。这需要对32个输入值进行5次洗牌操作(包括3 vpacksswb),但是计算和的方法是使用popcnt而不是4 vpsubd。在比较128个比特一半到256个比特向量之前,可以避免vpermq (_mm256_permute4x64_epi64)。另一个想法(因为您无论如何都需要对最终结果进行洗牌)是将部分结果混合在一起(这往往需要我检查过的体系结构上的p52*p015,所以可能不值得)。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58693907

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档