首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >MDP的值迭代实现

MDP的值迭代实现
EN

Code Review用户
提问于 2014-08-23 13:21:37
回答 2查看 3.1K关注 0票数 5

我已经在决策理论库工作了一段时间了,因为我从未真正接受过任何关于代码最佳实践的正式培训,我很想听听您的反馈。这个特定的类是我以前的类之一,它在一个尊重特定接口的提供的类上执行值迭代算法

代码中有包含,但我不确定是否应该包含所有相关代码,或者提供更多有关未显示部分的信息。请告诉我是否应该。

我主要关心的是是否应该像这样使用getter和setter,或者是否应该将每一种方法内联起来以提高易用性/速度。当然,如果你觉得我应该改变什么,我会很高兴知道。

头文件:

代码语言:javascript
复制
#ifndef AI_TOOLBOX_MDP_VALUE_ITERATION_HEADER_FILE
#define AI_TOOLBOX_MDP_VALUE_ITERATION_HEADER_FILE

#include <tuple>
#include <iostream>
#include <iterator>

#include <AIToolbox/MDP/Types.hpp>
#include <AIToolbox/MDP/Utils.hpp>
#include <AIToolbox/ProbabilityUtils.hpp>

namespace AIToolbox {
    namespace MDP {
        /**
         * @brief This class applies the value iteration algorithm on a Model.
         *
         * This algorithm solves an MDP model for the specified horizon, or less
         * if convergence is encountered.
         *
         * The idea of this algorithm is to iteratively compute the
         * ValueFunction for the MDP optimal policy. On the first iteration,
         * the ValueFunction for horizon 1 is obtained. On the second
         * iteration, the one for horizon 2. This process is repeated until the
         * ValueFunction has converged to a specific value within a certain
         * accuracy, or the horizon requested is reached.
         *
         * This implementation in particular is ported from the MATLAB
         * MDPToolbox (although it is simplified).
         */
        class ValueIteration {
            public:
                /**
                 * @brief Basic constructor.
                 *
                 * The epsilon parameter must be >= 0.0, otherwise the
                 * constructor will throw an std::runtime_error. The epsilon
                 * parameter sets the convergence criterion. An epsilon of 0.0
                 * forces ValueIteration to perform a number of iterations
                 * equal to the horizon specified. Otherwise, ValueIteration
                 * will stop as soon as the difference between two iterations
                 * is less than the epsilon specified.
                 *
                 * Note that the default value function size needs to match
                 * the number of states of the Model. Otherwise it will
                 * be ignored. An empty value function will be defaulted
                 * to all zeroes.
                 *
                 * @param horizon The maximum number of iterations to perform.
                 * @param epsilon The epsilon factor to stop the value iteration loop.
                 * @param v The initial value function from which to start the loop.
                 */
                ValueIteration(unsigned horizon, double epsilon = 0.001, ValueFunction v = ValueFunction(Values(0), Actions(0)));

                /**
                 * @brief This function applies value iteration on an MDP to solve it.
                 *
                 * The algorithm is constrained by the currently set parameters.
                 *
                 * @tparam M The type of the solvable MDP.
                 * @param m The MDP that needs to be solved.
                 * @return A tuple containing a boolean value specifying whether
                 *         the specified epsilon bound was reached and the
                 *         ValueFunction and the QFunction for the Model.
                 */
                template <typename M, typename std::enable_if<is_model<M>::value, int>::type = 0>
                std::tuple<bool, ValueFunction, QFunction> operator()(const M & m);

                /**
                 * @brief This function sets the epsilon parameter.
                 *
                 * The epsilon parameter must be >= 0.0, otherwise the
                 * constructor will throw an std::runtime_error. The epsilon
                 * parameter sets the convergence criterion. An epsilon of 0.0
                 * forces ValueIteration to perform a number of iterations
                 * equal to the horizon specified. Otherwise, ValueIteration
                 * will stop as soon as the difference between two iterations
                 * is less than the epsilon specified.
                 *
                 * @param e The new epsilon parameter.
                 */
                void setEpsilon(double e);

                /**
                 * @brief This function sets the horizon parameter.
                 *
                 * @param h The new horizon parameter.
                 */
                void setHorizon(unsigned h);

