首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >`get_variable()`不识别tf.estimator的现有变量

`get_variable()`不识别tf.estimator的现有变量
EN

Stack Overflow用户
提问于 2018-11-26 11:26:45
回答 1查看 449关注 0票数 6

这个问题已经被这里问过了,区别在于我的问题是集中在Estimator上。

一些上下文:我们使用估计器训练了一个模型,并在估计值input_fn中定义了一些变量,该函数将数据预处理成批处理。现在,我们开始预测。在预测过程中,我们使用相同的input_fn来读取和处理数据。但got错误声明变量(word_embeddings)不存在(变量存在于chkp图中),下面是input_fn中的相关代码

代码语言:javascript
复制
with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
    if mode == tf.estimator.ModeKeys.TRAIN:
        word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
        word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
                                          trainable=False,
                                          name="word_to_vec",
                                          dtype=tf.float32)
    else:
        word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)

基本上,当它处于预测模式时,会调用else来加载检查点中的变量。无法识别此变量表明:( a)范围使用不当;( b)图形未恢复。我认为只要正确设置reuse,范围在这里就不那么重要了。

我怀疑这是因为图形还没有恢复到input_fn阶段。通常,图形是通过调用saver.restore(sess, "/tmp/model.ckpt") 参考文献来恢复的。对估值器源代码的研究并没有给我带来任何与恢复有关的东西,最好的方法是MonitoredSession,一个训练的包装。它已经从最初的问题延伸了那么多,没有信心,如果我在正确的道路上,我在这里寻求帮助,如果有人有任何见解。

我问题的一行摘要:如何通过tf.estimator (通过input_fnmodel_fn )恢复图形

EN

回答 1

Stack Overflow用户

发布于 2018-12-11 23:23:39

嗨,我认为出现错误只是因为您没有在tf.get_variable中指定形状(at predict),即使要还原变量,也似乎需要指定形状。

我使用一个简单的线性回归估计器进行了下面的测试,它只需预测x+5

代码语言:javascript
复制
def input_fn(mode):
    def _input_fn():
        with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
            if mode == tf.estimator.ModeKeys.TRAIN:
                var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
                x_data = np.random.randn(1000)
                labels = x_data + 5
                return {'x':x_data}, labels
            elif mode == tf.estimator.ModeKeys.PREDICT:
                var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
                return {'x':[0,10,100,var_to_follow]}
    return _input_fn

featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')

这段代码运行得很好,const的值是20,也是为了好玩,在我的测试集中使用它来确认:p

但是,如果删除shape=[],它就会中断,您也可以给出另一个初始化器,例如tf.constant(500),一切都会正常工作,20将被使用。

通过跑

代码语言:javascript
复制
model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)

代码语言:javascript
复制
preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))

您可以可视化该图形,您将看到:( a)范围是正常的,b)图是恢复的。

希望这能帮到你。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53480116

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档