首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tff.simulation.datasets.ClientData从CSV文件构建联邦学习模型

tff.simulation.datasets.ClientData从CSV文件构建联邦学习模型
EN

Stack Overflow用户
提问于 2022-02-26 01:20:40
回答 1查看 337关注 0票数 2

我正在使用自己的数据集构建一个联邦学习模型。我的目标是建立一个多分类模型。数据显示在单独的8个CSV文件中。

我遵循这个帖子中的说明,如下面的代码所示。

代码语言:javascript
复制
dataset_paths = {
  'client_0': '/content/ds1.csv',
  'client_1': '/content/ds2.csv',
  'client_2': '/content/ds3.csv',
  'client_3': '/content/ds4.csv',
  'client_4': '/content/ds5.csv',
}

def create_tf_dataset_for_client_fn(id):
   path = dataset_paths.get(id)
   if path is None:
     raise ValueError(f'No dataset for client {id}')
   return tf.data.Dataset.TextLineDataset(path)

source = tff.simulation.datasets.ClientData.from_clients_and_fn(
  dataset_paths.keys(), create_tf_dataset_for_client_fn)

但它给了我这个错误

代码语言:javascript
复制
AttributeError: type object 'ClientData' has no attribute 'from_clients_and_fn'

我读了这个文档,发现.datasets方法可以工作,所以我用.from_clients_and_fn代替了它,错误消失了,但是我不知道它是否正确,接下来是什么?

我的问题是:

  1. 这是将数据上传到客户端的正确方法吗?
  2. 如果无法单独上传CSV文件,我是否可以将所有数据合并到一个CSV文件中,然后将它们视为非IID数据并进行相应的培训?我需要一些指导

并预先感谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-02-26 15:23:21

在这个设置中,考虑tff.simulation.datasets.FilePerUserClientDatatf.data.experimental.CsvDataset可能很有用。

这可能看起来像(这使得一些测试CSV数据,为了这个例子,您使用的数据集可能有其他形状):

代码语言:javascript
复制
dataset_paths = {
  'client_0': '/content/ds1.csv',
  'client_1': '/content/ds2.csv',
  'client_2': '/content/ds3.csv',
  'client_3': '/content/ds4.csv',
  'client_4': '/content/ds5.csv',
}

# Create some test data for the sake of the example,
# normally we wouldn't do this.
for i, (id, path) in enumerate(dataset_paths.items()):
  with open(path, 'w') as f:
    for _ in range(i):
      f.write(f'test,0.0,{i}\n')

# Values that will fill in any CSV cell if its missing,
# must match the dtypes above.
record_defaults = ['', 0.0, 0]

@tf.function
def create_tf_dataset_for_client_fn(dataset_path):
   return tf.data.experimental.CsvDataset(
     dataset_path, record_defaults=record_defaults )

source = tff.simulation.datasets.FilePerUserClientData(
  dataset_paths, create_tf_dataset_for_client_fn)


print(source.client_ids)
>>> ['client_0', 'client_1', 'client_2', 'client_3', 'client_4']

for x in source.create_tf_dataset_for_client('client_3'):
  print(x)
>>> (<tf.Tensor: shape=(), dtype=string, numpy=b'test'>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=int32, numpy=3>)
>>> (<tf.Tensor: shape=(), dtype=string, numpy=b'test'>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=int32, numpy=3>)
>>> (<tf.Tensor: shape=(), dtype=string, numpy=b'test'>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=int32, numpy=3>)

可以将所有数据连接到一个CSV中,但是每个记录仍然需要一些标识符来指示哪一行属于哪个客户端。将所有行混合在一起而没有任何类型的每个客户端映射将类似于标准的集中培训,而不是联邦学习。

一旦一个CSV拥有所有的行,并且可能有一个具有client_id值的列,人们可能可以使用tf.data.Dataset.filter()只生成属于特定客户端的行。但是,这可能不是特别有效,因为它将迭代每个客户端的整个全局数据集,而不仅仅是该客户端的示例。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71273332

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档