我正在使用tf-代理库来构建上下文强盗。为此,我正在构建一个自定义环境。
我正在创建一个banditpyenvironment,并将其封装在TFpyenvironment中。
tfpyenvironment自动添加批处理大小维度(在观察规范中)。我需要在_observe和_apply_Action方法中考虑这个批处理大小维度。由于根据批次大小,我应该提供所需的观察(批次大小)数量(用于观察),并且根据批次大小,我应该采取批次大小的操作,并提供奖励(用于应用操作)。
我无法找到一个关于如何告诉tfenvironment批处理大小的示例,而不允许自动向第一个维度添加一个1。有谁能澄清一下
def __init__(self, batch_size):
self.batchsize=batch_size
observation_spec = BoundedTensorSpec(
(2,), np.int32, minimum=[1,1], maximum=[5,2], name= 'observation')
action_spec = BoundedTensorSpec(
shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')
super(SampleEnvironment, self).__init__(observation_spec, action_spec)
def _observe(self):
batch=[]
for i in range(self.batchsize):
each=tf.cast(np.array([np.random.choice([1,2,3,4,5]),np.random.choice([1,2])]), 'int32')
batch.append(each)
self.observation=np.array(batch)
print("in observe",self.observation)
return np.array(self.observation)当我试图在上面的观察方法中解释批处理大小时(使用for循环作为批处理大小),tfenvironment将再次将1作为批处理大小添加到第一个维度。是否有一种方法可以自动告诉环境该批处理是3,而不是自动添加1。
发布于 2022-04-21 23:00:29
这可以使用BatchedPyEnvironment类来完成,如下面的示例所示。从上面看,强盗的环境是一个非批次的环境。
下面的SampleEnvironment是问题中所示的环带环境。
batch_size = 4
env= SampleEnvironment()
py_envs = [env for _ in range(0, batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)https://stackoverflow.com/questions/71671716
复制相似问题