首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow -我恢复模型正确吗?

Tensorflow -我恢复模型正确吗?
EN

Stack Overflow用户
提问于 2017-03-27 22:58:29
回答 3查看 1.9K关注 0票数 1

我有下面的代码正在工作(没有错误)。我的问题是,我是否恢复模型正确?特别是我看不到语句print(v_)的任何输出。

所以,我想知道我做的是否正确:

  1. 恢复模型
  2. 使用恢复的模型 将tensorflow作为tf导入 数据,标签= cifar_tools.read_data('C:\Users\abc\Desktop\Testing') x= tf.placeholder(tf.float32,None,150 * 150) y= tf.placeholder(tf.float32,None,2) w1 = tf.Variable(tf.random_normal(5,5,1,64 ) b1 = tf.Variable(tf.random_normal(64)) w2 = tf.Variable(tf.random_normal(5,5,64),b2 = tf.Variable(tf.random_normal(64)) w3 = tf.Variable(tf.random_normal(38*38*64,1024)) b3 = tf.Variable(tf.random_normal(1024)) w_out = tf.Variable(tf.random_normal(1024,2) b_out = tf.Variable(tf.random_normal(2)) def conv_layer(x,w,b):conv = tf.nn.conv2d(x,w,tf.random_normal,填充=‘相同’) conv_with_b = tf.nn.bias_add(conv,b) conv_out = tf.nn.relu(conv_with_b)返回conv_out def maxpool_layer(conv,k=2):返回tf.nn.max_pool(conv,ksize=1,k,k,1,strides=1,k,k,1,填充=‘同’) def模型():x_reshaped = tf.reshape(x,shape=-1,150,150,( 1) conv_out1 = conv_layer(x_reshaped,w1,b1) maxpool_out1 = maxpool_layer(conv_out1) norm1 = tf.nn.lrn(maxpool_out1,4,bias=1.0,alpha=0.001 / 9.0,beta=0.75) conv_out2 = conv_layer(norm1,conv_out2,beta=0.75)=(,4,,/ 9.0,( beta=0.75) maxpool_out2 = maxpool_layer(norm2) maxpool_reshaped = tf.reshape(maxpool_out2,[-1,w3.get_shape().as_list()]) local = tf.add(tf.matmul(maxpool_reshaped,w3),b3) local_out = tf.nn.relu(local) out =tf.add(local_out,en19#),( tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost) correct_pred =tf.equal( model_op,1),tf.argmax(y,1))精度= tf.reduce_mean(tf.cast(correct_pred ),( tf.float32))以tf.Session()作为sess: sess.run(tf.global_variables_initializer() onehot_labels =tf.one_hot(标签,2,on_value=1.,off_value=0.,axis=-1) onehot_vals = sess.run(onehot_labels) batch_size = len(data) # Restore模型保护程序=tf.Session saver.restore(sess ),Tf.train.latest_checkpoint(‘./’) all_vars = tf.get_collection('vars')表示v in all_vars: v_ = sess.run(v) print(v_)表示j在范围(0,5):打印(‘EPOCH’,j)在范围内(0,len(数据),batch_size):batch_data = datai:i+batch_size,:batch_onehot_vals = onehot_valsi:i+batch_size,:_,accuracy_val = sess.run(train_op,准确性,feed_dict={x: batch_data,y: batch_onehot_vals})打印(i,accuracy_val)打印

编辑1

恢复这种方式有效吗?

代码语言:javascript
复制
saver = tf.train.Saver()
saver = tf.train.import_meta_graph('C:\\Users\\Abder-Rahman\\Desktop\\\Testing\\mymodel.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print('model restored'

编辑2

这就是我如何保存我的模型的

代码语言:javascript
复制
#Save model
saver = tf.train.Saver()
saved_path = saver.save(sess, 'C:\\Users\\abc\\Desktop\\\Testing\\mymodel')
print("The model is in this file: ", saved_path)

谢谢。

EN

回答 3

Stack Overflow用户

发布于 2017-03-28 02:45:41

你的保护程序代码是正确的。而变量必须在检索集合之前添加到集合中。tf.add_to_collection("vars", w1) tf.add_to_collection("vars", b1) ...然后all_vars = tf.get_collection('vars')

票数 2
EN

Stack Overflow用户

发布于 2017-03-28 02:10:25

通常,我会像这样还原一个TensorFlow模型:

代码语言:javascript
复制
 with tf.Session(graph=graph) as session:
    if os.path.exists(save_path):
        # Restore variables from disk.
        saver.restore(session, save_path)
    else:
        tf.initialize_all_variables().run()
        print('Initialized')

    # do the work
    # ... 
 saver.save(session, save_path)   # save the model

示例代码可以获取这里

我需要了解更多关于如何保存模型的信息,您的模型似乎是在保存之前恢复的,并且您的模型没有转向tf.graph并与会话连接。

票数 1
EN

Stack Overflow用户

发布于 2017-03-28 05:45:53

我假设您已经阅读了我的博客这里,模型保存机制非常简单,当您加载一个模型时,参数值和关系(这可能是您所关心的)由变量名称匹配。

例如

代码语言:javascript
复制
#simplesave.py
import tensorflow as tf

with tf.Graph().as_default() as g:#yes you have to have a graph first
  with tf.Session() as sess:
    b = tf.Variable(1.0, name="bias")
    saver = tf.train.Saver()
    saver.save(sess,'model') #b should be saved in the model file

#simpleload.py

import tensorflow as tf

with tf.Graph().as_default() as g:
  with tf.Session() as sess:
    #still need the definition, again
    b = tf.Variable(0.0, name="bias")
    saver = tf.train.Saver() #now it is satisfied...
    saver.restore(sess,model)

让我感到困惑的是,您使用了一个函数all_vars = tf.get_collection('vars'),但是您从未定义过一个名为"vars“的范围。您可能应该首先使用tf.all_variables()进行测试。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43057816

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档