当我试图恢复一个学习模型时,我遇到了一个问题:
当我的程序第一次运行时,它似乎没有加载变量,第二次运行时,变量被加载,第三次我在"saver.restore(sess,'model.ckpt')“行上出现了一个巨大的错误,以"NotFoundError: Key beta2_power_2 not in检查点”开头。
下面是我的代码的开头:
with tf.Session() as sess:
myModel = SoundCNN(8)#classes
tf.global_variables_initializer().run()
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, 'model.ckpt')您可以在SoundCNN文件中看到这里类github项目。我是新的tensorflow和ML,并希望使用awjuliani的项目,以学习使用tf为声音导向的ML。
编辑:这是完整的代码:
print ("start")
bpm = 240
samplingRate = 44100
mypath = "instruments/drums/"
iterations = 1000
batchSize = 240
with tf.Session() as sess:
myModel = SoundCNN(8)#classes
tf.global_variables_initializer().run()
saver = tf.train.Saver(tf.global_variables())
print("loading session ...")
saver.restore(sess, 'model.ckpt')
print("session loaded")
print("processing audio ...")
classes,trainX,trainYa,valX,valY,testX,testY = util.processAudio(bpm,samplingRate,mypath)
print("audio processed")
fullTrain = np.concatenate((trainX,trainYa),axis=1)
quitFlag = False
inputsize = fullTrain.shape[0]-1 #6607
print("entering loop...")
while (not quitFlag):
indexstr = input("Type the index (0< _ <" + str(inputsize) + ") of the sample to test then press enter.\nYou can press enter without text for random index.\nType q to quit.\n")
if (indexstr == "q" or indexstr == "Q"):
quitFlag = True
else:
if(indexstr ==""):
index = randint(0, inputsize)
print("Index : " + str(index))
else:
index = int(indexstr)
tensors,labels_ = np.hsplit(fullTrain,[-1])
labels = util.oneHotIt(labels_)
tensor, label = tensors[index,:], labels[index]
tensor = tensor.reshape(1,1024)
result = myModel.prediction.eval(session=sess,feed_dict={myModel.x: tensor, myModel.keep_prob: 1.0})
print("Model found sound: n°"+ str(result) + ".\nActual sound: n°" + str(np.argmax(label)) + ".\n" )谢谢!
edit2:好的,我试了一下下面的代码:
print ("start")
bpm = 240
samplingRate = 44100
mypath = "instruments/drums/"
iterations = 1000
batchSize = 240
tf.reset_default_graph()
myModel = SoundCNN(8)
saver = tf.train.Saver()
with tf.Session() as sess:
print("loading session ...")
saver.restore(sess, 'model.ckpt')
print("session loaded")而且变量没有加载(错误的预测),但奇怪的是,我可以通过添加:
myModel = SoundCNN(8)
saver = tf.train.Saver()
print("loading session ...")
saver.restore(sess, 'model.ckpt')
print("session loaded")在第一个saver.restore之后(sess,'model.ckpt')
所以我让代码起作用了但这是个恶心的..。
发布于 2018-02-16 12:28:07
那么首先,模型的训练和测试是分开的。使用:存在和检查点运行条件if语句。类似于:
if tf.train.checkpoint_exists(tf.train.latest_checkpoint(".")):
test()
else:
trainNetConv(iterations)
test()您最好只使用latest_checkpoint,因为如果找到检查点,它将返回None或路径。
每当您知道要加载一个模型以清除任何现有的图形时,就运行'tf.reset_default_graph()‘。根据我的经验,它堆叠了图的副本,这减慢了运行时的速度,我猜这可能会导致其他问题。特别是如果您计划在运行时多次执行此操作。
假设您已经有了一个经过训练的模型,您必须像通常一样创建它,方法是使用与您希望加载的模型相同数量的类调用SoundCNN。确保创建完全相同的模型,即相同数量的类。在您提供的代码中,您使用8个类创建模型,但是在'trainNetConv‘中创建的模型的类数由'util.processAudio’决定。值得检查的是,对于任何有声音文件的给定目录,类的数量确实是8个。
加载模型时的关键区别在于不初始化变量,即不使用全局变量调用保护程序对象或运行全局变量初始化程序。你所要做的就是:
有关培训和测试阶段的完整示例,请查看我的GitHub。确保从'mnist‘开始,因为它只是一个文件,在那里是最简单的。
假设您希望为自己的使用定义额外的变量,那么假设有一些变量计数器和一个操作符,如果预测是正确的,它会增加计数器。需要在使用restore加载模型之后放置它,然后只初始化这些附加变量。同样,我认为我的例子在这种情况下可能会有所帮助。
如果你还有什么问题请问,我会尽力帮忙的。
https://stackoverflow.com/questions/48787714
复制相似问题