
Implementing SARSA in Unity

Stack Overflow user
Asked on 2018-06-08 20:43:35
1 answer · 181 views · 0 followers · 0 votes

So I am implementing Q-learning in Unity with the following code:

Code language: C#
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

namespace QLearner
{
    public class QLearnerScript
    {
        List<float[]> QStates; // Q states over time
        List<float[]> QActions; // Q actions over time

        float[] initialState;
        int initialActionIndex;
        float[] outcomeState;
        float outcomeActionValue;
        bool firstIteration;

        int possibleActions;

        float learningRate; // denoted by alpha
        float discountFactor; // denoted by gamma

        float simInterval;


        System.Random r = new System.Random();

        public int main(float[] currentState, float reward)
        {
            QLearning(currentState, reward);

            // Applies a sim interval and rounds
            initialState = new float[2] {
                (float)Math.Round((double)currentState[0] / simInterval) * simInterval,
                (float)Math.Round((double)currentState[1] / simInterval) * simInterval
            };

            firstIteration = false;

            int actionIndex = r.Next(0, possibleActions);

            bool exists = false;
            if(QStates.Count > 0)
            {
                for(int i = 0; i < QStates.Count; i++)
                {
                    float[] state = QStates.ElementAt(i);
                    float[] actions = QActions.ElementAt(i);

                    if(state[0] ==  initialState[0] && state[1] ==  initialState[1])
                    {
                        exists = true;
                        initialActionIndex = Array.IndexOf(actions, MaxFloat(actions));

                        return initialActionIndex;
                    }
                }
            }

            if(!exists)
            {
                float[] actionVals = new float[possibleActions];
                for (int i = 0; i < possibleActions; i++)
                {
                    actionVals[i] = 0f;
                }
                QStates.Add( initialState);
                QActions.Add(actionVals);
            }

            initialActionIndex = actionIndex;
            return initialActionIndex;
        }

        public QLearnerScript(int possActs)
        {
            QStates = new List<float[]>();
            QActions = new List<float[]>();
            possibleActions = possActs;

            learningRate = .5f; // Between 0 and 1
            discountFactor = 1f;

            simInterval = 1f;

            firstIteration = true;
        }

        public void QLearning(float[] outcomeStateFeed, float reward)
        {
            if(!firstIteration)
            {
                outcomeState = new float[2] {
                    (float)Math.Round((double)outcomeStateFeed[0] / simInterval) * simInterval,
                    (float)Math.Round((double)outcomeStateFeed[1] / simInterval) * simInterval
                };

                bool exists = false;
                for(int i = 0; i < QStates.Count; i++)
                {
                    float[] state = QStates.ElementAt(i);
                    float[] actions = QActions.ElementAt(i);

                    if(state[0] == outcomeState[0] && state[1] == outcomeState[1])
                    {
                        exists = true;
                        outcomeActionValue = MaxFloat(actions);
                    }
                }

                for(int i = 0; i < QStates.Count; i++)
                {
                    float[] state = QStates.ElementAt(i);
                    float[] actions = QActions.ElementAt(i);

                    if(state[0] ==  initialState[0] && state[1] ==  initialState[1])
                    {

                        if(exists)
                        {
                            actions[initialActionIndex] += learningRate * (reward + discountFactor * outcomeActionValue - actions[initialActionIndex]);
                        }
                        else // next state not in the table yet, so its value is treated as 0
                        {
                            actions[initialActionIndex] += learningRate * (reward + discountFactor * 0f - actions[initialActionIndex]);
                        }
                    }
                }
            }
        }

        public int getQtableCount()
        {
            return QStates.Count;
        }

        float MaxFloat(float[] numbers)
        {
            float max = numbers[0];

            for (int i = 0; i < numbers.Length; i++)
                if (max < numbers[i])
                {
                    max = numbers[i];
                }

            return max;
        }
    }
}

It works well in my environment. However, I am also trying to implement SARSA so I can compare the two algorithms. I understand that Q-learning is off-policy while SARSA is on-policy, which means I have to implement a policy to pick the next action rather than simply calling

MaxFloat(actions)

However, the actual implementation of this is what confuses me. How can I modify my script to incorporate such a policy?
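
For context, the usual on-policy choice is an ε-greedy policy: with probability ε take a random action, otherwise take the greedy one. Below is a minimal sketch of such a helper, assuming it lives inside the existing class (the name `EpsilonGreedy` and the `epsilon` parameter are my own, not part of the script above):

Code language: C#

// Hypothetical ε-greedy selection over one row of Q-values.
// With probability epsilon, explore with a random action; otherwise
// exploit the greedy action. Reuses the class's existing fields
// `r` and `possibleActions` and the `MaxFloat` helper.
int EpsilonGreedy(float[] actions, float epsilon)
{
    if (r.NextDouble() < epsilon)
    {
        return r.Next(0, possibleActions); // explore
    }
    return Array.IndexOf(actions, MaxFloat(actions)); // exploit
}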


1 Answer

Stack Overflow user

Posted on 2018-06-08 21:20:52

In SARSA, the name of the algorithm is also the algorithm: you store a state, an action, the reward, and then the next state and next action (S, A, R, S', A'), and use that tuple to perform the update.
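
For reference, the two standard update rules (textbook formulations, not quoted from this answer) differ only in which value is used for the next state:

$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right] \quad \text{(Q-learning)}$$
$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma\, Q(s',a') - Q(s,a) \right] \quad \text{(SARSA)}$$

where $a'$ in the SARSA rule is the action the policy actually selects in $s'$.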

Rather than computing the update as soon as you have only the current state and reward, you compute it at the point where you also have the previous state, the action taken there, and the reward received for it. SARSA then uses the action actually chosen in the current state, while Q-learning replaces it with the prediction of the greedy policy (the maximum Q-value).
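
Applied to the script in the question, that amounts to passing the next action into the update instead of calling MaxFloat. A minimal sketch under those assumptions (the method name `SarsaUpdate` and its parameters are mine; it presumes the caller has already chosen `nextActionIndex` with the same ε-greedy policy it acts with):

Code language: C#

// Hypothetical SARSA variant of the existing QLearning update.
// Instead of MaxFloat(actions) over the next state, it uses the
// Q-value of the action the policy actually chose there.
public void SarsaUpdate(float[] prevState, int prevActionIndex,
                        float reward, float[] nextState, int nextActionIndex)
{
    float nextValue = 0f; // unseen next states default to a value of 0

    // Look up Q(s', a') for the action chosen in the next state.
    for (int i = 0; i < QStates.Count; i++)
    {
        float[] state = QStates[i];
        if (state[0] == nextState[0] && state[1] == nextState[1])
        {
            nextValue = QActions[i][nextActionIndex];
        }
    }

    // Apply the SARSA update to Q(s, a).
    for (int i = 0; i < QStates.Count; i++)
    {
        float[] state = QStates[i];
        if (state[0] == prevState[0] && state[1] == prevState[1])
        {
            float[] actions = QActions[i];
            actions[prevActionIndex] +=
                learningRate * (reward + discountFactor * nextValue - actions[prevActionIndex]);
        }
    }
}

The calling code then has to remember the (state, action) pair from the previous step and select the next action before applying the update, so that the action being learned about is the one that is actually executed.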

Votes: 0
Original content provided by Stack Overflow.
Source: https://stackoverflow.com/questions/50760934