首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在pytorch中从多个数据集中加载数据

如何在pytorch中从多个数据集中加载数据
EN

Stack Overflow用户
提问于 2022-03-10 10:32:36
回答 1查看 481关注 0票数 0

我有两个图像的数据集-室内和室外,他们没有相同数目的例子。

每个数据集都有包含一定数量的类的图像(最少有1个,最多4个),这些类可以出现在两个数据集中,每个类有4个类别--红色、蓝色、绿色、白色。例如:室内猫,狗,马,户外狗,人类

我试着训练一个模型,在那里我告诉它,“这是一个图像,里面有一只猫,告诉我它是颜色的”,不管它是在哪里拍摄的(室内,户外,车里,月球上)

为此,我需要展示我的模型示例,以便每一批只有一个类别(猫、狗、马或人类),但我想从包含这些对象的所有数据集(在本例中是两个)中取样并混合它们。我该怎么做?

它必须考虑到每个数据集中的示例数是不同的,一些类别出现在一个数据集中,而另一些类别可以出现在多个数据集中。每批必须只包含一个类别。

我希望有任何帮助,我已经尝试了几天来解决这个问题。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-10 15:46:53

假设问题是:

  1. 将2+数据集与可能重叠的对象类别相结合(按标签可区分)
  2. 每个对象对每种颜色都有4个“子类别”(按标签可区分)。
  3. 每一批只应包含一个对象类别。

第一步是确保来自两个数据集的对象标签的一致性,如果不是已经一致的话。例如,如果狗类在第一个数据集中是label 0,而在第二个数据集中是label 2,那么我们需要确保两个狗类别正确合并。我们可以使用一个简单的数据集包装器来完成这个“转换”:

代码语言:javascript
复制
class TranslatedDataset(Dataset):
  """
  Args:
    dataset: The original dataset.
    translate_label: A lambda (function) that maps the original
      dataset label to the label it should have in the combined data set
  """
  def __init__(self, dataset, translate_label):
    super().__init__()
    self._dataset = dataset
    self._translate_label = translate_label

  def __len__(self):
    return len(self._dataset)

  def __getitem__(self, idx):
    inputs, target = self._dataset[idx]
    return inputs, self._translate_label(target)

下一步是将转换后的数据集组合在一起,这可以很容易地用ConcatDataset完成。

代码语言:javascript
复制
first_original_dataset = ...
second_original_dataset = ...

first_translated = TranslateDataset(
  first_original_dataset, 
  lambda y: 0 if y is 2 else 2 if y is 0 else y, # or similar
)
second_translated = TranslateDataset(
  second_original_dataset, 
  lambda y: y, # or similar
)

combined = ConcatDataset([first_translated, second_translated])

最后,我们需要将批抽样限制在同一个类上,这在创建数据加载器时使用自定义Sampler是可能的。

代码语言:javascript
复制
class SingleClassSampler(torch.utils.data.Sampler):
  def __init__(self, dataset, batch_size):
    super().__init__()
    # We need to create sequential groups
    # with batch_size elements from the same class
    indices_for_target = {} # dict to store a list of indices for each target
    
    for i, (_, target) in enumerate(dataset):
      # converting to string since Tensors hash by reference, not value
      str_targ = str(target)
      if str_targ not in indices_for_target:
        indices_for_target[str_targ] = []
      indices_for_target[str_targ] += [i]

    # make sure we have a whole number of batches for each class
    trimmed = { 
      k: v[:-(len(v) % batch_size)] 
      for k, v in indices_for_target.items()
    }

    # concatenate the lists of indices for each class
    self._indices = sum(list(trimmed.values()))
  
  def __len__(self):
    return len(self._indices)

  def __iter__(self):
    yield from self._indices

然后使用取样器:

代码语言:javascript
复制
loader = DataLoader(
  combined, 
  sampler=SingleClassSampler(combined, 64), 
  batch_size=64, 
  shuffle=True
)

我还没有运行这段代码,所以它可能不是完全正确的,但希望它将使您走上正确的轨道。

torch.utils.data文档

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71422639

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档