首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用AVX2代码优化查找元素索引

使用AVX2代码优化查找元素索引
EN

Stack Overflow用户
提问于 2020-05-31 18:32:59
回答 1查看 311关注 0票数 2

我正在利用AVX2编写一些代码,以便在一个包含14个条目的数组中搜索32位哈希,并返回找到的条目的索引。

因为大多数点击很可能在数组的前8个条目中,这段代码已经可以改进了,添加了__builtin_expect的使用,这不是我现在的首要任务。

虽然哈希数组(在变量哈希表示的代码中)总是有14个条目,但它包含在这类结构中。

代码语言:javascript
复制
typedef struct chain_ring chain_ring_t;
struct chain_ring {
    uint32_t hashes[14];
    chain_ring_t* next;
    ...other stuff...
} __attribute__((aligned(16)))

这里的代码

代码语言:javascript
复制
int8_t hash32_find_14_avx2(uint32_t hash, volatile uint32_t* hashes) {
    uint32_t compacted_result_mask, leading_zeroes;
    __m256i cmp_vector, ring_vector, result_mask_vector;
    int8_t found_index = -1;

    if (hashes[0] == hash) {
        return 0;
    }

    for(uint8_t base_index = 0; base_index < 14; base_index += 8) {
        cmp_vector = _mm256_set1_epi32(hash);
        ring_vector = _mm256_stream_load_si256((__m256i*) (hashes + base_index));

        result_mask_vector = _mm256_cmpeq_epi32(ring_vector, cmp_vector);
        compacted_result_mask = _mm256_movemask_epi8(result_mask_vector);

        if (compacted_result_mask != 0) {
            leading_zeroes = 32 - __builtin_clz(compacted_result_mask);
            found_index = base_index + (leading_zeroes >> 2u) - 1;
            break;
        }
    }

    return found_index > 13 ? -1 : found_index;
}

这个逻辑简单地解释了,它搜索前8个条目,然后搜索第二个8个条目。如果发现的索引大于13,就意味着它找到了一个匹配的东西,而这些东西不是数组的一部分,因此必须被认为不匹配。

备注:

为了加速加载(从对齐内存),我使用了上述的_mm256_stream_load_si256

  • because,我需要检查返回的值是否大于13,并且我真的不太喜欢这个特定的部分,如果我使用_mm256_maskload_epi32

  • I am使用for-循环来避免重复代码,gcc当然会展开使用__builtin_clz的循环

  • ,但是我正在用-mlzcnt编译代码,因为据我所读,运行bsr指令的速度要慢得多,gcc正在使用lzcnt代替带标志

  • 的bsr --如果平均延迟约0.30 ns的话,它将平均减少0.6ns --第一次匹配

  • 的时间减少了0.6ns--代码只适用于64位机器

  • ,在某个时候我需要为aarch64

优化这段代码

这里有一个很好的链接,用于生成程序集https://godbolt.org/z/5bxbN6

我也实现了SSE版本(它在要点中),但逻辑是相同的,尽管我不太确定它是否值得

作为参考,我构建了一个简单的线性搜索函数,并使用google-benchmark lib对其性能进行了比较。

代码语言:javascript
复制
int8_t hash32_find_14_loop(uint32_t hash, volatile uint32_t* hashes) {
    for(uint8_t index = 0; index <= 14; index++) {
        if (hashes[index] == hash) {
            return index;
        }
    }

    return -1;
}

完整的代码可以在这个url https://gist.github.com/danielealbano/9fcbc1ff0a42cc9ad61be205366bdb5f上找到。

除了google基准库所需的标志外,我正在使用-avx2 -avx -msse4 -O3 -mbmi -mlzcnt编译它。

为每个元素执行一个工作台(我想比较循环和备选方案)

