我有下面这几个类,它们除了 covar_module 之外完全相同。我想知道能否创建一个通用的超类,让它包含所有共同的属性和方法,而子类只需提供各自的 covar_module。我是 Python 新手,所以不知道具体的语法(syntax)该怎么写。
class _ScaledKernelSingleOutputGP(ExactGP, GPyTorchModel):
    """Shared implementation of the single-output exact GPs defined below.

    Builds a ``GaussianLikelihood``, a ``ConstantMean`` mean module and a
    covariance module of the form ``ScaleKernel(base_kernel)``. Subclasses
    differ only in the base kernel, which they supply by overriding
    ``_make_base_kernel``.
    """

    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y):
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        # ard_num_dims is set to the size of the last dimension of train_X
        # (the number of input features), exactly as in the original classes.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=self._make_base_kernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def _make_base_kernel(self, ard_num_dims):
        """Return the kernel to wrap in ``ScaleKernel``; subclasses override this."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _make_base_kernel"
        )

    def forward(self, x):
        """Return ``MultivariateNormal`` built from the mean and covariance modules at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class RbfGP(_ScaledKernelSingleOutputGP):
    """Single-output exact GP with a scaled RBF base kernel."""

    def _make_base_kernel(self, ard_num_dims):
        return gpytorch.kernels.RBFKernel(ard_num_dims=ard_num_dims)


class Matern12GP(_ScaledKernelSingleOutputGP):
    """Single-output exact GP with a scaled Matern base kernel (nu=0.5)."""

    def _make_base_kernel(self, ard_num_dims):
        return gpytorch.kernels.MaternKernel(nu=0.5, ard_num_dims=ard_num_dims)


class Matern32GP(_ScaledKernelSingleOutputGP):
    """Single-output exact GP with a scaled Matern base kernel (nu=1.5)."""

    def _make_base_kernel(self, ard_num_dims):
        return gpytorch.kernels.MaternKernel(nu=1.5, ard_num_dims=ard_num_dims)


# (Answer posted 2020-07-15 05:51:48.)
为什么不把基础核函数(base kernel)的类作为参数,传给一个通用的 GP 类呢?例如:
class GP(ExactGP, GPyTorchModel):
    """Single-output exact GP whose base kernel class is chosen by the caller.

    Args:
        train_X: Training inputs. ``train_X.shape[-1]`` is passed to the base
            kernel as ``ard_num_dims`` unless the caller supplies its own
            ``ard_num_dims`` in ``kwargs``.
        train_Y: Training targets (presumably shaped ``(n, 1)`` -- confirm);
            ``squeeze(-1)`` drops the trailing output dimension before they
            are handed to ``ExactGP``.
        base_kernel: A kernel class such as ``gpytorch.kernels.RBFKernel``.
            It is instantiated here and wrapped in a ``ScaleKernel``.
        **kwargs: Extra keyword arguments forwarded to ``base_kernel``
            (e.g. ``nu=0.5`` for ``gpytorch.kernels.MaternKernel``).
    """

    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y, base_kernel, **kwargs):
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        # Default ard_num_dims to the number of input features, but let the
        # caller override it. Passing it both explicitly and via **kwargs
        # would raise "got multiple values for keyword argument".
        kwargs.setdefault("ard_num_dims", train_X.shape[-1])
        self.covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=base_kernel(**kwargs),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        """Return ``MultivariateNormal`` built from the mean and covariance modules at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


# Then, for example:
# Drop-in equivalents of RbfGP and Matern12GP built with the generic GP class.
rbf_gp = GP(
    train_x,
    train_y,
    base_kernel=gpytorch.kernels.RBFKernel,
)
matern_12_gp = GP(
    train_x,
    train_y,
    base_kernel=gpytorch.kernels.MaternKernel,
    nu=0.5,  # forwarded to MaternKernel through GP's **kwargs
)
# Source: https://stackoverflow.com/questions/62899452
复制相似问题