我正在寻找在Keras中训练大于内存的数据的最佳方法,目前注意到普通的ImageDataGenerator往往比我希望的要慢。
我在Kaggle猫与狗的数据集(25000张图片)上进行了两个网络训练:
1)这种方法的代码来自:http://www.pyimagesearch.com/2016/09/26/a-simple-neural-network-with-python-and-keras/
2)与(1)相同,但使用ImageDataGenerator而不是将数据加载到内存中
注意:在下文中,“预处理”是指调整大小、缩放、展平。
我在我的gtx970上找到了以下内容:
对于网络1,每个时期需要约0s。
对于网络2,如果在数据生成器中进行预处理,则每个时期需要约36s。
对于网络2,如果在数据生成器外部的第一遍中完成预处理,则每个时期需要~13s。
这可能是ImageDataGenerator的速度极限吗(13s看起来像是磁盘和内存之间通常的10-100倍的差异……)?在使用Keras时,是否有更适合在大于内存的数据上进行训练的方法/机制?例如,也许有办法让Keras中的ImageDataGenerator在第一个时期(epoch)之后保存它处理过的图像?
谢谢!
发布于 2017-06-10 05:30:47
我想你可能已经解决了这个问题,但是不管怎样...
Keras图像预处理可以选择通过在flow()或flow_from_directory()函数中设置save_to_dir参数来保存结果:
发布于 2020-10-24 15:18:32
在我的理解中,问题是增强后的图像在模型的训练周期中只使用一次,甚至不会在几个时期中重复使用。因此,当CPU疲于生成数据时,GPU周期被大量浪费。我找到了以下解决方案:每隔若干个时期在内存中生成一批增强后的样本,并在这些时期内重复使用它们。
这种方法在大多数情况下使GPU保持繁忙,同时能够从数据增强中受益。我使用定制的Sequence子类来生成增强,同时修复类的不平衡。
编辑:添加一些代码以阐明想法
import json
import logging
import random
from gc import collect
from pathlib import Path

import cv2
import numpy as np
import tensorflow
from pyutilz.string import read_config_file
from tqdm.notebook import tqdm
class StoppingFromFile(tensorflow.keras.callbacks.Callback):
    """Keras callback that stops training when ``control.ini`` requests it.

    After every epoch the ``stop`` option of the ``ML`` section of
    ``control.ini`` is read through ``read_config_file``. When the value is
    ``True`` (or the string ``'True'``) the model's ``stop_training`` flag is
    set, so training ends after the current epoch.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Check the control file and flag the model to stop if requested.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metrics dict supplied by Keras (unused).
        """
        # NOTE(review): read_config_file is handed this module's globals();
        # presumably it stores the parsed value there under the name 'stop'
        # -- confirm against pyutilz.
        if not read_config_file('control.ini', 'ML', 'stop', globals()):
            return
        # Look the value up explicitly so that a missing 'stop' entry is
        # treated as "no request" instead of raising NameError.
        stop = globals().get('stop')
        # `in` compares with ==, so the accepted values are the same as the
        # original `stop == True or stop == 'True'`.
        if stop is not None and stop in (True, 'True'):
            logging.warning('Model should be stopped according to the control file')
            self.model.stop_training = True
