首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow中不同数量元素的标签和预测的Precision和recall eval_metrics

tensorflow中不同数量元素的标签和预测的Precision和recall eval_metrics
EN

Stack Overflow用户
提问于 2018-07-28 06:14:44
回答 1查看 81关注 0票数 0

我在Tensorflow中作为eval_metrics注册精度和召回率时遇到了问题。我的标签和预测没有相同数量的元素,所以我不能使用已经内置的函数。我有计算精度和召回率的函数,但是我似乎不能得到precision_update_op和recall_update_op。你知道我怎么才能从标签,预测和前面提到的计算精度和召回率的函数中得到它吗?谢谢

EN

回答 1

Stack Overflow用户

发布于 2018-07-30 03:32:12

这里有一个如何构建您自己的指标的简单示例。我将演示mean,您也应该能够适应上面提到的内容。

代码语言:javascript
复制
def mean_metrics(values):
   """ For mean, there are two variables that are 
 required to hold the sum and the total number of variables"""

   # total sum
   total = tf.Variable(initial_value=0., dtype=tf.float32, name='total')

   # total count
   count = tf.Variable(initial_value=0., dtype=tf.float32, name='count')

   # Update total op by updating total with the sum of the values
   update_total_op = tf.assign_add(total, tf.cast(tf.reduce_sum(values), tf.float32))

   # Update count op by updating the total size of the values
   update_count_op = tf.assign_add(count, tf.cast(tf.size(tf.squeeze(values)), tf.float32))

   # Mean
   mean = tf.div(total, count, 'value')

   # Mean update op
   update_op = tf.div(update_total_op, update_count_op, 'value')

   return mean, update_op

测试上面的代码:

代码语言:javascript
复制
tf.reset_default_graph()
values = tf.placeholder(tf.float32, shape=[None])

mean, mean_op = mean_metrics(values)

with tf.Session() as sess:
   tf.global_variables_initializer().run()
   print(sess.run([mean, mean_op], {values:[1.,2.,3.]}))
   print(sess.run([mean, mean_op], {values:[4.,5.,6.]}))

#output
#[nan, 2.0]
#[2.0, 3.5]
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51566194

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档