首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何解决与GPyTorch一起使用SpectralMixture内核时遇到的错误?

如何解决与GPyTorch一起使用SpectralMixture内核时遇到的错误?
EN

Stack Overflow用户
提问于 2019-05-10 21:36:43
回答 1查看 151关注 0票数 1

我使用GPyTorch来拟合高斯过程回归模型(主要用于学习过程)。在遵循他们的教程时,我正在尝试使用SpectralMixtureKernel。但是,我得到了以下错误。但是,这里首先是代码(这与他们的教程基本相同,但为了方便起见,在这里复制):

代码语言:javascript
复制
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self,train_x,train_y,likelihood):
        super(ExactGPModel, self).__init__(train_x,train_y,likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        self.covar_module = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4)

        self.covar_module.initialize_from_data(train_x, train_y)



    def forward(self,x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x,covar_x)

熊猫数据转换为下面的torch.tensor

代码语言:javascript
复制
train_x = torch.tensor(train_x.values.astype(np.float32))
train_y = torch.tensor(train_y.values.astype(np.float32))

test_x = torch.tensor(test_x.values.astype(np.float32))
test_y = torch.tensor(test_y.values.astype(np.float32))

然后

代码语言:javascript
复制
likelihood = gpytorch.likelihoods.GaussianLikelihood()

model = ExactGPModel(train_x,train_y, likelihood)

一旦运行最后一行,我将得到以下错误:

代码语言:javascript
复制
Traceback (most recent call last):

  File "<ipython-input-195-e3bc37af324c>", line 1, in <module>
    model = ExactGPModel(train_x,train_y, likelihood)

  File "<ipython-input-186-323eff9c5819>", line 7, in __init__
    self.covar_module.initialize_from_data(train_x, train_y)

  File "/anaconda3/envs/py36/lib/python3.6/site-packages/gpytorch/kernels/spectral_mixture_kernel.py", line 163, in initialize_from_data
    self.raw_mixture_scales.data.normal_().mul_(max_dist).abs_().pow_(-1)

RuntimeError: output with shape [4, 1, 1] doesn't match the broadcast shape [4, 1, 33]

如能为解决这一问题提供任何帮助,将不胜感激。

谢谢。

EN

回答 1

Stack Overflow用户

发布于 2022-04-26 07:54:30

我也有同样的问题。在我的例子中,我使用的是维数大于1的train_x向量。

代码语言:javascript
复制
self.covar_module = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4, ard_num_dims=33)

关于https://docs.gpytorch.ai/en/latest/kernels.html#spectralmixturekernel的更多信息

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56085127

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档