首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Q-learning模型没有改进

Q-learning模型没有改进
EN

Stack Overflow用户
提问于 2019-02-15 19:51:17
回答 1查看 190关注 0票数 1

我正在尝试解决openAI健身房的cartpole问题。通过Q学习。我想我误解了Q-learning的工作原理,因为我的模型没有改进。

我使用字典作为我的Q表。因此,我对每个观察结果进行“散列”(变成字符串)。并将其用作我的表中的关键字。

我的表中的每个键(观察值)都映射到另一个字典。其中我存储了在此状态下进行的每个移动及其相关的Q值。

如上所述,我的表中的条目可能如下所示:

代码语言:javascript
复制
'[''0.102'', ''1.021'', ''-0.133'', ''-1.574'']':
  0: 0.1

因此,在状态(Observation):'[''0.102'', ''1.021'', ''-0.133'', ''-1.574'']'中,已经使用Q值:0.01记录了一个动作:0

我的逻辑是不是错了?我真的不知道我的实现出了什么问题。

代码语言:javascript
复制
import gym
import random
import numpy as np

ENV = 'CartPole-v0'

env = gym.make(ENV)

class Qtable:
  def __init__(self):
    self.table = {}

  def update_table(self, obs, action, value):
    obs_hash = self.hash_obs(obs)

    # Update table with new observation
    if not obs_hash in self.table:
      self.table[obs_hash] = {}
      self.table[obs_hash][action] = value
    else:
      # Check if action has been recorded
      # If such, check if this value was better
      # If not, record new action for this obs
      if action in self.table[obs_hash]:
        if value > self.table[obs_hash][action]:
          self.table[obs_hash][action] = value
      else:
        self.table[obs_hash][action] = value

  def get_prev_value(self, obs, action):
    obs_hash = self.hash_obs(obs)
    if obs_hash in self.table:
      if action in self.table[obs_hash]:
        return self.table[obs_hash][action]
    return 0

  def get_max_value(self, obs):
    obs_hash = self.hash_obs(obs)
    if obs_hash in self.table:
      key = max(self.table[obs_hash])
      return self.table[obs_hash][key]
    return 0

  def has_action(self, obs):
    obs_hash = self.hash_obs(obs)
    if obs_hash in self.table:
      if len(self.table[obs_hash]) > 0:
        return True
    return False

  def get_best_action(self, obs):
    obs_hash = self.hash_obs(obs)
    if obs_hash in self.table:
      return max(self.table[obs_hash])

  # Makes a hashable entry of the observation
  def hash_obs(self, obs):
    return str(['{:.3f}'.format(i) for i in obs])

def play():

  q_table = Qtable()

  # Hyperparameters
  alpha   = 0.1
  gamma   = 0.6
  epsilon = 0.1
  episodes = 1000

  total = 0

  for i in range(episodes):

    done     = False
    prev_obs = env.reset()
    episode_reward = 0

    while not done:

      if random.uniform(0, 1) > epsilon and q_table.has_action(prev_obs):
        # Exploit learned values
        action = q_table.get_best_action(prev_obs)
      else:
        # Explore action space
        action = env.action_space.sample()

      # Render the environment
      #env.render()

      # Take a step
      obs, reward, done, info = env.step(action)

      if done:
        reward = -200

      episode_reward += reward

      old_value = q_table.get_prev_value(prev_obs, action)
      next_max  = q_table.get_max_value(obs)

      # Get the current sate value
      new_value = (1-alpha)*old_value + alpha*(reward + gamma*next_max)

      q_table.update_table(obs, action, new_value)

      prev_obs = obs

    total += episode_reward

  print("average", total/episodes)
  env.close()


play()
EN

回答 1

Stack Overflow用户

发布于 2019-02-16 02:51:00

我想我想通了。我误解了这部分new_value = (1-alpha)*old_value + alpha*(reward + gamma*next_max)

在这里,next_max是下一个状态的最好的移动。而不是(应该是)这个子树的最大值。

因此,将Q表实现为hashmap可能不是一个好主意。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54708749

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档