关于WideNDeep tutorial中的这行代码:
m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps)用于训练深度模型的batch_size是什么?目前,在我看来,这个模型不是batch_trained?是否有默认的batch_size?
谢谢
发布于 2016-09-09 00:23:11
您可以将batch_size作为参数传递给fit。See the documentation on BaseEstimator.fit
发布于 2016-09-10 12:12:24
我对批处理的本教程进行了如下更改:
中提供的read_batch函数创建队列
这是我使用的代码:
https://gist.github.com/cirocavani/7d9e827102093139acd400b02d2e7afb
input_fn如下所示:
def input_fn(mode, data_file, batch_size):
input_features = create_feature_columns()
features = tf.contrib.layers.create_feature_spec_for_parsing(input_features)
feature_map = tf.contrib.learn.io.read_batch_record_features(
file_pattern=[data_file],
batch_size=batch_size,
features=features,
name="read_batch_features_{}".format(mode))
target = feature_map.pop("label")
return feature_map, target我认为它会有一个更简单的解决方案,但我还不知道TensorFlow是否提供了一个:)
https://stackoverflow.com/questions/39379586
复制相似问题