首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >nd4j中的矩阵乘法广播

nd4j中的矩阵乘法广播
EN

Stack Overflow用户
提问于 2018-12-05 06:03:38
回答 1查看 292关注 0票数 1

在python中,假设

代码语言:javascript
复制
a = np.array(range(0,12)).reshape(2,2,3)
b = np.array(range(0,6)).reshape(3,2)
c = np.matmul(a,b) // a @ b

我们有

代码语言:javascript
复制
a: array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

b: array([[0, 1],
       [2, 3],
       [4, 5]])

c: array([[[10, 13],
        [28, 40]],

       [[46, 67],
        [64, 94]]])

有人能帮助我在没有for循环的情况下在java nd4j中实现等价的操作吗?我试过broadcast.mul,但事实证明,broadcast.mul是按元素进行乘法的。我没有找到任何关于mmul的广播操作。

EN

回答 1

Stack Overflow用户

发布于 2018-12-05 15:24:54

我自己想出来的。答案如下所示,以防有人需要答案。利用Nd4j.tensorMmul,可以很容易地实现矩阵广播。例如:

代码语言:javascript
复制
val a = Nd4j.create(0d to 11d by 1d toArray, Array[Int](2, 2, 3))
val b = Nd4j.create(0d to 5d by 1d toArray, Array[Int](3, 2))
Nd4j.tensorMmul(a, b, Array(Array(2), Array(0))) // matrix broadcast

这是scala的代码。对于java,只需更改代码就可以创建数组。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53626094

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档