Accessing and modifying the weights sent by clients on the server in TensorFlow Federated

Stack Overflow user
Asked on 2021-10-18 15:51:21
Answers: 1, Views: 153, Followers: 0, Votes: 0

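I am using TensorFlow Federated, but I am running into problems when I try to perform some operations on the server after reading the client updates.

This is the function: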

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight


    tf.print(client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp

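Before calling tff.federated_mean, I want to print client_outputs.weights_delta and perform some operations on the weights that the clients send to the server, but I don't know how to do that.

When I try to print, I get this: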

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

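Is there any way to modify these elements?

I tried returning client_outputs.weights_delta.comp and doing the modification in main (which I can do), and then I tried calling a new method to perform the rest of the server update, but the error is:

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean', where calculate_federated_mean is the name of the new function I created.

This is the main loop: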

    for round_num in range(FLAGS.total_rounds):
        print("--------------------------------------------------------")
        sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
        sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

        server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

        print(f'Round {round_num}')
        print(f'\tTraining loss: {train_metrics:.4f}')
        if round_num % FLAGS.rounds_per_eval == 0:
            server_state.model_weights.assign_weights_to(keras_model)
            accuracy = evaluate(keras_model, test_data)
            print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
            tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

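This is based on the simple_fedavg project on GitHub (TensorFlow Federated simple_fedavg), which I am using as the base project.

Edit 1:

So, thanks to @Jakub Konecny, I have made some progress, but I ran into a new problem that I don't really understand.

So, if I use this client_update: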

@tf.function
def client_update(model, dataset, server_message, client_optimizer):
    """Performs client local training of `model` on `dataset`.
    Args:
      model: A `tff.learning.Model`.
      dataset: A 'tf.data.Dataset'.
      server_message: A `BroadcastMessage` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer`.
    Returns:
      A `ClientOutput`.
    """
    model_weights = model.weights
    initial_weights = server_message.model_weights
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)

    num_examples = tf.constant(0, dtype=tf.int32)
    loss_sum = tf.constant(0, dtype=tf.float32)
    # Explicit use `iter` for dataset is a trick that makes TFF more robust in
    # GPU simulation and slightly more performant in the unconventional usage
    # of large number of small datasets.
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, model_weights.trainable)
        client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
        batch_size = tf.shape(batch['x'])[0]
        num_examples += batch_size
        loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)

    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          model_weights.trainable,
                                          initial_weights.trainable)


    client_weight = tf.cast(num_examples, tf.float32)

    import sparse_ternary_compression
    sparsification_rate = 1
    testing_new = []
    # TODO: do not apply this to the biases
    for tensor in weights_delta:
        testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))

    return ClientOutput(weights_delta, client_weight, loss_sum / client_weight, testing_new)

with these functions:

@tff.tf_computation
def stc_compression(original_tensor, sparsification_percentage):
    original_shape = tf.shape(original_tensor)
    tensor = tf.reshape(original_tensor, [-1])
    sparsification_percentage = tf.cast(sparsification_percentage, tf.float64)
    sparsification_rate = tf.size(tensor) / 100 * sparsification_percentage
    sparsification_rate = tf.cast(sparsification_rate, tf.int32)
    new_shape = tensor.get_shape().as_list()
    if sparsification_rate == 0:
        sparsification_rate = 1
    mask = tf.cast(tf.abs(tensor) >= tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    inv_mask = tf.cast(tf.abs(tensor) < tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    tensor_masked = tf.multiply(tensor, mask)
    sparsification_rate = tf.cast(sparsification_rate, tf.float32)
    average = tf.reduce_sum(tf.abs(tensor_masked)) / sparsification_rate
    compressed_tensor = tf.add(tf.multiply(average, mask) * tf.sign(tensor), tf.multiply(tensor_masked, inv_mask))
    negatives = tf.where(compressed_tensor < 0)
    positives = tf.where(compressed_tensor > 0)
    return negatives, positives, average, original_shape, new_shape

@tff.tf_computation
def stc_decompression(negatives, positives, average, original_shape, new_shape):
    decompressed_tensor = tf.zeros(new_shape, tf.float32)
    average_values_negative = tf.fill([tf.shape(negatives)[0], ], -average)
    average_values_positive = tf.fill([tf.shape(positives)[0], ], average)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, negatives, average_values_negative)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, positives, average_values_positive)
    decompressed_tensor = tf.reshape(decompressed_tensor, original_shape)
    return decompressed_tensor


@tff.tf_computation
def testing_new_list(list):
    testing = []
    for index in list:
        testing.append(
            stc_decompression(index[0], index[1],
                              index[2], index[3],
                              index[4]))

    return testing
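For readers who want to check what the compression and decompression above are meant to do, here is a minimal NumPy sketch of a sparse ternary round trip. It is not the asker's TensorFlow code and it simplifies a few things: it picks exactly k entries instead of thresholding, and the names stc_compress, stc_decompress and sparsity_percent are made up for this illustration.

```python
import numpy as np


def stc_compress(tensor, sparsity_percent):
    """Keep the largest-magnitude entries as index lists plus one shared magnitude."""
    flat = np.asarray(tensor, dtype=np.float32).reshape(-1)
    # Number of entries to keep; always keep at least one.
    k = max(1, int(flat.size * sparsity_percent / 100))
    # Indices of the k entries with the largest absolute value.
    top_idx = np.argpartition(np.abs(flat), -k)[-k:]
    # Every kept entry is later restored as +/- this mean magnitude.
    mean_magnitude = float(np.mean(np.abs(flat[top_idx])))
    positives = np.sort(top_idx[flat[top_idx] > 0])
    negatives = np.sort(top_idx[flat[top_idx] < 0])
    return negatives, positives, mean_magnitude, np.shape(tensor)


def stc_decompress(negatives, positives, mean_magnitude, original_shape):
    """Rebuild a dense tensor: zeros everywhere except +/- mean_magnitude at the kept indices."""
    flat = np.zeros(int(np.prod(original_shape)), dtype=np.float32)
    flat[positives] = mean_magnitude
    flat[negatives] = -mean_magnitude
    return flat.reshape(original_shape)


delta = np.array([[0.1, -2.0], [3.0, -0.2]], dtype=np.float32)
packed = stc_compress(delta, sparsity_percent=50)
restored = stc_decompress(*packed)
print(restored)
```

With 50% sparsity only the entries 3.0 and -2.0 survive, and both come back with the shared magnitude 2.5. Comparing against a reference like this can help verify a TensorFlow version on small inputs.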

called inside the run_one_round function like this:

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.
        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.
        Returns:
          A tuple of updated `ServerState` and `tf.Tensor` of average loss.
        """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        weight_denom = client_outputs.client_weight

        import sparse_ternary_compression
        testing = tff.federated_map(sparse_ternary_compression.testing_new_list, client_outputs.test)

        # round_model_delta holds the weights used by server_update, so this is what needs to change
        round_model_delta = tff.federated_mean(
            client_outputs.weights_delta, weight=weight_denom)

        server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
        round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

        return server_state, round_loss_metric, testing

But I get an exception:

Traceback (most recent call last):
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 214, in <module>
    app.run(main)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 171, in main
    iterative_process = simple_fedavg_tff.build_federated_averaging_process(
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 95, in build_federated_averaging_process
    def client_update_fn(tf_dataset, server_message):
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 478, in __call__
    wrapped_func = self._strategy(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 216, in __call__
    result = fn_to_wrap(*args, **kwargs)
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 98, in client_update_fn
    return client_update(model, tf_dataset, server_message, client_optimizer)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 763, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3050, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.pyct.error_utils.KeyError: in user code:

        /mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tf.py:222 client_update  *
            testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:608 __call__  *
            return concrete_fn(packed_arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:525 __call__  *
            return context.invoke(self, arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py:54 invoke  *
            init_op, result = (
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py:1097 deserialize_and_call_tf_computation  *
            input_map = {
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3931 get_tensor_by_name  **
            return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3755 as_graph_element
            return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3795 _as_graph_element_locked
            raise KeyError("The name %s refers to a Tensor which does not "
    
        KeyError: "The name 'sub:0' refers to a Tensor which does not exist. The operation, 'sub', does not exist in the graph."
    
    
    Process finished with exit code 1

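Edit 2:

I fixed the problem above by changing the decorator on stc_compression and stc_decompression from tff.tf_computation to tf.function. It now seems to work well: if I print the variable testing obtained from return server_state, round_loss_metric, testing inside run_one_round, I get the weights I wanted from the start.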


1 Answer

Stack Overflow user

Accepted answer

Posted on 2021-10-19 09:04:47

I think this reply, which I wrote a while ago for another question, applies here as well.

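When you print client_outputs.weights_delta, you get an abstract representation of the result of another computation, which is mostly an internal implementation detail of TFF.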
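To make that point concrete, here is a toy pure-Python analogy. It is not TFF code and does not reflect TFF's real internals; Deferred, toy_map and halve are names invented for this sketch. It only shows the general idea that a value can be described while the computation is being built, with actual numbers appearing only when it runs.

```python
class Deferred:
    """Toy stand-in for a value that is described now and computed later."""

    def __init__(self, description, thunk):
        self.description = description
        self._thunk = thunk  # zero-argument function producing the real values

    def __repr__(self):
        # Printing shows the description, never the numbers.
        return self.description

    def run(self):
        return self._thunk()


def toy_map(fn, deferred):
    """Record that fn should be applied to every client value; compute nothing yet."""
    return Deferred(
        f"Call(map, {fn.__name__}, {deferred!r})",
        lambda: [fn(v) for v in deferred.run()],
    )


def halve(x):
    return x / 2


client_values = Deferred("Placed(clients)", lambda: [2.0, 4.0, 6.0])
halved = toy_map(halve, client_values)

shown = repr(halved)     # what a print during construction would display
computed = halved.run()  # values exist only once the computation executes
print(shown)
print(computed)
```

The first print shows a nested Call(...) description, much like the output in the question, while the list of numbers only appears after run() is called.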

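Write a tff.tf_computation-decorated method in TensorFlow code that performs the modification you want, and then invoke it with the tff.federated_map operator at the point where you are trying to print the value.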
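As a rough picture of the resulting data flow, the NumPy sketch below applies a per-client function to each client's delta and then takes a weighted mean, which is what a pointwise map followed by a weighted mean amounts to numerically. This is plain NumPy, not TFF; clip_delta, max_norm and the sample numbers are assumptions chosen only for illustration.

```python
import numpy as np


def clip_delta(delta, max_norm=1.0):
    """Hypothetical per-client modification: rescale so the L2 norm is at most max_norm."""
    norm = np.linalg.norm(delta)
    return delta if norm <= max_norm else delta * (max_norm / norm)


def weighted_mean(values, weights):
    """Average the per-client arrays, weighting each client by its weight."""
    weights = np.asarray(weights, dtype=np.float64)
    stacked = np.stack(values)  # shape: (num_clients, ...)
    return np.tensordot(weights, stacked, axes=1) / weights.sum()


# Two clients, each with a weights_delta and a client_weight (number of examples).
client_deltas = [np.array([3.0, 4.0]), np.array([0.0, 0.5])]
client_weights = [10.0, 30.0]

# Step 1: apply the modification to each client's value independently.
modified = [clip_delta(d) for d in client_deltas]
# Step 2: aggregate the modified values into a single round-level delta.
round_model_delta = weighted_mean(modified, client_weights)
print(round_model_delta)
```

In the question's run_one_round, the analogous place for the map step would be between computing client_outputs and the tff.federated_mean call.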

Votes: 2
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/69619028