代码语言:javascript
复制
----------------------------------------------------------------------------------------------------
Benchmark                                                          Time             CPU   Iterations
----------------------------------------------------------------------------------------------------
bench_template_hash32_find_14_loop/0/iterations:100000000       0.610 ns        0.610 ns    100000000
bench_template_hash32_find_14_loop/1/iterations:100000000        1.16 ns         1.16 ns    100000000
bench_template_hash32_find_14_loop/2/iterations:100000000        1.18 ns         1.18 ns    100000000
bench_template_hash32_find_14_loop/3/iterations:100000000        1.19 ns         1.19 ns    100000000
bench_template_hash32_find_14_loop/4/iterations:100000000        1.28 ns         1.28 ns    100000000
bench_template_hash32_find_14_loop/5/iterations:100000000        1.26 ns         1.26 ns    100000000
bench_template_hash32_find_14_loop/6/iterations:100000000        1.52 ns         1.52 ns    100000000
bench_template_hash32_find_14_loop/7/iterations:100000000        2.15 ns         2.15 ns    100000000
bench_template_hash32_find_14_loop/8/iterations:100000000        1.66 ns         1.66 ns    100000000
bench_template_hash32_find_14_loop/9/iterations:100000000        1.67 ns         1.67 ns    100000000
bench_template_hash32_find_14_loop/10/iterations:100000000       1.90 ns         1.90 ns    100000000
bench_template_hash32_find_14_loop/11/iterations:100000000       1.89 ns         1.89 ns    100000000
bench_template_hash32_find_14_loop/12/iterations:100000000       2.13 ns         2.13 ns    100000000
bench_template_hash32_find_14_loop/13/iterations:100000000       2.20 ns         2.20 ns    100000000
bench_template_hash32_find_14_loop/14/iterations:100000000       2.32 ns         2.32 ns    100000000
bench_template_hash32_find_14_loop/15/iterations:100000000       2.53 ns         2.53 ns    100000000
bench_template_hash32_find_14_sse/0/iterations:100000000        0.531 ns        0.531 ns    100000000
bench_template_hash32_find_14_sse/1/iterations:100000000         1.42 ns         1.42 ns    100000000
bench_template_hash32_find_14_sse/2/iterations:100000000         2.53 ns         2.53 ns    100000000
bench_template_hash32_find_14_sse/3/iterations:100000000         1.45 ns         1.45 ns    100000000
bench_template_hash32_find_14_sse/4/iterations:100000000         2.26 ns         2.26 ns    100000000
bench_template_hash32_find_14_sse/5/iterations:100000000         1.90 ns         1.90 ns    100000000
bench_template_hash32_find_14_sse/6/iterations:100000000         1.90 ns         1.90 ns    100000000
bench_template_hash32_find_14_sse/7/iterations:100000000         1.93 ns         1.93 ns    100000000
bench_template_hash32_find_14_sse/8/iterations:100000000         2.07 ns         2.07 ns    100000000
bench_template_hash32_find_14_sse/9/iterations:100000000         2.05 ns         2.05 ns    100000000
bench_template_hash32_find_14_sse/10/iterations:100000000        2.08 ns         2.08 ns    100000000
bench_template_hash32_find_14_sse/11/iterations:100000000        2.08 ns         2.08 ns    100000000
bench_template_hash32_find_14_sse/12/iterations:100000000        2.55 ns         2.55 ns    100000000
bench_template_hash32_find_14_sse/13/iterations:100000000        2.53 ns         2.53 ns    100000000
bench_template_hash32_find_14_sse/14/iterations:100000000        2.37 ns         2.37 ns    100000000
bench_template_hash32_find_14_sse/15/iterations:100000000        2.59 ns         2.59 ns    100000000
bench_template_hash32_find_14_avx2/0/iterations:100000000       0.537 ns        0.537 ns    100000000
bench_template_hash32_find_14_avx2/1/iterations:100000000        1.37 ns         1.37 ns    100000000
bench_template_hash32_find_14_avx2/2/iterations:100000000        1.38 ns         1.38 ns    100000000
bench_template_hash32_find_14_avx2/3/iterations:100000000        1.36 ns         1.36 ns    100000000
bench_template_hash32_find_14_avx2/4/iterations:100000000        1.37 ns         1.37 ns    100000000
bench_template_hash32_find_14_avx2/5/iterations:100000000        1.38 ns         1.38 ns    100000000
bench_template_hash32_find_14_avx2/6/iterations:100000000        1.40 ns         1.40 ns    100000000
bench_template_hash32_find_14_avx2/7/iterations:100000000        1.39 ns         1.39 ns    100000000
bench_template_hash32_find_14_avx2/8/iterations:100000000        1.99 ns         1.99 ns    100000000
bench_template_hash32_find_14_avx2/9/iterations:100000000        2.02 ns         2.02 ns    100000000
bench_template_hash32_find_14_avx2/10/iterations:100000000       1.98 ns         1.98 ns    100000000
bench_template_hash32_find_14_avx2/11/iterations:100000000       1.98 ns         1.98 ns    100000000
bench_template_hash32_find_14_avx2/12/iterations:100000000       2.03 ns         2.03 ns    100000000
bench_template_hash32_find_14_avx2/13/iterations:100000000       1.98 ns         1.98 ns    100000000
bench_template_hash32_find_14_avx2/14/iterations:100000000       1.96 ns         1.96 ns    100000000
bench_template_hash32_find_14_avx2/15/iterations:100000000       1.97 ns         1.97 ns    100000000

