首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf_agents自定义time_step_spec

tf_agents自定义time_step_spec
EN

Stack Overflow用户
提问于 2019-10-12 04:41:23
回答 1查看 428关注 0票数 2

我正在修补tf-agents,但我在制作自定义time_step_spec时遇到了问题。

我正在尝试在健身房'Breakout-v0‘中训练一个tf-agent,我已经做了一个函数来预处理观察结果(游戏像素),现在我想修改time_step和time_step_spec来反映新的数据。

原始time_step_spec.observation()为:

代码语言:javascript
复制
BoundedTensorSpec(shape=(210, 160, 3), dtype=tf.uint8, name='observation', minimum=array(0, dtype=uint8), maximum=array(255, dtype=uint8))  

我的建议是:

代码语言:javascript
复制
BoundedTensorSpec(shape=(1, 165, 150), dtype=tf.float32, name='observation', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32))        

我已经成功地创建了一个自定义BoundedTensorSpec,并使用以下函数修改了time_step

代码语言:javascript
复制
processed_timestep = timestep._replace(observation=processed_obs)

现在我在理解如何修改time_step_spec时遇到了麻烦,我既不完全理解它是什么,也不知道如何修改它的组件。

原始time_step_spec为:

代码语言:javascript
复制
TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(210, 160, 3), dtype=tf.uint8, name='observation', minimum=array(0, dtype=uint8), maximum=array(255, dtype=uint8)))

它到底是什么结构?张量数组?

如何访问它的组件?

是否可以制作包含多个组件的自定义time_step_spec?(奖励、观察等)

我可以只修改一个组件吗?

EN

回答 1

Stack Overflow用户

发布于 2020-12-12 19:56:40

这可以通过使用环境很容易地解决。在TF-Agents中,环境需要遵循PyEnvironment类(然后用TFPyEnvironment包装它,以便并行执行多个env)。如果您已经定义了与这个类的规范相匹配的环境,那么您的环境应该已经为您提供了两个方法env.time_step_spec()env.action_spec()。只需将这两个文件提供给您的代理,就可以完成任务了。

要创建包含多个组件的time_step,您可以查看以下示例:

代码语言:javascript
复制
self._observation_spec = {'observations': array_spec.ArraySpec(shape=(100,), dtype=np.float64),
                          'legal_moves': array_spec.ArraySpec(shape=(self.num_moves(),), dtype=np.bool_),
                          'other_output1': array_spec.ArraySpec(shape=(10,), dtype=np.int64),
                          'other_output2': array_spec.ArraySpec(shape=(2, 5, 2), dtype=np.int64)}

这一行放在环境的__init__函数中。您不需要知道这一切实际上在做什么,但您可以看到它基本上是一个带有字符串键和ArraySpec值的字典。请注意,如果这样做,您将不得不定义一个observation_and_action_constraint_splitter来传递给代理,该代理丢弃所有不应该提供给代理输入的组件。下面是一个示例,说明如何在env._step方法中使用此观察值字典来构造适当的TimeStep:

代码语言:javascript
复制
observations_and_legal_moves = {'observations': current_agent_obs,
                                'legal_moves': np.ones(shape=(self.num_moves(), dtype=np.bool_),
                                'other_output1': np.ones(shape=(10, dtype=np.int64,
                                'other_output2': np.ones(shape=(2, 5, 2), dtype=np.int64}

ts.transition(observations_and_legal_moves, reward, self.gamma)

如果您的环境已经构建了张量输出,而您只是不能确定适当的TensorSpec应该是什么(获取它可能非常麻烦),那么您可以简单地调用tf.TensorSpec.from_tensor(tensor)来确定您必须定义什么。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58348203

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档