首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow数据API:重复()

Tensorflow数据API:重复()
EN

Stack Overflow用户
提问于 2021-03-17 15:39:44
回答 2查看 203关注 0票数 0

下面的代码是“用剪刀进行机器学习--学习、Keras和tensorflow”中的一个片段。我理解以下代码中的所有内容,除了第二行中的.repeat(repeat)函数链接。

我知道重复重复dataset元素(在本例中是文件路径),如果参数设置为None或空,重复将永远持续,直到使用它的函数决定何时停止。

正如您在下面的代码中所看到的,作者正在将repeat()参数设置为None

基本上,我想知道作者为什么决定这么做?

或者是因为代码试图模拟数据集不适合内存的情况,如果是这样,那么在实际情况下我们应该避免repeat(),对吗?

代码语言:javascript
复制
def csv_reader_dataset(filepaths, repeat=1, n_readers=5,
                       n_read_threads=None, shuffle_buffer_size=10000,
                       n_parse_threads=5, batch_size=32):
    dataset = tf.data.Dataset.list_files(filepaths, seed = 42).repeat(repeat)
    dataset = dataset.interleave(
        lambda filepath: tf.data.TextLineDataset(filepath).skip(1),
        cycle_length = n_readers, num_parallel_calls = n_read_threads)
    
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(preprocess, num_parallel_calls = n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(1)

train_set = csv_reader_dataset(train_filepaths, repeat = None)
valid_set = csv_reader_dataset(valid_filepaths)
test_set = csv_reader_dataset(test_filepaths)


keras.backend.clear_session()
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.InputLayer(input_shape = X_train.shape[-1: ]),
    keras.layers.Dense(30, activation = 'relu'),
    keras.layers.Dense(1)
])

m_loss = keras.losses.mean_squared_error
m_optimizer = keras.optimizers.SGD(lr = 1e-3)

batch_size = 32
model.compile(loss = m_loss, optimizer = m_optimizer, metrics = ['accuracy'])
model.fit(train_set, steps_per_epoch = len(X_train) // batch_size, epochs = 10, validation_data = valid_set)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-03-18 10:49:36

我还有一本同样的问题,关于书的作者git repo。问题被澄清了;这是由于Keras2.0中的一个bug。

阅读更多信息:https://github.com/ageron/handson-ml2/issues/407

票数 0
EN

Stack Overflow用户

发布于 2021-03-17 15:54:18

对于你们的问题,我认为:

  • tf.data API不会轻易导致内存不足,因为它在文件路径或tfrec棒(压缩模式)的情况下加载数据。因此,repeat()在这里与内存无关;相反,在将data-transforming.
  • I设置为#时,只能使用steps_per_epoch (#)。假设您的batch_num = 32steps_per_epoch = 100//32 =3 ->需要3* 32 =每一个时代96个样本,但是您的数据只有80个样本。然后,我必须使用data.repeat(2)得到总共160个样本,其中repeat_1中的80个样本和repeat_2中的前16个样本将在1个时期内使用。这是为了防止错误输入耗尽数据。
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66676389

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档