首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何按标签或每组数据分割数据集;火炬

如何按标签或每组数据分割数据集;火炬
EN

Stack Overflow用户
提问于 2022-10-31 18:57:21
回答 1查看 41关注 0票数 0

我有一个数据集的苹果图像和他们的糖水平。我为数据集拍了6张苹果照片。所以一个苹果有6张照片&它的糖水平。

我想将数据集分成训练和验证两部分。我希望整个苹果图片(6张照片在一套)进入火车或验证集。我不知道怎样才能那样分裂。

这是数据集的CSV文件。

苹果是标签。

提前谢谢你!

EN

回答 1

Stack Overflow用户

发布于 2022-11-01 22:30:06

您可以简单地找到苹果ID,然后按这些ID进行拆分。然后,可以将其传递到dataset类中,以便将它们拆分到苹果ids中,而不是在df的行之间随机拆分的标准方法。

代码语言:javascript
复制
apple_df = pd.read_csv(...)
apple_ids = apple_df['apple'].unique() #drop_duplicates() if DataFrame
apple_ids = apple_ids.sample(frac=1) #shuffle
train_val_split = int(0.9 * len(apple_ids))
train_apple_ids = apple_ids[:train_val_split]
val_apple_ids = apple_ids[train_val_split:]

class apple_dset(torch.utils.data.Dataset):
     def __init__(self,df)
          super(apple_dset,self).__init__()
          self.df = df
     def __len__(self):
          return len(self.df.index)
     def __getitem__(self,idx):
          apple = self.df.iloc[idx]
          # do loading...
          return img, label

train_apple_df = apple_df.loc[apple_df['apple'].isin([train_apple_ids])]
val_apple_df = apple_df.loc[apple_df['apple'].isin([val_apple_ids])]

train_apple_ds = apple_dset(train_apple_df)
val_apple_ds = apple_dset(val_apple_df)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74267976

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档