In my training file (train.py), I wrote:
def deep_part(self):
    with tf.variable_scope("deep-part"):
        y_deep = tf.reshape(self.embeddings, shape=[-1, self.field_size * self.factor_size])  # None * (F*K)
        # self.deep_layers holds the sizes of the 2 hidden layers
        for i in range(len(self.deep_layers)):
            y_deep = tf.contrib.layers.fully_connected(y_deep, self.deep_layers[i],
                                                       activation_fn=self.deep_layers_activation,
                                                       scope='fc%d' % i)
        return y_deep

Now, in my prediction file (predict.py), I restore the checkpoint, but I don't know how to reload the weights and biases of the "deep" network, because I think the fully_connected function may hide the weights and biases.
Posted on 2018-01-25 08:38:30
I wrote a longer explanation of this here. Brief summary:

Calling saver.save(sess, '/tmp/my_model') makes TensorFlow produce multiple files:
checkpoint
my_model.data-00000-of-00001
my_model.index
my_model.meta

The checkpoint file is just a pointer to the most recent version of the model weights; it is a plain text file containing:
$ cat /tmp/checkpoint
model_checkpoint_path: "/tmp/my_model"
all_model_checkpoint_paths: "/tmp/my_model"

The other files are binary files containing the graph (.meta) and the weights (.data*).
You can help yourself by running:
import tensorflow as tf
import numpy as np

data = np.arange(9 * 1).reshape(1, 9).astype(np.float32)

plhdr = tf.placeholder(tf.float32, shape=[1, 9], name='input')
print(plhdr.name)
activation = tf.layers.dense(plhdr, 10, name='fc')
print(activation.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    expected = sess.run(activation, {plhdr: data})
    print(expected)
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, '/tmp/my_model')

tf.reset_default_graph()
with tf.Session() as sess:
    # load the computation graph (the fully connected layer + placeholder)
    loader = tf.train.import_meta_graph('/tmp/my_model.meta')
    sess.run(tf.global_variables_initializer())
    plhdr = tf.get_default_graph().get_tensor_by_name('input:0')
    activation = tf.get_default_graph().get_tensor_by_name('fc/BiasAdd:0')

    # the graph is loaded, but the variables still hold fresh random values
    actual = sess.run(activation, {plhdr: data})
    assert not np.allclose(actual, expected)

    # now load the weights
    loader.restore(sess, '/tmp/my_model')
    actual = sess.run(activation, {plhdr: data})
    assert np.allclose(actual, expected)

https://stackoverflow.com/questions/48438419