首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用保存的检查点多次运行推理模型会产生随机错误- Tensorflow

使用保存的检查点多次运行推理模型会产生随机错误- Tensorflow
EN

Stack Overflow用户
提问于 2017-01-12 13:20:41
回答 1查看 347关注 0票数 0

我在GPU上运行Tensorflow 0.12.1。我有一个经过训练的Deep CNN模型,我使用检查点文件保存了它的权重。在推断过程中,我使用restorer.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))重新加载保存的检查点。代码似乎运行没有问题,但每次我重新运行脚本时,我都会得到错误的输出。AFAIK,我不会打乱我的测试集输入。输入正在正确加载并馈送到网络。这只是CNN在相同测试集上不同运行的输出,使用相同的顺序产生非常不同的输出。我被弄糊涂了!另外,在推理过程中,如何在不运行init_op的情况下执行加载了保存的检查点的图形?我的代码似乎要求所有的全局和局部变量在执行前都被初始化。(我首先初始化,然后只恢复检查点!)下面是我的代码片段:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
import os
import os.path
from datetime import datetime
import time
import random
import json

from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes


from modelFCNN3 import model 

def read_input(inp_queue,height=224,width=224,channels=3, mask=False):
  value = tf.read_file(inp_queue)
  image = tf.image.decode_png(value)
  image =  tf.image.resize_images(image, [height, width],method=2)
  image = tf.cast(image, tf.uint8)
  image.set_shape([height,width,channels])
  image = tf.reshape(image,[height,width,channels])
  if mask:
      image = tf.to_float(tf.greater_equal(image,128))
      image = tf.cast(image,tf.float32)
  else:
      image = tf.image.per_image_standardization(image)
      image = tf.cast(image,tf.float32)
  return image




if __name__ == '__main__':

    tf.reset_default_graph()

    with open('X_test.json', 'r') as infile:
        X_test = json.load(infile)

    with open('y_test.json', 'r') as infile:
        y_test = json.load(infile)

    imagelist = ops.convert_to_tensor(X_test, dtype=dtypes.string)
    labellist = ops.convert_to_tensor(y_test, dtype=dtypes.string)

    input_queue = tf.train.slice_input_producer([imagelist, labellist],
                                            num_epochs=1,
                                            shuffle=False)

    image = read_input(input_queue[0],height=224,width=224,channels=3, mask=False)

    label = read_input(input_queue[1],height=224,width=224,channels=1, mask=True)

    images_batch, labels_batch = tf.train.batch([image, label], batch_size=FLAGS.batch_size,
        enqueue_many=False,shapes=None, allow_smaller_final_batch=True)

    global_step = tf.Variable(0, trainable=False)
    images = tf.placeholder_with_default(images_batch, shape=[None, 224,224,3])
    labels = tf.placeholder_with_default(labels_batch, shape=[None, 224,224,1])

    restorer = tf.train.Saver()

    logits = model(images).logits
    labels = tf.cast(labels,tf.int32)
    labels.set_shape([FLAGS.batch_size,224,224,1])

    valid_prediction = tf.argmax(tf.nn.softmax(logits), dimension=3)
    valid_prediction.set_shape([FLAGS.batch_size,224,224])

    meanIOU,update_op_mIOU= tf.contrib.metrics.streaming_mean_iou(tf.cast(valid_prediction,tf.int32), tf.squeeze(labels),FLAGS.num_classes)

    init = tf.global_variables_initializer()
    init_locals = tf.local_variables_initializer()


    with tf.Session() as sess:

        sess.run([init, init_locals])

        restorer.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))
        print("Model restored.")


        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord,sess=sess)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        try:
            step = 0
            avg = []
            while not coord.should_stop():
                myimg, predimg, mylbl= sess.run([images,valid_prediction,labels])
                mIOU,_ = sess.run([meanIOU,update_op_mIOU])
                avg.append(mIOU)

            step += 1

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')

        finally:

            coord.request_stop()
            coord.join(threads)
            sess.close()
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-03-07 13:43:29

您是在同一台机器上运行还是在不同的机器上运行#saver = tf.train.Saver()

以下注释在tensorflow文档中#注:仅当设备分配未更改时,才能从保存的meta_graph重新开始训练。#saver =tf.train.import_meta_graph(元文件)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41605636

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档