HuggingFace BERT sentiment analysis

Stack Overflow user
Asked 2021-01-25 17:13:28
1 answer, 3.2K views, 0 followers, 1 vote

I get the following error:

AssertionError: text input must of type str (single example), List[str] (batch or single pretokenized example) or List[List[str]] (batch of pretokenized examples).

This happens when I run classifier(encoded). My text is of type str, so I'm not sure what I'm doing wrong. Any help is greatly appreciated.
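The message lists the only input types that are accepted. In the code below, classifier is called with encoded, the dict-like BatchEncoding that encode_plus returns, rather than with the str held in text. The following is a hypothetical re-creation of that kind of type check, not the library's own code: the function name is_valid_text_input is made up, and a plain dict stands in for BatchEncoding.

```python
from typing import Any


def is_valid_text_input(inputs: Any) -> bool:
    """Hypothetical re-creation of the check described by the error message.

    Accepts str, List[str], or List[List[str]]; rejects anything else,
    such as the dict-like object returned by tokenizer.encode_plus().
    """
    # Single example
    if isinstance(inputs, str):
        return True
    if isinstance(inputs, list):
        # Batch of examples, or one pretokenized example
        if all(isinstance(item, str) for item in inputs):
            return True
        # Batch of pretokenized examples
        if all(isinstance(item, list) and all(isinstance(tok, str) for tok in item)
               for item in inputs):
            return True
    return False


# A plain dict stands in for the BatchEncoding produced by encode_plus.
encoded_like = {"input_ids": [[101, 2040, 102]], "attention_mask": [[1, 1, 1]]}

print(is_valid_text_input("Jim Henson was a puppeteer."))  # True
print(is_valid_text_input(encoded_like))  # False
```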

import torch
from transformers import AutoTokenizer, BertTokenizer, BertModel, BertForMaskedLM, AutoModelForSequenceClassification, pipeline

# OPTIONAL: if you want to have more information on what's happening under the hood, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

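# Load pre-trained model tokenizer (vocabulary)
# used the cased instead of uncased to account for cases like BAD.
# NOTE: textattack/bert-base-uncased-SST-2 was fine-tuned with the uncased vocabulary,
# so this cased tokenizer produces token IDs that do not match what the model expects.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')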


# alternative? what is the difference between these two tokenizers? 
#tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")

model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")


# feed the model and the tokenizer into the pipeline
classifier = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)


#---------------sample raw input passage--------

text = "Who was Jim Henson ? Jim Henson was a puppeteer. He is simply awful."
# tokenized_text = tokenizer.tokenize(text)

#----------Tokenization and Padding---------
# Encode the text: tokenize, add special tokens, truncate, and pad
encoded = tokenizer.encode_plus(
    text=text,  # the text to be encoded
    add_special_tokens=True,  # add [CLS] and [SEP]
    max_length=64,  # maximum sequence length (TODO: figure out the longest passage length)
    padding='max_length',  # pad with [PAD] up to max_length
    return_attention_mask=True,  # generate the attention mask
    truncation=True,  # explicitly truncate examples to max_length
    return_tensors='pt',  # return PyTorch tensors
)

#-------------------------------------------
# view the IDs
for key, value in encoded.items():
    print(f"{key}: {value.numpy().tolist()}")
    
#-------------------------------------------


classifier(encoded)
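The encode_plus call above adds [CLS] and [SEP], truncates to max_length, pads up to max_length, and returns an attention mask. The toy function below sketches only that truncate/pad/mask step on IDs that are already tokenized. The name encode_toy is hypothetical, the default IDs (101 for [CLS], 102 for [SEP], 0 for [PAD], as in the standard BERT vocabularies) are assumptions, and a real tokenizer also returns extra fields such as token_type_ids.

```python
from typing import Dict, List


def encode_toy(token_ids: List[int], max_length: int,
               cls_id: int = 101, sep_id: int = 102, pad_id: int = 0) -> Dict[str, List[int]]:
    """Hypothetical sketch of add_special_tokens + truncation + max_length padding."""
    if max_length < 2:
        raise ValueError("max_length must leave room for [CLS] and [SEP]")
    # Truncate the content first (from the end) so [CLS] and [SEP] always fit.
    body = token_ids[: max_length - 2]
    input_ids = [cls_id] + body + [sep_id]
    # 1 = real token, 0 = padding the model should ignore.
    attention_mask = [1] * len(input_ids)
    # Pad both lists up to exactly max_length.
    n_pad = max_length - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad
    return {"input_ids": input_ids, "attention_mask": attention_mask}


print(encode_toy([7, 8, 9], max_length=6))
# {'input_ids': [101, 7, 8, 9, 102, 0], 'attention_mask': [1, 1, 1, 1, 1, 0]}
```

The token IDs 7, 8, 9 are placeholders; in the real call the tokenizer looks them up in its vocabulary, which is why it has to match the model's.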

1 answer

Stack Overflow user

Accepted answer

Answered 2021-01-25 18:02:05

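The pipeline already includes the encoder (the tokenizer), so it expects the raw text. encoded is the dict-like BatchEncoding returned by encode_plus, not a str, which is why the type check fails. Instead of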

classifier(encoded)

use

classifier(text)
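A minimal sketch of two routes, assuming transformers and torch are installed and the checkpoint can be downloaded. classify_raw_text follows the answer: raw text goes to the pipeline, with the tokenizer loaded from the same checkpoint as the model. predict_from_encoded is a different technique, not part of the answer: it bypasses the pipeline, calls the model on an existing encoding, and applies a softmax to the logits. The function names and the softmax helper are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Union

# Checkpoint from the question; its tokenizer is loaded from the same name.
MODEL_NAME = "textattack/bert-base-uncased-SST-2"


def classify_raw_text(texts: Union[str, List[str]], model_name: str = MODEL_NAME):
    """The answer's fix: give the pipeline raw text and let it tokenize."""
    # Imported here so the helpers can be loaded without transformers installed.
    from transformers import pipeline

    # Same checkpoint name for model and tokenizer keeps the vocabularies in sync.
    classifier = pipeline("sentiment-analysis", model=model_name, tokenizer=model_name)
    # A str or a list of str goes in; the pipeline runs its own tokenizer.
    return classifier(texts)


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_from_encoded(model, encoded) -> List[Dict[str, object]]:
    """Alternative: skip the pipeline and run the model on an existing encoding."""
    import torch

    model.eval()
    with torch.no_grad():
        # encoded holds input_ids, attention_mask, etc. as PyTorch tensors.
        logits = model(**encoded).logits  # shape: (batch_size, num_labels)

    results = []
    for row in logits.tolist():
        probs = softmax(row)
        best = max(range(len(probs)), key=probs.__getitem__)
        # Label names come from the checkpoint's config.
        results.append({"label": model.config.id2label[best], "score": probs[best]})
    return results
```

With the question's variables, classify_raw_text(text) replaces classifier(encoded), and predict_from_encoded(model, encoded) works on the existing encoded, provided it was produced by a tokenizer that matches the checkpoint. Both return a list of {'label': ..., 'score': ...} dicts, the same shape the pipeline returns.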
0 votes
Page content originally provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/65881820
