首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为非分类问题导入成批图像

为非分类问题导入成批图像
EN

Stack Overflow用户
提问于 2021-12-09 19:46:33
回答 1查看 40关注 0票数 0

我试图做一个超分辨率的网络,但我在导入我自己的数据时遇到了困难。我有两种类型的图像:调整大小的图像(较小的),原始图像。第一个将用作网络的输入,第二个将用于网络的培训。

问题是,我需要批量加载图像,因为我的计算机没有足够的GPU内存来同时构建整个数据集。我认为使用以下代码可以工作:

代码语言:javascript
复制
train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

问题是,我只知道如何使其工作在分类问题,因为,就我而言,它是专为只有培训和验证数据集。

为了实现超分辨率,我需要四个数据集:

正常尺寸列车

小型列车

正常尺寸试验

小尺寸试验

注意:当我为调整大小的图像创建张量和为原始图像创建另一个张量时,我的程序可以工作,但是现在我想实现一个更大的数据集。

EN

回答 1

Stack Overflow用户

发布于 2021-12-09 20:57:43

我认为最好为这个任务实现一个数据生成器。举个例子。如果数据集中的图像不具有相同的形状,则可以/必须添加图像整形。

代码语言:javascript
复制
def image_generator(path, batch_size=16):
    list_path = glob.glob(path)
    index = 0
    list_of_low_dim_images = []
    list_of_high_dim_images = []
    size = len(list_path)
    while True:
        index +=1
        for path in list_path:
            path2 = path.replace("small", "normal")
            small_img = tf.io.read_file(path)
            small_img = decode_img(small_img)
            normal_img = tf.io.read_file(path2)
            normal_img = decode_img(normal_img)
            list_of_low_dim_images.append(small_img)
            list_of_high_dim_images.append(normal_img)
            if index == batch_size:
                inedx = 0
                yield list_of_low_dim_images,list_of_high_dim_images
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70295859

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档