首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法加载tensorflow (tf-agent)保存模型

无法加载tensorflow (tf-agent)保存模型
EN

Stack Overflow用户
提问于 2019-06-11 00:35:32
回答 1查看 1.6K关注 0票数 7

我在下面的代码中创建一个tf代理DqnAgent:

代码语言:javascript
复制
tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
    train_step_counter=train_step_counter

)

在训练过程中,我用

代码语言:javascript
复制
tf.saved_model.save(tf_agent, saved_models_path)

一旦经过训练,我想用

代码语言:javascript
复制
if tf.saved_model.contains_saved_model(saved_models_path):
    tf_agent = tf.saved_model.load(saved_models_path)

只有当saved_path中的文件夹包含一个文件夹时,此代码才会加载保存的模型,函数contains_saved_model(saved_models_path)返回True,因此加载了模型,但是存在一个excetion,程序崩溃:

代码语言:javascript
复制
Traceback (most recent call last):
    File "/home/claudino/Projetos/dino-tf-agents/dino_ia/model/agent.py", line 50, in <module>
        tf_agent = tf.saved_model.load(saved_models_path)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 408, in load
        return load_internal(export_dir, tags)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 432, in load_internal
        export_dir)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 58, in __init__
        self._load_all()
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 168, in _load_all
        slot_variable = optimizer_object.add_slot(
    AttributeError: '_UserObject' object has no attribute 'add_slot'

    Process finished with exit code 1

我浏览了tensorflow代码,但找不到问题。有人能帮我吗?

我之所以使用tf-agents-nightly,是因为谷歌的冒牌源代码不适用于tf-agents的“稳定”版本(我不确定tf-agent是否真的稳定),并且尝试使用tensorflow 1.3和2.0.0-beta0的代码时,也会出现同样的问题。

EN

回答 1

Stack Overflow用户

发布于 2022-01-19 17:31:33

你试过TensorFlow 2.7吗?这通常有助于解决这个问题。

其他对我有用的方法是以这种方式加载模型(假设模型是keras/tf.keras模型):

代码语言:javascript
复制
try:
    model = tf.keras.models.load_model(model_dir)
except:
  load_options = tf.saved_model.LoadOptions(experimental_io_device= '/job:localhost')
  model = tf.saved_model.load(model_dir, options= load_options)

try子句将导致异常,因为load_model()需要一个keras_metadata.pb文件,在用saved_model.save()保存模型时缺少该文件。

但是,运行该子句将使tf.saved_model.load()在没有任何问题的情况下运行。在我不太理解的背景下,可能会发生某种交互,但这对我来说很有效,而且“no attribute add_slot”错误也没有出现。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56535020

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档