                /**
                 * @brief This function sets the starting value function.
                 *
                 * An empty value function defaults to all zeroes. Note
                 * that the default value function size needs to match
                 * the number of states of the Model that needs to be
                 * solved. Otherwise it will be ignored.
                 *
                 * @param v The new starting value function.
                 */
                void setValueFunction(ValueFunction v);

                /**
                 * @brief This function will return the currently set epsilon parameter.
                 *
                 * @return The currently set epsilon parameter.
                 */
                double getEpsilon() const;

                /**
                 * @brief This function will return the current horizon parameter.
                 *
                 * @return The currently set horizon parameter.
                 */
                unsigned getHorizon() const;

                /**
                 * @brief This function will return the current set default value function.
                 *
                 * @return The currently set default value function.
                 */
                const ValueFunction & getValueFunction() const;

            private:
                // Parameters
                double discount_, epsilon_;
                unsigned horizon_;
                ValueFunction vParameter_;

                // Internals
                ValueFunction v1_;
                size_t S, A;

                // Internal methods
                /**
                 * @brief This function computes the single PRType of the MDP once for improved speed.
                 *
                 * @tparam M The type of the solvable MDP.
                 * @param m The MDP that needs to be solved.
                 *
                 * @return The Models's PRType.
                 */
                template <typename M, typename std::enable_if<is_model<M>::value, int>::type = 0>
                Table2D computeImmediateRewards(const M & model) const;

                /**
                 * @brief This function creates the Model's most up-to-date QFunction.
                 *
                 * @tparam M The type of the solvable MDP.
                 *
                 * @param m The MDP that needs to be solved.
                 * @param ir The immediate rewards of the model.
                 *
                 * @return A new QFunction.
                 */
                template <typename M, typename std::enable_if<is_model<M>::value, int>::type = 0>
                QFunction computeQFunction(const M & model, const Table2D & ir) const;

                /**
                 * @brief This function applies a single pass Bellman operator, improving the current ValueFunction estimate.
                 *
                 * This function computes the optimal value and action for
                 * each state, given the precomputed QFunction.
                 *
                 * @param q The precomputed QFunction.
                 * @param vOut The newly estimated ValueFunction.
                 */
                inline void bellmanOperator(const QFunction & q, ValueFunction * vOut) const;
        };

        template <typename M, typename std::enable_if<is_model<M>::value, int>::type>
        std::tuple<bool, ValueFunction, QFunction> ValueIteration::operator()(const M & model) {
            // Extract necessary knowledge from model so we don't have to pass it around
            S = model.getS();
            A = model.getA();
            discount_ = model.getDiscount();

            {
                // Verify that parameter value function is compatible.
                size_t size = std::get<VALUES>(vParameter_).size();
                if ( size != S ) {
                    if ( size != 0 )
                        std::cerr << "AIToolbox: Size of starting value function in ValueIteration::solve() is incorrect, ignoring...\n";
                    // Defaulting
                    v1_ = makeValueFunction(S);
                }
                else
                    v1_ = vParameter_;
            }

            auto ir = computeImmediateRewards(model);

            unsigned timestep = 0;
            double variation = epsilon_ * 2; // Make it bigger

            Values val0;
            QFunction q = makeQFunction(S, A);

            bool useEpsilon = checkDifferent(epsilon_, 0.0);
            while ( timestep < horizon_ && (!useEpsilon || variation > epsilon_) ) {
                ++timestep;

                auto & val1 = std::get<VALUES>(v1_);
                val0 = val1;

                q = computeQFunction(model, ir);
                bellmanOperator(q, &v1_);

                // We do this only if the epsilon specified is positive, otherwise we
                // continue for all the timesteps.
                if ( useEpsilon ) {
                    auto computeVariation = [](double lhs, double rhs) { return std::fabs(lhs - rhs); };
                    // We compute the difference and store it into v0 for comparison.
                    std::transform(std::begin(val1), std::end(val1), std::begin(val0), std::begin(val0), computeVariation);

                    variation = *std::max_element(std::begin(val0), std::end(val0));
                }
            }

            // We do not guarantee that the Value/QFunctions are the perfect ones, as we stop as within epsilon.
            return std::make_tuple(variation <= epsilon_, v1_, q);
        }

