
Keras ImageDataGenerator Slow

Stack Overflow user
Asked 2016-12-10 11:24:15
2 answers · 4.2K views · 0 followers · 8 votes

I'm looking for the best way to train on larger-than-memory data in Keras, and I've noticed that the vanilla ImageDataGenerator tends to be slower than I'd hoped.

I trained two networks on the Kaggle Cats vs. Dogs dataset (25,000 images):

1) The code for this approach comes from: http://www.pyimagesearch.com/2016/09/26/a-simple-neural-network-with-python-and-keras/

2) The same as (1), but using an ImageDataGenerator instead of loading the data into memory.

Note: below, "preprocessing" means resizing, scaling, and flattening.
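
To make that per-image "preprocessing" concrete, here is a minimal NumPy-only sketch. The function name `preprocess` and the 32x32 target size are assumptions for illustration, and nearest-neighbour index sampling stands in for `cv2.resize` so the example has no OpenCV dependency.

```python
import numpy as np

def preprocess(image: np.ndarray, size: tuple = (32, 32)) -> np.ndarray:
    """Resize (nearest neighbour), scale to [0, 1] and flatten one HxWxC uint8 image.

    `size` is (width, height), following the cv2.resize convention.
    """
    h, w = image.shape[:2]
    out_w, out_h = size
    # Nearest-neighbour resize: pick evenly spaced source rows and columns.
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    resized = image[rows[:, None], cols[None, :]]
    # Scale pixel values to [0, 1], then flatten into a 1-D feature vector.
    return (resized.astype(np.float32) / 255.0).ravel()

img = np.random.randint(0, 256, size=(375, 500, 3), dtype=np.uint8)
features = preprocess(img)
print(features.shape)  # (3072,)
```

Doing this inside the generator means repeating the resize for every image on every epoch, which is the cost behind the 36 s vs. 13 s comparison below.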

On my GTX 970 I found the following:

For network 1, each epoch takes ~0.

For network 2, each epoch takes ~36 s if the preprocessing is done inside the data generator.

For network 2, each epoch takes ~13 s if the preprocessing is done in a first pass outside the data generator.
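
One way to keep the "first pass outside the generator" result when it does not fit in RAM is to write the preprocessed arrays to a disk-backed NumPy file once and read shuffled batches from it on every epoch. This is a minimal sketch, not something from the question: `cache_to_memmap` and `iterate_batches` are hypothetical names, while `np.lib.format.open_memmap` and `np.load(..., mmap_mode="r")` are standard NumPy calls.

```python
import os
import tempfile
import numpy as np

def cache_to_memmap(images, preprocess_fn, path, feature_dim):
    """First pass: preprocess every image once and store the results in a .npy memmap."""
    cache = np.lib.format.open_memmap(path, mode="w+", dtype=np.float32,
                                      shape=(len(images), feature_dim))
    for i, img in enumerate(images):
        cache[i] = preprocess_fn(img)
    cache.flush()  # make sure everything is written to disk
    return path

def iterate_batches(path, labels, batch_size, rng):
    """Every epoch: yield shuffled batches read straight from the cached file."""
    data = np.load(path, mmap_mode="r")  # only the rows a batch touches are paged in
    order = rng.permutation(len(data))
    for start in range(0, len(order), batch_size):
        # Sorting indices inside a batch gives more sequential disk reads.
        idx = np.sort(order[start:start + batch_size])
        yield np.asarray(data[idx]), labels[idx]

rng = np.random.default_rng(0)
images = [rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8) for _ in range(10)]
labels = np.arange(10) % 2
path = os.path.join(tempfile.mkdtemp(), "features.npy")
cache_to_memmap(images, lambda im: (im.astype(np.float32) / 255).ravel(), path, 8 * 8 * 3)

seen = 0
for epoch in range(2):
    for xb, yb in iterate_batches(path, labels, batch_size=4, rng=rng):
        seen += len(xb)  # model.train_on_batch(xb, yb) would go here
print(seen)  # 20: 10 samples per epoch, 2 epochs
```

The design choice is to pay the CPU cost once and let later epochs be bounded by disk read speed only.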

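Is this simply the speed limit of ImageDataGenerator (13 s looks like the usual 10-100x difference between disk and RAM...)? Are there approaches or mechanisms better suited to training on larger-than-memory data in Keras? For example, is there a way to make ImageDataGenerator save the images it has processed after the first epoch?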
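
The "save after the first epoch" idea can be sketched as a wrapper that writes each batch to an `.npz` file the first time it is produced and loads it from disk afterwards. `FirstEpochCache` and `produce_batch` are hypothetical names; the class only mirrors the `__len__`/`__getitem__` protocol of `keras.utils.Sequence`, and in real Keras code you would subclass that class instead.

```python
import os
import tempfile
import numpy as np

class FirstEpochCache:
    """Serve batches from disk after an expensive producer has built them once."""

    def __init__(self, produce_batch, num_batches, cache_dir):
        self.produce_batch = produce_batch  # hypothetical callable: batch index -> (x, y)
        self.num_batches = num_batches
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        path = os.path.join(self.cache_dir, f"batch_{idx:05d}.npz")
        if os.path.exists(path):
            # Later epochs: load the stored arrays instead of recomputing them.
            with np.load(path) as cached:
                return cached["x"], cached["y"]
        # First epoch: compute the batch and persist it.
        x, y = self.produce_batch(idx)
        np.savez(path, x=x, y=y)
        return x, y

calls = []

def slow_batch(idx):
    calls.append(idx)  # record how often the expensive path runs
    return np.full((4, 3), idx, dtype=np.float32), np.full(4, idx % 2)

seq = FirstEpochCache(slow_batch, num_batches=3, cache_dir=tempfile.mkdtemp())
for epoch in range(3):
    for i in range(len(seq)):
        x, y = seq[i]
print(len(calls))  # 3: the producer ran only during the first epoch
```

Caching fixed batches also freezes any random augmentation, so every epoch sees the same augmented images; that trade-off is what the second answer below works around.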

Thanks!


2 Answers

Stack Overflow user

Posted 2017-06-10 05:30:47

You've probably solved this by now, but anyway...

Keras image preprocessing can optionally save its results: set the save_to_dir argument of the flow() or flow_from_directory() functions:

https://keras.io/preprocessing/image/

Votes: 2

Stack Overflow user

Posted 2020-10-24 15:18:32

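As I understand it, the problem is that each augmented image is used only once during training, not even reused across a few epochs. That wastes a huge number of GPU cycles while the CPU struggles to keep up. I found the following solution: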

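  1. I generate as many augmentations in memory as I can,
  2. I train on them over a frame of 10 to 30 epochs, until the effect of those augmented images is noticeable; then I generate a new batch of augmented images (by implementing on_epoch_end) and carry on.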
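
The two steps above can be sketched without TensorFlow as a class that follows the `keras.utils.Sequence` protocol (`__len__`, `__getitem__`, `on_epoch_end`). This is a simplified illustration of the idea, not the answerer's class: `FramedAugmentedSequence`, `augment`, `pool_size` and `flip_and_noise` are hypothetical names, and real code would subclass `tf.keras.utils.Sequence`.

```python
import numpy as np

class FramedAugmentedSequence:
    """Keep a pool of augmented samples in RAM and rebuild it every `frame_length` epochs."""

    def __init__(self, x, y, augment, pool_size, batch_size, frame_length, seed=0):
        self.x, self.y = x, y
        self.augment = augment  # hypothetical callable: (image, rng) -> augmented image
        self.pool_size = pool_size
        self.batch_size = batch_size
        self.frame_length = frame_length
        self.rng = np.random.default_rng(seed)
        self.epoch = 0
        self.generations = 0
        self._regenerate()

    def _regenerate(self):
        # The CPU-heavy augmentation runs once per frame, not once per batch.
        picks = self.rng.integers(0, len(self.x), size=self.pool_size)
        self.pool_x = np.stack([self.augment(self.x[i], self.rng) for i in picks])
        self.pool_y = self.y[picks]
        self.generations += 1

    def __len__(self):
        return int(np.ceil(self.pool_size / self.batch_size))

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.pool_x[s], self.pool_y[s]

    def on_epoch_end(self):
        self.epoch += 1
        if self.epoch % self.frame_length == 0:
            self._regenerate()

def flip_and_noise(image, rng):
    out = image[:, ::-1] if rng.random() < 0.5 else image  # random horizontal flip
    return out + rng.normal(0, 0.01, size=image.shape)     # small Gaussian noise

x = np.zeros((20, 8, 8, 3))
y = np.arange(20) % 2
seq = FramedAugmentedSequence(x, y, flip_and_noise, pool_size=50, batch_size=16, frame_length=10)
for epoch in range(30):
    for i in range(len(seq)):
        xb, yb = seq[i]  # model.train_on_batch(xb, yb) would go here
    seq.on_epoch_end()
print(seq.generations)  # 4: the initial pool plus rebuilds after epochs 10, 20 and 30
```

Keras calls `on_epoch_end` itself when such a Sequence is passed to `model.fit`; the manual loop here just makes the frame logic visible.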

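This keeps the GPU busy most of the time while still benefiting from data augmentation. I use a custom Sequence subclass that generates the augmentations and fixes class imbalance at the same time.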
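
The class-balancing part can be shown as a small planning step: draw exactly `num_class_samples` items per class, without replacement for rich classes (light augmentation) and with replacement for scarce ones (strong augmentation, since base images repeat). This mirrors the relaxed/maxed-out split in the code below, but `balanced_plan` and the `"relaxed"`/`"maxed"` labels are hypothetical.

```python
import numpy as np

def balanced_plan(labels, num_class_samples, rng):
    """Return a shuffled list of (sample_index, augmentation_strength), balanced per class."""
    plan = []
    for cls in np.unique(labels):
        members = np.flatnonzero(labels == cls)
        if len(members) >= num_class_samples:
            # Rich class: distinct base images, so light augmentation is enough.
            picks = rng.choice(members, num_class_samples, replace=False)
            strength = "relaxed"
        else:
            # Scarce class: base images repeat, so strong augmentation makes them differ.
            picks = rng.choice(members, num_class_samples, replace=True)
            strength = "maxed"
        plan.extend((int(i), strength) for i in picks)
    order = rng.permutation(len(plan))
    return [plan[i] for i in order]

rng = np.random.default_rng(0)
labels = np.array([0] * 90 + [1] * 10)
plan = balanced_plan(labels, num_class_samples=40, rng=rng)
drawn = labels[[i for i, _ in plan]]
print(np.bincount(drawn).tolist())  # [40, 40]
```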

Edit: adding some code to illustrate the idea

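from pyutilz.string import read_config_file  # third-party helper used by StoppingFromFile
from tqdm.notebook import tqdm
from gc import collect
from pathlib import Path  # used in _add_sample
import logging
import json  # used in on_epoch_end
import numpy as np
import tensorflow
import random
import cv2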

class StoppingFromFile(tensorflow.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if read_config_file('control.ini','ML','stop',globals()):        
            if stop is not None:        
                if stop==True or stop=='True':
                    logging.warning('Model should be stopped according to the control file')
                    self.model.stop_training = True

class AugmentedBalancedSequence(tensorflow.keras.utils.Sequence):
    def __init__(self, images_and_classes:dict,input_size:tuple,class_sizes:dict, augmentations_fn:object, preprocessing_fn:object, batch_size:int=10,
                 num_class_samples=100, frame_length:int=5, aug_p:float=0.1,aug_pipe_p:float=0.2,is_validation:bool=False,
                disk_saving_prob:float=.01,disk_example_nfiles:int=50):
        """
            From a dict of file paths grouped by class label, creates each N epochs augmented balanced training set.
            If current class is too scarce, ensures that current frame has no duplicate final images.
            If it's rich enough, ensures that current frame has no duplicate base images.
        
        """
        logging.info(f'Got {len(images_and_classes)} classes.')
        self.disk_example_nfiles=disk_example_nfiles;self.disk_saving_prob=disk_saving_prob;self.cur_example_file=0
        
        self.images_and_classes=images_and_classes        
        self.num_class_samples=num_class_samples
        self.augmentations_fn=augmentations_fn
        self.preprocessing_fn=preprocessing_fn
        
        self.is_validation=is_validation
        self.frame_length=frame_length                    
        self.batch_size = batch_size      
        self.class_sizes=class_sizes
        self.input_size=input_size        
        self.aug_pipe_p=aug_pipe_p
        self.aug_p=aug_p        
        self.images=None
        self.epoch = 0
        #print(f'got frame_length={self.frame_length}')
        self._generate_data()
        

    def __len__(self):
        return int(np.ceil(len(self.images)/ float(self.batch_size)))

    def __getitem__(self, idx):
        a=idx * self.batch_size;b=a+self.batch_size
        return self.images[a:b],self.labels[a:b]
    
    def on_epoch_end(self):
        import ast
        self.epoch += 1    
        mydict={}

        import pathlib
        fname='control.json'
        p = pathlib.Path(fname)
        if p.is_file():
            try:
                with open (fname) as f:
                    mydict=json.load(f)
                for var,val in mydict.items():
                    if hasattr(self,var):
                        converted = val #ast.literal_eval(val)
                        if converted is not None:
                            if getattr(self, var)!=converted:
                                setattr(self, var, converted)                                        
                                print(f'{var} became {val}')
            except Exception as e:
                logging.error(str(e))
        if self.epoch % self.frame_length == 0:
            #print('generating data...')
            self._generate_data()
            
    def _add_sample(self,image,label):
        from random import random
        idx=self.indices[self.img_sent]
        
        if self.disk_saving_prob>0:
            if random()<self.disk_saving_prob:
                self.cur_example_file+=1
                if self.cur_example_file>self.disk_example_nfiles:
                    self.cur_example_file=1
                Path(r'example_images/').mkdir(parents=True, exist_ok=True)
                cv2.imwrite(f'example_images/test{self.cur_example_file}.jpg',cv2.cvtColor(image,cv2.COLOR_RGB2BGR))
        
        if self.preprocessing_fn: 
            self.images[idx]=self.preprocessing_fn(image)
        else:
            self.images[idx]=image
        
        self.labels[idx]=label
        self.img_sent+=1        
        
    def _generate_data(self):
        logging.info('Generating new set of augmented data...')
        
        collect()
        #del self.images
        #del self.labels        
        #collect()
        
        if self.num_class_samples:
            expected_length=len(self.images_and_classes)*self.num_class_samples
        else:
            expected_length=sum(self.class_sizes.values())        
            
        if self.images is None:
            self.images=np.empty((expected_length,)+(self.input_size[1],)+(self.input_size[0],)+(3,))
            self.labels=np.empty((expected_length),np.int32)
        
        self.indices=np.random.choice(expected_length, expected_length, replace=False)
        self.img_sent=0
        
        
        collect()
        
        relaxed_augmentation_pipeline=self.augmentations_fn(p=self.aug_p,pipe_p=self.aug_pipe_p)
        maxed_out_augmentation_pipeline=self.augmentations_fn(p=self.aug_p,pipe_p=1.0)
        
        #for each class
        x,y=[],[]
        nartificial=0
        for label,images in tqdm(self.images_and_classes.items()):
            if self.num_class_samples is None:
                #Just all native samples without augmentations
                for image in images:
                    self._add_sample(image,label)                        
            else:
                #if there are enough native samples
                if len(images)>=self.num_class_samples:
                    #randomly select samples of this class which will participate in this frame of epochs                
                    indices=np.random.choice(len(images), self.num_class_samples, replace=False)
                    #apply albumentations pipeline to selected samples

                    for idx in indices:
                        if not self.is_validation:
                            self._add_sample(relaxed_augmentation_pipeline(image=images[idx])['image'],label)
                        else:
                            self._add_sample(images[idx],label)
                                                    
                else:
                    #------------------------------------------------------------------------------------------------------------------------------------------------------------------
                    # Randomly pick next image from existing. try applying augmentation pipeline (with maxed out probability) till we get num_class_samples DIFFERENT images
                    #------------------------------------------------------------------------------------------------------------------------------------------------------------------
                    hashes=set()
                    norig=0
                    nadded=0
                    while nadded<self.num_class_samples:
                        is_original=self.is_validation and norig<len(images)
                        if is_original:
                            #just include all originals first, each exactly once
                            image=images[norig]
                            norig+=1
                        else:
                            image=maxed_out_augmentation_pipeline(image=random.choice(images))['image']
                        #hash the raw bytes: np.sum(image) collides for different images with equal pixel sums
                        next_hash=hash(image.tobytes())
                        if not is_original and next_hash in hashes:
                            continue  #duplicate augmented image: draw again
                        hashes.add(next_hash)
                        self._add_sample(image,label)
                        nadded+=1
                        if not is_original:
                            nartificial+=1
                                
        
        #self.images=self.images[indices];self.labels=self.labels[indices]                              
        
        logging.info(f'Generated {self.img_sent} samples ({nartificial} artificial)')
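
The two control-file mechanisms above (the stop flag read in `StoppingFromFile` and the `control.json` attribute overrides in `on_epoch_end`) can be reproduced with the standard library alone. This is a sketch under assumptions: `stop_requested` and `apply_overrides` are hypothetical helpers, and since the exact behaviour of pyutilz's `read_config_file` is not documented here, `configparser` is used as a plain stand-in for reading `[ML] stop`.

```python
import configparser
import json
import os
import tempfile

def stop_requested(path="control.ini"):
    """True if the [ML] section of an INI control file sets stop = True."""
    parser = configparser.ConfigParser()
    if not parser.read(path):  # read() returns the list of files it could parse
        return False
    return parser.getboolean("ML", "stop", fallback=False)

def apply_overrides(obj, path="control.json"):
    """Copy values from a JSON control file onto attributes that already exist on obj."""
    if not os.path.isfile(path):
        return []
    with open(path) as f:
        overrides = json.load(f)
    changed = []
    for name, value in overrides.items():
        if value is not None and hasattr(obj, name) and getattr(obj, name) != value:
            setattr(obj, name, value)
            changed.append(name)
    return changed

d = tempfile.mkdtemp()
ini = os.path.join(d, "control.ini")
with open(ini, "w") as f:
    f.write("[ML]\nstop = True\n")
print(stop_requested(ini))  # True

class Settings:
    frame_length = 5
    aug_p = 0.1

cfg = os.path.join(d, "control.json")
with open(cfg, "w") as f:
    json.dump({"frame_length": 10, "unknown": 1}, f)
s = Settings()
changed = apply_overrides(s, cfg)
print(changed, s.frame_length)  # ['frame_length'] 10
```

Inside a Keras callback, `on_epoch_end` would call `stop_requested()` and set `self.model.stop_training = True` when it returns True; unknown keys in the JSON file are ignored rather than creating new attributes.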

Once I have loaded the images and classes:

train_datagen = AugmentedBalancedSequence(images_and_classes=images_and_classes_train,
                          input_size=INPUT_SIZE,class_sizes=class_sizes_train,num_class_samples=UPSCALE_SAMPLES,
    augmentations_fn=get_albumentations_pipeline,aug_p=AUG_P,aug_pipe_p=AUG_PIPE_P,preprocessing_fn=preprocess_input, batch_size=BATCH_SIZE,frame_length=FRAME_LENGTH,disk_saving_prob=0.05)

val_datagen = AugmentedBalancedSequence(images_and_classes=images_and_classes_val,
                                        input_size=INPUT_SIZE,class_sizes=class_sizes_val,num_class_samples=None,
    augmentations_fn=get_albumentations_pipeline,preprocessing_fn=preprocess_input, batch_size=BATCH_SIZE,frame_length=FRAME_LENGTH,is_validation=True)
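
`get_albumentations_pipeline` is not shown in the answer. From the way `AugmentedBalancedSequence` uses it, it must be a factory called as `augmentations_fn(p=..., pipe_p=...)` that returns a callable used as `pipeline(image=img)['image']`, which is the albumentations calling convention. The NumPy stand-in below is hypothetical, and reading `pipe_p` as "probability the whole pipeline runs" and `p` as "probability each transform fires" is an assumption. It only uses flips, so image shapes stay fixed, as the preallocated `self.images` array requires.

```python
import numpy as np

def get_toy_pipeline(p: float = 0.5, pipe_p: float = 1.0, seed=None):
    """Hypothetical stand-in for get_albumentations_pipeline.

    Returns a callable with the albumentations-style signature
    pipeline(image=...) -> {"image": ...}.
    """
    rng = np.random.default_rng(seed)

    def pipeline(image):
        out = image
        if rng.random() < pipe_p:      # does the pipeline run at all?
            if rng.random() < p:
                out = out[:, ::-1]     # horizontal flip
            if rng.random() < p:
                out = out[::-1]        # vertical flip
        return {"image": np.ascontiguousarray(out)}

    return pipeline

img = np.arange(2 * 3 * 3).reshape(2, 3, 3)
always = get_toy_pipeline(p=1.0, pipe_p=1.0)
never = get_toy_pipeline(p=1.0, pipe_p=0.0)
print(np.array_equal(never(image=img)["image"], img))  # True
print(always(image=img)["image"][0, 0].tolist())       # [15, 16, 17]
```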

After instantiating the model, I do:

model.fit(train_datagen,epochs=600,verbose=1,
          validation_data=(val_datagen.images,val_datagen.labels),validation_batch_size=BATCH_SIZE,
          callbacks=[checkpointer,StoppingFromFile()],validation_freq=1)
Votes: 2
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/41071842
