首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何用dataset_fn编写适当的tff.simulation.FilePerUserClientData?

如何用dataset_fn编写适当的tff.simulation.FilePerUserClientData?
EN

Stack Overflow用户
提问于 2020-11-22 20:42:11
回答 1查看 161关注 0票数 1

我目前正在使用tff实现联邦学习。

由于数据集非常大,我们将其拆分为许多npy文件,我目前正在使用tff.simulation.FilePerUserClientData将数据集放在一起。

这就是我想要做的

代码语言:javascript
复制
client_ids_to_files = dict()
for i in range(len(train_filepaths)):
  client_ids_to_files[str(i)] = train_filepaths[i]

def dataset_fn(filepath):
  print(filepath)
  dataSample = np.load(filepath)
  label = filepath[:-4].strip().split('_')[-1]
  return tf.data.Dataset.from_tensor_slices((dataSample, label))
train_filePerClient = tff.simulation.FilePerUserClientData(client_ids_to_files,dataset_fn)

但是,它似乎不能很好地工作,回调函数中的filepath是一个具有dtype字符串的张量。filepath的值是:Tensor("hash_table_Lookup/LookupTableFindV2:0", shape=(), dtype=string)

张量似乎包含错误消息,而不是包含client_ids_to_files中的路径?我做错了什么吗?如何使用npy文件为tff.simulation.FilePerUserClientData编写适当的tff.simulation.FilePerUserClientData

编辑:这是错误日志。错误本身实际上与我所问的问题无关,但您可以找到被调用的函数:

代码语言:javascript
复制
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-46-e61ddbe06cdb> in <module>
     22     return tf.data.Dataset.from_tensor_slices(filepath)
     23 
---> 24 train_filePerClient = tff.simulation.FilePerUserClientData(client_ids_to_files,dataset_fn)
     25 

~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/simulation/file_per_user_client_data.py in __init__(self, client_ids_to_files, dataset_fn)
     52       return dataset_fn(client_ids_to_files[client_id])
     53 
---> 54     @computations.tf_computation(tf.string)
     55     def dataset_computation(client_id):
     56       client_ids_to_path = tf.lookup.StaticHashTable(

~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py in __call__(self, tff_internal_types, *args)
    405                                             parameter_type)
    406       args, kwargs = unpack_arguments_fn(next(wrapped_fn_generator))
--> 407       result = fn_to_wrap(*args, **kwargs)
    408       if result is None:
    409         raise ComputationReturnedNoneError(fn_to_wrap)

~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/simulation/file_per_user_client_data.py in dataset_computation(client_id)
     59               list(client_ids_to_files.values())), '')
     60       client_path = client_ids_to_path.lookup(client_id)
---> 61       return dataset_fn(client_path)
     62 
     63     self._create_tf_dataset_fn = create_dataset_for_filename_fn

<ipython-input-46-e61ddbe06cdb> in dataset_fn(filepath)
     17         filepath = tf.print(filepath)
     18     print(filepath)
---> 19     dataSample = np.load(filepath)
     20     print(dataSample)
     21     label = filepath[:-4].strip().split('_')[-1]

~/fasttext-venv/lib/python3.6/site-packages/numpy/lib/npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
    426         own_fid = False
    427     else:
--> 428         fid = open(os_fspath(file), "rb")
    429         own_fid = True
    430 

TypeError: expected str, bytes or os.PathLike object, not Operation
EN

回答 1

Stack Overflow用户

发布于 2020-11-25 22:19:23

问题是dataset_fn必须可序列化为tf.Graph。这是必需的,因为TFF使用TensorFlow图在远程机器上执行逻辑。

在这种情况下,np.load不能序列化到图操作。看起来,numpy是用来从磁盘加载到内存中的,然后tf.data.Dataset.from_tensor_slices用于从内存中的对象创建数据集?我可能可以以不同的格式保存文件,并使用本机tf.data.Dataset操作从磁盘加载,而不是使用Python。一些选项可以是tf.data.TFRecordDatasettf.data.TextLineDatasettf.data.experimental.SqlDataset

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64959332

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档