首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我想在稳定基线3中获得csv格式的片段奖励。

我想在稳定基线3中获得csv格式的片段奖励。
EN

Stack Overflow用户
提问于 2022-02-07 13:53:14
回答 1查看 435关注 0票数 0

我想在每一集之后检索数据,我已经阅读了您可以使用的文档,stable_baselines3.common.monitor.ResultsWriter,但是我不知道如何在代码中实现它。

代码语言:javascript
复制
import gym
import numpy as np
import Neural_Traffic_Env

import stable_baselines3
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, CallbackList, StopTrainingOnMaxEpisodes, EveryNTimesteps
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.monitor import Monitor, ResultsWriter

env = gym.make('NeuralTraffic-v1')
env = Monitor(env, filename="Monitor")

eval_callback = EvalCallback(env, best_model_save_path='./logs/best_model', log_path='./logs/', eval_freq=500)
checkpoint_callback = CheckpointCallback(save_freq=100, save_path='./saves/')
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=1000, verbose=1)
callback = CallbackList([callback_max_episodes, checkpoint_callback, eval_callback])

n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=1e6, log_interval=1, callback=callback)
model.save("ddpg")
env = model.get_env()

还有一个稳定的基线论坛,我也可以直接问我的问题吗?

EN

回答 1

Stack Overflow用户

发布于 2022-04-28 20:57:08

代码语言:javascript
复制
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor

tmp_path = "./tmp/sb3_log/"
# set up logger
new_logger = configure(tmp_path, ["stdout", "csv", "tensorboard"])

model = PPO('MlpPolicy', env, verbose=1)
model.set_logger(new_logger)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71019713

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档