首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在DJL (Deep Java Library)中调用自定义mxnet运算符?

如何在DJL (Deep Java Library)中调用自定义mxnet运算符?
EN

Stack Overflow用户
提问于 2021-04-11 23:09:20
回答 1查看 36关注 0票数 0

如何从DJL调用自定义mxnet运算符?例如来自examplesmy_gemm运算符。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-11 23:09:20

可以像内置的mxnet引擎一样手动调用JnaUtils,只需使用定制的库即可。对于my_gemm示例,如下所示:

代码语言:javascript
复制
import ai.djl.Device;
import ai.djl.mxnet.jna.FunctionInfo;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.util.PairList;

import java.util.Map;

// Load the external mxnet operator library
JnaUtils.loadLib("path/to/incubator-mxnet/example/extensions/lib_custom_op/libgemm_lib.so", 1);
// get a handle to the loaded operator
Map<String, FunctionInfo> allFunctionsAfterLoading = JnaUtils.getNdArrayFunctions();
FunctionInfo myGemmFunction = allFunctionsAfterLoading.get("my_gemm");
// create a manager to execute the example with
try (NDManager ndManager = NDManager.newBaseManager().newSubManager(Device.cpu())) {
    // create input for the gemm call
    NDArray a = ndManager.create(new float[][]{new float[]{1, 2, 3}, new float[]{4, 5, 6}});
    NDArray b = ndManager.create(new float[][]{new float[]{7}, new float[]{8}, new float[]{9}});
    // call the function manually (NDManager.invoke will not work, as it caches the mxnet
    // engine operators and ignores external ones)
    PairList<String, Object> params = new PairList<>();
    NDArray result = myGemmFunction.invoke(ndManager, new NDArray[]{a, b}, params)[0];
    // prints
    // ND: (2, 1) cpu() float32
    //[[ 50.],
    // [122.],
    //]
    // (same as the python example)
    System.out.println(result);
}
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67046609

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档