首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用arviz识别pymc3链中的分歧

如何使用arviz识别pymc3链中的分歧
EN

Stack Overflow用户
提问于 2021-04-01 05:38:30
回答 1查看 263关注 0票数 0

我想要识别pymc3采样的链中的分歧。每个样本与一个组和一个条件(轨迹中的坐标)相关联。

出于本例的目的,以下结果仅考虑1个链和1个条件(轨迹的坐标)。

我使用Arviz.InferenceData绘制特定变量a_kg的样本跟踪图(其中每行表示一组):

代码语言:javascript
复制
import arviz as az

# trace variable coming from pymc3.sample
azdata = az.from_pymc3(
    trace=trace, 
    coords={'group': groups, 'condition': conditions}, 
    dims={'a_kg': ['group', 'condition']}
)
azdata_sel = azdata.sel(chain=[0], condition='Control')
az.plot_trace(azdata_sel, var_names=['a_kg'], divergences='bottom');

每组的轨迹如下所示:

如果我没理解错的话,分歧显示在痕迹的底部,上面有一个小地毯图。

如果这是正确的,则在图30附近存在分歧。因此,我得到了一个至少有一个分歧的样本切片(在本例中是包含样本30的切片),以更详细地探索这部分轨迹。

代码语言:javascript
复制
azdata_sel = azdata.sel(draw=slice(25, 35))
az.plot_trace(azdata_sel, var_names=['a_kg'], divergences='bottom')

为了更好地理解这个模型是如何工作的,我想找出链条为什么会有分歧。然而,当我查看变量a_kg的样本时,对于每组,在图30附近,所有值都被限制在一个狭窄的有限范围内:

代码语言:javascript
复制
array([[7.03689753e+01, 7.08419788e+01, 4.18270946e+01, 5.56815107e+01,
        2.89069656e+01, 3.21847218e+01, 1.72809154e+01, 6.80358410e+00,
        8.27741780e+00, 8.61561309e+00, 9.52030649e+00, 7.42601279e+00,
        4.86924384e+00, 4.65123572e+00, 3.42272331e+00, 3.72094392e+00,
        3.79496877e+00, 3.63692105e+00, 4.53843102e+00, 4.49938710e+00,
        1.16647181e+00, 1.57530039e+00, 1.38785612e+00, 2.93999569e+00,
        3.19698360e-01, 1.09373256e+00, 8.91772857e-01, 1.27258163e+00,
        7.30115016e-01, 6.48975286e-01, 9.53344198e-01, 7.10095320e-01,
        1.94587869e-01, 2.37110242e-01, 1.74995857e-02, 1.09717525e-01,
        2.49860304e-01, 1.73485239e-01, 3.15215749e-01]])

在采样过程中是否过滤掉了绘图中的差异?在这种情况下,您将如何继续诊断哪里出了问题?

EN

回答 1

Stack Overflow用户

发布于 2021-04-08 17:33:43

我认为这个文档有很多你需要知道的东西:https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html

但总而言之,您应该知道,理解差异通常是一个困难的问题,并且没有灵丹妙药--您必须尝试许多(有时是很多很多的)事情。仅看轨迹图是不够的。话虽如此,我链接的文档有很多很好的建议。

我能给出的一般建议是,你不应该专注于一个有分歧的特定样本。那是没有意义的。发散的东西不是样本,而是轨迹。使用arviz.pair_plot并聚焦于分歧集中的地方(设置divergences=True)。运行更长的链(超过10k个样本),这样你就可以得到更多的分歧,并且可以很容易地发现病变区域。一旦你发现了病变区域,决定如何处理它将取决于你的特定模型。也许增加适应的步骤,也许改变你的先验,也许重新参数化你的模型。

由于您正在谈论组,我怀疑您使用的是分层模型。在这种情况下,我认为最好的方法是尝试另一种参数化。查找关于分层模型中中心参数化与非中心参数化的讨论。

祝你在寻找分歧时好运!:)

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66895673

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档