首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >澄清observation_and_action_constraint_splitter在DqnAgent试剂中的应用

澄清observation_and_action_constraint_splitter在DqnAgent试剂中的应用
EN

Stack Overflow用户
提问于 2022-03-23 11:29:34
回答 1查看 65关注 0票数 0

我试图创建一个具有有效/无效操作掩码的DqnAgent代理,根据这个职位,我应该为observation_and_action_constraint_splitter arg指定一个splitter_fn。根据tf_agents 文档

splitter_fn就像:

代码语言:javascript
复制
def observation_and_action_constraint_splitter(observation):
  return observation['network_input'], observation['constraint'] 

在我看来,变量observation应该是由env.step(action).observation返回的数组,在我的例子中,它是一个形状为(56 )的数组(它是一个具有原始形状的扁平数组(14,4),每一行是每个选项的4个特征值,有5-14个选项,如果选择无效,相应的特性将全部为0),所以我这样写了我的splitter_fn:

代码语言:javascript
复制
def observation_and_action_constrain_splitter(observation):
     print(observation)
     temp = observation.reshape(14,-1)
     action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
     return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)

agent = DqnAgent(
    tf_time_step_spec,
    tf_action_spec,
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=tf_common.element_wise_squared_loss,
    train_step_counter=train_step_counter,
    observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
)

但是,它在运行上述代码单元时返回了以下错误:

代码语言:javascript
复制
BoundedTensorSpec(shape=(56,), dtype=tf.float32, name='observation', minimum=array(-3.4028235e+38, dtype=float32), maximum=array(3.4028235e+38, dtype=float32))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-213-07450ea5ba21> in <module>()
     13     td_errors_loss_fn=tf_common.element_wise_squared_loss,
     14     train_step_counter=train_step_counter,
---> 15     observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
     16     )
     17 

4 frames
<ipython-input-212-dbfee6076511> in observation_and_action_constrain_splitter(observation)
      1 def observation_and_action_constrain_splitter(observation):
      2      print(observation)
----> 3      temp = observation.reshape(14,-1)
      4      action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
      5      return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)

AttributeError: 'BoundedTensorSpec' object has no attribute 'reshape'
  In call to configurable 'DqnAgent' (<class 'tf_agents.agents.dqn.dqn_agent.DqnAgent'>)

结果是,print(observation)返回一个BoundedTensorSpec对象,而不是数组或tf.Tensor对象。如何从BoundedTensorSpec创建我的动作掩码,它甚至不包含用于观察的数组。

提前感谢!

PS: tf_agents版本为0.12.0

EN

回答 1

Stack Overflow用户

发布于 2022-10-19 09:27:26

我也面临着同样的问题。我通过将函数observation_and_action_constrain_splitter传递给策略而不是DqnAgent来解决它。

代码语言:javascript
复制
agent = categorical_dqn_agent.CategoricalDqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec(),
                                                observation_and_action_constraint_splitter=observation_and_action_constraint_splitter)

希望这能帮到你。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71586439

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档