首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在本例中,是否可以使用我创建的所有类通用的超类?

在本例中,是否可以使用我创建的所有类通用的超类?
EN

Stack Overflow用户
提问于 2020-07-15 00:13:34
回答 1查看 26关注 0票数 0

我有所有这些类都非常相似,除了它们的covar_module。我想知道我是否可以创建一个通用的超类,它拥有所有的属性和函数,而子类只有covar_module。我是python的新手,所以我不知道synthax会是什么样子。

代码语言:javascript
复制
class RbfGP(ExactGP, GPyTorchModel):
    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y):
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel= gpytorch.kernels.RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class Matern12GP(ExactGP, GPyTorchModel):
    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y):
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=gpytorch.kernels.MaternKernel(nu=0.5, ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class Matern32GP(ExactGP, GPyTorchModel):
    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y):
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=gpytorch.kernels.MaternKernel(nu=1.5, ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
EN

回答 1

Stack Overflow用户

发布于 2020-07-15 05:51:48

为什么不这样做呢:

代码语言:javascript
复制
class GP(ExactGP, GPyTorchModel):
    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y, base_kernel, **kwargs):
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=base_kernel(ard_num_dims=train_X.shape[-1], **kwargs),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

然后

代码语言:javascript
复制
rbf_gp = GP(train_x, train_y, base_kernel=gpytorch.kernels.RBFKernel)

matern_12_gp = GP(train_x, train_y, base_kernel=gpytorch.kernels.MaternKernel, nu=0.5)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62899452

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档