我最近一直在学习TensorFlow Federated框架,但遇到了一个问题。我想在聚合之前查看发送到中央服务器的经过训练的客户端权重。
例如,在this教程中,我可以访问状态变量:
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
state, metrics = iterative_process.next(state, federated_train_data)
print('round {:2d}, metrics={}'.format(round_num, metrics))状态变量保存中心模型的权重(通过聚合客户端权重创建)。在TensorFlow联邦中聚合之前,有没有办法检查客户端发送的权重?
谢谢,感谢您的帮助。
发布于 2021-01-25 10:26:35
Federated Learning for Image Classification教程使用tff.learning.build_federated_averaging_process应用程序接口构建训练过程;查看该方法如何聚合的代码,并在此之前插入一些内容。然而!这个方法相当复杂,因为它是由tff.aggregators中的聚合方法参数化的。
要创建替代算法,我建议查看更简单的federated/tensorflow_federated/python/examples/simple_fedavg/实现。特别是,客户端更新的平均值是在Line 131上计算的。在此之前插入电源将是值得关注的地方。
另请参阅问题Collecting the weights returned by clients without aggregating them ,其中还解释了如何避免完全聚合。
https://stackoverflow.com/questions/65794925
复制相似问题