我试图创建一个具有有效/无效操作掩码的DqnAgent代理,根据这个职位,我应该为observation_and_action_constraint_splitter arg指定一个splitter_fn。根据tf_agents 文档
,splitter_fn就像:
def observation_and_action_constraint_splitter(observation):
return observation['network_input'], observation['constraint'] 在我看来,变量observation应该是由env.step(action).observation返回的数组,在我的例子中,它是一个形状为(56 )的数组(它是一个具有原始形状的扁平数组(14,4),每一行是每个选项的4个特征值,有5-14个选项,如果选择无效,相应的特性将全部为0),所以我这样写了我的splitter_fn:
def observation_and_action_constrain_splitter(observation):
print(observation)
temp = observation.reshape(14,-1)
action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)
agent = DqnAgent(
tf_time_step_spec,
tf_action_spec,
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=tf_common.element_wise_squared_loss,
train_step_counter=train_step_counter,
observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
)但是,它在运行上述代码单元时返回了以下错误:
BoundedTensorSpec(shape=(56,), dtype=tf.float32, name='observation', minimum=array(-3.4028235e+38, dtype=float32), maximum=array(3.4028235e+38, dtype=float32))
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-213-07450ea5ba21> in <module>()
13 td_errors_loss_fn=tf_common.element_wise_squared_loss,
14 train_step_counter=train_step_counter,
---> 15 observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
16 )
17
4 frames
<ipython-input-212-dbfee6076511> in observation_and_action_constrain_splitter(observation)
1 def observation_and_action_constrain_splitter(observation):
2 print(observation)
----> 3 temp = observation.reshape(14,-1)
4 action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
5 return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)
AttributeError: 'BoundedTensorSpec' object has no attribute 'reshape'
In call to configurable 'DqnAgent' (<class 'tf_agents.agents.dqn.dqn_agent.DqnAgent'>)结果是,print(observation)返回一个BoundedTensorSpec对象,而不是数组或tf.Tensor对象。如何从BoundedTensorSpec创建我的动作掩码,它甚至不包含用于观察的数组。
提前感谢!
PS: tf_agents版本为0.12.0
发布于 2022-10-19 09:27:26
我也面临着同样的问题。我通过将函数observation_and_action_constrain_splitter传递给策略而不是DqnAgent来解决它。
agent = categorical_dqn_agent.CategoricalDqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
categorical_q_network=categorical_q_net,
optimizer=optimizer,
min_q_value=min_q_value,
max_q_value=max_q_value,
n_step_update=n_step_update,
td_errors_loss_fn=common.element_wise_squared_loss,
gamma=gamma,
train_step_counter=train_step_counter)
agent.initialize()
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
train_env.action_spec(),
observation_and_action_constraint_splitter=observation_and_action_constraint_splitter)希望这能帮到你。
https://stackoverflow.com/questions/71586439
复制相似问题