首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >DQN不收敛

DQN不收敛
EN

Stack Overflow用户
提问于 2022-10-10 12:17:20
回答 1查看 33关注 0票数 0

我正在尝试在openai-健身房的“月球着陆器”环境中实现DQN。

经过3000集的训练,它没有收敛的迹象。(作为比较,一个非常简单的政策梯度方法在2000集之后收敛)

我多次检查我的代码,但找不到哪里出了问题。我希望这里的人能指出问题的所在。下面是我的代码:

我使用一个简单的完全连接的网络:

代码语言:javascript
复制
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(8, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 4)
        )
    def forward(self, state):
        return self.main(state)

我在选择行动时使用贪婪的epsilon,而epsilon(从0.5开始)以指数方式减少加班费:

代码语言:javascript
复制
def sample_action(self, state):
        self.epsilon = self.epsilon * 0.99
        action_probs = self.network_train(state)
        random_number = random.random()
        if random_number < (1-self.epsilon):
            action = torch.argmax(action_probs, dim=-1).item()
        else:
            action = random.choice([0, 1, 2, 3])
        return action

在培训时,我使用重放缓冲区,批处理大小为64,以及渐变裁剪:

代码语言:javascript
复制
def learn(self):
        if len(self.buffer) >= BATCH_SIZE:
            self.learn_counter += 1
            transitions = self.buffer.sample(BATCH_SIZE)
            batch = Transition(*zip(*transitions))
            state = torch.from_numpy(np.concatenate(batch.state)).reshape(-1, 8)
            action = torch.tensor(batch.action).reshape(-1, 1)
            reward = torch.tensor(batch.reward).reshape(-1, 1)
            state_value = self.network_train(state).gather(1, action)
            next_state = torch.from_numpy(np.concatenate(batch.next_state)).reshape(-1, 8)
            next_state_value = self.network_target(next_state).max(1)[0].reshape(-1, 1).detach()
            loss = F.mse_loss(state_value.float(), (self.DISCOUNT_FACTOR*next_state_value + reward).float())
            self.optim.zero_grad()
            loss.backward()
            for param in self.network_train.parameters():
                param.grad.data.clamp_(-1, 1)
            self.optim.step()

我还使用目标网络,它的参数每100个时间步骤更新一次:

代码语言:javascript
复制
def update_network_target(self):
        if (self.learn_counter % 100) == 0:
            self.network_target.load_state_dict(self.network_train.state_dict())

顺便说一下,我使用的是亚当优化器和1e-3的LR。

EN

回答 1

Stack Overflow用户

发布于 2022-10-11 02:18:17

解决了。显然,更新目标网络的频率太高了。我把它设置为每10集,并解决了问题。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74014835

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档