首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何恢复fully_connected函数中的变量

如何恢复fully_connected函数中的变量
EN

Stack Overflow用户
提问于 2018-01-25 08:20:00
回答 1查看 875关注 0票数 1

在我的培训文件(train.py)中,我写道:

代码语言:javascript
复制
def deep_part(self):
    with tf.variable_scope("deep-part"):                                              
        y_deep = tf.reshape(self.embeddings, shape=[-1, self.field_size * self.factor_size]) # None * (F*K)
        # self.deep_layers = 2
        for i in range(0,len(self.deep_layers)):
            y_deep = tf.contrib.layers.fully_connected(y_deep, self.deep_layers[i], \ 
            activation_fn=self.deep_layers_activation, scope = 'fc%d' % i)
        return y_deep

现在在预测文件(predict.py)中,我恢复检查点,但我不知道如何重新加载“深层”网络的权重和biases.Because,我认为"fully_conncted“函数可能隐藏权重和偏差。

EN

回答 1

Stack Overflow用户

发布于 2018-01-25 08:38:30

我写了一个很长的在此解释。简短摘要:

saver.save(sess, '/tmp/my_model') Tensorflow生成多个文件:

代码语言:javascript
复制
checkpoint
my_model.data-00000-of-00001
my_model.index
my_model.meta

检查点文件checkpoint只是指向我们模型权重的最新版本的指针,它只是一个简单的文本文件,包含

代码语言:javascript
复制
$ !cat /tmp/model/checkpoint
model_checkpoint_path: "/tmp/my_model"
all_model_checkpoint_paths: "/tmp/my_model"

其他文件是包含图形(.meta)和权重(.data*)的二进制文件。

你可以通过跑步来帮助自己

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

data = np.arange(9 * 1).reshape(1, 9).astype(np.float32)

plhdr = tf.placeholder(tf.float32, shape=[1, 9], name='input')
print plhdr.name

activation = tf.layers.dense(plhdr, 10, name='fc')
print activation.name

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    expected = sess.run(activation, {plhdr: data})
    print expected

    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, '/tmp/my_model')


tf.reset_default_graph()
with tf.Session() as sess:
    # load the computation graph (the fully connected + placeholder)
    loader = tf.train.import_meta_graph('/tmp/my_model.meta')
    sess.run(tf.global_variables_initializer())

    plhdr = tf.get_default_graph().get_tensor_by_name('input:0')
    activation = tf.get_default_graph().get_tensor_by_name('fc/BiasAdd:0')
    actual = sess.run(activation, {plhdr: data})
    assert np.allclose(actual, expected) is False

    # now load the weights
    loader = loader.restore(sess, '/tmp/my_model')
    actual = sess.run(activation, {plhdr: data})
    assert np.allclose(actual, expected) is True
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48438419

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档