首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras predict_generator损坏的图像

Keras predict_generator损坏的图像
EN

Stack Overflow用户
提问于 2018-10-11 22:57:45
回答 2查看 465关注 0票数 1

我正在尝试使用Python3中的predict_generator,使用keras和tensorflow作为后端,使用我训练过的模型预测数百万张图像。生成器和模型预测工作,但是,目录中的一些图像被损坏或损坏,并导致predict_generator停止并抛出错误。一旦图像被删除,它就会再次工作,直到下一个损坏/损坏的图像通过该函数馈送。

由于有如此多的图像,因此运行脚本来打开每个图像并删除抛出错误的图像是不可行的。有没有办法将“如果损坏则跳过图像”参数合并到生成器或来自目录的流函数中?

任何帮助都是非常感谢的!

EN

回答 2

Stack Overflow用户

发布于 2018-10-11 23:09:01

ImageDataGenerator中没有这样的参数,在flow_from_directory方法中也没有这样的参数,因为你可以看到这两种方法的Keras文档(herehere)。一种解决方法是扩展ImageDataGenerator类并重载flow_from_directory方法,以便在将图像放入生成器之前检查图像是否已损坏。Here你可以找到它的源代码。

票数 1
EN

Stack Overflow用户

发布于 2020-10-13 14:40:56

因为它发生在预测期间,如果您跳过任何图像或批次,您需要跟踪哪些图像被跳过,以便您可以正确地将预测分数映射到图像文件名。

基于这种思想,我的DataGenerator是用一个有效的图像索引跟踪器实现的。特别是,关注变量valid_index,其中跟踪有效图像的索引。

代码语言:javascript
复制
class DataGenerator(keras.utils.Sequence):
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ',idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]

        # return a list whose element is either an image array (when image is valid) or None(when image is corrupted)
        x = load_batch_image_to_arrays(batch_df['image_file_names'])

        # filter out corrupted images
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist()) if
               u is not None]

        # boundary case. # all image failed, return another random batch
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # based on https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621,
            # Keras will automatically find the next batch if it returns None
            return None

        print('successfully loaded image in {}th batch {}/{}'.format(str(idx), len(tmp), self.batch_size))
        self.success_count += len(tmp)

        x, batch_index = zip(*tmp) 
        x = np.stack(x)  # list to np.array
        self.valid_index[idx] = batch_index

        # follow preprocess input function provided by keras
        x = resnet50_preprocess(np.array(x, dtype=np.float))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful images count', self.success_count)
        self.success_count = self.total_count = 0 # reset count after one epoch ends.

在预测过程中。

代码语言:javascript
复制
predictions = model.predict_generator(
            generator=data_gen,
            workers=10,
            use_multiprocessing=False,
            max_queue_size=20,
            verbose=1
        ).squeeze()
indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])
result_df = df.loc[indexes]
result_df['score'] = predictions
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52763225

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档