我正在尝试在gpflow中实现我自己的MultioutputKernel (MOK),但是我被(Kernel, Inducing Variable)组合的Multiple Dispatch所困。
根据docs的说法,组合MultioutputKernel和InducingPoints的回退方法应该是调用fully_correlated_conditional (通过gpf.conditionals.multioutput.inducing_point_conditional)。
然而,我不能让任何MOK与非独立的诱导变量一起工作,即使是预先实现的。下面是SharedIndependent的一个最小的不工作的示例
######################## toy data
d = 1
X = np.random.normal(0, 10, (100, d))
xx = np.linspace(-10, 10, 200).reshape(-1, 1)
f = lambda x: x ** 2
f_ = lambda x: 2 * x
Y = f(X)
Y_ = f_(X)
Y_combined = np.hstack((Y, Y_))
data = (X, Y_combined)
######################### gpflow stuff
kernel = gpf.kernels.RBF(lengthscales=[1.0] * d)
Z = X.copy()
# create multi-output inducing variables from Z
iv = gpf.inducing_variables.InducingPoints(Z)
MOK_K = gpf.kernels.SharedIndependent(kernel, output_dim=2)
m = gpf.models.SVGP(likelihood=gpf.likelihoods.Gaussian(), kernel=MOK_K, num_latent_gps=2,
inducing_variable=iv)
optimizer = gpf.optimizers.Scipy()
optimizer.minimize(
m.training_loss_closure(data),
variables=m.trainable_variables,
method="l-bfgs-b",
options={"disp": True, "maxiter": 1000},
)这是行不通的,除非你把诱导点换成
iv = gpf.inducing_variables.SharedIndependentInducingVariables(
gpf.inducing_variables.InducingPoints(Z)
)但是对于我的自定义非独立内核,我需要完全相关的条件。我得到的错误是
ValueError: base_conditional() arguments [Note that this check verifies the shape of an alternative representation of Kmn. See the docs for the actual expected shape.]在传递给base_conditional之前,似乎在inducing_point_conditional方法中,它试图将内核矩阵“展平”(到2d)为经典的多输出表示。然而,我不明白哪里出了问题,因为形状应该是好的。它们就像文档中定义的那样。
我需要做些什么才能让它与完全相关的条件语句一起运行?
发布于 2020-11-24 18:45:23
问题是变分布参数的维数,q_mu和q_sqrt。
必须事先手动定义它们,以便它们分别具有形状[N*P, 1]和[1, N*P, N*P],其中P是输出尺寸(在我的示例中为2)。
请注意,这与文档中的略有不同,他们说q_mu必须是[1, N*P]。
我的问题中的代码可以正常运行:
# [200, 1]
q_mu = tf.zeros([200, 1], dtype=gpf.config.default_float())
# [1, 200, 200]
q_sqrt = tf.eye(200, dtype=gpf.config.default_float())[tf.newaxis, ...]
m = gpf.models.SVGP(likelihood=gpf.likelihoods.Gaussian(), kernel=MOK_K, num_latent_gps=2,
inducing_variable=iv, q_mu=q_mu, q_sqrt=q_sqrt)https://stackoverflow.com/questions/64967543
复制相似问题