首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在每个时代之后重置度量的局部变量

在每个时代之后重置度量的局部变量
EN

Stack Overflow用户
提问于 2019-01-02 14:00:05
回答 2查看 1.2K关注 0票数 0

我使用内置方法tf.metrics.precision来评估我的模型。我查看了它的定义,但是局部变量从不重置。

难道不应该在每个时代之后重新设置它们,以便从最后一个时代中删除计数吗?这是自动完成的,我只是在源代码中忽略了它,还是应该这样做呢?如果后者为真,如何重置局部变量?我在文件里什么也没看过。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-01-02 14:59:17

用于跟踪度量的变量是使用metric_variable函数创建的,从而添加到具有键tf.GraphKeys.METRIC_VARIABLES的集合中。在定义了所有度量之后,可以进行如下重置操作:

代码语言:javascript
复制
reset_metrics_op = tf.variables_initializer(tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))

在每个时代结束后再运行。

票数 1
EN

Stack Overflow用户

发布于 2019-01-02 15:07:24

是。在批量处理数据时,必须注意如何重置变量。在计算总体度量(即精度、准确性或auc)和批处理度量时,对操作进行安排是不同的。在计算每一批新数据的精度值之前,需要将运行的变量重置为零。

使用tf.metrics.precision,将创建两个运行变量并将其放入计算图中:true_positivesfalse_positives。因此,您可以使用scope参数tf.get_collection()来选择要重置的变量。

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

import numpy as np
import tensorflow as tf

labels = np.array([[1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0]], dtype=np.uint8)

predictions = np.array([[1,0,0,0],
                        [1,1,0,0],
                        [1,1,1,0],
                        [0,1,1,1]], dtype=np.uint8)

precision, update_op = tf.metrics.precision(labels, predictions, name = 'precision')

print(precision)
#Tensor("precision/value:0", shape=(), dtype=float32)
print(update_op)
#Tensor("precision/update_op:0", shape=(), dtype=float32)

tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
#[<tf.Variable 'precision/true_positives/count:0' shape=() dtype=float32_ref>,
# <tf.Variable 'precision/false_positives/count:0' shape=() dtype=float32_ref>,

running_vars_precision = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='precision')
running_vars_auc_initializer = tf.variables_initializer(var_list=running_vars_precision )

with tf.Session() as sess:
    sess.run(running_vars_auc_initializer)
    print("tf precision/update_op: {}".format(sess.run([precision, update_op])))
    #tf precision/update_op: [0.8888889, 0.8888889]
    print("tf precision: {}".format(sess.run(precision)))
    #tf precision: 0.8888888955116272
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54007669

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档