首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >紧致AVX2寄存器,因此根据掩码选择的整数是连续的

紧致AVX2寄存器,因此根据掩码选择的整数是连续的
EN

Stack Overflow用户
提问于 2014-08-01 06:31:46
回答 1查看 1.1K关注 0票数 4

在问题Optimizing Array Compaction中,最上面的答案是:

SSE/AVX寄存器与最新的指令集允许一个更好的方法。我们可以直接使用PMOVMSKB的结果,将其转换为类似PSHUFB的控件寄存器。

这与哈斯韦尔(AVX2)有可能吗?还是需要一种AVX512的口味?

我得到了一个包含AVX2向量的int32s,以及一个对应的比较结果向量。我想以某种方式对其进行洗牌,以便在掩码中设置相应msb的元素(比较为真)在向量的低端是连续的。

我能看到的最好是使用_ps 256_movemask_ps/vmovmskps获得一个位掩码(没有*d变体?)然后在一个256个AVX2向量查找表中使用它来获得交叉车道_ to 256 _permutevar8x32_pi32/vpermd的混洗掩码。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2014-08-01 10:05:11

首先要做的是找到一个快速的标量函数。这里有一个不使用分支的版本。

代码语言:javascript
复制
inline int compact(int *x, int *y, const int n) {
    int cnt = 0;
    for(int i=0; i<n; i++) {
        int cut = x[i]!=0;
        y[cnt] = cut*x[i];
        cnt += cut;
    }
    return cnt;
}

SIMD的最佳结果可能取决于零点的分布。如果是稀疏的还是稠密的。对于稀疏或密集的分布,下面的代码应该运行良好。例如,长时间运行零和非零。如果发行版更多,我不知道这段代码是否会有任何好处。但无论如何,它都会给出正确的结果。

下面是我测试过的AVX2版本。

代码语言:javascript
复制
int compact_AVX2(int *x, int *y, int n) {
    int i =0, cnt = 0;
    for(i=0; i<n-8; i+=8) {
        __m256i x4 = _mm256_loadu_si256((__m256i*)&x[i]);
        __m256i cmp = _mm256_cmpeq_epi32(x4, _mm256_setzero_si256());
        int mask = _mm256_movemask_epi8(cmp);
        if(mask == -1) continue; //all zeros
        if(mask) {
            cnt += compact(&x[i],&y[cnt], 8);
        }
        else {
            _mm256_storeu_si256((__m256i*)&y[cnt], x4);
            cnt +=8;
        }       
    }
    cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 8
    return cnt;
}

这是我测试过的SSE2版本。

代码语言:javascript
复制
int compact_SSE2(int *x, int *y, int n) {
    int i =0, cnt = 0;
    for(i=0; i<n-4; i+=4) {
        __m128i x4 = _mm_loadu_si128((__m128i*)&x[i]);
        __m128i cmp = _mm_cmpeq_epi32(x4, _mm_setzero_si128());
        int mask = _mm_movemask_epi8(cmp);
        if(mask == 0xffff) continue; //all zeroes
        if(mask) {
            cnt += compact(&x[i],&y[cnt], 4);
        }
        else {
            _mm_storeu_si128((__m128i*)&y[cnt], x4);
            cnt +=4;
        }       
    }
    cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 4
    return cnt;
}

这是一个完整的测试

代码语言:javascript
复制
#include <stdio.h>
#include <stdlib.h>
#if defined (__GNUC__) && ! defined (__INTEL_COMPILER)
#include <x86intrin.h>                
#else
#include <immintrin.h>                
#endif

#define N 50

inline int compact(int *x, int *y, const int n) {
    int cnt = 0;
    for(int i=0; i<n; i++) {
        int cut = x[i]!=0;
        y[cnt] = cut*x[i];
        cnt += cut;
    }
    return cnt;
}

int compact_SSE2(int *x, int *y, int n) {
        int i =0, cnt = 0;
        for(i=0; i<n-4; i+=4) {
            __m128i x4 = _mm_loadu_si128((__m128i*)&x[i]);
            __m128i cmp = _mm_cmpeq_epi32(x4, _mm_setzero_si128());
            int mask = _mm_movemask_epi8(cmp);
            if(mask == 0xffff) continue; //all zeroes
            if(mask) {
                cnt += compact(&x[i],&y[cnt], 4);
            }
            else {
                _mm_storeu_si128((__m128i*)&y[cnt], x4);
                cnt +=4;
            }       
        }
        cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 4
        return cnt;
    }

int compact_AVX2(int *x, int *y, int n) {
    int i =0, cnt = 0;
    for(i=0; i<n-8; i+=8) {
        __m256i x4 = _mm256_loadu_si256((__m256i*)&x[i]);
        __m256i cmp = _mm256_cmpeq_epi32(x4, _mm256_setzero_si256());
        int mask = _mm256_movemask_epi8(cmp);
        if(mask == -1) continue; //all zeros
        if(mask) {
            cnt += compact(&x[i],&y[cnt], 8);
        }
        else {
            _mm256_storeu_si256((__m256i*)&y[cnt], x4);
            cnt +=8;
        }       
    }
    cnt += compact(&x[i], &y[cnt], n-i); // cleanup for n not a multiple of 8
    return cnt;
}

int main() {
    int x[N], y[N];
    for(int i=0; i<N; i++) x[i] = rand()%10;
    //int cnt = compact_SSE2(x,y,N);
    int cnt = compact_AVX2(x,y,N);
    for(int i=0; i<N; i++) printf("%d ", x[i]); printf("\n");
    for(int i=0; i<cnt; i++) printf("%d ", y[i]); printf("\n");
}
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/25074197

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档