
What raises StopIteration in my Keras Model.fit_generator

Stack Overflow user
Asked on 2017-09-19 14:19:31
2 answers · 11.7K views · 0 followers · 8 votes

I have the following code:

import numpy as np  # np is used below but was missing from the snippet
from sklearn.model_selection import train_test_split
from scipy.misc import imresize

def _chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]


def _batch_generator(data, batch_size):
    indexes = range(len(data))
    index_chunks = _chunks(indexes, batch_size)
    for i, indexes in enumerate(index_chunks):
        print("\nLoaded batch {0}\n".format(i + 1))
        batch_X = []
        batch_y = []
        for index in indexes:
            record = data[index]
            image = _read_train_image(record["id"], record["index"])
            mask = _read_train_mask(record["id"], record["index"])
            mask_resized = imresize(mask, (1276, 1916)) >= 123
            mask_reshaped = mask_resized.reshape((1276, 1916, 1))
            batch_X.append(image)
            batch_y.append(mask_reshaped)
        np_batch_X = np.array(batch_X)
        np_batch_y = np.array(batch_y)
        yield np_batch_X, np_batch_y


def train(data, model, batch_size, epochs):
    train_data, test_data = train_test_split(data)
    samples_per_epoch = len(train_data)
    steps_per_epoch = samples_per_epoch // batch_size
    print("Train on {0} records ({1} batches)".format(samples_per_epoch, steps_per_epoch))
    train_generator = _batch_generator(train_data, batch_size)
    model.fit_generator(train_generator, 
                        steps_per_epoch=steps_per_epoch, 
                        nb_epoch=epochs, 
                        verbose=1)

train(train_indexes[:30], autoencoder,
    batch_size=2,
    epochs=1)

So it seems it should work like this:

  • take 30 indexes from the dataset (just as an example)
  • split them into 22 train records and 8 validation indexes (not used yet)
  • split the train indexes into batches of 2 indexes inside the generator (so, 11 batches) - and it works: len(list(_batch_generator(train_indexes[:22], 2))) actually returns 11
  • fit the model:
    • on the batches produced by train_generator (in my case - 11 batches, 2 images each)
    • with 11 batches per epoch (steps_per_epoch=steps_per_epoch)
    • and one epoch (nb_epoch=epochs, where epochs=1)
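The arithmetic in the third bullet can be checked standalone, using the same _chunks helper as in the question (dummy index data in place of the real records):

```python
def _chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]

# 22 train records split into batches of 2 gives exactly 11 batches
# from one pass over the generator.
batches = list(_chunks(range(22), 2))
print(len(batches))  # 11
```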

But the output tells a different story:

Train on 22 records (11 batches)
Epoch 1/1

Loaded batch 1

C:\Users\user\venv\machinelearning\lib\site-packages\ipykernel_launcher.py:39: UserWarning: The semantics of the Keras 2 argument `steps_per_epoch` is not the same as the Keras 1 argument `samples_per_epoch`. `steps_per_epoch` is the number of batches to draw from the generator at each epoch. Basically steps_per_epoch = samples_per_epoch/batch_size. Similarly `nb_val_samples`->`validation_steps` and `val_samples`->`steps` arguments have changed. Update your method calls accordingly.
C:\Users\user\venv\machinelearning\lib\site-packages\ipykernel_launcher.py:39: UserWarning: Update your `fit_generator` call to the Keras 2 API: `fit_generator(<generator..., steps_per_epoch=11, verbose=1, epochs=1)`

Loaded batch 2

1/11 [=>............................] - ETA: 11s - loss: 0.7471
Loaded batch 3


Loaded batch 4


Loaded batch 5


Loaded batch 6

2/11 [====>.........................] - ETA: 17s - loss: 0.7116
Loaded batch 7


Loaded batch 8


Loaded batch 9


Loaded batch 10

3/11 [=======>......................] - ETA: 18s - loss: 0.6931
Loaded batch 11

Exception in thread Thread-50:
Traceback (most recent call last):
File "C:\Anaconda3\Lib\threading.py", line 916, in _bootstrap_inner
    self.run()
File "C:\Anaconda3\Lib\threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
File "C:\Users\user\venv\machinelearning\lib\site-packages\keras\utils\data_utils.py", line 560, in data_generator_task
    generator_output = next(self._generator)
StopIteration

4/11 [=========>....................] - ETA: 18s - loss: 0.6663
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-16-092ba6eb51d2> in <module>()
    1 train(train_indexes[:30], autoencoder,
    2       batch_size=2,
----> 3       epochs=1)

<ipython-input-15-f2fec4e53382> in train(data, model, batch_size, epochs)
    37                         steps_per_epoch=steps_per_epoch,
    38                         nb_epoch=epochs,
---> 39                         verbose=1)

C:\Users\user\venv\machinelearning\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
    85                 warnings.warn('Update your `' + object_name +
    86                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87             return func(*args, **kwargs)
    88         wrapper._original_function = func
    89         return wrapper

C:\Users\user\venv\machinelearning\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, initial_epoch)
1807                 batch_index = 0
1808                 while steps_done < steps_per_epoch:
-> 1809                     generator_output = next(output_generator)
1810 
1811                     if not hasattr(generator_output, '__len__'):

StopIteration: 

So, as far as I can see, all the batches are read successfully (see the "Loaded batch" messages).

But while processing batch 3 of epoch 1, a StopIteration is raised by Keras.
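The error can be reproduced without Keras at all: _batch_generator yields exactly 11 batches and is then exhausted, but fit_generator keeps calling next() on it (its background queue prefetches ahead of the progress bar, which is why the "Loaded batch" prints run ahead of the step counter), so the 12th draw raises StopIteration. A minimal standalone sketch:

```python
def finite_batches(n_batches):
    # Stand-in for _batch_generator: yields a fixed number of batches,
    # then is exhausted.
    for i in range(n_batches):
        yield i

gen = finite_batches(11)
for _ in range(11):
    next(gen)          # the first 11 draws succeed (one "epoch")

try:
    next(gen)          # the 12th draw fails
    raised = False
except StopIteration:
    raised = True
print(raised)  # True
```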


2 Answers

Stack Overflow user

Posted on 2018-10-10 08:47:20

I also ran into this problem, and one way around it that I found is to put a "while True" block in your data generator function. I couldn't track down the root cause, though. You can refer to the code below:

# Body of the data generator function (inputs, targets, batchsize and
# shuffle are its arguments; requires `import numpy as np` and `import sys`):
while True:
    assert len(inputs) == len(targets)
    indices = np.arange(len(inputs))
    if shuffle:
        np.random.shuffle(indices)
    if batchsize > len(indices):
        sys.stderr.write('BatchSize out of index size')
        batchsize = len(indices)  # clamp batchsize to the number of samples
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt]
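To see why the while True wrapper matters, here is a minimal, self-contained sketch of the same idea (infinite_batches is an illustrative name, and dummy NumPy arrays stand in for the asker's images): the wrapped generator can be drawn from indefinitely, so Keras never exhausts it.

```python
import numpy as np

def infinite_batches(inputs, targets, batch_size):
    """Yield (X, y) batches forever, reshuffling on each pass over the data."""
    assert len(inputs) == len(targets)
    while True:
        indices = np.random.permutation(len(inputs))
        for start in range(0, len(inputs) - batch_size + 1, batch_size):
            excerpt = indices[start:start + batch_size]
            yield inputs[excerpt], targets[excerpt]

# With 22 samples and batch_size=2 (the asker's numbers), one pass is
# 11 batches - but we can keep drawing past that without StopIteration.
X = np.zeros((22, 4))
y = np.zeros((22, 1))
gen = infinite_batches(X, y, batch_size=2)
batches = [next(gen) for _ in range(33)]  # three "epochs" worth of draws
print(len(batches))  # 33
```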
Votes: 8

Stack Overflow user

Posted on 2018-11-12 16:07:18

A note on this question, in case anyone else comes to this page chasing it. The StopIteration error is a known issue in Keras, and it can sometimes be resolved by making sure the batch size divides the number of samples evenly. If that doesn't fix it, one thing I found is that funky file formats the data generator can't read can also cause a StopIteration error. To deal with this, I run a script over the training folder before training that converts all the images to a standard file type (jpg or png). It looks like this.

import glob
from PIL import Image

# Re-save every training image as PNG so the data generator only ever
# sees a standard, readable file format.
d = 1
for sample in glob.glob(r'C:\Users\Jeremiah\Pictures\training\classLabel_unformatted\*'):
    im = Image.open(sample)
    im.save(r'C:\Users\Jeremiah\Pictures\training\classLabel_formatted\%s.png' % d)
    d = d + 1

I found that running this script, or something like it, greatly reduced how often these errors occurred, especially when my training data came from somewhere like Google Image Search.
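The divisibility point above can be made concrete with a small helper (safe_steps_per_epoch is an illustrative name, not part of Keras): pick steps_per_epoch so it never asks the generator for more batches than one pass can produce.

```python
def safe_steps_per_epoch(n_samples, batch_size):
    """Number of full batches one pass over the data can yield."""
    if n_samples % batch_size != 0:
        # The remainder samples would never form a full batch; either drop
        # them (as here) or choose a batch_size that divides n_samples.
        print("warning: %d samples is not a multiple of batch_size %d; "
              "the last %d samples are dropped"
              % (n_samples, batch_size, n_samples % batch_size))
    return n_samples // batch_size

print(safe_steps_per_epoch(22, 2))  # 11
```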

Votes: 2
The original content of this page was provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/46302911