首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow中两个同级张量之间的广播

tensorflow中两个同级张量之间的广播
EN

Stack Overflow用户
提问于 2017-09-01 02:48:08
回答 2查看 106关注 0票数 0

我有两个具有形状的张量xs

代码语言:javascript
复制
> x.shape
TensorShape([Dimension(None), Dimension(3), Dimension(5), Dimension(5)])
> s.shape
TensorShape([Dimension(None), Dimension(12), Dimension(5), Dimension(5)])

我想通过维数x广播ss之间的点积,如下所示:

代码语言:javascript
复制
> x_s.shape
TensorShape([Dimension(None), Dimension(4), Dimension(5), Dimension(5)])

哪里

代码语言:javascript
复制
x_s[i, 0, k, l] = sum([x[i, j, k, l] * s[i, j, k, l] for j in range (3)])
x_s[i, 1, k, l] = sum([x[i, j-3, k, l] * s[i, j, k, l] for j in range (3, 6)])
x_s[i, 2, k, l] = sum([x[i, j-6, k, l] * s[i, j, k, l] for j in range (6, 9)])
x_s[i, 3, k, l] = sum([x[i, j-9, k, l] * s[i, j, k, l] for j in range (9, 12)])

我有这样的实施:

代码语言:javascript
复制
s_t = tf.transpose(s, [0, 2, 3, 1]) # [None, 5, 5, 12]
x_t = tf.transpose(x, [0, 2, 3, 1]) # [None, 5, 5, 3]
x_t = tf.tile(x_t, [1, 1, 1, 4]) # [None, 5, 5, 12]

x_s = x_t * s_t # [None, 5, 5, 12]
x_s = tf.reshape(x_s, [tf.shape(x_s)[0], 5, 5, 4, 3]) # [None, 5, 5, 4, 3]
x_s = tf.reduce_sum(x_s, axis=-1) # [None, 5, 5, 4]
x_s = tf.transpose(x_s, [0, 3, 1, 2]) # [None, 4, 5, 5]

我知道这在内存中是不有效的,因为tile。此外,reshape's、transpose's element-wisereduce_sum的操作也会对大张量的性能造成损害。还有其他方法可以让它变得更干净吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-09-01 03:42:56

你有证据证明reshape很贵吗?以下是使用整形和维度广播:

代码语言:javascript
复制
x_s = tf.reduce_sum(tf.reshape(s, (-1, 4, 3, 5, 5)) *
                    tf.expand_dims(x, axis=1), axis=2)
票数 1
EN

Stack Overflow用户

发布于 2017-09-01 03:31:37

只是一些建议,也许不会比你的更快。首先将stf.split拆分为四个张量,然后使用tf.tensordot获得最终结果,如下所示

代码语言:javascript
复制
splits = tf.split(s, [3] * 4, axis=1)
splits = map(lambda split: tf.tensordot(split, x, axes=[[1], [1]]), splits)
x_s = tf.stack(splits, axis=1)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45992667

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档