首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从TensorFlow word2vec教程加载保存的模型并用于word比较

如何从TensorFlow word2vec教程加载保存的模型并用于word比较
EN

Stack Overflow用户
提问于 2017-10-07 00:03:25
回答 1查看 866关注 0票数 0

我是TensorFlow、word2vec和神经网络的新手,我正在努力学习它们。我正在编写这个TensorFlow教程:https://www.tensorflow.org/tutorials/word2vec。我运行了本教程的word2vec_optimized.py代码,该教程位于以下位置:https://github.com/tensorflow/models/blob/master/tutorials/embedding/word2vec_optimized.py。当教程代码运行完毕时,它会输出一个保存的TensorFlow模型。我正在尝试看看是否可以重新加载模型,并使用它进行单词比较,例如,法国对于巴黎就像俄罗斯对于莫斯科一样。

我在教程代码中看到了一种类比方法,可以用来做这个:

代码语言:javascript
复制
def analogy(self, w0, w1, w2):
    """Predict word w3 as in w0:w1 vs w2:w3."""
    wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
    idx = self._predict(wid)
    for c in [self._id2word[i] for i in idx[0, :]]:
      if c not in [w0, w1, w2]:
        print(c)
        break
    print("unknown")

但首先,我需要重新加载保存的模型,这是在我的main方法中完成的:

代码语言:javascript
复制
def main(_):
  with tf.Graph().as_default(), tf.Session() as session:
    with tf.device("/cpu:0"):
      model = tf.train.import_meta_graph('model.ckpt.meta')
      model.restore(session, tf.train.latest_checkpoint('/results/'))
      model.analogy(b'france', b'paris', b'russia')

这将导致以下错误:

代码语言:javascript
复制
Traceback (most recent call last):
  File "./word2vec_test.py", line 539, in <module>
  tf.app.run()
  File "/util/opt/anaconda/2.2/envs/tensorflow-1.0.0/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "./embedding_tutorial/embedding/word2vec_test.py", line 534, in main
    model.analogy(b'france', b'paris', b'russia')
AttributeError: 'Saver' object has no attribute 'analogy'

如何加载保存的模块并使用它调用类比方法?我将我的main方法和类比方法放在同一个文件中。

EN

回答 1

Stack Overflow用户

发布于 2017-10-07 01:44:55

我认为问题在于analogy()Word2Vec类的一个方法,但是您并没有将model实例化到该类型的对象中。

试试这个:

代码语言:javascript
复制
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
  with tf.device("/cpu:0"):
    model = Word2Vec(opts, session)
    model.saver = tf.train.import_meta_graph('/path/to/model.ckpt.meta')
    model.saver.restore(session, 
                        tf.train.latest_checkpoint('/path/to/results/'))
    model.analogy(b'france', b'paris', b'russia')

或者实际上,如果你在训练时使用了--interactive,那么在训练完成后,你将使用ipython进入interactive mode

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46609740

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档