首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow -如何操作保护程序

Tensorflow -如何操作保护程序
EN

Stack Overflow用户
提问于 2016-12-09 18:38:52
回答 1查看 107关注 0票数 0

我正在使用tensorflow的波士顿住房数据教程,但我正在插入我自己的数据集:

代码语言:javascript
复制
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pandas as pd
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)

COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age",
       "dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm",
        "age", "dis", "tax", "ptratio"]
LABEL = "medv"


def input_fn(data_set):
    feature_cols = {k: tf.constant(data_set[k].values) for k in FEATURES}
    labels = tf.constant(data_set[LABEL].values)
    return feature_cols, labels

def main(unused_argv):
    # Load datasets
    training_set = pd.read_csv("boston_train.csv", skipinitialspace=True,
                         skiprows=1, names=COLUMNS)
    test_set = pd.read_csv("boston_test.csv", skipinitialspace=True,
                     skiprows=1, names=COLUMNS)

    # Set of 6 examples for which to predict median house values
    prediction_set = pd.read_csv("boston_predict.csv",             skipinitialspace=True,
                           skiprows=1, names=COLUMNS)

    # Feature cols
    feature_cols = [tf.contrib.layers.real_valued_column(k)
              for k in FEATURES]

    # Build 2 layer fully connected DNN with 10, 10 units respectively.
    regressor = tf.contrib.learn.DNNRegressor(
    feature_columns=feature_cols, hidden_units=[10, 10])

    # Fit
    regressor.fit(input_fn=lambda: input_fn(training_set), steps=5000)

    # Score accuracy
    ev = regressor.evaluate(input_fn=lambda: input_fn(test_set), steps=1)
    loss_score = ev["loss"]
    print("Loss: {0:f}".format(loss_score))

    # Print out predictions
    y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
    print("Predictions: {}".format(str(y)))

if __name__ == "__main__":
    tf.app.run()

我遇到的问题是数据集太大了,所以通过tf.train.Saver()保存检查点文件就填满了我所有的磁盘空间。

是否有一种方法可以禁用检查点文件的保存,或者减少上面脚本中保存的检查点的数量?

谢谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-12-09 19:06:08

tf.contrib.learn.DNNRegressor初始化程序接受一个tf.contrib.learn.RunConfig对象,该对象可用于控制内部创建的保护程序的行为。例如,您可以执行以下操作,只保留一个检查点:

代码语言:javascript
复制
config = tf.contrib.learn.RunConfig(keep_checkpoint_max=1)
regressor = tf.contrib.learn.DNNRegressor(
    feature_columns=feature_cols, hidden_units=[10, 10], config=config)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41066853

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档