首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在TensorFlow 2.0中使用padded_batch()

如何在TensorFlow 2.0中使用padded_batch()
EN

Stack Overflow用户
提问于 2019-12-21 08:45:23
回答 1查看 506关注 0票数 1
代码语言:javascript
复制
X = tf.range(10)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset2 = dataset.repeat(3).padded_batch(7, padded_shapes=([]))
for item in dataset2:
    print(item)

输出

代码语言:javascript
复制
tf.Tensor([0 1 2 3 4 5 6], shape=(7,), dtype=int32)
tf.Tensor([7 8 9 0 1 2 3], shape=(7,), dtype=int32)
tf.Tensor([4 5 6 7 8 9 0], shape=(7,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)
tf.Tensor([8 9], shape=(2,), dtype=int32)

如何定义padded_shapes来获得像下面这样的结果?

代码语言:javascript
复制
tf.Tensor([0 1 2 3 4 5 6], shape=(7,), dtype=int32)
tf.Tensor([7 8 9 0 1 2 3], shape=(7,), dtype=int32)
tf.Tensor([4 5 6 7 8 9 0], shape=(7,), dtype=int32)
tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)
tf.Tensor([8 9 0 0 0 0 0], shape=(7,), dtype=int32)
EN

回答 1

Stack Overflow用户

发布于 2019-12-24 01:39:03

我用batch(7)解决了这个问题。

代码语言:javascript
复制
dataset2 = dataset.repeat(3).batch(7).padded_batch(7, padded_shapes=([None]))

输出

代码语言:javascript
复制
tf.Tensor(
[[0 1 2 3 4 5 6]
 [7 8 9 0 1 2 3]
 [4 5 6 7 8 9 0]
 [1 2 3 4 5 6 7]
 [8 9 0 0 0 0 0]], shape=(5, 7), dtype=int32)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59432717

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档