首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >获取opennmt-py中的注意力权重

获取opennmt-py中的注意力权重
EN

Stack Overflow用户
提问于 2019-07-25 06:37:18
回答 1查看 287关注 0票数 0

特别是在opennmt-py中。现在有很多关于这个话题的问题,比如Getting alignment/attention during translation in OpenNMT-py和开放论坛https://github.com/OpenNMT/OpenNMT-py/issues/575上的下面的帖子。我使用后者建议的代码。然而,似乎没有人能解决我的问题。我尝试运行以下简单的代码片段。

代码语言:javascript
复制
import onmt
import onmt.inputters
import onmt.translate
import onmt.model_builder
from collections import namedtuple


Opt = namedtuple('Opt', ['models', 'data_type', 'reuse_copy_attn', "gpu"])


opt = Opt("/home/Desktop/hidden-att/model/hidden-2/seed-0/LSTMlang1_step_400.pt", "text",False, 0)
fields, model, model_opt =  onmt.model_builder.load_test_model(opt,{"reuse_copy_attn":False})

我得到了这个错误跟踪。

代码语言:javascript
复制
Traceback (most recent call last):

  File "<ipython-input-63-94c1f45c429f>", line 1, in <module>
    runfile('/home/Desktop/hidden-att/graph_hidden_exp.py', wdir='/home/Desktop/hidden-att')

  File "/home/anaconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 786, in runfile
    execfile(filename, namespace)

  File "/home/anaconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 110, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "/home/Desktop/hidden-att/graph_hidden_exp.py", line 33, in <module>
    fields, model, model_opt =  onmt.model_builder.load_test_model(opt,{"reuse_copy_attn":False})

  File "../../Documents/NMT/OpenNMT-py/onmt/model_builder.py", line 85, in load_test_model
    map_location=lambda storage, loc: storage)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 387, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 549, in _load
    _check_seekable(f)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 194, in _check_seekable
    raise_err_msg(["seek", "tell"], e)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 187, in raise_err_msg
    raise type(e)(msg)

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

那么有没有人体验过并解决了这个问题呢?或者知道去哪里找?我猜这与加载的文件有关,但它是以相当标准的方式使用opennmt-py进行训练的。

EN

回答 1

Stack Overflow用户

发布于 2019-11-29 19:20:00

您可以在翻译脚本中添加--attn_debug参数来查看注意力权重。

代码语言:javascript
复制
translate.py ... \
             -attn_debug \
             ...
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57192111

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档