我在CIFAR-10数据集上训练一个神经网络(不管哪一个)。我使用的是联合学习:
def server_aggregate(server_model, client_models):
global_dict = server_model.state_dict()
for k in global_dict.keys():
global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)
server_model.load_state_dict(global_dict)
for model in client_models:
model.load_state_dict(server_model.state_dict())0只有与类0相对应的示例,等等。我这样做的方法如下:def split_into_classes(full_ds, batch_size, num_classes=10):
class2indices = [[] for _ in range(num_classes)]
for i, y in enumerate(full_ds.targets):
class2indices[y].append(i)
datasets = [torch.utils.data.Subset(full_ds, indices) for indices in class2indices]
return [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets]问题.在训练期间,我可以看到我的联邦训练损失减少了。然而,我从来没有看到我的测试丢失/准确性提高(acc总是在10%左右)。和,当我在训练/测试数据集上检查准确性时
def epoch_summary(model, fed_loaders, true_train_loader, test_loader, frac):
with torch.no_grad():
train_len = 0
train_loss, train_acc = 0, 0
for train_loader in fed_loaders:
cur_loss, cur_acc, cur_len = true_results(model, train_loader, frac)
train_loss += cur_len * cur_loss
train_acc += cur_len * cur_acc
train_len += cur_len
train_loss /= train_len
train_acc /= train_len
true_train_loss, true_train_acc, true_train_len = true_results(model, true_train_loader, frac)
test_loss, test_acc, test_len = true_results(model, test_loader, frac)
print("TrainLoss: {:.4f} TrainAcc: {:.2f} TrueLoss: {:.4f} TrueAcc: {:.2f} TestLoss: {:.4f} TestAcc: {:.2f}".format(
train_loss, train_acc, true_train_loss, true_train_acc, test_loss, test_acc
), flush=True)完整的代码可以找到这里。似乎不重要的事情:
state_dict或直接操作model.parameters(),没有效果。optim.SGD或直接更新param.data -= learning_rate * param.grad,没有效果。.detach().clone()和with torch.no_grad(),但没有效果。因此,我怀疑问题出在联邦数据本身(特别是考虑到奇怪的准确性结果)。有什么问题吗?
发布于 2021-01-31 11:25:29
CIFAR-10的10%基本上是随机的--你的模型随机输出标签,得到10%。
我认为问题在于你的“联合培训”策略:当你的子模型只看到一个标签的时候,你就不能期望你的子模型学到任何有意义的东西。这就是为什么培训数据洗牌。
想想看:如果您的每个子模型都知道所有的权重都为零,那么除了在子模型所看到的类的条目中包含1的最后一个分类层的1向量之外,每个子模型的训练都是完美的(它对它所看到的所有训练样本都是正确的),但是平均模型是没有意义的。
https://stackoverflow.com/questions/65976605
复制相似问题