首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用jcuda的cublasSgemmBatched用法

使用jcuda的cublasSgemmBatched用法
EN

Stack Overflow用户
提问于 2012-07-04 23:54:49
回答 1查看 289关注 0票数 1

我一直在尝试使用jcuda中的cublasSgemmBatched()函数进行矩阵乘法,但我不确定如何正确处理指针传递和批处理矩阵的向量。如果有人知道如何修改我的代码来正确处理这个问题,我将非常感激。在这个例子中,C数组在cublasGetVector之后保持不变。

代码语言:javascript
复制
public static void SsmmBatchJCublas(int m, int n, int k, float A[], float B[]){

    // Create a CUBLAS handle
    cublasHandle handle = new cublasHandle();
    cublasCreate(handle);

    // Allocate memory on the device
    Pointer d_A = new Pointer();
    Pointer d_B = new Pointer();
    Pointer d_C = new Pointer();


    cudaMalloc(d_A, m*k * Sizeof.FLOAT);
    cudaMalloc(d_B, n*k * Sizeof.FLOAT);
    cudaMalloc(d_C, m*n * Sizeof.FLOAT);

    float[] C = new float[m*n];
    // Copy the memory from the host to the device
    cublasSetVector(m*k, Sizeof.FLOAT, Pointer.to(A), 1, d_A, 1);
    cublasSetVector(n*k, Sizeof.FLOAT, Pointer.to(B), 1, d_B, 1);
    cublasSetVector(m*n, Sizeof.FLOAT, Pointer.to(C), 1, d_C, 1);

    Pointer[] Aarray = new Pointer[]{d_A};
    Pointer AarrayPtr = Pointer.to(Aarray);
    Pointer[] Barray = new Pointer[]{d_B};
    Pointer BarrayPtr = Pointer.to(Barray);
    Pointer[] Carray = new Pointer[]{d_C};
    Pointer CarrayPtr = Pointer.to(Carray);

    // Execute sgemm
    Pointer pAlpha = Pointer.to(new float[]{1});
    Pointer pBeta = Pointer.to(new float[]{0});


    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, pAlpha, AarrayPtr, Aarray.length, BarrayPtr, Barray.length, pBeta, CarrayPtr, Carray.length, Aarray.length);
    // Copy the result from the device to the host
    cublasGetVector(m*n, Sizeof.FLOAT, d_C, 1, Pointer.to(C), 1);

    // Clean up
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    cublasDestroy(handle);
}
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2012-09-12 21:29:24

我在jcuda官方论坛上询问,很快就得到了here的回答。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/11332327

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档