首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何为多类分类提取随机森林树规则?

如何为多类分类提取随机森林树规则?
EN

Stack Overflow用户
提问于 2019-06-27 20:46:01
回答 1查看 141关注 0票数 1

嗨,我想在多类分类的情况下从一棵树中提取规则

代码语言:javascript
复制
from sklearn.tree import _tree 
from sklearn.tree import DecisionTreeClassifier

#creat a gaussian classifier
clf=RandomForestClassifier(n_estimators=100)

#train the model using the training sets y_pred=clf.predict(X_test)

clf.fit(X_train,y_train)

#extract one tree from the forest
model = clf.estimators_[0]


def find_rules(tree,features): 
    dt=tree.tree_
    def visitor(node,depth):
        indent= ' ' * depth
        if dt.feature[node] != _tree.TREE_UNDEFINED:
            print('{} if <{}> <= {}:'.format(indent,features[node],round(dt.threshold[node],100)))
            visitor(dt.children_left[node],depth+1)
            print('{}else:'.format(indent))
            visitor(dt.children_right[node],depth+1)
        else:
            print('{} return {}'.format(indent,dt.value[node]))
    visitor(0,1)


find_rules(model, iris.feature_names)

EN

回答 1

Stack Overflow用户

发布于 2019-06-27 22:47:19

请检查以下代码。它似乎起作用了。只有一个小小的变化

代码语言:javascript
复制
def find_rules(tree,features): 
    dt=tree.tree_
    def visitor(node,depth):
        indent= ' ' * depth
        if dt.feature[node] != _tree.TREE_UNDEFINED:
            print('{} if <{}> <= {}:'.format(indent,features[dt.feature[node]],round(dt.threshold[node],100)))
            # in the previous line i added a backward-mapping
            # for the feature id
            visitor(dt.children_left[node],depth+1)
            print('{} else:'.format(indent))
            visitor(dt.children_right[node],depth+1)
        else:
            print('{} return {}'.format(indent,dt.value[node]))
    visitor(0,1)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56791341

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档