首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >装载机的无效数据类型-火炬闪电DataModule

装载机的无效数据类型-火炬闪电DataModule
EN

Stack Overflow用户
提问于 2021-12-18 17:13:09
回答 1查看 243关注 0票数 0

我正在尝试一个文本摘要练习,我有两列文本和摘要(标签)的训练和测试数据集。我使用的是T5、Pytorch和闪电包装器,我有一个Pytorch Dataset类,我可以确认它是否正确工作,并将以下内容作为文本字典以及ids、标签和掩码作为张量返回。

代码语言:javascript
复制
return dict(
    text=text,
    summary = data_row['summary'],
    text_input_ids = text_encoding['input_ids'].flatten(),
    text_attention_mask = text_encoding['attention_mask'].flatten(),
    labels = labels.flatten(),
    labels_attention_mask = summary_encoding['attention_mask'].flatten()
)

然后,我有一个闪电数据模块类,它将数据格式转换成PyTorch数据集,并适合于数据加载器、返回列车、val和测试数据加载器

代码语言:javascript
复制
class TextSummaryDataModule(pl.LightningModule):
  def __init__(
      self, 
      train_df: pd.DataFrame, 
      test_df: pd.DataFrame, 
      tokenizer: T5Tokenizer, 
      batch_size: int=8, 
      text_max_token_len: int=512, 
      summary_max_token_len: int=128
    ):
    
      super().__init__()
      
      self.train_df = train_df
      self.test_df = test_df

      self.tokenizer = tokenizer
      self.batch_size = batch_size
      self.text_max_token_len = text_max_token_len
      self.summary_max_token_len = summary_max_token_len

  def setup(self):
    self.train_dataset = TextSummaryDataset(
        self.train_df,
        self.tokenizer,
        self.text_max_token_len,
        self.summary_max_token_len
    )

    self.test_dataset = TextSummaryDataset(
        self.test_df,
        self.tokenizer,
        self.text_max_token_len,
        self.summary_max_token_len
    )

  def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size = self.batch_size,
        shuffle=True,
        num_workers=2
    )

  def val_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size = self.batch_size,
        shuffle=False,
        num_workers=2
    )

  def test_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size = self.batch_size,
        shuffle=False,
        num_workers=2
    )

一切都正常,直到我尝试执行模型,并得到以下警告和错误

  1. UserWarning:您定义了validation_step,但没有val_dataloader。跳过验证循环--我已经在数据模块

中明确定义并返回了这个循环。

用于加载程序的

  1. 无效数据类型: TextSummaryDataModule -我已确认正在返回文本和摘要

的标记、attention_mask和标签的字典

EN

回答 1

Stack Overflow用户

发布于 2021-12-27 12:06:40

可耻的是我用的是pl.LightningModule而不是DataModule .

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70405450

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档