因此,我已经通过以下代码使我的keras模型能够与tf.Dataset一起工作:
# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)
# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()
# Create Model
logits = .....(some layers)
keras.models.Model(inputs=dataset_inputs, outputs=logits)
# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_epoch=num_batches, callbacks=callbacks, verbose=1)但是,当我试图将validation_data参数传递给模型时。它告诉我我不能和发电机一起使用它。在使用tf.Dataset时是否有使用验证的方法?
,例如在tensorflow中,我可以执行以下
# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)
# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
batch_train.output_shapes)
# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)然后只需使用sess.run(init_op_train)和sess.run(init_op_valid)在数据集之间切换。
我尝试实现一个回调(切换到验证集、预测和返回),但它告诉我不能在回调中使用model.predict
有人能帮我验证一下Keras+Tf.Dataset吗?
编辑:将答案合并到代码中
,所以最后对我有用的答案是:
# Initialize batch generators(returns tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset
# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_train)
# Create Model
logits = # the keras model
keras.models.Model(inputs=dataset_inputs, outputs=logits)
# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit_generator(
generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,
epochs=epochs, steps_per_epoch=num_batches, callbacks=cbs, verbose=1, workers=0)
def make_iterator(dataset):
iterator = dataset.make_one_shot_iterator()
next_val = iterator.get_next()
with K.get_session().as_default() as sess:
while True:
*inputs, labels = sess.run(next_val)
yield inputs, labels这没有引入任何开销
发布于 2018-06-22 01:50:59
我用fit_genertor解决了这个问题。我找到了解决方案here。我申请了@Dat-Nguyen的解决方案。
您只需创建两个迭代器,一个用于培训,一个用于验证,然后创建您自己的生成器,在这里您将从数据集中提取批处理并以(batch_data,batch_labels)的形式提供数据。最后,在model.fit_generator中,您将传递train_generator和validation_generator。
发布于 2019-01-02 17:40:04
将可重新初始化的迭代器连接到Keras模型的方法是插入一个同时返回x和y值的iterator:
sess = tf.Session()
keras.backend.set_session(sess)
x = np.random.random((5, 2))
y = np.array([0, 1] * 3 + [1, 0] * 2).reshape(5, 2) # One hot encoded
input_dataset = tf.data.Dataset.from_tensor_slices((x, y))
# Create your reinitializable_iterator and initializer
reinitializable_iterator = tf.data.Iterator.from_structure(input_dataset.output_types, input_dataset.output_shapes)
init_op = reinitializable_iterator.make_initializer(input_dataset)
#run the initializer
sess.run(init_op) # feed_dict if you're using placeholders as input
# build keras model and plug in the iterator
model = keras.Model.model(...)
model.compile(...)
model.fit(reinitializable_iterator,...)如果您还有一个验证数据集,那么最简单的方法就是创建一个单独的迭代器并将其插入validation_data参数。确保定义您的steps_per_epoch和validation_steps,因为它们不能被推断。
https://stackoverflow.com/questions/50955798
复制相似问题