class AugmentedBalancedSequence(tensorflow.keras.utils.Sequence):
    """In-memory, class-balanced, augmented data set rebuilt every few epochs.

    From a dict of images grouped by class label, creates each N epochs
    (``frame_length``) a new augmented balanced set held entirely in memory.
    If the current class is too scarce, ensures that the current frame has no
    duplicate final images. If it is rich enough, ensures that the current
    frame has no duplicate base images.

    The values of ``images_and_classes`` are used directly as pixel arrays
    (they are passed to the augmentation pipeline and to OpenCV); nothing is
    loaded from disk here.
    """

    def __init__(
        self,
        images_and_classes: dict,
        input_size: tuple,
        class_sizes: dict,
        augmentations_fn: object,
        preprocessing_fn: object,
        batch_size: int = 10,
        num_class_samples=100,
        frame_length: int = 5,
        aug_p: float = 0.1,
        aug_pipe_p: float = 0.2,
        is_validation: bool = False,
        disk_saving_prob: float = .01,
        disk_example_nfiles: int = 50,
    ):
        """Store the configuration and generate the first frame of data.

        Args:
            images_and_classes: Maps a class label to an indexable collection
                of images of that class.
            input_size: Image size; the buffer is allocated as
                ``(n, input_size[1], input_size[0], 3)``
                (presumably ``(width, height)`` -- confirm with the caller).
            class_sizes: Mapping whose values are summed to size the buffer
                when ``num_class_samples`` is not set.
            augmentations_fn: Called as ``augmentations_fn(p=..., pipe_p=...)``;
                must return a pipeline callable as ``pipeline(image=img)``
                returning a dict with an ``'image'`` entry.
            preprocessing_fn: Optional callable applied to every image right
                before it is stored.
            batch_size: Number of samples per batch.
            num_class_samples: Samples generated per class and frame. When
                ``None`` (or 0) every native sample is used without
                augmentation.
            frame_length: Number of epochs after which the data is regenerated.
            aug_p: Forwarded to ``augmentations_fn`` as ``p``.
            aug_pipe_p: Forwarded to ``augmentations_fn`` as ``pipe_p`` for the
                regular (non maxed-out) pipeline.
            is_validation: When true, native images are stored unaugmented
                wherever enough of them exist.
            disk_saving_prob: Probability that a sample is also written to
                ``example_images/`` for visual inspection.
            disk_example_nfiles: Number of example file slots that are cycled.
        """
        # Newer Keras versions expect the Sequence/PyDataset base initialiser
        # to run; on older versions this is a harmless no-op.
        super().__init__()
        logging.info(f'Got {len(images_and_classes)} classes.')
        self.disk_example_nfiles = disk_example_nfiles
        self.disk_saving_prob = disk_saving_prob
        self.cur_example_file = 0
        self.images_and_classes = images_and_classes
        self.num_class_samples = num_class_samples
        self.augmentations_fn = augmentations_fn
        self.preprocessing_fn = preprocessing_fn
        self.is_validation = is_validation
        self.frame_length = frame_length
        self.batch_size = batch_size
        self.class_sizes = class_sizes
        self.input_size = input_size
        self.aug_pipe_p = aug_pipe_p
        self.aug_p = aug_p
        self.images = None
        self.labels = None
        self.epoch = 0
        self._generate_data()

    def __len__(self):
        """Return the number of batches per epoch (the last one may be short)."""
        return int(np.ceil(len(self.images) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return the ``idx``-th batch as an ``(images, labels)`` pair."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.images[start:stop], self.labels[start:stop]

    def on_epoch_end(self):
        """Apply ``control.json`` overrides and regenerate data when a frame ends."""
        self.epoch += 1
        self._apply_control_file('control.json')
        if self.epoch % self.frame_length == 0:
            self._generate_data()

    def _apply_control_file(self, fname):
        """Overwrite existing attributes with non-null values from a JSON file.

        A missing file is ignored. Any failure while reading or applying the
        file is logged and otherwise ignored, so that a broken control file
        cannot abort a running training.
        """
        if not Path(fname).is_file():
            return
        try:
            with open(fname) as f:
                overrides = json.load(f)
            for var, val in overrides.items():
                # Only attributes that already exist may be changed.
                if not hasattr(self, var) or val is None:
                    continue
                if getattr(self, var) != val:
                    setattr(self, var, val)
                    print(f'{var} became {val}')
        except Exception as e:
            logging.error(str(e))

    def _add_sample(self, image, label):
        """Store one sample in the next free shuffled slot.

        With probability ``disk_saving_prob`` the image is also written to
        ``example_images/test<k>.jpg`` before preprocessing, where ``k``
        cycles through ``1..disk_example_nfiles``.
        """
        slot = self.indices[self.img_sent]
        if self.disk_saving_prob > 0 and random.random() < self.disk_saving_prob:
            self.cur_example_file += 1
            if self.cur_example_file > self.disk_example_nfiles:
                self.cur_example_file = 1
            Path(r'example_images/').mkdir(parents=True, exist_ok=True)
            # The conversion below treats the image as RGB; OpenCV writes BGR.
            cv2.imwrite(
                f'example_images/test{self.cur_example_file}.jpg',
                cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
            )
        if self.preprocessing_fn:
            self.images[slot] = self.preprocessing_fn(image)
        else:
            self.images[slot] = image
        self.labels[slot] = label
        self.img_sent += 1

    def _generate_data(self):
        """Fill the in-memory buffers with a fresh, shuffled set of samples."""
        logging.info('Generating new set of augmented data...')
        collect()
        # One flag for both sizing and filling, so the two always agree
        # (including for num_class_samples == 0).
        native_only = not self.num_class_samples
        if native_only:
            expected_length = sum(self.class_sizes.values())
        else:
            expected_length = len(self.images_and_classes) * self.num_class_samples
        # Re-allocate when the size changed (e.g. num_class_samples was
        # overridden through the control file); otherwise reuse the buffers.
        if self.images is None or len(self.images) != expected_length:
            self.images = np.empty(
                (expected_length, self.input_size[1], self.input_size[0], 3)
            )
            self.labels = np.empty(expected_length, np.int32)
        # Random permutation of target slots: samples end up shuffled in place.
        self.indices = np.random.choice(expected_length, expected_length, replace=False)
        self.img_sent = 0
        collect()
        relaxed_augmentation_pipeline = self.augmentations_fn(
            p=self.aug_p, pipe_p=self.aug_pipe_p
        )
        maxed_out_augmentation_pipeline = self.augmentations_fn(
            p=self.aug_p, pipe_p=1.0
        )
        nartificial = 0
        for label, images in tqdm(self.images_and_classes.items()):
            if native_only:
                # Just all native samples without augmentations.
                for image in images:
                    self._add_sample(image, label)
            elif len(images) >= self.num_class_samples:
                # Enough native samples: randomly select distinct base images
                # that will take part in this frame of epochs.
                chosen = np.random.choice(
                    len(images), self.num_class_samples, replace=False
                )
                for i in chosen:
                    if self.is_validation:
                        self._add_sample(images[i], label)
                    else:
                        augmented = relaxed_augmentation_pipeline(image=images[i])
                        self._add_sample(augmented['image'], label)
            else:
                nartificial += self._fill_scarce_class(
                    images, label, maxed_out_augmentation_pipeline
                )
        logging.info(f'Generated {self.img_sent} samples ({nartificial} artificial)')

    def _fill_scarce_class(self, images, label, pipeline):
        """Add exactly ``num_class_samples`` samples for an under-represented class.

        In validation mode every original image is stored once first. The
        remaining slots are filled by repeatedly augmenting a randomly picked
        original until enough DIFFERENT images were produced. Images are told
        apart by their pixel sum, which is a cheap fingerprint: two different
        images with the same sum are treated as duplicates and retried.

        Returns:
            The number of augmented (artificial) samples that were added.

        Raises:
            ValueError: If the class has no images to augment.
        """
        if len(images) == 0:
            raise ValueError(f'Class {label!r} has no images to augment')
        seen = set()
        added = 0
        if self.is_validation:
            # This branch is only reached when the class has fewer originals
            # than num_class_samples, so all of them fit.
            for image in images:
                self._add_sample(image, label)
                seen.add(np.sum(image))
                added += 1
        n_augmented = 0
        # Count stored samples (not fingerprints) so that the class fills
        # exactly its share of the pre-allocated buffer.
        while added < self.num_class_samples:
            image = pipeline(image=random.choice(images))['image']
            fingerprint = np.sum(image)
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            self._add_sample(image, label)
            added += 1
            n_augmented += 1
        return n_augmented


# Once the images and classes have been loaded,
# Training sequence: balanced to UPSCALE_SAMPLES samples per class, with
# augmentation and occasional example images written to disk.
train_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_train,
    input_size=INPUT_SIZE,
    class_sizes=class_sizes_train,
    num_class_samples=UPSCALE_SAMPLES,
    augmentations_fn=get_albumentations_pipeline,
    aug_p=AUG_P,
    aug_pipe_p=AUG_PIPE_P,
    preprocessing_fn=preprocess_input,
    batch_size=BATCH_SIZE,
    frame_length=FRAME_LENGTH,
    disk_saving_prob=0.05,
)
# Validation sequence: num_class_samples=None means all native samples are
# used without augmentation.
val_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_val,
    input_size=INPUT_SIZE,
    class_sizes=class_sizes_val,
    num_class_samples=None,
    augmentations_fn=get_albumentations_pipeline,
    preprocessing_fn=preprocess_input,
    batch_size=BATCH_SIZE,
    frame_length=FRAME_LENGTH,
    is_validation=True,
)

# After instantiating the model, I do this.
# The validation data is passed as plain arrays taken from the sequence's
# in-memory buffers rather than as the sequence object itself.
model.fit(
    train_datagen,
    epochs=600,
    verbose=1,
    validation_data=(val_datagen.images, val_datagen.labels),
    validation_batch_size=BATCH_SIZE,
    callbacks=[checkpointer, StoppingFromFile()],
    validation_freq=1,
)
# Source: https://stackoverflow.com/questions/41071842
复制相似问题