首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow1和Tensorflow2中的批处理

Tensorflow1和Tensorflow2中的批处理
EN

Stack Overflow用户
提问于 2021-04-21 19:56:32
回答 1查看 44关注 0票数 0

我正在尝试将图像单应性代码从TF1版本转换为TF2,只是TF脚本转换在这里不起作用。我坚持批量处理数据集,因为图像,image_patch和image_Indices具有不同的形状。虽然TF1在摄取和批处理数据集方面没有问题,但TF2在这方面遇到了麻烦。

代码语言:javascript
复制
imgs= np.random.rand(11,240,320,3)
pts = np.random.randint(100, size =(11,8))
patch = np.random.rand(11,128,128,1)

imgs = tf.convert_to_tensor(imgs)
pts = tf.convert_to_tensor(pts)
patch = tf.convert_to_tensor(patch)

pts= tf.cast(pts,dtype=tf.float64)

tensorflow2:

代码语言:javascript
复制
    img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch]).shuffle(buffer_size=batch_size*4)

这里,11是图像数,240和320是图像尺寸,3是通道数。

错误-

代码语言:javascript
复制
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [11,240,320,3] != values[2].shape = [11,128,128,1] [Op:Pack] name: component_0

tensorflow1:

代码语言:javascript
复制
tf.compat.v1.train.batch([imgs,pts,patch], batch_size=5)

输出-

代码语言:javascript
复制
[<tf.Tensor 'batch_2:0' shape=(5, 11, 240, 320, 3) dtype=float64>,
 <tf.Tensor 'batch_2:1' shape=(5, 11, 8) dtype=float64>,
 <tf.Tensor 'batch_2:2' shape=(5, 11, 128, 128, 1) dtype=float64>]

如何在tensorflow2中批量处理不同维度的数据集?同样运行时,"tf.compat.v1.train.batch()“在TF2 (版本2.3)中不起作用,因为它给出了急切的执行错误。

在TF2中对这些数据集进行批处理的正确方法是什么?

EN

回答 1

Stack Overflow用户

发布于 2021-07-20 17:45:15

这里的问题不是批处理,而是tf.data.Dataset本身的生成。错误是由img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch])引起的,而不是由.shuffle(batch_size=...)引起的。

我认为这里的.from_tensor_slices级别太高了,请查看tf.data.Dataset.from_generator

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67195414

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档