首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何创建一个自定义内核的高斯过程回归器在科学工具包-学习?

如何创建一个自定义内核的高斯过程回归器在科学工具包-学习?
EN

Stack Overflow用户
提问于 2018-03-08 17:01:01
回答 1查看 5.1K关注 0票数 10

我正在为一个相当特殊的上下文使用探地雷达,在这里我需要写我自己的内核。然而,我发现没有关于如何做到这一点的文档。尝试简单地继承Kernel并实现__call__get_paramsdiagis_stationary方法就足以使拟合过程正常工作,但当我试图预测y值和标准差时,就会崩溃。在使用自己的函数时,构建从Kernel继承的最小但功能类的必要步骤是什么?谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-02-13 15:03:41

根据内核的异国情调,您的问题的答案可能是不同的。

我发现RBF核的实现非常自文档化,所以我使用它作为参考。这是要点:

代码语言:javascript
复制
class RBF(StationaryKernelMixin, NormalizedKernelMixin, Kernel):
    def __init__(self, length_scale=1.0, length_scale_bounds=(1e-5, 1e5)):
        self.length_scale = length_scale
        self.length_scale_bounds = length_scale_bounds

    @property
    def hyperparameter_length_scale(self):
        if self.anisotropic:
            return Hyperparameter("length_scale", "numeric",
                                  self.length_scale_bounds,
                                  len(self.length_scale))
        return Hyperparameter(
            "length_scale", "numeric", self.length_scale_bounds)

    def __call__(self, X, Y=None, eval_gradient=False):
        # ...

正如您提到的,您的内核应该继承内核,这需要您实现__call__diagis_stationary。注意,sklearn.gaussian_process.kernels提供了StationaryKernelMixinNormalizedKernelMixin,它们为您实现diagis_stationary (cf )。在代码中定义RBF类)。

您不应该覆盖get_params!这是由Kernel类为您完成的,它期望scikit学习内核遵循一个约定,您的内核也应该这样做:将构造函数签名中的参数指定为关键字参数(参见前面的length_scale内核示例)。这确保您的内核可以被复制,这是由GaussianProcessRegressor.fit(...)完成的(这可能是您无法预测标准偏差的原因)。

此时,您可能会注意到另一个参数length_scale_bounds。这只是对实际超参数length_scale (cf )的一个约束。约束优化)。这就引出了这样一个事实:您还需要声明您的超参数,您想要优化,并且需要在您的__call__实现中计算梯度。您可以通过定义类的一个以hyperparameter_ (cf )为前缀的属性来做到这一点。(代码中的hyperparameter_length_scale )。每个未固定的超参数(fixed = hyperparameter.fixed == True)由Kernel.theta返回,GP在fit()上使用该参数并计算边际日志似然。因此,如果您想要将参数与数据相匹配,这是非常重要的。

关于Kernel.theta的最后一个细节,实现声明:

返回(扁平的、日志转换的)非固定的超参数。

因此,您应该小心使用超调参数中的0值,因为它们最终可能会成为np.nan并破坏一些内容。

我希望这会有所帮助,尽管这个问题已经有点老了。实际上,我自己从来没有实现过内核,但我渴望浏览sklearn代码库。不幸的是,没有关于这方面的官方教程,然而,代码库相当干净和注释。

票数 10
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49188159

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档