首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何处理无维数的张量乘法

如何处理无维数的张量乘法
EN

Stack Overflow用户
提问于 2021-11-27 22:35:52
回答 1查看 63关注 0票数 1

例如,当我使用时,我有两个张量A和B,它们都有维度(None,HWC)

代码语言:javascript
复制
tf.matmul(tf.transpose(A),B)

结果维度将是(HWC,HWC),这是正确的,但我希望保留None维度,这样它就可以是(None,HWC,HWC)。有什么办法可以做到这一点吗?

EN

回答 1

Stack Overflow用户

发布于 2021-11-28 14:37:18

也许可以试试这样的东西:

代码语言:javascript
复制
import tensorflow as tf

input1 = tf.keras.layers.Input(((32, 32, 3)))
input2 = tf.keras.layers.Input(((32, 32, 3)))
a = tf.keras.layers.Conv2D(64, (1, 1))(input1)
b = tf.keras.layers.Conv2D(64, (1, 1))(input2)
z = tf.matmul(a, b, transpose_a=True)
model = tf.keras.Model([input1, input2], z)
print(model.summary())
代码语言:javascript
复制
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_11 (InputLayer)          [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 input_12 (InputLayer)          [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_17 (Conv2D)             (None, 32, 32, 64)   256         ['input_11[0][0]']               
                                                                                                  
 conv2d_18 (Conv2D)             (None, 32, 32, 64)   256         ['input_12[0][0]']               
                                                                                                  
 tf.linalg.matmul_4 (TFOpLambda  (None, 32, 64, 64)  0           ['conv2d_17[0][0]',              
 )                                                                'conv2d_18[0][0]']              
                                                                                                  
==================================================================================================
Total params: 512
Trainable params: 512
Non-trainable params: 0
__________________________________________________________________________________________________
None
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70139251

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档