首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >WideNDeep教程代码

WideNDeep教程代码
EN

Stack Overflow用户
提问于 2016-09-08 05:54:16
回答 2查看 270关注 0票数 0

关于WideNDeep tutorial中的这行代码:

代码语言:javascript
复制
m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps)

用于训练深度模型的batch_size是什么?目前,在我看来,这个模型不是batch_trained?是否有默认的batch_size?

谢谢

EN

回答 2

Stack Overflow用户

发布于 2016-09-09 00:23:11

您可以将batch_size作为参数传递给fit。See the documentation on BaseEstimator.fit

票数 0
EN

Stack Overflow用户

发布于 2016-09-10 12:12:24

我对批处理的本教程进行了如下更改:

  1. 将CSV数据转换为TensorFlow格式( TFRecord文件中的示例);然后
  2. 从tf.contrib.learn ( input_fn )

中提供的read_batch函数创建队列

这是我使用的代码:

https://gist.github.com/cirocavani/7d9e827102093139acd400b02d2e7afb

input_fn如下所示:

代码语言:javascript
复制
def input_fn(mode, data_file, batch_size):
    input_features = create_feature_columns()
    features = tf.contrib.layers.create_feature_spec_for_parsing(input_features)

    feature_map = tf.contrib.learn.io.read_batch_record_features(
        file_pattern=[data_file],
        batch_size=batch_size,
        features=features,
        name="read_batch_features_{}".format(mode))

    target = feature_map.pop("label")

    return feature_map, target

我认为它会有一个更简单的解决方案,但我还不知道TensorFlow是否提供了一个:)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39379586

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档