首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.data.Dataset apply()不更新dataset

tf.data.Dataset apply()不更新dataset
EN

Stack Overflow用户
提问于 2022-09-22 13:09:48
回答 1查看 30关注 0票数 0

我正在用image_dataset_from_directory加载一个图像数据集,它给我一个PrefetchDataset,其中包含我的图像和它们相关的标签,一个热编码。

为了构建一个二值图像分类器,我想转换我的PrefetchDataset标签,以知道图像是照片还是其他东西。

我是这样写的:

代码语言:javascript
复制
batch_size = 32
img_height = 250
img_width = 250

train_ds = image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  color_mode="rgb",
  subset="training",
  seed=69,
  crop_to_aspect_ratio=False,
  image_size=(img_height, img_width),
  batch_size=batch_size)

class_names = train_ds.class_names
# ['Painting', 'Photo', 'Schematics', 'Sketch', 'Text'] in my case

# Convert label to 1 is a photo or else 0
i = 1 # class_names.index('Photo')

def is_photo(batch):
    for images, labels in batch:
        bool_labels = tf.constant([int(l == 1) for l in labels],
                                  dtype=np.int32)
        labels = bool_labels
    return batch

new_train_ds = train_ds.apply(is_photo)

我的问题是,new_train_ds没有放弃train_ds,这导致我认为apply方法肯定有问题。我还检查了bool_labels,它运行得很好。

有没有人知道如何解决这个问题。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-09-26 10:20:40

也许可以试试这样的方法:

代码语言:javascript
复制
train_ds = train_ds.map(lambda x, y: (x, tf.cast(y == 1, dtype=tf.int64)))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73815242

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档