首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用tensorflow提取ELMo特性并将它们转换为numpy

使用tensorflow提取ELMo特性并将它们转换为numpy
EN

Stack Overflow用户
提问于 2021-04-28 10:53:35
回答 1查看 681关注 0票数 1

因此,我感兴趣的是使用ELMo模型提取句子嵌入。

一开始我试过:

代码语言:javascript
复制
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)

x = ["Hi my friend"]

embeddings = elmo_model(x, signature="default", as_dict=True)["elmo"]


print(embeddings.shape)
print(embeddings.numpy())

它在最后一行之前运行良好,我无法将其转换为numpy数组。

我搜索了一下,发现如果我在代码的开头加上下面的一行,问题就必须解决。

代码语言:javascript
复制
tf.enable_eager_execution()

但是,我把它放在代码的开头,我意识到我无法编译

代码语言:javascript
复制
elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)

我收到了这个错误:

启用急切执行时,不支持

导出/导入元图。启用急切执行时,不存在图形。

我该如何解决我的问题?我的目标是获得句子特征,并在NumPy数组中使用它们。

提前感谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-28 14:30:22

TF 2.x

TF2行为更接近于传统的python行为,因为它默认为急切的执行。但是,您应该使用hub.load在TF2中加载您的模型。

代码语言:javascript
复制
elmo = hub.load("https://tfhub.dev/google/elmo/2").signatures["default"]
x = ["Hi my friend"]
embeddings = elmo(tf.constant(x))["elmo"]

然后,可以使用numpy方法访问结果并将其转换为numpy数组。

代码语言:javascript
复制
>>> embeddings.numpy()
array([[[-0.7205108 , -0.27990735, -0.7735629 , ..., -0.24703965,
         -0.8358178 , -0.1974785 ],
        [ 0.18500198, -0.12270843, -0.35163105, ...,  0.14234722,
          0.08479916, -0.11709933],
        [-0.49985904, -0.88964033, -0.30124515, ...,  0.15846594,
          0.05210422,  0.25386307]]], dtype=float32)

TF 1.x

如果使用Tf1.x,则应该在tf.Session中运行该操作。TensorFlow不使用急切的执行,需要先构建图表,然后在会话中评估结果。

代码语言:javascript
复制
elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)
x = ["Hi my friend"]
embeddings_op = elmo_model(x, signature="default", as_dict=True)["elmo"]
# required to load the weights into the graph
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    embeddings = sess.run(embeddings_op)

在这种情况下,结果将是一个numpy数组:

代码语言:javascript
复制
>>> embeddings
array([[[-0.72051036, -0.27990723, -0.773563  , ..., -0.24703972,
         -0.83581805, -0.19747877],
        [ 0.18500218, -0.12270836, -0.35163072, ...,  0.14234722,
          0.08479934, -0.11709933],
        [-0.49985906, -0.8896401 , -0.3012453 , ...,  0.15846589,
          0.05210405,  0.2538631 ]]], dtype=float32)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67298869

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档