首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用Keras ImageDataGenerator为pix2pix CNN模型提供数据?

如何使用Keras ImageDataGenerator为pix2pix CNN模型提供数据?
EN

Stack Overflow用户
提问于 2021-08-27 17:17:07
回答 1查看 32关注 0票数 0

我正在尝试使用keras ImageDataGenerator来训练pix2pix CNN模型。它将输入图像映射到输出图像。我们知道pix2pix ImageDataGenerator可以很容易地用于图像分类,但我在训练keras模型时遇到了问题。这是我的尝试:

自定义生成器:

代码语言:javascript
复制
class JoinedGen(tf.keras.utils.Sequence):
    def __init__(self, input_gen, target_gen):
        self.input_gen = input_gen
        self.target_gen = target_gen

        assert len(input_gen) == len(target_gen)

    def __len__(self):
        return len(self.input_gen)

    def __getitem__(self, i):
        x = self.input_gen[i]
        y = self.target_gen[i]

        return x, y

    def on_epoch_end(self):
        self.input_gen.on_epoch_end()
        self.target_gen.on_epoch_end()
        self.target_gen.index_array = self.input_gen.index_array

使用ImageDataGenerator实现:

代码语言:javascript
复制
generator = ImageDataGenerator(shear_range=0.2,
       zoom_range=0.2,
       horizontal_flip=True,
       validation_split=0.3) 

  
input_gen = generator.flow_from_directory(path, 
                                          classes=['area'], 
                                          shuffle=False,
                                          target_size=(256, 256),   
                                          class_mode=None,
                                          batch_size=32,
                                          subset='training')

target_gen = generator.flow_from_directory(path, 
                                          classes=['sat'],
                                          shuffle=False, 
                                          target_size=(256, 256),
                                          class_mode=None,
                                          batch_size=32,
                                          subset='training')

input_gen_val = generator.flow_from_directory(path, 
                                          classes=['area'], 
                                          shuffle=False,
                                          target_size=(256, 256),   
                                          class_mode=None,
                                          batch_size=32,
                                          subset='validation')

target_gen_val = generator.flow_from_directory(path, 
                                          classes=['sat'],   
                                          shuffle=False, 
                                          target_size=(256, 256),
                                          class_mode=None,
                                          batch_size=32,
                                          subset='validation')

但是,当我使用input_gen.next()[0]target_gen.next()[0]请求两个训练生成器的第一个图像时,它没有给我相应的输入和输出!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-09-02 21:51:56

正如Keras文档中所说,解决方案是“向fit和flow方法提供相同的种子和关键字参数- seed = 1”。

只需添加到flow_from_directory方法seed = 1

有关更多信息,请查看链接here

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68957206

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档