下面的代码是“用剪刀进行机器学习--学习、Keras和tensorflow”中的一个片段。我理解以下代码中的所有内容,除了第二行中的.repeat(repeat)函数链接。
我知道重复重复dataset元素(在本例中是文件路径),如果参数设置为None或空,重复将永远持续,直到使用它的函数决定何时停止。
正如您在下面的代码中所看到的,作者正在将repeat()参数设置为None;
基本上,我想知道作者为什么决定这么做?
或者是因为代码试图模拟数据集不适合内存的情况,如果是这样,那么在实际情况下我们应该避免repeat(),对吗?
def csv_reader_dataset(filepaths, repeat=1, n_readers=5,
n_read_threads=None, shuffle_buffer_size=10000,
n_parse_threads=5, batch_size=32):
dataset = tf.data.Dataset.list_files(filepaths, seed = 42).repeat(repeat)
dataset = dataset.interleave(
lambda filepath: tf.data.TextLineDataset(filepath).skip(1),
cycle_length = n_readers, num_parallel_calls = n_read_threads)
dataset = dataset.shuffle(shuffle_buffer_size)
dataset = dataset.map(preprocess, num_parallel_calls = n_parse_threads)
dataset = dataset.batch(batch_size)
return dataset.prefetch(1)
train_set = csv_reader_dataset(train_filepaths, repeat = None)
valid_set = csv_reader_dataset(valid_filepaths)
test_set = csv_reader_dataset(test_filepaths)
keras.backend.clear_session()
np.random.seed(42)
tf.random.set_seed(42)
model = keras.models.Sequential([
keras.layers.InputLayer(input_shape = X_train.shape[-1: ]),
keras.layers.Dense(30, activation = 'relu'),
keras.layers.Dense(1)
])
m_loss = keras.losses.mean_squared_error
m_optimizer = keras.optimizers.SGD(lr = 1e-3)
batch_size = 32
model.compile(loss = m_loss, optimizer = m_optimizer, metrics = ['accuracy'])
model.fit(train_set, steps_per_epoch = len(X_train) // batch_size, epochs = 10, validation_data = valid_set)发布于 2021-03-18 10:49:36
我还有一本同样的问题,关于书的作者git repo。问题被澄清了;这是由于Keras2.0中的一个bug。
发布于 2021-03-17 15:54:18
对于你们的问题,我认为:
https://stackoverflow.com/questions/66676389
复制相似问题