
TypeError with DataLoader

Stack Overflow user
Asked 2022-06-16 22:10:30
1 answer · 127 views · 0 following · 0 votes

I am testing my model on a very large dataset. To make sampling test examples faster, I want to build a DataLoader, but I get an error that I have not been able to solve for two days. Here is my code:

```python
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

class GPReviewDataset(Dataset):
    def __init__(self, Paragraph, target, tokenizer, max_len):
        self.Paragraph = Paragraph
        self.target = target
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.Paragraph)

    def __getitem__(self, item):
        Paragraph = str(self.Paragraph[item])
        target = self.target[item]
        encoding = self.tokenizer.encode_plus(
            Paragraph,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'review_text': Paragraph,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long)
        }


def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = GPReviewDataset(
        Paragraph=df.Paragraph.to_numpy(),
        target=df.target.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=4
    )

# Main function
paragraph = ['Image to PDF Converter. ', 'Test Test']
target = ['0', '1']
df = pd.DataFrame({'Paragraph': paragraph, 'target': target})

MAX_LEN = '512'
BATCH_SIZE = 1
train_data_loader1 = create_data_loader(df, tokenizer, MAX_LEN, BATCH_SIZE)
for d in train_data_loader1:
    print(d)
```

When I iterate over the dataloader, I get the following error:

```
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-3-c4f87a4dbb48>", line 20, in __getitem__
    return_tensors='pt',
  File "/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils.py", line 1069, in encode_plus
    return_special_tokens_mask=return_special_tokens_mask,
  File "/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils.py", line 1365, in prepare_for_model
    if max_length and total_len > max_length:
TypeError: '>' not supported between instances of 'int' and 'str'
```
Can anyone help me? Also, could you give some advice on how to test my model on a large dataset? That is, what is a faster way to test my model on 3M data samples?


1 Answer

Stack Overflow user
Accepted answer
Posted 2022-06-25 19:55:36

The error is exactly what it says:

```
  File "/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils.py", line 1365, in prepare_for_model
    if max_length and total_len > max_length:
TypeError: '>' not supported between instances of 'int' and 'str'
```

You should change `MAX_LEN` from a string to an int:

```python
# MAX_LEN = '512'
MAX_LEN = 512
```
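As a quick sanity check (an illustration added here, not part of the original answer), the snippet below reproduces the root cause in isolation: in Python 3, ordering comparisons between `int` and `str` raise `TypeError`, which is exactly what `prepare_for_model` hits when `max_length` arrives as the string `'512'`:

```python
# Minimal reproduction of the root cause, independent of transformers:
# Python 3 forbids ordering comparisons between int and str.
total_len = 20          # stands in for the tokenized sequence length

max_length = '512'      # str, as in MAX_LEN = '512'
try:
    total_len > max_length
except TypeError as e:
    print(e)            # '>' not supported between instances of 'int' and 'str'

max_length = 512        # int: the comparison now works
print(total_len > max_length)  # False - 20 tokens fit within 512
```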
Votes: 1
Original content provided by Stack Overflow.
Source: https://stackoverflow.com/questions/72652399