首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在TF 2中使用带有自定义函数的tf.data.Dataset.interleave()?

如何在TF 2中使用带有自定义函数的tf.data.Dataset.interleave()?
EN

Stack Overflow用户
提问于 2020-05-31 01:02:30
回答 1查看 915关注 0票数 2

我正在使用TF 2.2,并尝试使用tf.data来创建管道。

下面的代码运行良好:

代码语言:javascript
复制
def load_image(filePath, label):

    print('Loading File: {}' + filePath)
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations = False)

    return image, label

# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.map(load_image, num_parallel_calls=AUTOTUNE)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

我想在Dataset.interleave()上使用load_image()。然后我试着:

代码语言:javascript
复制
# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

但是我得到了以下错误:

代码语言:javascript
复制
Exception has occurred: TypeError
`map_func` must return a `Dataset` object. Got <class 'tuple'>
  File "/data/dev/train_daninhas.py", line 44, in <module>
    trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

如何调整我的代码,使Dataset.interleave()load_image()协同工作,以并行读取图像?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-02 21:55:24

如错误所示,您需要修改load_image,使其返回一个Dataset对象,我已经展示了一个示例,其中包含两个图像,说明如何在tensorflow 2.2.0中执行此操作

代码语言:javascript
复制
import tensorflow as tf
filenames = ["./img1.jpg", "./img2.jpg"]
labels = ["A", "B"]

def load_image(filePath, label):
    print('Loading File: {}' + filePath)
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations = False)
    return tf.data.Dataset.from_tensors((image, label))

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.interleave(lambda x, y: load_image(x, y), cycle_length=4)

for i in dataset.as_numpy_iterator():
    image = i[0]
    label = i[1]
    print(image.shape)
    print(label.decode())

# (275, 183, 3)
# A
# (275, 183, 3)
# B

希望这能有所帮助!

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62105896

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档