首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow成批乘法广播

Tensorflow成批乘法广播
EN

Stack Overflow用户
提问于 2017-01-11 13:41:46
回答 1查看 4.4K关注 0票数 2

我们知道tf.multiply可以像这样广播:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
a = tf.Variable(np.arange(12).reshape(3, 4))
b = tf.Variable(np.arange(4))
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(tf.multiply(a, b))

这会给我们

代码语言:javascript
复制
[[0, 1, 4, 9],
 [0, 5, 12, 21],
 [0, 9, 20, 33]]

但我的问题是,如果ab都是分批的,我该怎么办?那是,

代码语言:javascript
复制
a = tf.Variable(np.arange(24).reshape(2, 3, 4))
b = tf.Variable(np.arange(8).reshape(2, 4))

那么,如何才能得到将向量乘(广播)到每批矩阵的结果呢?如下所示:

代码语言:javascript
复制
[[[0, 1, 4, 9],
  [0, 5, 12, 21],
  [0, 9, 20, 33]],

 [[48, 65, 84, 105],
  [64, 85, 108, 133],
  [80, 105, 132, 161]]]

谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-01-11 16:04:50

广播首先向左边添加单点维数,直到秩匹配。在第一种情况下,添加批处理维度。但在第二种情况下,您已经有了批处理维度,因此需要在第二个位置手动插入单例维度:

代码语言:javascript
复制
a = tf.reshape(tf.range(24), (2, 3, 4))
b = tf.reshape(tf.range(8), (2, 4))
sess.run(tf.mul(a, tf.expand_dims(b, 1)))
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41592537

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档