我试图做一个超分辨率的网络,但我在导入我自己的数据时遇到了困难。我有两种类型的图像:调整大小的图像(较小的),原始图像。第一个将用作网络的输入,第二个将用于网络的培训。
问题是,我需要批量加载图像,因为我的计算机没有足够的GPU内存来同时构建整个数据集。我认为使用以下代码可以工作:
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)问题是,我只知道如何使其工作在分类问题,因为,就我而言,它是专为只有培训和验证数据集。
为了实现超分辨率,我需要四个数据集:
正常尺寸列车
小型列车
正常尺寸试验
小尺寸试验
注意:当我为调整大小的图像创建张量和为原始图像创建另一个张量时,我的程序可以工作,但是现在我想实现一个更大的数据集。
发布于 2021-12-09 20:57:43
我认为最好为这个任务实现一个数据生成器。举个例子。如果数据集中的图像不具有相同的形状,则可以/必须添加图像整形。
def image_generator(path, batch_size=16):
list_path = glob.glob(path)
index = 0
list_of_low_dim_images = []
list_of_high_dim_images = []
size = len(list_path)
while True:
index +=1
for path in list_path:
path2 = path.replace("small", "normal")
small_img = tf.io.read_file(path)
small_img = decode_img(small_img)
normal_img = tf.io.read_file(path2)
normal_img = decode_img(normal_img)
list_of_low_dim_images.append(small_img)
list_of_high_dim_images.append(normal_img)
if index == batch_size:
inedx = 0
yield list_of_low_dim_images,list_of_high_dim_imageshttps://stackoverflow.com/questions/70295859
复制相似问题