        template <typename M, typename std::enable_if<is_model<M>::value, int>::type>
        Table2D ValueIteration::computeImmediateRewards(const M & model) const {
            Table2D pr(boost::extents[S][A]);

            for ( size_t s = 0; s < S; ++s )
                for ( size_t a = 0; a < A; ++a )
                    for ( size_t s1 = 0; s1 < S; ++s1 )
                        pr[s][a] += model.getTransitionProbability(s,a,s1) * model.getExpectedReward(s,a,s1);

            return pr;
        }

        template <typename M, typename std::enable_if<is_model<M>::value, int>::type>
        QFunction ValueIteration::computeQFunction(const M & model, const Table2D & ir) const {
            QFunction q = ir;

            for ( size_t s = 0; s < S; ++s )
                for ( size_t a = 0; a < A; ++a )
                    for ( size_t s1 = 0; s1 < S; ++s1 )
                        q[s][a] += model.getTransitionProbability(s,a,s1) * discount_ * std::get<VALUES>(v1_)[s1];
            return q;
        }

        void ValueIteration::bellmanOperator(const QFunction & q, ValueFunction * v) const {
            auto & values  = std::get<VALUES> (*v);
            auto & actions = std::get<ACTIONS>(*v);

            for ( size_t s = 0; s < S; ++s ) {
                // Accessing an element like this creates a temporary. Thus we need to bind it.
                QFunction::const_reference ref = q[s];
                auto begin = std::begin(ref);
                auto it = std::max_element(begin, std::end(ref));

                values[s] = *it;
                actions[s] = std::distance(begin, it);
            }
        }
    }
}

#endif

源文件:

代码语言:javascript
复制
#include <AIToolbox/MDP/Algorithms/ValueIteration.hpp>

namespace AIToolbox {
    namespace MDP {
        ValueIteration::ValueIteration(unsigned horizon, double epsilon, ValueFunction v) : horizon_(horizon), vParameter_(v),
                                                                                            S(0), A(0)
        {
            setEpsilon(epsilon);
        }

        void ValueIteration::setEpsilon(double e) {
            if ( e < 0.0 ) throw std::invalid_argument("Epsilon must be >= 0");
            epsilon_ = e;
        }
        void ValueIteration::setHorizon(unsigned h) {
            horizon_ = h;
        }
        void ValueIteration::setValueFunction(ValueFunction v) {
            vParameter_ = v;
        }

        double                  ValueIteration::getEpsilon() const {
            return epsilon_;
        }
        unsigned                ValueIteration::getHorizon() const {
            return horizon_;
        }
        const ValueFunction &   ValueIteration::getValueFunction() const {
            return vParameter_;
        }
    }
}
EN

回答 2

Code Review用户

发布于 2014-08-23 15:59:40

我个人认为,在这里将事情分解为hpp/cpp文件没有意义。对于这样少量的功能(基本上只是getter和setter),您最好将它们与其他所有内容放在头文件中。

我不喜欢使用std::enable_if在非绝对必要的情况下切换函数。它确实有一个位置,但是这通常是当您需要根据模板参数在多个函数之间执行重载解析时。如果您只是想确保满足给定模板参数的某些条件,那么最好使用static_assert。没有确切地了解is_model的实现是什么,很难确切说明应该是什么,但我设想如下:

代码语言:javascript
复制
template <typename M>
QFunction ValueIteration::computeQFunction(const M & model, const Table2D & ir) const {
    static_assert(is_model<M>::value, "M must be a model!");
    //Implementation
    .....
}

另外,您在这里缺少了一个包含:应该有一个#include <type_traits> for std::enable_if。您还缺少了一个#include <algorithm> (对于std::max_elementstd::transform之类的东西)和一个用于std::fabs#include <cmath>。即使您的其他内部头正在拉这些,这是不好的做法,依赖-每个文件应该把它自己需要的一切。

也许变量名称SA在您的特定上下文中是有意义的,但是如果可能的话,我会尝试给它们更多的描述性名称。

