首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow对象检测评估损失

Tensorflow对象检测评估损失
EN

Stack Overflow用户
提问于 2017-07-26 21:27:16
回答 1查看 1.9K关注 0票数 1

我感兴趣的是运行验证图像,并在Tensorflow的对象检测库中的验证数据集上获得损失(类似于训练期间的损失)。

我正在尝试修改evaluator.py (https://github.com/tensorflow/models/blob/master/object_detection/evaluator.py#L38)中的_extract_prediction_tensors函数,如下所示。我在tensor_dict中添加一个损失字典,这样损失就会得到评估。

代码语言:javascript
复制
groundtruth_boxes_list = 
[input_dict[fields.InputDataFields.groundtruth_boxes]]
label_id_offset = 1
groundtruth_classes_list = 
tf.cast(input_dict[fields.InputDataFields.groundtruth_classes],
                  tf.int32)
groundtruth_classes_list -= label_id_offset
groundtruth_classes_list = 
[ops.padded_one_hot_encoding(indices=groundtruth_classes_list,
                    depth=model.num_classes, left_pad=0)]
model.provide_groundtruth(groundtruth_boxes_list, 
groundtruth_classes_list)          
losses_dict = model.loss(prediction_dict)    
tensor_dict['loss'] = losses_dict

但是,我得到的分类损失是错误的,即使我可以看到它已经正确分类。不确定实现中是否仍有错误。

EN

回答 1

Stack Overflow用户

发布于 2018-08-01 22:23:59

评估脚本已经计算了损失。我目前正在用几行代码来提取它。方法是修改文件"/models/research/object_detection/eval_util.py".在repeated_checkpoint_run()中的行:

代码语言:javascript
复制
  write_metrics(metrics, global_step, summary_dir)

这里的“指标”是一个包含所有类as以及所有损失的字典。要提取它们,请添加以下新行(缩进为前一行):

代码语言:javascript
复制
  for k,v in iter(metrics.items()):
    if 'mAP' in k:
      mAP = v
    elif 'localization_loss' in k:
      loc_loss = v
    elif 'classification_loss' in k:
      cls_loss = v

  print('-> mAP:{} loc_loss:{} cls_loss:{} tot_loss {}'.format(mAP,loc_loss,cls_loss,loc_loss + cls_loss))

这些值应该是您在配置文件中设置的评估数据的mAP、本地化和分类损失。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45328395

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档