首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tf-agent中传递自定义环境的批处理大小

如何在Tf-agent中传递自定义环境的批处理大小
EN

Stack Overflow用户
提问于 2022-03-30 04:56:06
回答 1查看 221关注 0票数 2

我正在使用tf-代理库来构建上下文强盗。为此,我正在构建一个自定义环境。

我正在创建一个banditpyenvironment,并将其封装在TFpyenvironment中。

tfpyenvironment自动添加批处理大小维度(在观察规范中)。我需要在_observe和_apply_Action方法中考虑这个批处理大小维度。由于根据批次大小,我应该提供所需的观察(批次大小)数量(用于观察),并且根据批次大小,我应该采取批次大小的操作,并提供奖励(用于应用操作)。

我无法找到一个关于如何告诉tfenvironment批处理大小的示例,而不允许自动向第一个维度添加一个1。有谁能澄清一下

代码语言:javascript
复制
 def __init__(self, batch_size):

    self.batchsize=batch_size
    observation_spec = BoundedTensorSpec(
    (2,), np.int32, minimum=[1,1], maximum=[5,2], name= 'observation')
    action_spec = BoundedTensorSpec(
        shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')


    super(SampleEnvironment, self).__init__(observation_spec, action_spec)

  def _observe(self):
    batch=[]
    for i in range(self.batchsize):
        each=tf.cast(np.array([np.random.choice([1,2,3,4,5]),np.random.choice([1,2])]), 'int32')
        batch.append(each)
    self.observation=np.array(batch)
    print("in observe",self.observation)
    return np.array(self.observation)

当我试图在上面的观察方法中解释批处理大小时(使用for循环作为批处理大小),tfenvironment将再次将1作为批处理大小添加到第一个维度。是否有一种方法可以自动告诉环境该批处理是3,而不是自动添加1。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-04-21 23:00:29

这可以使用BatchedPyEnvironment类来完成,如下面的示例所示。从上面看,强盗的环境是一个非批次的环境。

下面的SampleEnvironment是问题中所示的环带环境。

代码语言:javascript
复制
batch_size = 4
env= SampleEnvironment()
py_envs = [env for _ in range(0, batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71671716

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档