首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch -创建联邦CIFAR-10数据集

PyTorch -创建联邦CIFAR-10数据集
EN

Stack Overflow用户
提问于 2021-01-31 07:09:42
回答 1查看 393关注 0票数 1

我在CIFAR-10数据集上训练一个神经网络(不管哪一个)。我使用的是联合学习:

  • 我有10个模型,每个模型都可以访问自己的部分数据集。在每一时间步骤中,每个模型都使用自己的数据执行一个步骤,然后全局模型是模型的平均值(这个版本基于,但我尝试了许多选项):
代码语言:javascript
复制
def server_aggregate(server_model, client_models):
    global_dict = server_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)
    server_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(server_model.state_dict())
  • 具体而言,每台机器只能访问与单个类相对应的数据。也就是说,机器0只有与类0相对应的示例,等等。我这样做的方法如下:
代码语言:javascript
复制
def split_into_classes(full_ds, batch_size, num_classes=10):
  class2indices = [[] for _ in range(num_classes)]
  for i, y in enumerate(full_ds.targets):
    class2indices[y].append(i)

  datasets = [torch.utils.data.Subset(full_ds, indices) for indices in class2indices]
  return [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets]

问题.在训练期间,我可以看到我的联邦训练损失减少了。然而,我从来没有看到我的测试丢失/准确性提高(acc总是在10%左右)。,当我在训练/测试数据集上检查准确性时

  • 对于联邦数据集,准确度提高了。
  • 对于测试数据集,准确性没有提高。
  • (最令人惊讶的是)对于训练数据集,准确性并没有提高。请注意,此数据集本质上与联邦数据集相同,但不被拆分为类。检查代码是下面是
代码语言:javascript
复制
def epoch_summary(model, fed_loaders, true_train_loader, test_loader, frac):
  with torch.no_grad():
    train_len = 0
    train_loss, train_acc = 0, 0
    for train_loader in fed_loaders:
      cur_loss, cur_acc, cur_len = true_results(model, train_loader, frac)
      train_loss += cur_len * cur_loss
      train_acc += cur_len * cur_acc
      train_len += cur_len

    train_loss /= train_len
    train_acc /= train_len

    true_train_loss, true_train_acc, true_train_len = true_results(model, true_train_loader, frac)
    test_loss, test_acc, test_len = true_results(model, test_loader, frac)

  print("TrainLoss: {:.4f} TrainAcc: {:.2f} TrueLoss: {:.4f} TrueAcc: {:.2f} TestLoss: {:.4f} TestAcc: {:.2f}".format(
        train_loss, train_acc, true_train_loss, true_train_acc, test_loss, test_acc
        ), flush=True)

完整的代码可以找到这里。似乎不重要的事情:

  • 模型。对于Resnet模型和其他一些模型,我也遇到了同样的问题。
  • 我是如何把模型聚合起来的。我尝试使用state_dict或直接操作model.parameters(),没有效果。
  • 我是如何学习模特的。我尝试使用optim.SGD或直接更新param.data -= learning_rate * param.grad,没有效果。
  • 计算图我尝试过在所有可能的地方添加.detach().clone()with torch.no_grad(),但没有效果。

因此,我怀疑问题出在联邦数据本身(特别是考虑到奇怪的准确性结果)。有什么问题吗?

EN

回答 1

Stack Overflow用户

发布于 2021-01-31 11:25:29

CIFAR-10的10%基本上是随机的--你的模型随机输出标签,得到10%。

我认为问题在于你的“联合培训”策略:当你的子模型只看到一个标签的时候,你就不能期望你的子模型学到任何有意义的东西。这就是为什么培训数据洗牌

想想看:如果您的每个子模型都知道所有的权重都为零,那么除了在子模型所看到的类的条目中包含1的最后一个分类层的1向量之外,每个子模型的训练都是完美的(它对它所看到的所有训练样本都是正确的),但是平均模型是没有意义的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65976605

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档