首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将预训练嵌入导入Tensorflow的Embedding Feature列

将预训练嵌入导入Tensorflow的Embedding Feature列
EN

Stack Overflow用户
提问于 2019-10-08 06:52:37
回答 1查看 435关注 0票数 0

我有一个TF Estimator,它在输入层使用特征列。其中之一是我随机初始化的EmbeddingColumn (默认行为)。

现在我想在gensim中预先训练我的嵌入,并将学习到的嵌入转移到我的TF模型中。embedding_column接受初始值设定项参数,该参数需要一个可调用对象,该可调用对象可以使用tf.contrib.framework.load_embedding_initializer进行created

然而,该函数需要一个保存的TF检查点,而我没有,因为我在gensim中训练了我的嵌入。

问题是:如何将gensim字向量(即numpy数组)保存为TF检查点格式的张量,以便我可以使用它来初始化我的嵌入列?

EN

回答 1

Stack Overflow用户

发布于 2019-10-10 11:19:59

想明白了!这在Tensorflow 1.14.0中起作用。

您首先需要将嵌入向量转换为tf.Variable。然后使用tf.train.Saver将其保存在检查点中。

代码语言:javascript
复制
import tensorflow as tf
import numpy as np


ckpt_name = 'gensim_embeddings'
vocab_file = 'vocab.txt'
tensor_name = 'embeddings_tensor'
vocab = ['A', 'B', 'C']
embedding_vectors = np.array([
    [1,2,3],
    [4,5,6],
    [7,8,9]
], dtype=np.float32)

embeddings = tf.Variable(initial_value=embedding_vectors)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver({tensor_name: embeddings})
with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, ckpt_name)

# writing vocab file
with open(vocab_file, 'w') as f:
    f.write('\n'.join(vocab))

要使用此检查点初始化嵌入功能列,请执行以下操作:

代码语言:javascript
复制
cat = tf.feature_column.categorical_column_with_vocabulary_file(
    key='cat', vocabulary_file=vocab_file)

embedding_initializer = tf.contrib.framework.load_embedding_initializer(
    ckpt_path=ckpt_name,
    embedding_tensor_name='embeddings_tensor',
    new_vocab_size=3,
    embedding_dim=3,
    old_vocab_file=vocab_file,
    new_vocab_file=vocab_file
)

emb = tf.feature_column.embedding_column(cat, dimension=3, initializer=embedding_initializer, trainable=False)

我们可以测试以确保它已被正确初始化:

代码语言:javascript
复制
def test_embedding(feature_column, sample):
    feature_layer = tf.keras.layers.DenseFeatures(feature_column)
    print(feature_layer(sample).numpy())

tf.enable_eager_execution()

sample = {'cat': tf.constant(['B', 'A'], dtype=tf.string)}

test_embedding(item_emb, sample)

不出所料,输出为:

代码语言:javascript
复制
[[4. 5. 6.]
 [1. 2. 3.]]

它们分别是'B‘和'A’的嵌入。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58278111

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档