首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >内存延迟的金属优化GPU矩阵乘法

内存延迟的金属优化GPU矩阵乘法
EN

Stack Overflow用户
提问于 2017-05-10 21:51:26
回答 1查看 575关注 0票数 1

这是一个非常基本的C++问题,用于计算GPU上的矩阵乘法。以下代码在技术上是MSL,但语法几乎相同。

苹果提供了一个用于计算矩阵乘法实例A^T * B。我正在寻找一些帮助来修改它以简单地计算A * B

对这个着色器的每个调用都在C的一个8×8扇区上运行,而gid是这个扇区在网格中的位置。这是消息来源:

代码语言:javascript
复制
// Note:
//
// (1) m is the number of rows in matrices A and C.
//
// (2) n is the number of columns in matrix A; number of rows in matrix B.
//
// (3) k is the number of columns in matrices B and C.
//
// (4) Matrix multiple computes C = A^T * B where A is m x n matrix (so
//     that, A^T is n x m), B is n x k .
//
// (5) pbytes is stride in bytes from row to another of matrix A.
//     pbytes should be multiple of 32, i.e. A is padded to be
//     M x k matrix where M > m and P is multiple of 8.
//
// (6) Similarly qbytes is stride in bytes from one row to another
//     of B, i.e. B is n x K matrix where K > k matrix where K is
//     multiple of 8.
//
// (7) The output matrix C is the M x K matrix.

typedef struct
{
    ushort m, k, n, pbytes, qbytes;
} MetalMatrixDim;


kernel void MatrixMultiply(const device float*       A    [[ buffer(0) ]],
                           const device float*       B    [[ buffer(1) ]],
                           device float*             C    [[ buffer(2) ]],
                           constant MetalMatrixDim&  dims [[ buffer(3) ]],
                           ushort2                   gid  [[ thread_position_in_grid ]])
{
    ushort m = dims.m;
    ushort k = dims.k;
    ushort n = dims.n;

    ushort pbytes = dims.pbytes;
    ushort qbytes = dims.qbytes;

    // Multiply gid by 8 to get the absolute position in C
    ushort2 gidIn = ushort2(gid.x << 3, gid.y << 3);

    if (gidIn.x >= m || gidIn.y >= k) return;

    const device float4* a = (const device float4*)(A + gidIn.x);
    const device float4* b = (const device float4*)(B + gidIn.y);

    C = (device float*)((device char*)C + gidIn.x*qbytes);

    device float4* c = (device float4*)(C + gidIn.y);

    const device float4* Bend = (const device float4*)((const device char*)B + qbytes*n);

    float4 s0  = 0.0f, s1  = 0.0f, s2  = 0.0f, s3  = 0.0f;
    float4 s4  = 0.0f, s5  = 0.0f, s6  = 0.0f, s7  = 0.0f;
    float4 s8  = 0.0f, s9  = 0.0f, s10 = 0.0f, s11 = 0.0f;
    float4 s12 = 0.0f, s13 = 0.0f, s14 = 0.0f, s15 = 0.0f;

    do
    {
        float4 aCurr0 = a[0];
        float4 aCurr1 = a[1];
        float4 bCurr0 = b[0];
        float4 bCurr1 = b[1];

        s0   += (aCurr0.x * bCurr0);
        s2   += (aCurr0.y * bCurr0);
        s4   += (aCurr0.z * bCurr0);
        s6   += (aCurr0.w * bCurr0);

        s1   += (aCurr0.x * bCurr1);
        s3   += (aCurr0.y * bCurr1);
        s5   += (aCurr0.z * bCurr1);
        s7   += (aCurr0.w * bCurr1);

        s8   += (aCurr1.x * bCurr0);
        s10  += (aCurr1.y * bCurr0);
        s12  += (aCurr1.z * bCurr0);
        s14  += (aCurr1.w * bCurr0);

        s9   += (aCurr1.x * bCurr1);
        s11  += (aCurr1.y * bCurr1);
        s13  += (aCurr1.z * bCurr1);
        s15  += (aCurr1.w * bCurr1);

        a = (device float4*)((device char*)a + pbytes);
        b = (device float4*)((device char*)b + qbytes);

    } while(b < Bend);

    c[0] = s0;  c[1] = s1;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s2;  c[1] = s3;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s4;  c[1] = s5;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s6;  c[1] = s7;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s8;  c[1] = s9;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s10; c[1] = s11; c = (device float4*)((device char*)c + qbytes);
    c[0] = s12; c[1] = s13; c = (device float4*)((device char*)c + qbytes);
    c[0] = s14; c[1] = s15;
}

我花了很多时间在这个问题上,但是我想出的最好的解决方案是一个不考虑内存延迟的天真的解决方案。相反,我希望修改苹果的代码,以消除A的转换,同时仍然允许GPU优化内存读写。

有人能帮我一下吗?

编辑:这里是我(非常)天真的实现。它的执行速度比苹果的内核慢了大约100倍:

代码语言:javascript
复制
int pbytes = (int)dims.pbytes;
int qbytes = (int)dims.qbytes;

for (int row = 0; row < 8; row++) {
    int aStart = (gidIn.y + row) * pbytes / 4;
    for (int col = 0; col < 8; col++) {
        int cIdx = gidIn.y + (row * qbytes / 4) + gidIn.x + col;
        int bStart = gidIn.x + col;
        float sum = 0.0f;
        for (int i = 0; i < (pbytes / 4); i++) {
            float prod = A[aStart + i] * B[bStart + (i * qbytes / 4)];
            sum += prod;
        }
        C[cIdx] = sum;
    }
}

这个实现的问题是它根本不优化内存访问。理想情况下,您可以一次读取和写入尽可能多的数据,从而使编译器能够将操作向量化。

EN

回答 1

Stack Overflow用户

发布于 2017-08-15 22:42:15

MetalPerformanceShaders框架有一个内置矩阵乘法内核,您只需将其编码到金属命令缓冲区中即可。我建议这样做,而不是浪费大量的时间在这里。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43903290

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档