首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在R包'party‘中创建ctree

如何在R包'party‘中创建ctree
EN

Stack Overflow用户
提问于 2014-03-20 23:27:01
回答 1查看 1K关注 0票数 0

我想得到创建条件推理树的包'party‘中的R代码的一部分。也就是说,代码的一部分指定了如何通过bootstrap示例一步一步地生长树。我检查了函数'ctree',但它的结构似乎很复杂。甚至我也找不到这样的迹象,比如sample(所有特征中的mtry特征)。有人知道在哪里可以找到它吗?或者方法如何找到它?

EN

回答 1

Stack Overflow用户

发布于 2014-03-21 03:43:32

这是一个C函数。从http://cran.r-project.org/src/contrib/party_1.0-13.tar.gz下载包源并在src/TreeGrow.c中查找

代码语言:javascript
复制
/**
    The tree growing recursion
    *\file TreeGrow.c
    *\author $Author: hothorn $
    *\date $Date: 2009-06-16 09:17:31 +0200 (Tue, 16 Jun 2009) $
*/

#include "party.h"


/**
    The main tree growing function, handles the recursion. \n
    *\param node  a list representing the current node
    *\param learnsample an object of class `LearningSample'
    *\param fitmem an object of class `TreeFitMemory'
    *\param controls an object of class `TreeControl'
    *\param where a pointer to an integer vector of n-elements
    *\param nodenum a pointer to a integer vector of length 1
    *\param depth an integer giving the depth of the current node
*/

void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem,
                SEXP controls, int *where, int *nodenum, int depth) {

    SEXP weights;
    int nobs, i, stop;
    double *dweights;

    weights = S3get_nodeweights(node);

    /* stop if either stumps have been requested or
       the maximum depth is exceeded */
    stop = (nodenum[0] == 2 || nodenum[0] == 3) &&
           get_stump(get_tgctrl(controls));
    stop = stop || !check_depth(get_tgctrl(controls), depth);

    if (stop)
        C_Node(node, learnsample, weights, fitmem, controls, 1, depth);
    else
        C_Node(node, learnsample, weights, fitmem, controls, 0, depth);

    S3set_nodeID(node, nodenum[0]);

    if (!S3get_nodeterminal(node)) {

        C_splitnode(node, learnsample, controls);

        /* determine surrogate splits and split missing values */
        if (get_maxsurrogate(get_splitctrl(controls)) > 0) {
            C_surrogates(node, learnsample, weights, controls, fitmem);
            C_splitsurrogate(node, learnsample);
        }

        nodenum[0] += 1;
        C_TreeGrow(S3get_leftnode(node), learnsample, fitmem,
                   controls, where, nodenum, depth + 1);

        nodenum[0] += 1;
        C_TreeGrow(S3get_rightnode(node), learnsample, fitmem,
                   controls, where, nodenum, depth + 1);

    } else {
        dweights = REAL(weights);
        nobs = get_nobs(learnsample);
        for (i = 0; i < nobs; i++)
            if (dweights[i] > 0) where[i] = nodenum[0];
    }
}


/**
    R-interface to C_TreeGrow\n
    *\param learnsample an object of class `LearningSample'
    *\param weights a vector of case weights
    *\param fitmem an object of class `TreeFitMemory'
    *\param controls an object of class `TreeControl'
    *\param where a vector of node indices for each observation
*/

SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) {

     SEXP ans, nweights;
     double *dnweights, *dweights;
     int nobs, i, nodenum = 1;

     GetRNGstate();

     nobs = get_nobs(learnsample);
     PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
     C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)),
                 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));

     nweights = S3get_nodeweights(ans);
     dnweights = REAL(nweights);
     dweights = REAL(weights);
     for (i = 0; i < nobs; i++) dnweights[i] = dweights[i];

     C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1);

     PutRNGstate();

     UNPROTECT(1);
     return(ans);
}
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/22537635

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档