
NotFoundError when using tf.slim.evaluation.evaluate_once

Asked by a Stack Overflow user on 2017-08-25 20:27:36
Answers: 1, views: 549, followers: 0, score: 0

When I try to evaluate my model with the `slim.evaluation.evaluate_once()` function, I get a NotFoundError telling me that a key of my model cannot be found in the checkpoint:

```text
Running evaluation Loop...
INFO:tensorflow:Starting evaluation at 2017-08-25-11:40:57
INFO:tensorflow:Starting evaluation at 2017-08-25-11:40:57
INFO:tensorflow:Restoring parameters from tmp/flowers/finetune_log/model.ckpt-5000
INFO:tensorflow:Restoring parameters from tmp/flowers/finetune_log/model.ckpt-5000

NotFoundError                             Traceback (most recent call last)
/home/wangx/Dev_env/.tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1326     try:
-> 1327       return fn(*args)
   1328     except errors.OpError as e:

/home/wangx/Dev_env/.tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1305                                    feed_dict, fetch_list, target_list,
-> 1306                                    status, run_metadata)
   1307 

/usr/lib/python3.5/contextlib.py in __exit__(self, type, value, traceback)
     65             try:
---> 66                 next(self.gen)
     67             except StopIteration:

...
NotFoundError (see above for traceback): Key InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/biases not found in checkpoint
     [[Node: save/RestoreV2_44 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_44/tensor_names, save/RestoreV2_44/shape_and_slices)]]
     [[Node: save/RestoreV2_6/_1 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_238_save/RestoreV2_6", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
```
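One way to confirm a mismatch like this is to compare the variable names the evaluation graph expects with the keys actually stored in the checkpoint. The sketch below is plain Python over two lists of names. In TF 1.x the lists could, for example, come from `[v.op.name for v in tf.global_variables()]` and from the keys of `tf.train.NewCheckpointReader(checkpoint_path).get_variable_to_shape_map()`; the sample names here are hypothetical, chosen to mirror the error above.

```python
def missing_from_checkpoint(graph_var_names, checkpoint_keys):
    """Return graph variable names that have no matching checkpoint key.

    These are the names a restore would fail on with NotFoundError.
    """
    available = set(checkpoint_keys)
    return sorted(name for name in graph_var_names if name not in available)


# Hypothetical names: the checkpoint stores BatchNorm variables for this layer,
# while the evaluation graph asks for a biases variable instead.
checkpoint_keys = [
    "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/weights",
    "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/BatchNorm/beta",
    "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/BatchNorm/moving_mean",
    "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/BatchNorm/moving_variance",
]
eval_graph_vars = [
    "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/weights",
    "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/biases",
]

print(missing_from_checkpoint(eval_graph_vars, checkpoint_keys))
# ['InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/biases']
```

If the list is non-empty, the graph was built differently from the one that wrote the checkpoint, and the checkpoint file itself is usually fine.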

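I saved my checkpoints in ./tmp/flowers/finetune_log, and downloaded the flowers images following this tutorial. Is something wrong with the checkpoint files I got from training, or am I missing something in the evaluation? Here is my evaluation code: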

```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib.slim)
from datasets import flowers  # from the TF-Slim model library
from nets import inception

slim = tf.contrib.slim

# load_batch() is a helper defined elsewhere (not shown in the question).

with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)
    tf_global_step = slim.get_or_create_global_step()

    dataset = flowers.get_split('validation', 'tmp/flowers')
    images, labels = load_batch(dataset)
    logits, endpoints = inception.inception_v1(
        images, num_classes=dataset.num_classes, is_training=False)
    predictions = tf.argmax(logits, 1)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'eval/Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'eval/Recall': slim.metrics.streaming_recall(predictions, labels),
    })

    print('Running evaluation Loop...')
    checkpoint_path = tf.train.latest_checkpoint('tmp/flowers/finetune_log')

    # Restore the latest checkpoint once and run 20 evaluation batches.
    metric_values = slim.evaluation.evaluate_once(
        num_evals=20,
        master='',
        checkpoint_path=checkpoint_path,
        logdir='tmp/flowers/eval_finetune_log',
        eval_op=names_to_updates.values(),
        final_op=names_to_values.values())
```

Just in case, here is my training code:

```python
import os

import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib.slim)
from datasets import flowers  # from the TF-Slim model library
from nets import inception

slim = tf.contrib.slim

# load_batch() is a helper defined elsewhere (not shown in the question).


def get_init_fn():
    """Returns a function run by the chief worker to warm-start the training."""
    # The new classification layers are not restored; they are trained from scratch.
    checkpoint_exclude_scopes = ["InceptionV1/Logits", "InceptionV1/AuxLogits"]

    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]

    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    # Restore all remaining model variables from the pretrained checkpoint.
    return slim.assign_from_checkpoint_fn(
        os.path.join('tmp/checkpoints', 'inception_v1.ckpt'),
        variables_to_restore)


train_dir = 'tmp/flowers/finetune_log'

with tf.Graph().as_default():
    dataset = flowers.get_split('train', 'tmp/flowers')
    images, labels = load_batch(dataset)

    # Build Inception V1 with the default Inception arg scope.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True)

    one_hot_labels = slim.one_hot_encoding(labels, 5)
    slim.losses.softmax_cross_entropy(logits, one_hot_labels)
    total_loss = slim.losses.get_total_loss()

    tf.summary.scalar('losses/Total Loss', total_loss)

    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=5000,
        save_summaries_secs=1)

print('done.')
```
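The exclusion loop in `get_init_fn` only decides which variables are warm-started from the pretrained `inception_v1.ckpt`; the `Logits` layer is skipped because the pretrained one is sized for ImageNet's classes, not the 5 flower classes. The checkpoints written to `train_dir` come from the saver inside `slim.learning.train`, which is assumed here to save all global variables by default, so this filter should not make `model.ckpt-5000` incomplete. The sketch below reproduces the same prefix filter on plain strings; the helper name and the variable names are hypothetical.

```python
def names_to_restore(var_names, exclude_scopes):
    """Keep only names that do not start with any excluded scope prefix.

    Mirrors the filtering loop in get_init_fn, but works on plain strings
    instead of tf.Variable objects.
    """
    exclusions = [scope.strip() for scope in exclude_scopes]
    return [name for name in var_names
            if not any(name.startswith(ex) for ex in exclusions)]


exclude = ["InceptionV1/Logits", "InceptionV1/AuxLogits"]
# Hypothetical model variable names for illustration.
model_vars = [
    "InceptionV1/Conv2d_1a_7x7/weights",
    "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/weights",
    "InceptionV1/Logits/Conv2d_0c_1x1/weights",
    "InceptionV1/Logits/Conv2d_0c_1x1/biases",
]

print(names_to_restore(model_vars, exclude))
# ['InceptionV1/Conv2d_1a_7x7/weights', 'InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/weights']
```

Because this is a plain prefix match, a scope such as `InceptionV1/Logits` would also exclude any name that merely begins with those characters; appending a trailing `/` to each scope would make the match stricter.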

Thanks a lot. This has been blocking me for a long time.


1 Answer

Answered by a Stack Overflow user on 2017-08-27 09:19:26

I found that the evaluation snippet runs if I make the following two changes:

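1. Define `slim.arg_scope()` for the model, the same way as in training. I think the missing arg scope is the reason for the NotFoundError: without it, the evaluation graph does not match the parameters the model was trained with. Specifically, `inception_v1_arg_scope()` turns on batch normalization for the conv layers, so the trained checkpoint stores `BatchNorm` variables for them instead of `biases`; built without the scope, the evaluation graph asks for `biases` variables such as `InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/biases`, which the checkpoint does not contain. The code should change like this: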

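```python
images, labels = load_batch(dataset)
# Build the model with the same arg scope used in training, so the variable
# names (BatchNorm instead of biases) match the checkpoint.
with slim.arg_scope(inception.inception_v1_arg_scope()):
    # is_training=False for evaluation: dropout off, batch norm uses moving averages.
    logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=False)
predictions = tf.argmax(logits, 1)
```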

2. I removed `slim.metrics.aggregate_metric_map()` and used a single simple metric instead:

```python
accuracy, accuracy_updates = slim.metrics.streaming_accuracy(predictions, labels)
```
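Why the arg scope changes the variable names can be sketched with a toy function (hypothetical, not part of TF-Slim) that lists the variables one conv layer creates. It assumes, as in TF-Slim's `conv2d`, that a layer gets `biases` only when no `normalizer_fn` is set, and that the Inception arg scope sets `normalizer_fn=slim.batch_norm` without scaling, giving `beta`, `moving_mean` and `moving_variance` but no `gamma`.

```python
def conv_variable_names(scope, use_batch_norm):
    """Toy model of the variables a TF-Slim conv layer creates.

    Assumption: with batch norm as the normalizer the layer has no biases
    and gets BatchNorm variables instead; without it the layer gets biases.
    """
    names = [f"{scope}/weights"]
    if use_batch_norm:
        names += [f"{scope}/BatchNorm/{v}"
                  for v in ("beta", "moving_mean", "moving_variance")]
    else:
        names.append(f"{scope}/biases")
    return names


layer = "InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1"
saved = set(conv_variable_names(layer, use_batch_norm=True))       # training graph, inside arg_scope
without_scope = conv_variable_names(layer, use_batch_norm=False)   # evaluation graph in the question
with_scope = conv_variable_names(layer, use_batch_norm=True)       # evaluation graph in the answer

print([n for n in without_scope if n not in saved])  # ['InceptionV1/Mixed_4c/Branch_0/Conv2d_0a_1x1/biases']
print([n for n in with_scope if n not in saved])     # []
```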

Now it runs.
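A likely reason the second change helped (an assumption; the answer does not say) is Python 3's `dict.values()`: it returns a `dict_values` view rather than a list, and in TF 1.x `Session.run` rejects such a view as a fetch with a `TypeError`. Converting the view with `list(...)`, e.g. `eval_op=list(names_to_updates.values())`, is a common way to keep `aggregate_metric_map`. The check below only demonstrates the Python side; the dictionary values are hypothetical stand-ins for the update ops.

```python
# Stand-ins for the update ops returned by aggregate_metric_map (hypothetical strings).
names_to_updates = {"eval/Accuracy": "accuracy_update_op",
                    "eval/Recall": "recall_update_op"}

view = names_to_updates.values()
print(type(view).__name__)              # dict_values
print(isinstance(view, (list, tuple)))  # False

# A real list, which is the kind of fetch Session.run accepts.
eval_op = list(names_to_updates.values())
print(eval_op)  # ['accuracy_update_op', 'recall_update_op']
```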

Page content originally from Stack Overflow (Chinese translation provided by Tencent Cloud Xiaowei's IT-domain engine).
Original link:

https://stackoverflow.com/questions/45881448
