我正在学习本教程:An Introduction to Inference in Pyro
我不明白的是以下几点。为了获得(??????|?????,???????????=9.5),我们可以使用pyro.condition函数
def scale(guess):
weight = pyro.sample("weight", dist.Normal(guess, 1.0))
print(weight)
return pyro.sample("measurement", dist.Normal(weight, 0.75))和conditioned_scale = pyro.condition(scale, data={"measurement": 9.5})
我写了以下脚本:
pyro.set_rng_seed(101)
scale(0.3) # tensor(-1.0905)
pyro.set_rng_seed(101)
conditioned_scale(0.3) # tensor(-1.0905)对于这两个函数,我们得到相同的权重样本。本教程不是在说,使用conditioned_scale,我们可以从以measurement=9.5为条件的权重分布中获得样本吗?如果是这样,权重的样本不应该是不同的,因为在第一个调用中我们没有观察到任何数据,但在第二个调用中我们以数据为条件?
谢谢!
发布于 2021-06-13 04:29:16
运行模型不会从后部产生样本,您需要运行推断(如SVI或MCMC)。
condition会将示例站点值替换为您指定的值。由于您为measurement指定值,因此weight不受影响。您所编写的模型等同于N(measurement;N(weight;guess,1),.75),并且通过条件化,您已经声明了measurement=9.5。conditioned_scale = pyro.condition(scale, data={"weight": 9.5})和相同的密钥将产生不同的测量值。下面我用NumPyro写了同样的程序。你应该去看看https://forum.pyro.ai/。
import numpyro
import numpyro.distributions as dist
def scale(rng_key, guess):
w_key, m_key = random.split(rng_key)
weight = numpyro.sample("weight", dist.Normal(guess, 1.0), rng_key=w_key)
print(weight)
return numpyro.sample("measurement", dist.Normal(weight, 0.75), rng_key=m_key)
if __name__ == '__main__':
rng_key = random.PRNGKey(0)
print(scale(rng_key, 0.3)) # -0.49476373
conditioned_scale = numpyro.handlers.condition(scale, data={"weight": 9.5})
print(conditioned_scale(rng_key, 0.3)) # 8.561346https://stackoverflow.com/questions/67335114
复制相似问题