
How to fix the transformers DistilBert model error: got an unexpected keyword argument 'special_tokens_mask'

Stack Overflow user
Asked on 2022-05-17 20:02:50
1 answer · 593 views · 0 followers · 1 vote

I am using:

Apple Mac M1

OS: macOS Monterey

Python 3.10.4

I am trying to implement vector search with DistilBERT and Weaviate by following this tutorial.

Here is the code setup:

import nltk
import os
import random
import time
import torch
import weaviate
from transformers import AutoModel, AutoTokenizer
from nltk.tokenize import sent_tokenize

torch.set_grad_enabled(False)

# updated to use a different model if desired
MODEL_NAME = "distilbert-base-uncased"
model = AutoModel.from_pretrained(MODEL_NAME)
model.to('cuda') # remove if working without GPUs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# initialize nltk (for tokenizing sentences)
nltk.download('punkt')

# initialize weaviate client for importing and searching
client = weaviate.Client("http://localhost:8080")

def get_post_filenames(limit_objects=100):
    file_names = []
    for root, dirs, files in os.walk("./data/20news-bydate-test"):
        for filename in files:
            path = os.path.join(root, filename)
            file_names += [path]
        
    random.shuffle(file_names)
    limit_objects = min(len(file_names), limit_objects)
      
    file_names = file_names[:limit_objects]

    return file_names

def read_posts(filenames=[]):
    posts = []
    for filename in filenames:
        with open(filename, encoding="utf-8", errors='ignore') as f:
            post = f.read()
        
        # strip the headers (the first occurrence of two newlines)
        post = post[post.find('\n\n'):]
        
        # remove posts with less than 10 words to remove some of the noise
        if len(post.split(' ')) < 10:
               continue
        
        post = post.replace('\n', ' ').replace('\t', ' ').strip()
        if len(post) > 1000:
            post = post[:1000]
        posts += [post]

    return posts       


def text2vec(text):
    tokens_pt = tokenizer(text, padding=True, truncation=True, max_length=500, add_special_tokens = True, return_tensors="pt")
    tokens_pt.to('cuda') # remove if working without GPUs
    outputs = model(**tokens_pt)
    return outputs[0].mean(0).mean(0).detach()

def vectorize_posts(posts=[]):
    post_vectors=[]
    before=time.time()
    for i, post in enumerate(posts):
        vec=text2vec(sent_tokenize(post))
        post_vectors += [vec]
        if i % 100 == 0 and i != 0:
            print("So far {} objects vectorized in {}s".format(i, time.time()-before))
    after=time.time()
    
    print("Vectorized {} items in {}s".format(len(posts), after-before))
    
    return post_vectors

def init_weaviate_schema():
    # a simple schema containing just a single class for our posts
    schema = {
        "classes": [{
                "class": "Post",
                "vectorizer": "none", # explicitly tell Weaviate not to vectorize anything, we are providing the vectors ourselves through our BERT model
                "properties": [{
                    "name": "content",
                    "dataType": ["text"],
                }]
        }]
    }

    # cleanup from previous runs
    client.schema.delete_all()

    client.schema.create(schema)

def import_posts_with_vectors(posts, vectors, batchsize=256):
    batch = weaviate.ObjectsBatchRequest()

    for i, post in enumerate(posts):
        props = {
            "content": post,
        }
        batch.add(props, "Post", vector=vectors[i])
        
        # when either batch size is reached or we are at the last object
        if (i !=0 and i % batchsize == 0) or i == len(posts) - 1:
            # send off the batch
            client.batch.create(batch)
            
            # and reset for the next batch
            batch = weaviate.ObjectsBatchRequest() 
    

def search(query="", limit=3):
    before = time.time()
    vec = text2vec(query)
    vec_took = time.time() - before

    before = time.time()
    near_vec = {"vector": vec.tolist()}
    res = client \
        .query.get("Post", ["content", "_additional {certainty}"]) \
        .with_near_vector(near_vec) \
        .with_limit(limit) \
        .do()
    search_took = time.time() - before

    print("\nQuery \"{}\" with {} results took {:.3f}s ({:.3f}s to vectorize and {:.3f}s to search)" \
          .format(query, limit, vec_took+search_took, vec_took, search_took))
    for post in res["data"]["Get"]["Post"]:
        print("{:.4f}: {}".format(post["_additional"]["certainty"], post["content"]))
        print('---')

# run everything
init_weaviate_schema()
posts = read_posts(get_post_filenames(4000))
vectors = vectorize_posts(posts)
import_posts_with_vectors(posts, vectors)

search("the best camera lens", 1)
search("which software do i need to view jpeg files", 1)
search("windows vs mac", 1)
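A side note on the hard-coded `model.to('cuda')` calls in the tutorial code: on an Apple M1 Mac there is no CUDA device, so those lines would fail outright. A minimal sketch of picking the device once and reusing it (the `try/except` is only so the snippet runs even where torch is not installed):

```python
# Select a device once instead of hard-coding 'cuda' everywhere.
# On an M1 Mac torch.cuda.is_available() is False, so this falls back to 'cpu'.
try:
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:  # torch not installed in this environment
    device = "cpu"

print(device)
```

Then `model.to(device)` and `tokens_pt.to(device)` work on both setups without editing the code. Newer PyTorch builds also expose an `"mps"` backend on Apple silicon, which could be added as a third option.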

The function below triggers the error:

def text2vec(text):
    # tokens_pt = tokenizer(text, padding=True, truncation=True, max_length=500, add_special_tokens = True, return_tensors="pt")
    tokens_pt = tokenizer.encode_plus(text, add_special_tokens=True, truncation=True, padding="max_length", return_attention_mask=True, return_tensors="pt")

    tokens_pt.to('cuda') # remove if working without GPUs
    outputs = model(**tokens_pt)
    return outputs[0].mean(0).mean(0).detach()

Error 1, at tokens_pt.to('cuda') # remove if working without GPUs: 'dict' object has no attribute 'to'
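A note on Error 1: the original `tokenizer(...)` call returns a `BatchEncoding`, which does have a `.to()` method, but if the tokenizer output here is a plain dict, each tensor can still be moved individually with a dict comprehension. A sketch of the pattern, using a stand-in class instead of real torch tensors so it runs anywhere:

```python
class FakeTensor:
    """Stand-in for torch.Tensor; only the .to() device move is modeled."""
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # torch.Tensor.to returns a tensor on the target device
        return FakeTensor(self.data, device)

# What the tokenizer might return: a plain dict of tensors.
tokens_pt = {
    "input_ids": FakeTensor([101, 2023, 102]),
    "attention_mask": FakeTensor([1, 1, 1]),
}

# A plain dict has no .to(), so move each tensor individually:
tokens_pt = {k: v.to("cuda") for k, v in tokens_pt.items()}

print({k: v.device for k, v in tokens_pt.items()})
# {'input_ids': 'cuda', 'attention_mask': 'cuda'}
```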

When I comment out the GPU line:

#tokens_pt.to('cuda')

and run the code, I get this error:

Error 2, at outputs = model(**tokens_pt):

File "…", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: DistilBertModel.forward() got an unexpected keyword argument 'special_tokens_mask'

What is causing this error, and how can I fix it?
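Some context on the traceback above: `model(**tokens_pt)` unpacks every key of the dict as a keyword argument, so any key that the model's `forward()` does not declare triggers exactly this `TypeError`. A minimal pure-Python sketch of the mechanism, with a hypothetical `forward()` standing in for `DistilBertModel.forward()`, plus a generic filter built on `inspect.signature`:

```python
import inspect

def forward(input_ids=None, attention_mask=None):
    """Stand-in for DistilBertModel.forward: no special_tokens_mask param."""
    return (input_ids, attention_mask)

tokens_pt = {
    "input_ids": [101, 2023, 102],
    "attention_mask": [1, 1, 1],
    "special_tokens_mask": [1, 0, 1],  # the key forward() does not accept
}

try:
    forward(**tokens_pt)
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'special_tokens_mask'

# Generic fix: keep only the keys forward() actually declares.
accepted = inspect.signature(forward).parameters
filtered = {k: v for k, v in tokens_pt.items() if k in accepted}
outputs = forward(**filtered)
print(outputs)  # ([101, 2023, 102], [1, 1, 1])
```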


1 Answer

Stack Overflow user

Accepted answer

Posted on 2022-05-18 08:46:16

I could not reproduce your error in my environment (Ubuntu), but from what I can see, I would suggest trying to add the return_special_tokens_mask=False parameter:

tokens_pt = tokenizer.encode_plus(
    text, 
    add_special_tokens=True,
    truncation=True,
    padding="max_length",
    return_attention_mask=True,
    return_tensors="pt",
    return_special_tokens_mask=False
)

If that fails, remove it explicitly:

tokens_pt.pop("special_tokens_mask")
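One caveat with the `pop` line: if the tokenizer output happens not to contain the key, a bare `pop` raises `KeyError`. Passing a default makes the removal a no-op in that case. A sketch with a plain dict standing in for the tokenizer's `BatchEncoding`:

```python
tokens_pt = {"input_ids": [101, 102], "attention_mask": [1, 1]}

# Safe removal: returns None instead of raising KeyError when absent.
removed = tokens_pt.pop("special_tokens_mask", None)

print(removed)    # None
print(tokens_pt)  # {'input_ids': [101, 102], 'attention_mask': [1, 1]}
```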
1 vote
Original page content provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/72280030