首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TFF:带有自定义数据集的自定义输入规范- TypeError:类型为“TensorSpec”的对象没有len()

TFF:带有自定义数据集的自定义输入规范- TypeError:类型为“TensorSpec”的对象没有len()
EN

Stack Overflow用户
提问于 2020-06-11 19:52:49
回答 1查看 296关注 0票数 1

问题:我需要在tff模拟中使用自定义数据集。我在tff/python/research/压缩示例"run_experiment.py“的基础上进行了构建。错误:

代码语言:javascript
复制
  File "B:\tools and software\Anaconda\envs\bookProjects\lib\site-packages\IPython\core\interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-47998fd56829>", line 1, in <module>
    runfile('B:/projects/openProjects/githubprojects/BotnetTrafficAnalysisFederaedLearning/anomaly-detection/train_v04.py', args=['--experiment_name=temp', '--client_batch_size=20', '--client_optimizer=sgd', '--client_learning_rate=0.2', '--server_optimizer=sgd', '--server_learning_rate=1.0', '--total_rounds=200', '--rounds_per_eval=1', '--rounds_per_checkpoint=50', '--rounds_per_profile=0', '--root_output_dir=B:/projects/openProjects/githubprojects/BotnetTrafficAnalysisFederaedLearning/anomaly-detection/logs/fed_out/'], wdir='B:/projects/openProjects/githubprojects/BotnetTrafficAnalysisFederaedLearning/anomaly-detection')
  File "B:\tools and software\PyCharm 2020.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "B:\tools and software\PyCharm 2020.1\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "B:/projects/openProjects/githubprojects/BotnetTrafficAnalysisFederaedLearning/anomaly-detection/train_v04.py", line 292, in <module>
    app.run(main)
  File "B:\tools and software\Anaconda\envs\bookProjects\lib\site-packages\absl\app.py", line 299, in run
    _run_main(main, args)
  File "B:\tools and software\Anaconda\envs\bookProjects\lib\site-packages\absl\app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "B:/projects/openProjects/githubprojects/BotnetTrafficAnalysisFederaedLearning/anomaly-detection/train_v04.py", line 285, in main
    train_main()
  File "B:/projects/openProjects/githubprojects/BotnetTrafficAnalysisFederaedLearning/anomaly-detection/train_v04.py", line 244, in train_main
    input_spec=input_spec),
  File "B:/projects/openProjects/githubprojects/BotnetTrafficAnalysisFederaedLearning/anomaly-detection/train_v04.py", line 193, in model_builder
    metrics=[tf.keras.metrics.Accuracy()]
  File "B:\tools and software\Anaconda\envs\bookProjects\lib\site-packages\tensorflow_federated\python\learning\keras_utils.py", line 125, in from_keras_model
    if len(input_spec) != 2:
TypeError: object of type 'TensorSpec' has no len()

突出显示:TypeError:类型为“TensorSpec”的对象没有len()

2:尝试过:我查看了对:TensorFlow联邦:如何为具有多个输入的模型编写输入规范?的响应,描述了生成自定义输入规范所需的内容。我可能错过了理解输入规范。

如果我不需要这样做,而且有更好的方法,请告诉我。

3:资料来源:

代码语言:javascript
复制
    df = get_train_data(sysarg)
    x_train, x_opt, x_test = np.split(df.sample(frac=1,
                                                random_state=17),
                                      [int(1 / 3 * len(df)), int(2 / 3 * len(df))])

    x_train, x_opt, x_test = create_scalar(x_opt, x_test, x_train)
    input_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, tf.convert_to_tensor(x_train))
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-12 18:33:55

TFF的模型声明的输入规范与您可能期望的略有不同;它们通常同时将xy值作为参数(IE、数据和标签)。不幸的是,您正在访问该AttributeError,因为在这种情况下,ValueError TFF 会提高可能更有帮助。在此插入信息的执行部分:

代码语言:javascript
复制
The top-level structure in `input_spec` must contain exactly two elements,
as it must specify type information for both inputs to and predictions from the model.

在您的特定示例中,TLDR是:如果您也可以访问标签(下面的y_train),只需将您的input_spec定义更改为:

代码语言:javascript
复制
input_spec = tf.nest.map_structure(
    tf.TensorSpec.from_tensor,
    [tf.convert_to_tensor(x_train), tf.convert_to_tensor(y_train)])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62332459

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档