首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >fit后检索训练数据

fit后检索训练数据
EN

Stack Overflow用户
提问于 2017-02-27 09:22:33
回答 1查看 147关注 0票数 1

我希望能够检索到scikit-learn估计器训练过的数据(即,在拟合之后)。

例如,如果我像这样拟合一个RandomForestClassifier:

代码语言:javascript
复制
rf = RandomForestClassifier()
train_X = np.asarray([[0, 1, 0], [1, 1, 1], [0, 1, 1]])
train_y = np.asarray([1, 0, 1])
rf.fit(train_X, train_y)

有没有办法从估计器返回我的训练数据和类标签?

就像..。

代码语言:javascript
复制
rf.X_
>>>array([[0, 1, 0],
          [1, 1, 1],
          [0, 1, 1]])
EN

回答 1

Stack Overflow用户

发布于 2017-02-27 09:53:11

docs上看,我没有看到任何真正允许这样做的东西。但是,您可以尝试像这样定义一个类:

代码语言:javascript
复制
class RFClassifierWithData:
    def __init__(self):
        self.clf = RandomForestClassifier()
    def fit(self, train_X, train_y):
        self.train_X = train_X
        self.train_y = train_y
        self.clf.fit(self.train_X, self.train_y)

尝试一下:

代码语言:javascript
复制
>>> model = RFClassifierWithData()
>>> model.fit(train_X, train_y)
>>> model.train_X
array([[0, 1, 0],
       [1, 1, 1],
       [0, 1, 1]])
>>> model.train_y
array([1, 0, 1])

并且您仍然可以访问已安装的分类器:

代码语言:javascript
复制
>>> model.clf
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)

请注意,这可能不是执行此操作的最安全或最健壮的方式,但它应该为您提供一个很好的起点。您可能希望使该类中传递给__init__的参数与RandomForestClassifier基类中的参数相等。

编辑:

我仍然认为这是一个有效的选择,即使你试图从一个酸洗过的分类器中获取数据:

代码语言:javascript
复制
from sklearn.externals import joblib

joblib.dump(model, 'model.pkl')
same_model = joblib.load('model.pkl')

一切都还在:

代码语言:javascript
复制
In [19]: same_model.train_X
Out[19]: 
array([[0, 1, 0],
       [1, 1, 1],
       [0, 1, 1]])

same_model.train_y
Out[20]: array([1, 0, 1])
In [21]: same_model.clf
Out[21]: 
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/42476383

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档