首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用AVX-512或AVX-2在大数据上计数1位(人口计数)。

使用AVX-512或AVX-2在大数据上计数1位(人口计数)。
EN

Stack Overflow用户
提问于 2018-04-28 22:04:34
回答 2查看 4.4K关注 0票数 8

我有很长的内存,比如说,256 KiB或更长的内存。我想计算整个块中的1位数,或者换句话说:将所有字节的“填充计数”值相加。

我知道AVX-512有一个VPOPCNTDQ指令,它在一个512位向量中连续计算64位中的1位数,IIANM应该可以在每一个周期中发出一个(如果有一个合适的SIMD向量寄存器)--但是我没有编写SIMD代码的经验(我更像一个GPU的家伙)。此外,我也不能百分之百肯定对AVX-512目标的编译器支持.

在大多数CPU上,AVX-512仍然不被(完全)支持;但是AVX-2是广泛使用的.我还没有找到类似于VPOPCNTDQ的小于-512位的矢量化指令,所以即使从理论上讲,我也不知道如何用AVX-2功能的CPU快速计算比特数;也许存在这样的东西,但我不知怎么错过了它?

无论如何,我希望有一个简短的C/C++函数--或者使用一些内部包装库,或者使用内联程序集--用于这两个指令集中的每一个。签名是

代码语言:javascript
复制
uint64_t count_bits(void* ptr, size_t size);

备注:

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-04-29 00:16:26

AVX-2

@HadiBreis的注释链接到关于快速人口统计的文章 - SSSE3,由Wojciech Muła编写;文章链接到这个GitHub存储库;存储库提供以下AVX-2实现。它基于一个矢量化的查找指令,并使用一个16值的查找表来进行比特计数。

代码语言:javascript
复制
#   include <immintrin.h>
#   include <x86intrin.h>

std::uint64_t popcnt_AVX2_lookup(const uint8_t* data, const size_t n) {

    size_t i = 0;

    const __m256i lookup = _mm256_setr_epi8(
        /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
        /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
        /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
        /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4,

        /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
        /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
        /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
        /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4
    );

    const __m256i low_mask = _mm256_set1_epi8(0x0f);

    __m256i acc = _mm256_setzero_si256();

#define ITER { \
        const __m256i vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data + i)); \
        const __m256i lo  = _mm256_and_si256(vec, low_mask); \
        const __m256i hi  = _mm256_and_si256(_mm256_srli_epi16(vec, 4), low_mask); \
        const __m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo); \
        const __m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi); \
        local = _mm256_add_epi8(local, popcnt1); \
        local = _mm256_add_epi8(local, popcnt2); \
        i += 32; \
    }

    while (i + 8*32 <= n) {
        __m256i local = _mm256_setzero_si256();
        ITER ITER ITER ITER
        ITER ITER ITER ITER
        acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));
    }

    __m256i local = _mm256_setzero_si256();

    while (i + 32 <= n) {
        ITER;
    }

    acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));

#undef ITER

    uint64_t result = 0;

    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 0));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 1));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 2));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 3));

    for (/**/; i < n; i++) {
        result += lookup8bit[data[i]];
    }

    return result;
}

AVX-512

同一存储库还具有一个基于VPOPCNT的AVX-512实现。在列出它的代码之前,下面是简化的、更具可读性的伪代码:

  • 对于64个字节的连续序列:
代码语言:javascript
复制
- Load the sequence into a SIMD register with 64x8 = 512 bits
- Perform 8 parallel population counts of 64 bits each on that register
- Add the 8 population-count results in parallel, into an "accumulator" register holding 8 sums
  • 将累加器中的8个值相加
  • 如果有一个小于64个字节的尾,用一些更简单的方式来计算那里的位
  • 返回主和加上尾和

现在是真正的交易:

代码语言:javascript
复制
#   include <immintrin.h>
#   include <x86intrin.h>

uint64_t avx512_vpopcnt(const uint8_t* data, const size_t size) {
    
    const size_t chunks = size / 64;

    uint8_t* ptr = const_cast<uint8_t*>(data);
    const uint8_t* end = ptr + size;

    // count using AVX512 registers
    __m512i accumulator = _mm512_setzero_si512();
    for (size_t i=0; i < chunks; i++, ptr += 64) {
        
        // Note: a short chain of dependencies, likely unrolling will be needed.
        const __m512i v = _mm512_loadu_si512((const __m512i*)ptr);
        const __m512i p = _mm512_popcnt_epi64(v);

        accumulator = _mm512_add_epi64(accumulator, p);
    }

    // horizontal sum of a register
    uint64_t tmp[8] __attribute__((aligned(64)));
    _mm512_store_si512((__m512i*)tmp, accumulator);

    uint64_t total = 0;
    for (size_t i=0; i < 8; i++) {
        total += tmp[i];
    }

    // popcount the tail
    while (ptr + 8 < end) {
        total += _mm_popcnt_u64(*reinterpret_cast<const uint64_t*>(ptr));
        ptr += 8;
    }

    while (ptr < end) {
        total += lookup8bit[*ptr++];
    }

    return total;
}

lookup8bit是一个字节而不是位的弹出查找表,它被定义为这里编辑:作为评论注意,在结尾使用8位查找表不是一个好主意,可以改进。

票数 5
EN

Stack Overflow用户

发布于 2018-04-29 09:35:02

除了标量清理循环之外,Wojciech Muła的大数组弹出函数看起来是最优的。(有关主循环的详细信息,请参阅@einpoklum的答案)。

在结束时只使用几次的256条目LUT可能会缓存失败,即使缓存是热的,超过1字节的LUT也不是最佳选择。我相信所有的AVX2 CPU都有硬件popcnt,我们可以很容易地分离出最后的8个字节,这些字节还没有被计算出来,从而为单个popcnt设置好了。

与通常的SIMD算法一样,它通常可以很好地完成以缓冲区的最后一个字节结束的全宽度加载。但与矢量寄存器不同的是,全整数寄存器的可变计数移位非常便宜(特别是在BMI2中)。Popcnt并不关心比特在哪里,所以我们只需要使用移位,而不需要构造和掩码之类的东西。

代码语言:javascript
复制
// untested
// ptr points at the first byte that hasn't been counted yet
uint64_t final_bytes = reinterpret_cast<const uint64_t*>(end)[-1] >> (8*(end-ptr));
total += _mm_popcnt_u64( final_bytes );
// Careful, this could read outside a small buffer.

或者更好的是,使用更复杂的逻辑来避免页面交叉.例如,这可以避免6字节缓冲区在页面开始时跨页。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50081465

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档