首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在批处理学习中为tf-agent定义正确的形状

如何在批处理学习中为tf-agent定义正确的形状
EN

Stack Overflow用户
提问于 2019-11-01 18:44:05
回答 1查看 1.8K关注 0票数 6

我正在尝试使用tf_agents库训练一个具有批处理学习功能的DDPG代理。但是,我需要定义一个observation_spec和action_spec,它们说明代理将接收的张量的形状。我已经成功地创建了轨迹,我可以用它来提供数据,但是这些轨迹和代理本身的形状不匹配

我已经尝试使用agent定义来更改观察和操作规范。这是我的代理定义:

代码语言:javascript
复制
observation_spec = TensorSpec(shape = (1,),dtype =  tf.float32)
time_step_spec = time_step.time_step_spec(observation_spec)
action_spec = BoundedTensorSpec([1],tf.float32,minimum = -100, maximum = 100)
actor_network = ActorNetwork(
        input_tensor_spec=observation_spec,
        output_tensor_spec=action_spec,
        fc_layer_params=(100,200,100),
        name="ddpg_ActorNetwork"
    )
critic_net_input_specs = (observation_spec, action_spec)
critic_network = CriticNetwork(
    input_tensor_spec=critic_net_input_specs,
    observation_fc_layer_params=(200,100),
    joint_fc_layer_params=(100,200),
    action_fc_layer_params=None,
    name="ddpg_CriticNetwork"
)



agent = ddpg_agent.DdpgAgent(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    actor_network=actor_network,
    critic_network=critic_network,
    actor_optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    critic_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
)

这就是轨迹的样子

代码语言:javascript
复制
Trajectory(step_type=<tf.Variable 'Variable:0' shape=(1, 2) dtype=int32, numpy=array([[0, 1]], dtype=int32)>, observation=<tf.Variable 'Variable:0' shape=(1, 2) dtype=int32, numpy=array([[280, 280]], dtype=int32)>, action=<tf.Variable 'Variable:0' shape=(1, 2) dtype=float64, numpy=array([[nan,  0.]])>, policy_info=(), next_step_type=<tf.Variable 'Variable:0' shape=(1, 2) dtype=int32, numpy=array([[1, 1]], dtype=int32)>, reward=<tf.Variable 'Variable:0' shape=(1, 2) dtype=float64, numpy=array([[ -6.93147181, -12.14113521]])>, discount=<tf.Variable 'Variable:0' shape=(1, 2) dtype=float32, numpy=array([[0.9, 0.9]], dtype=float32)>)

我应该能够调用agent.train(轨迹),它可以工作,但我得到以下错误:

代码语言:javascript
复制
ValueError                                Traceback (most recent call last)
<ipython-input-325-bf162a5dc8d7> in <module>
----> 1 agent.train(trajs[0])

~/.local/lib/python3.7/site-packages/tf_agents/agents/tf_agent.py in train(self, experience, weights)
    213           "experience must be type Trajectory, saw type: %s" % type(experience))
    214 
--> 215     self._check_trajectory_dimensions(experience)
    216 
    217     if self._enable_functions:

~/.local/lib/python3.7/site-packages/tf_agents/agents/tf_agent.py in _check_trajectory_dimensions(self, experience)
    137     if not nest_utils.is_batched_nested_tensors(
    138         experience, self.collect_data_spec,
--> 139         num_outer_dims=self._num_outer_dims):
    140       debug_str_1 = tf.nest.map_structure(lambda tp: tp.shape, experience)
    141       debug_str_2 = tf.nest.map_structure(lambda spec: spec.shape,

~/.local/lib/python3.7/site-packages/tf_agents/utils/nest_utils.py in is_batched_nested_tensors(tensors, specs, num_outer_dims)
    142       'And spec_shapes:\n   %s' %
    143       (num_outer_dims, tf.nest.pack_sequence_as(tensors, tensor_shapes),
--> 144        tf.nest.pack_sequence_as(specs, spec_shapes)))
    145 
    146 

ValueError: Received a mix of batched and unbatched Tensors, or Tensors are not compatible with Specs.  num_outer_dims: 2.
Saw tensor_shapes:
   Trajectory(step_type=TensorShape([1, 2]), observation=TensorShape([1, 2]), action=TensorShape([1, 2]), policy_info=(), next_step_type=TensorShape([1, 2]), reward=TensorShape([1, 2]), discount=TensorShape([1, 2]))
And spec_shapes:
   Trajectory(step_type=TensorShape([]), observation=TensorShape([1]), action=TensorShape([1]), policy_info=(), next_step_type=TensorShape([]), reward=TensorShape([]), discount=TensorShape([]))
EN

回答 1

Stack Overflow用户

发布于 2020-12-11 23:46:39

这可以通过使用环境很容易地解决。在TF-Agents中,环境需要遵循PyEnvironment类(然后用TFPyEnvironment包装它,以便并行执行多个env)。如果您已经定义了与这个类的规范相匹配的环境,那么您的环境应该已经为您提供了两个方法env.time_step_spec()env.action_spec()。只需将这两个文件提供给您的代理,就可以完成任务了。

如果您希望从您的环境中获得多个输出,而这些输出不会全部进入您的代理,则会变得有点复杂。在这种情况下,您需要定义一个observation_and_action_constraint_splitter函数来传递给您的代理。有关如何正确使用TensorSpecs/ArraySpecs的更多详细信息,以及有效的示例,请参阅我的答案here

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58657866

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档