首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从tensorflow_federated中提取聚集梯度?

如何从tensorflow_federated中提取聚集梯度?
EN

Stack Overflow用户
提问于 2022-07-31 14:59:52
回答 1查看 32关注 0票数 0

我有一个像这样的tensorflow模型

代码语言:javascript
复制
def input_spec():
return(
      tf.TensorSpec([None, 122], tf.float64),
      tf.TensorSpec([None, 5],tf.uint8))

def model_fn():
    model=tf.keras.models.Sequential([
          tf.keras.layers.Dense(64,input_shape=(122,)),
          tf.keras.layers.Dense(32,activation='relu'),
          tf.keras.layers.Dropout(.15),
          tf.keras.layers.Dense(32,activation='relu'),
          tf.keras.layers.Dropout(.15),
          tf.keras.layers.Dense(32,activation='relu'),
          tf.keras.layers.Dropout(.15),
          tf.keras.layers.Dense(5,activation='softmax')])
    return tff.learning.from_keras_model(
           model,
           input_spec=input_spec(),
           loss=tf.keras.losses.CategoricalCrossentropy(),
           metrics=[tf.keras.metrics.CategoricalAccuracy()])

我在下面设置了iterative_process

代码语言:javascript
复制
iterative_process=tff.learning.algorithms.build_weighted_fed_avg(
                  model_fn,
                  client_optimizer_fn=lambda: tf.keras.optimizers.Adam(),
                  server_optimizer_fn=lambda: tf.keras.optimizers.Adam())

我已经了解到,我们可以通过model_weights=iterative_process.get_model_weights(state)获得聚合权重,但我仍然需要知道如何获得聚合梯度。

EN

回答 1

Stack Overflow用户

发布于 2022-07-31 20:52:42

在运行训练过程时,在某些情况下可以通过从结束时减去回合开始时的状态来计算聚集(伪)梯度。在上面的代码片段中,这并不是完全正确的,因为服务器优化器是Adam (如果我没记错的话,它执行一些伪梯度的重新缩放,以及添加一个动量累加器)。

如果您只是在服务器上使用梯度下降,学习速率为1(传统上是FedAvg的默认设置),下面的代码应该会为您提供这个聚集的伪梯度:

代码语言:javascript
复制
pseudo_grad = tf.nest.map_structure(
        lambda x, y: x - y, previous_state.global_model_weights.trainable,
        state.global_model_weights.trainable)

可以通过将聚合器参数封装到build_weighted_fed_avg调用的添加这些调试度量的聚合器中来访问用于调试的一些有用的度量,如果这是这里的底层目标。另外,如果您实现了一个tff.templates.AggregationProcess,输出其结果的measurements字段中的平均伪梯度,那么您还可以直接读取这些值;这些值应该由FedAvg实现的其余部分直接传递。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73184351

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档