Hugging Face Transformers model returns strings instead of logits

Stack Overflow user
Asked on 2020-11-19 05:40:15
Answers: 1, Views: 3.6K, Followers: 0, Votes: 6

I am trying to run this example from the Hugging Face website: https://huggingface.co/transformers/task_summary.html. The model seems to return two strings instead of logits! This causes torch.argmax() to throw an error.

    from transformers import AutoTokenizer, AutoModelForQuestionAnswering
    import torch
    
    tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
    
    model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad", return_dict=True)
    
    text = r"""🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
    architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
    Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
    TensorFlow 2.0 and PyTorch.
    """
    
    questions = ["How many pretrained models are available in 🤗 Transformers?",
                 "What does 🤗 Transformers provide?",
                 "🤗 Transformers provides interoperability between which frameworks?"]
    
    for question in questions:
      inputs = tokenizer(question, text, add_special_tokens=True, return_tensors="pt")
      input_ids = inputs["input_ids"].tolist()[0] # the list of all indices of words in question + context
    
      text_tokens = tokenizer.convert_ids_to_tokens(input_ids) # Get the tokens for the question + context
      answer_start_scores, answer_end_scores = model(**inputs)
    
      answer_start = torch.argmax(answer_start_scores)  # Get the most likely beginning of answer with the argmax of the score
      answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score
    
      answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
    
      print(f"Question: {question}")
      print(f"Answer: {answer}")

1 Answer

Stack Overflow user

Answered on 2020-11-19 06:23:10

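Since one of the recent updates, the models return task-specific output objects (which are dictionaries) instead of plain tuples. The site you used has not been updated to reflect that change yet. You can force the model to return a tuple by specifying return_dict=False: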

    answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)
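
As a side note on why strings show up at all: the output object behaves like an ordered dictionary, and tuple-unpacking a dictionary iterates over its keys rather than its values. The sketch below reproduces that with a plain `OrderedDict` standing in for the model output. The stand-in and the name `fake_outputs` are assumptions for illustration; plain lists replace the real tensors and no model is loaded.

```python
from collections import OrderedDict

# Hypothetical stand-in for the model output: an ordered mapping with the
# same two keys. Plain lists replace the real logits tensors.
fake_outputs = OrderedDict(
    start_logits=[0.1, 2.5, 0.3],
    end_logits=[0.2, 0.1, 3.0],
)

# Unpacking a mapping iterates over its keys, so both names end up as strings.
answer_start_scores, answer_end_scores = fake_outputs
print(type(answer_start_scores).__name__, answer_start_scores)
print(type(answer_end_scores).__name__, answer_end_scores)
```

Passing either of those strings to torch.argmax() is what raises the error described in the question.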

Or you can extract the values from the QuestionAnsweringModelOutput object by calling the values() method:

    answer_start_scores, answer_end_scores = model(**inputs).values()
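
One caveat with values(), sketched here with a dictionary stand-in (an assumption for illustration, not the real output class): unpacking into two names only works while the object holds exactly two entries. If the model is also asked for a loss or for hidden states, more entries appear and the unpacking fails, so selecting by key is likely the more robust choice.

```python
from collections import OrderedDict

# Hypothetical stand-in for an output that also carries a loss entry.
fake_outputs = OrderedDict(
    loss=0.42,
    start_logits=[0.1, 2.5, 0.3],
    end_logits=[0.2, 0.1, 3.0],
)

# Positional unpacking now sees three values and fails.
try:
    answer_start_scores, answer_end_scores = fake_outputs.values()
    unpack_error = None
except ValueError as exc:
    unpack_error = exc
    print(f"Unpacking failed: {exc}")

# Looking the entries up by key does not depend on how many there are.
answer_start_scores = fake_outputs["start_logits"]
answer_end_scores = fake_outputs["end_logits"]
print(answer_start_scores, answer_end_scores)
```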

Or even use the QuestionAnsweringModelOutput object directly:

    outputs = model(**inputs)
    answer_start_scores = outputs.start_logits
    answer_end_scores = outputs.end_logits
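
Once the two score vectors are available, the remaining step in the question's loop is an argmax over each one followed by a slice. Below is a torch-free sketch of that step. The tokens and scores are made up, and the helper names `argmax` and `extract_answer_tokens` are hypothetical, not part of any library.

```python
def argmax(scores):
    """Return the index of the largest score (the first one on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)


def extract_answer_tokens(tokens, start_scores, end_scores):
    """Slice out the tokens from the best start to the best end, inclusive."""
    start = argmax(start_scores)
    end = argmax(end_scores) + 1  # +1 because slicing excludes the end index
    return tokens[start:end]


# Made-up example data, not real model output.
tokens = ["[CLS]", "which", "frameworks", "?", "[SEP]",
          "tensorflow", "and", "pytorch", "[SEP]"]
start_scores = [0.0, 0.1, 0.2, 0.0, 0.0, 4.0, 0.5, 1.0, 0.0]
end_scores = [0.0, 0.0, 0.1, 0.0, 0.0, 0.5, 0.3, 3.5, 0.0]

answer_tokens = extract_answer_tokens(tokens, start_scores, end_scores)
print(" ".join(answer_tokens))
```

With real tensors, torch.argmax plays the role of the `argmax` helper, as in the question's code.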

Votes: 17
Original page content provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/64901831