谢谢你的建议!

--更新

我已经用@chtz所做的无分支实现更新了gist,并将__lzcnt32替换为_tzcnt_u32,我不得不稍微改变行为,以考虑在返回32而不是-1时没有发现的行为,但这并不重要。

他们运行的CPU是英特尔核心i7 8700 (6c/12t,3.20GHZ)。

工作台使用cpu钉扎,使用比物理或逻辑cpu内核更多的线程,并执行一些额外的操作,特别是一个for循环,因此存在开销,但在两个测试之间是相同的,因此应该以相同的方式影响它们。

如果您想要运行测试,您需要对CPU_CORE_LOGICAL_COUNT进行优化,以手动匹配您的cpu的逻辑cpu核心的数量。

有趣的是,当有更多争用(从单个线程到64个线程)时,性能改善是如何从+17%跃升到+41%的。我还对128个和256个线程进行了一些测试,在使用AVX2时速度提高了60%,但我没有包括下面的数字。

(bench_template_hash32_find_14_avx2正在使用无分支版本,我缩短了名称以提高文章的可读性)

代码语言:javascript
复制
------------------------------------------------------------------------------------------
Benchmark                                                                 CPU   Iterations
------------------------------------------------------------------------------------------
bench_template_hash32_find_14_loop/iterations:10000000/threads:1      45.2 ns     10000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:2      50.4 ns     20000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:4      52.1 ns     40000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:8      70.9 ns     80000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:16     86.8 ns    160000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:32     87.3 ns    320000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:64     92.9 ns    640000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:1      38.4 ns     10000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:2      42.1 ns     20000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:4      46.5 ns     40000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:8      52.6 ns     80000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:16     60.0 ns    160000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:32     62.1 ns    320000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:64     65.8 ns    640000000
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-05-31 23:22:56

您可以在没有分支的情况下完全实现这一点,方法是比较数组的两个重叠部分,位-或它们在一起,并与单个lzcnt获得最后的位位置。另外,使用vmovmskps而不是vpmovmskb可以将结果除以4(但我不确定这是否会导致跨域延迟)。

代码语言:javascript
复制
int8_t hash32_find_14_avx2(uint32_t hash, volatile uint32_t* hashes) {
    uint32_t compacted_result_mask = 0;
    __m256i cmp_vector = _mm256_set1_epi32(hash);
    for(uint8_t base_index = 0; base_index < 12; base_index += 6) {
        __m256i ring_vector = _mm256_loadu_si256((__m256i*) (hashes + base_index));

        __m256i result_mask_vector = _mm256_cmpeq_epi32(ring_vector, cmp_vector);
        compacted_result_mask |= _mm256_movemask_ps(_mm256_castsi256_ps(result_mask_vector)) << (base_index);
    }
    int32_t leading_zeros = __lzcnt32(compacted_result_mask);
    return (31 - leading_zeros);
}

正如Peter在评论中已经指出的,在大多数情况下,_mm256_stream_load_si256比普通负载更糟糕。另外,请注意,在gcc使用未对齐负载时,您必须使用-mno-avx256-split-unaligned-load (或者实际上只使用-march=native) -- see this post for details进行编译。

通过简单的测试代码(注意,如果数组中有多个匹配值,则循环和avx2 2版本的行为不同):https://godbolt.org/z/2jNWqK

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62120797

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档