首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >saver.restore不恢复

saver.restore不恢复
EN

Stack Overflow用户
提问于 2018-02-14 12:51:41
回答 1查看 1.5K关注 0票数 1

当我试图恢复一个学习模型时,我遇到了一个问题:

当我的程序第一次运行时,它似乎没有加载变量,第二次运行时,变量被加载,第三次我在"saver.restore(sess,'model.ckpt')“行上出现了一个巨大的错误,以"NotFoundError: Key beta2_power_2 not in检查点”开头。

下面是我的代码的开头:

代码语言:javascript
复制
with tf.Session() as sess:
    myModel = SoundCNN(8)#classes
    tf.global_variables_initializer().run() 

    saver = tf.train.Saver(tf.global_variables())

    saver.restore(sess, 'model.ckpt')

您可以在SoundCNN文件中看到这里类github项目。我是新的tensorflow和ML,并希望使用awjuliani的项目,以学习使用tf为声音导向的ML。

编辑:这是完整的代码:

代码语言:javascript
复制
print ("start")
bpm = 240
samplingRate = 44100
mypath = "instruments/drums/"
iterations = 1000
batchSize = 240

with tf.Session() as sess:
    myModel = SoundCNN(8)#classes
    tf.global_variables_initializer().run() 

    saver = tf.train.Saver(tf.global_variables())
    print("loading session ...")
    saver.restore(sess, 'model.ckpt')
    print("session loaded")


    print("processing audio ...")
    classes,trainX,trainYa,valX,valY,testX,testY = util.processAudio(bpm,samplingRate,mypath)
    print("audio processed")

    fullTrain = np.concatenate((trainX,trainYa),axis=1)

    quitFlag = False

    inputsize = fullTrain.shape[0]-1 #6607

    print("entering loop...")
    while (not quitFlag):
        indexstr = input("Type the index (0< _ <" + str(inputsize) + ") of the sample to test then press enter.\nYou can press enter without text for random index.\nType q to quit.\n")

        if (indexstr == "q" or indexstr == "Q"):
            quitFlag = True
        else:
            if(indexstr ==""):
                index = randint(0, inputsize)
                print("Index : " + str(index))
            else:
                index = int(indexstr)     

            tensors,labels_ = np.hsplit(fullTrain,[-1])
            labels = util.oneHotIt(labels_)
            tensor, label = tensors[index,:], labels[index]

            tensor = tensor.reshape(1,1024)

            result = myModel.prediction.eval(session=sess,feed_dict={myModel.x: tensor, myModel.keep_prob: 1.0})

            print("Model found sound: n°"+ str(result) + ".\nActual sound: n°" + str(np.argmax(label)) + ".\n" )

谢谢!

edit2:好的,我试了一下下面的代码:

代码语言:javascript
复制
print ("start")
bpm = 240
samplingRate = 44100
mypath = "instruments/drums/"
iterations = 1000
batchSize = 240


tf.reset_default_graph()
myModel = SoundCNN(8)
saver = tf.train.Saver()

with tf.Session() as sess:

    print("loading session ...")
    saver.restore(sess, 'model.ckpt')
    print("session loaded")

而且变量没有加载(错误的预测),但奇怪的是,我可以通过添加:

代码语言:javascript
复制
    myModel = SoundCNN(8)
    saver = tf.train.Saver()
    print("loading session ...")
    saver.restore(sess, 'model.ckpt')
    print("session loaded")

在第一个saver.restore之后(sess,'model.ckpt')

所以我让代码起作用了但这是个恶心的..。

EN

回答 1

Stack Overflow用户

发布于 2018-02-16 12:28:07

那么首先,模型的训练和测试是分开的。使用:存在检查点运行条件if语句。类似于:

代码语言:javascript
复制
if tf.train.checkpoint_exists(tf.train.latest_checkpoint(".")):
    test()
else:
    trainNetConv(iterations)
    test()

您最好只使用latest_checkpoint,因为如果找到检查点,它将返回None或路径。

每当您知道要加载一个模型以清除任何现有的图形时,就运行'tf.reset_default_graph()‘。根据我的经验,它堆叠了图的副本,这减慢了运行时的速度,我猜这可能会导致其他问题。特别是如果您计划在运行时多次执行此操作。

假设您已经有了一个经过训练的模型,您必须像通常一样创建它,方法是使用与您希望加载的模型相同数量的类调用SoundCNN。确保创建完全相同的模型,即相同数量的类。在您提供的代码中,您使用8个类创建模型,但是在'trainNetConv‘中创建的模型的类数由'util.processAudio’决定。值得检查的是,对于任何有声音文件的给定目录,类的数量确实是8个。

加载模型时的关键区别在于不初始化变量,即不使用全局变量调用保护程序对象或运行全局变量初始化程序。你所要做的就是:

  1. 确保运行tf.reset_default_graph()
  2. 创建模型,调用SoundCNN
  3. 创建一个没有参数的保护程序对象。
  4. 像你一样创建一个会话,
  5. 使用通往最新检查点的路径调用保护程序对象的函数还原。在模型的基dir中使用“tf.train.latest_checkpoint”。
  6. 你就完蛋了。

有关培训和测试阶段的完整示例,请查看我的GitHub。确保从'mnist‘开始,因为它只是一个文件,在那里是最简单的。

假设您希望为自己的使用定义额外的变量,那么假设有一些变量计数器和一个操作符,如果预测是正确的,它会增加计数器。需要在使用restore加载模型之后放置它,然后只初始化这些附加变量。同样,我认为我的例子在这种情况下可能会有所帮助。

如果你还有什么问题请问,我会尽力帮忙的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48787714

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档