首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >不能加载保存策略(TF-代理)

不能加载保存策略(TF-代理)
EN

Stack Overflow用户
提问于 2020-12-25 13:03:28
回答 2查看 780关注 0票数 3

我用策略保护程序保存了经过培训的策略如下:

代码语言:javascript
复制
  tf_policy_saver = policy_saver.PolicySaver(agent.policy)
  tf_policy_saver.save(policy_dir)

我想继续用保存的政策训练。因此,我尝试使用保存的策略初始化培训,这会导致一些错误。

代码语言:javascript
复制
agent = dqn_agent.DqnAgent(
tf_env.time_step_spec(),
tf_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=train_step_counter)

agent.initialize()

agent.policy=tf.compat.v2.saved_model.load(policy_dir)

错误:

代码语言:javascript
复制
  File "C:/Users/Rohit/PycharmProjects/pythonProject/waypoint.py", line 172, in <module>
agent.policy=tf.compat.v2.saved_model.load('waypoints\\Two_rewards')


File "C:\Users\Rohit\anaconda3\envs\btp36\lib\site-packages\tensorflow\python\training\tracking\tracking.py", line 92, in __setattr__
    super(AutoTrackable, self).__setattr__(name, value)
AttributeError: can't set attribute

我只想从每次再培训中节省时间。如何加载保存的策略并继续培训?

提前感谢

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-01-01 09:16:44

是的,如前所述,您应该使用Check指针来实现这一点--请看下面的示例代码。

代码语言:javascript
复制
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)

... # Train the agent

# Policy --> X
policy_checkpointer.save(global_step=epoch_counter.numpy())

稍后要重新加载策略时,只需运行相同的初始化代码。

代码语言:javascript
复制
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y1, possibly Y1==Y depending on agent class you are using, if it's DQN
#               then they are different because of random initialization of network weights
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)
# Policy --> X

创建时,policy_checkpointer将自动意识到是否存在任何预先存在的检查点。如果存在,它将在创建时自动更新正在跟踪的变量的值。

有几个笔记要做:

  1. 不仅可以保存策略,而且我建议这样做。TF-代理的Check指针对象非常灵活,例如:

代码语言:javascript
复制
train_checkpointer = common.Checkpointer(ckpt_dir=first/dir,
                                         agent=tf_agent,               # tf_agent.TFAgent
                                         train_step=train_step,        # tf.Variable
                                         epoch_counter=epoch_counter,  # tf.Variable
                                         metrics=metric_utils.MetricsGroup(
                                                 train_metrics, 'train_metrics'))

policy_checkpointer = common.Checkpointer(ckpt_dir=second/dir,
                                          policy=agent.policy)

rb_checkpointer = common.Checkpointer(ckpt_dir=third/dir,
                                      max_to_keep=1,
                                      replay_buffer=replay_buffer  # TFUniformReplayBuffer
                                      )

  1. 注意到,在DqnAgent的情况下,agent.policyagent.collect_policy本质上是QNetwork的包装器。下面的代码(查看有关策略变量状态的注释)

显示了这一点的含义。

代码语言:javascript
复制
agent = DqnAgent(...)
policy = agent.policy      # Random initial policy ---> X

dataset = replay_buffer.as_dataset(...)
for data in dataset:
   experience, _ = data
   loss_agent_info = agent.train(experience=experience)

# policy variable stores a trained Policy object ---> Y

这是因为TF中的张量在运行时是共享的。因此,当您用QNetwork更新代理的agent.train时,这些工具也会在policy变量的QNetwork中隐式更新。事实上,这并不是因为policy的张量被更新,而是因为它们与agent中的张量是相同的。

票数 5
EN

Stack Overflow用户

发布于 2020-12-29 20:57:43

为此,您应该查看Check指针。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65448313

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档