您是否有理由在ValueIterator::operator()中创建一个内部作用域?

代码语言:javascript
复制
template <typename M, typename std::enable_if<is_model<M>::value, int>::type>
std::tuple<bool, ValueFunction, QFunction> ValueIteration::operator()(const M & model) {
// Extract necessary knowledge from model so we don't have to pass it around
S = model.getS();
A = model.getA();
discount_ = model.getDiscount();

// Inner scope here?
{
    // Verify that parameter value function is compatible.
    size_t size = std::get<VALUES>(vParameter_).size();
    if ( size != S ) {
        if ( size != 0 )
            std::cerr << "AIToolbox: Size of starting value function in ValueIteration::solve() is incorrect, ignoring...\n";
        // Defaulting
        v1_ = makeValueFunction(S);
    }
    else
        v1_ = vParameter_;
}

通常,当我看到这个时,我在寻找使用RAII清理东西(文件、锁、线程等)的东西。我在这里看不出有什么特别的理由,读起来有点困惑。

为什么在这里通过指针传递ValueFunction有什么特殊的原因?

代码语言:javascript
复制
void ValueIteration::bellmanOperator(const QFunction & q, ValueFunction * v) const;

没有对nullptr的检查,所以我希望它通过引用通过。

大多数情况下,你的代码看起来相当合理。已经提到了在{}中使用for/if/etc,我将对此予以支持。

票数 4
EN

Code Review用户

发布于 2014-08-23 15:27:06

AI不是我的知识领域,所以我不能评论算法。代码的总体结构对我来说似乎不错,所以我只想评论几个您可以改进/更改的风格小贴士。

我建议您始终在控件语句中添加大括号(ifforwhile、.)即使它是一个单行语句。示例:

代码语言:javascript
复制
for ( size_t s = 0; s < S; ++s )
{
    for ( size_t a = 0; a < A; ++a )
    {
        for ( size_t s1 = 0; s1 < S; ++s1 )
        {
            q[s][a] += model.getTransitionProbability(s,a,s1) * discount_ * std::get<VALUES>(v1_)[s1];
        }
    }
}

通过这种方式,您将保护代码不被意外插入缩进行,而缩进行看起来可能属于以下语句:

代码语言:javascript
复制
for ( size_t s = 0; s < S; ++s )
    for ( size_t a = 0; a < A; ++a )
        for ( size_t s1 = 0; s1 < S; ++s1 )
            q[s][a] += model.getTransitionProbability(s,a,s1) * discount_ * std::get<VALUES>(v1_)[s1];
            if (somethingSomething(q))
            {
                ....
            }

在快速通过该块时,您可能不会注意到错误,并认为if将在最后一个for循环中执行。

这一行相当长:

代码语言:javascript
复制
ValueIteration::ValueIteration(unsigned horizon, double epsilon, ValueFunction v) : horizon_(horizon), vParameter_(v),
                                                                                        S(0), A(0)
{
    setEpsilon(epsilon);
}

如果将初始化列表分解为其他行,则它将更具可读性:

代码语言:javascript
复制
ValueIteration::ValueIteration(unsigned int horizon, double epsilon, ValueFunction v) 
    : horizon_(horizon)
    , vParameter_(v)
    , S(0)
    , A(0)
{
    setEpsilon(epsilon);
}

缩进这些最后一行的方式(在返回类型之后):

代码语言:javascript
复制
double                  ValueIteration::getEpsilon() const {
    return epsilon_;
}
unsigned                ValueIteration::getHorizon() const {
    return horizon_;
}
const ValueFunction &   ValueIteration::getValueFunction() const {
    return vParameter_;
}

这个额外的空间似乎并没有提高可读性,我只需要在返回类型之后用一个空格声明函数:

代码语言:javascript
复制
double ValueIteration::getEpsilon() const {
    return epsilon_;
}

unsigned int ValueIteration::getHorizon() const {
    return horizon_;
}

const ValueFunction & ValueIteration::getValueFunction() const {
    return vParameter_;
}

和您一样,单独使用unsigned对我来说似乎有点过分。我喜欢避免使用“隐式int”,所以我建议您始终使用unsigned int

票数 3
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/60874

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档