我正在用image_dataset_from_directory加载一个图像数据集,它给我一个PrefetchDataset,其中包含我的图像和它们相关的标签,一个热编码。
为了构建一个二值图像分类器,我想转换我的PrefetchDataset标签,以知道图像是照片还是其他东西。
我是这样写的:
batch_size = 32
img_height = 250
img_width = 250
train_ds = image_dataset_from_directory(
data_dir,
validation_split=0.2,
color_mode="rgb",
subset="training",
seed=69,
crop_to_aspect_ratio=False,
image_size=(img_height, img_width),
batch_size=batch_size)
class_names = train_ds.class_names
# ['Painting', 'Photo', 'Schematics', 'Sketch', 'Text'] in my case
# Convert label to 1 is a photo or else 0
i = 1 # class_names.index('Photo')
def is_photo(batch):
for images, labels in batch:
bool_labels = tf.constant([int(l == 1) for l in labels],
dtype=np.int32)
labels = bool_labels
return batch
new_train_ds = train_ds.apply(is_photo)我的问题是,new_train_ds没有放弃train_ds,这导致我认为apply方法肯定有问题。我还检查了bool_labels,它运行得很好。
有没有人知道如何解决这个问题。
发布于 2022-09-26 10:20:40
也许可以试试这样的方法:
train_ds = train_ds.map(lambda x, y: (x, tf.cast(y == 1, dtype=tf.int64)))https://stackoverflow.com/questions/73815242
复制相似问题