首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我能导出tensorflow中单词的嵌入矩阵吗?

我能导出tensorflow中单词的嵌入矩阵吗?
EN

Stack Overflow用户
提问于 2017-12-29 08:41:07
回答 1查看 4.3K关注 0票数 3
代码语言:javascript
复制
def word_embedding(shape, dtype=tf.float32, name='word_embedding'):
  with tf.device('/cpu:0'), tf.variable_scope(name):
    return tf.get_variable('embedding', shape, dtype=dtype, initializer=tf.random_normal_initializer(stddev=0.1), trainable=True,partitioner=tf.fixed_size_partitioner(20))
embedding = word_embedding([vocab_size, embed_size])
inputs_embedding = tf.contrib.layers.embedding_lookup_unique(embedding, inputs)

这是我的代码,embedding是word查找自己的嵌入向量的变量。

我已经训练了嵌入矩阵,我想从保存的模型中提取它。该模型还包含其它参数,如嵌入以上的神经网络。我能实现它吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-12-29 09:45:08

类似的问题请参见my answer

最简单的方法是将嵌入矩阵计算到numpy数组中,并将其与解析的单词一起写入文件。

代码语言:javascript
复制
with tf.Session() as sess:
  embedding_val = sess.run(embedding)
  with open('embedding.txt', 'w') as file_:
    for i in range(vocabulary_size):
      embed = embedding_val[i, :]
      word = word_to_idx[i]
      file_.write('%s %s\n' % (word, ' '.join(map(str, embed))))

如果只想保存此图的嵌入,可以创建tf.train.Saver并传递要保存的变量列表:

代码语言:javascript
复制
saver = tf.train.Saver([embedding])
with tf.Session() as sess:
  saver.save(sess, 'path/to/checkpoint')
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48019799

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档