我目前正在使用tff实现联邦学习。
由于数据集非常大,我们将其拆分为许多npy文件,我目前正在使用tff.simulation.FilePerUserClientData将数据集放在一起。
这就是我想要做的
client_ids_to_files = dict()
for i in range(len(train_filepaths)):
client_ids_to_files[str(i)] = train_filepaths[i]
def dataset_fn(filepath):
print(filepath)
dataSample = np.load(filepath)
label = filepath[:-4].strip().split('_')[-1]
return tf.data.Dataset.from_tensor_slices((dataSample, label))
train_filePerClient = tff.simulation.FilePerUserClientData(client_ids_to_files,dataset_fn)但是,它似乎不能很好地工作,回调函数中的filepath是一个具有dtype字符串的张量。filepath的值是:Tensor("hash_table_Lookup/LookupTableFindV2:0", shape=(), dtype=string)
张量似乎包含错误消息,而不是包含client_ids_to_files中的路径?我做错了什么吗?如何使用npy文件为tff.simulation.FilePerUserClientData编写适当的tff.simulation.FilePerUserClientData?
编辑:这是错误日志。错误本身实际上与我所问的问题无关,但您可以找到被调用的函数:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-46-e61ddbe06cdb> in <module>
22 return tf.data.Dataset.from_tensor_slices(filepath)
23
---> 24 train_filePerClient = tff.simulation.FilePerUserClientData(client_ids_to_files,dataset_fn)
25
~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/simulation/file_per_user_client_data.py in __init__(self, client_ids_to_files, dataset_fn)
52 return dataset_fn(client_ids_to_files[client_id])
53
---> 54 @computations.tf_computation(tf.string)
55 def dataset_computation(client_id):
56 client_ids_to_path = tf.lookup.StaticHashTable(
~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py in __call__(self, tff_internal_types, *args)
405 parameter_type)
406 args, kwargs = unpack_arguments_fn(next(wrapped_fn_generator))
--> 407 result = fn_to_wrap(*args, **kwargs)
408 if result is None:
409 raise ComputationReturnedNoneError(fn_to_wrap)
~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/simulation/file_per_user_client_data.py in dataset_computation(client_id)
59 list(client_ids_to_files.values())), '')
60 client_path = client_ids_to_path.lookup(client_id)
---> 61 return dataset_fn(client_path)
62
63 self._create_tf_dataset_fn = create_dataset_for_filename_fn
<ipython-input-46-e61ddbe06cdb> in dataset_fn(filepath)
17 filepath = tf.print(filepath)
18 print(filepath)
---> 19 dataSample = np.load(filepath)
20 print(dataSample)
21 label = filepath[:-4].strip().split('_')[-1]
~/fasttext-venv/lib/python3.6/site-packages/numpy/lib/npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
426 own_fid = False
427 else:
--> 428 fid = open(os_fspath(file), "rb")
429 own_fid = True
430
TypeError: expected str, bytes or os.PathLike object, not Operation发布于 2020-11-25 22:19:23
问题是dataset_fn必须可序列化为tf.Graph。这是必需的,因为TFF使用TensorFlow图在远程机器上执行逻辑。
在这种情况下,np.load不能序列化到图操作。看起来,numpy是用来从磁盘加载到内存中的,然后tf.data.Dataset.from_tensor_slices用于从内存中的对象创建数据集?我可能可以以不同的格式保存文件,并使用本机tf.data.Dataset操作从磁盘加载,而不是使用Python。一些选项可以是tf.data.TFRecordDataset、tf.data.TextLineDataset或tf.data.experimental.SqlDataset。
https://stackoverflow.com/questions/64959332
复制相似问题