首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow中复制PyTorch的nn.functional.unfold函数?

如何在Tensorflow中复制PyTorch的nn.functional.unfold函数?
EN

Stack Overflow用户
提问于 2020-10-25 20:01:56
回答 1查看 538关注 0票数 5

我想用tensorflow重写pytorch的torch.nn.functional.unfold函数:

代码语言:javascript
复制
#input x:[16, 1, 50, 36]
x = torch.nn.functional.unfold(x, kernel_size=(5, 36), stride=3)
#output x:[16, 180, 16]

我尝试使用函数tf.extract_image_patches()

x = tf.extract_image_patches(x,ksizes=[1, 1,5, 98],strides=[1, 1, 3, 1], rates=[1, 1, 1, 1],padding='VALID')

input x.shape[16,1,64,98]

我得到了x.shape[16,1,20,490]的输出

然后我将X重塑为[16,490,20],这是我所期望的。

但是当我提供数据时,我得到了错误:

代码语言:javascript
复制
UnimplementedError (see above for traceback): Only support ksizes across space.
[[Node:hcn/ExtractImagePatches = ExtractImagePatches[T=DT_FLOAT, ksizes=[1, 1, 5, 98], padding="VALID", rates=[1, 1, 1, 1], strides=[1, 1, 3, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](hcn/Reshape)]]

如何使用tensorflow重写pytorch torch.nn.functional.unfold函数来更改X

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-10-26 21:21:29

代码语言:javascript
复制
x = tf.reshape(x, [16, 50, 36, 1])
x = tf.extract_image_patches(x, ksizes=[1, 4, 98, 1], strides=[1, 4, 1, 1], rates=[1, 1, 1, 1], padding='VALID')
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64523441

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档