
How can I get the score from the question-answering pipeline? Is there a bug when the question-answering pipeline is used?

Stack Overflow user
Asked 2020-08-22 08:06:35
Answers 1 · Views 711 · Followers 0 · Votes 2

When I run the following code:

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

text = r"""
As checked Dis is not yet on boarded to ARB portal, hence we cannot upload the invoices in portal
"""

questions = [
    "Dis asked if it is possible to post the two invoice in ARB.I have not access so I wanted to check if you would be able to do it.",
]

for question in questions:
    inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]

    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
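    outputs = model(**inputs)
    # Index by position so this works whether the model returns a tuple or a ModelOutput
    answer_start_scores, answer_end_scores = outputs[0], outputs[1]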

    answer_start = torch.argmax(
        answer_start_scores
    )  # Get the most likely beginning of answer with the argmax of the score
    answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score

    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

    print(f"Question: {question}")
    print(f"Answer: {answer}\n")

The answer I get here is:

Question: Dis asked if it is possible to post the two invoice in ARB.I have not access so I wanted to check if you would be able to do it.
Answer: dis is not yet on boarded to ARB portal

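How can I get a score for this answer, comparable to the score I would get from running the question-answering pipeline?

I have to take this approach because the question-answering pipeline raises a KeyError when I run it with the code below.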

from transformers import pipeline

nlp = pipeline("question-answering")

context = r"""
As checked Dis is not yet on boarded to ARB portal, hence we cannot upload the invoices in portal.
"""

print(nlp(question="Dis asked if it is possible to post the two invoice in ARB?", context=context))

1 Answer

Stack Overflow user

Answered 2020-12-11 11:17:24

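This is my attempt at getting the score. I could not figure out what feature.p_mask is, so for now I cannot exclude the non-context indexes that contribute to the softmax.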

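# Assumes `question` and `context` are already defined as strings
import numpy as np
from transformers import AutoTokenizer, AutoModelForQuestionAnswering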

model_name="deepset/roberta-base-squad2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

inputs = tokenizer(question, context, 
                       add_special_tokens=True, 
                       return_tensors='pt')
input_ids = inputs['input_ids'].tolist()[0]

outputs = model(**inputs)
    

# used to compute score
start = outputs.start_logits.detach().numpy()
end = outputs.end_logits.detach().numpy()

# from source code

# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
#?? undesired_tokens = np.abs(np.array(feature.p_mask) - 1) & feature.attention_mask

# Generate mask

# NOTE: this only masks padding; question and special tokens still contribute to the softmax
undesired_tokens = inputs['attention_mask'].numpy()
undesired_tokens_mask = undesired_tokens == 0

# Make sure non-context indexes in the tensor cannot contribute to the softmax
start_ = np.where(undesired_tokens_mask, -10000.0, start)
end_ = np.where(undesired_tokens_mask, -10000.0, end)

# Normalize logits and spans to retrieve the answer
start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))

# Compute the score of each tuple(start, end) to be the real answer
outer = np.matmul(np.expand_dims(start_, -1), np.expand_dims(end_, 1))

# Remove candidate with end < start and end - start > max_answer_len
max_answer_len = 15
candidates = np.tril(np.triu(outer), max_answer_len - 1)
scores_flat = candidates.flatten()

idx_sort = [np.argmax(scores_flat)]
start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
end += 1
score = candidates[0, start, end-1]
start, end, score = start.item(), end.item(), score.item()


print(tokenizer.decode(input_ids[start:end]))
print(score)
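
The scoring steps above can be checked in isolation, without loading a model. The following is a minimal sketch under stated assumptions: the toy logits are made up, and `softmax` and `best_span` are hypothetical helper names, not part of transformers. It also swaps in a softmax that subtracts the maximum logit first, since `np.exp(x - np.log(np.sum(np.exp(x))))` can overflow when logits are large.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

def best_span(start_logits, end_logits, max_answer_len=15):
    """Return (start, end, score) of the best span; `end` is inclusive."""
    start_probs = softmax(np.asarray(start_logits, dtype=float))
    end_probs = softmax(np.asarray(end_logits, dtype=float))
    # outer[i, j] = P(start = i) * P(end = j)
    outer = np.outer(start_probs, end_probs)
    # Keep only spans with start <= end and at most max_answer_len tokens
    candidates = np.tril(np.triu(outer), max_answer_len - 1)
    start, end = np.unravel_index(np.argmax(candidates), candidates.shape)
    return int(start), int(end), float(candidates[start, end])

# Hypothetical toy logits for a 5-token input (not real model output)
toy_start_logits = [0.1, 2.0, 0.3, 0.2, 0.1]
toy_end_logits = [0.1, 0.2, 0.3, 3.0, 0.1]

start, end, score = best_span(toy_start_logits, toy_end_logits)
print(start, end, score)
```

The score is the product of the start and end probabilities, so it always lies between 0 and 1. It is only comparable to the pipeline's score if the same tokens are masked out before the softmax.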

Take a closer look at the source code for more details.
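
As far as I understand the pipeline source, `p_mask` marks the tokens that cannot be part of an answer (question, special, and padding tokens). A comparable mask can be sketched from per-token sequence ids. This is an illustration under assumptions, not the pipeline's actual implementation: `context_only_mask` and `mask_logits` are hypothetical helpers, and the token layout below is made up. With a fast tokenizer, `inputs.sequence_ids(0)` should return a list of this shape (`None` for special tokens, 0 for the question, 1 for the context).

```python
import numpy as np

def context_only_mask(sequence_ids, attention_mask):
    """True where a token may belong to the answer (context and not padding)."""
    # sequence_ids: None for special tokens, 0 for question, 1 for context
    is_context = np.array([sid == 1 for sid in sequence_ids])
    return is_context & (np.asarray(attention_mask) == 1)

def mask_logits(logits, keep_mask, fill_value=-10000.0):
    """Replace logits outside keep_mask so they vanish under the softmax."""
    return np.where(keep_mask, logits, fill_value)

# Hypothetical layout: <s> q q </s> </s> c c c </s> <pad>
sequence_ids = [None, 0, 0, None, None, 1, 1, 1, None, None]
attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]

keep = context_only_mask(sequence_ids, attention_mask)
toy_logits = np.arange(10, dtype=float)  # made-up logits, one per token
masked = mask_logits(toy_logits, keep)

print(keep.astype(int))
print(masked)
```

Applying `mask_logits` to both the start and end logits before the softmax would keep question and special tokens from contributing to the score.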

Votes 0

The original content of this page is provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/63533941
