首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何保存深度学习模式,并在培训后进行测试?

如何保存深度学习模式,并在培训后进行测试?
EN

Data Science用户
提问于 2018-08-14 13:52:16
回答 1查看 2.7K关注 0票数 0

我用tensorflow编写了一个用于python的CNN模型,该模型用于对肺CT图像(癌症/非癌症)进行分类,经过训练和验证数据训练模型并获得合理的准确性,毕竟,我需要用测试数据来测试模型,但我不知道如何做到这一点?如何保存模型并将其用于测试?

EN

回答 1

Data Science用户

回答已采纳

发布于 2018-08-14 17:10:21

您可以在本教程中找到详细信息:拯救CNN模型

概括地说:

Tensorflow变量仅在会话中处于活动状态。因此,您必须通过对saver对象调用save方法来保存会话中的模型。

代码语言:javascript
复制
import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

为了在1000次迭代后保存模型,通过传递步骤计数来调用保存:

代码语言:javascript
复制
saver.save(sess, 'my_test_model',global_step=1000)

使用预先训练的模型进行微调:

代码语言:javascript
复制
with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('my-model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value of w1.

通过添加更多的层来添加更多的操作,然后对其进行训练。

代码语言:javascript
复制
sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print sess.run(add_on_op,feed_dict)
#This will print 120.
票数 2
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/36929

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档