首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将tf.dataset转换为PyTorch数据集?

将tf.dataset转换为PyTorch数据集?
EN

Stack Overflow用户
提问于 2021-05-01 11:00:42
回答 1查看 1.5K关注 0票数 3

在这个项目中,所有数据都经过预处理,并准备好作为tensorflow数据集,如下所示:

代码语言:javascript
复制
<MapDataset shapes: {input_ids: (128,), input_mask: (128,), label_ids: (), segment_ids: (128,)}, types: {input_ids: tf.int64, input_mask: tf.int64, label_ids: tf.int64, segment_ids: tf.int64}>

我拥有的脚本在PyTorch中,并接受一个Dataset对象,如下所示:

代码语言:javascript
复制
Dataset({
    features: ['attention_mask', 'input_ids', 'label', 'sentence', 'token_type_ids'],
    num_rows: 12
})

有什么办法把其中一个转换成另一个吗?我对这两个API都很陌生,所以我不太确定它们是如何工作的?我是否可以使用dict将其中一个转换成另一个呢?

谢谢

EN

回答 1

Stack Overflow用户

发布于 2022-04-24 12:24:39

我使用tfds.as_numpy(dataset)作为我的模型培训的数据处理程序。为了转换传递给我的模型的数据,我在模型的前向函数中使用了torch.as_tensor(data, device=<device>)

代码语言:javascript
复制
import tensorflow_datasets as tfds
import torch.nn as nn

def train_dataloader(batch_size):
    return tfds.as_numpy(tfds.load('mnist').batch(batch_size))

class Model(nn.Module):
    def forward(self, x):
        x = torch.as_tensor(x, device='cuda')
        ...
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67345480

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档