首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >制作一个ML模型科学工具包-学习兼容

制作一个ML模型科学工具包-学习兼容
EN

Stack Overflow用户
提问于 2021-05-09 00:03:35
回答 1查看 271关注 0票数 5

我想要使这个ML模型科学学习兼容:https://github.com/manifoldai/merf

为此,我遵循这里的说明:https://danielhnyk.cz/creating-your-own-estimator-scikit-learn/和导入from sklearn.base import BaseEstimator, RegressorMixin,并从它们继承如下:class MERF(BaseEstimator, RegressorMixin):

但是,当我检查scikit-学习兼容性时:

代码语言:javascript
复制
from sklearn.utils.estimator_checks import check_estimator

import merf
check_estimator(merf)

我知道这个错误:

代码语言:javascript
复制
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 500, in check_estimator
    for estimator, check in checks_generator:
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in _generate_instance_checks
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in <genexpr>
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 232, in _yield_all_checks
    tags = estimator._get_tags()
AttributeError: module 'merf' has no attribute '_get_tags'

我如何使这个模型科学学习兼容?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-05-09 01:11:54

文档中,check_estimator被用来“检查估值器是否遵守科学知识学习惯例”。

此评估器将运行一个广泛的测试套件,用于输入验证、形状等,确保该评估器符合scikit学习约定,详细介绍在滚动您自己的评估器中。如果Estimator类从sklearn.base中继承相应的混音,则将运行对分类器、回归器、群集或转换器的附加测试。

因此,check_estimator不仅仅是一个兼容性检查,它还检查您是否遵守了所有约定等等。

您可以在滚动你自己的估测器上阅读,以确保您遵循惯例。

然后,您需要传递估计器类的一个实例来检查esimator,比如check_estimator(MERF())。实际上,要使它遵循所有的约定,您必须解决它抛出的每一个错误,并逐个修复它们。

例如,这样的检查之一是__init__方法只设置它作为参数接受的那些属性。

MERF类违反了以下规定:

代码语言:javascript
复制
    def __init__(
        self,
        fixed_effects_model=RandomForestRegressor(n_estimators=300, n_jobs=-1),
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

        self.cluster_counts = None
        # Note fixed_effects_model must already be instantiated when passed in.
        self.fe_model = fixed_effects_model
        self.trained_fe_model = None
        self.trained_b = None

        self.b_hat_history = []
        self.sigma2_hat_history = []
        self.D_hat_history = []
        self.gll_history = []
        self.val_loss_history = []

它正在设置属性(如self.b_hat_history ),即使它们不是参数。

还有很多像这样的支票。

我个人的建议是,除非有必要,否则不要检查所有这些条件,只需继承Mixins和基类,实现所需的方法并使用模型。

票数 6
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67453220

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档