首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何利用pgmpy在贝叶斯网络中求出新事件的概率?

如何利用pgmpy在贝叶斯网络中求出新事件的概率?
EN

Stack Overflow用户
提问于 2022-01-06 08:39:15
回答 1查看 213关注 0票数 1

我用pgmpy库训练了一个贝叶斯网络。我希望找到一个新事件的联合概率(作为每个变量的概率的乘积,如果它有父母的话)。

目前我正在做

代码语言:javascript
复制
infer = VariableElimination(model)
evidence = dict(x_test.iloc[0])
result = infer.query(variables=[], evidence=evidence, joint=True)
print(result)

这里x_test是测试数据。

result是非常大的输出,所有组合的列车数据和他们的概率。

代码语言:javascript
复制
+----------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------+------------------------------------------+-----------------+---------------------------+-----------------------------------------+------------------------------+------------------------+---------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+
| data_devicetype                                                                                                                              | data_username                      | data_applicationtype                     | event_type      | servicename               | data_applicationname                    | tenantname                   | data_origin            | geoip_country_name        |   phi(data_devicetype,data_username,data_applicationtype,event_type,servicename,data_applicationname,tenantname,data_origin,geoip_country_name) |
+==============================================================================================================================================+====================================+==========================================+=================+===========================+=========================================+==============================+========================+===========================+=================================================================================================================================================+
| data_devicetype(Mozilla_5_0_Windows_NT_10_0_Win64_x64_AppleWebKit_537_36_KHTML_like_Gecko_Chrome_94_0_4606_81_Safari_537_36)                 | data_username(christofer) | data_applicationtype(Custom_Application) | event_type(sso) | servicename(saml_runtime) | data_applicationname(GD)            | tenantname(amx-sni-ksll0) | data_origin(1_0_64_66) | geoip_country_name(Japan) |                                                                                                                                          0.0326 |
+----------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------+------------------------------------------+-----------------+---------------------------+-----------------------------------------+------------------------------+------------------------+---------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+
| data_devicetype(Mozilla_5_0_Windows_NT_10_0_Win64_x64_AppleWebKit_537_36_KHTML_like_Gecko_Chrome_94_0_4606_81_Safari_537_36)                 | data_username(marty) | data_applicationtype(Custom_Application) | event_type(sso) | servicename(saml_runtime) | data_applicationname(VAULT)      | tenantname(login_pqr_com) | data_origin(1_0_64_66) | geoip_country_name(Japan) |                                                                                                                                          0.0156 |
+----------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------+------------------------------------------+-----------------+---------------------------+-----------------------------------------+------------------------------+------------------------+---------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+
| data_devicetype(Mozilla_5_0_Windows_NT_10_0_Win64_x64_AppleWebKit_537_36_KHTML_like_Gecko_Chrome_94_0_4606_81_Safari_537_36)                 | data_username(lincon) | data_applicationtype(Custom_Application) | event_type(sso) | servicename(saml_runtime) | data_applicationname(apps_think4ch_com) | tenantname(login_abc_com) | data_origin(1_0_64_66) | geoip_country_name(Japan) |                                                                                                                                          0.0113 |
......contd

请帮助我了解如何才能找到新事件的概率(即测试数据中的一行)。概率表达式为P(data_devicetype, data_username, data_applicationtype, event_type, servicename, data_applicationname, tenantname, data_origin, geoip_country_name)

EN

回答 1

Stack Overflow用户

发布于 2022-01-06 09:55:34

如果我正确地理解了你,你就会试图计算一个新的数据点的概率。不幸的是,在pgmpy中还没有直接的方法来实现它。虽然可以从推理结果中得到概率值。就像这样:

代码语言:javascript
复制
infer = VariableElimination(model)
result = infer.query(variables=list(model.nodes()), joint=True)
evidence = dict(x_test.iloc[0])
p_evidence = result.get_value(**evidence)

本质上,我们在这里计算所有变量的联合分布,然后取evidence数据点的概率值。正如您所预期的那样,在大型网络的情况下,这在计算上是非常低的。在这种情况下,计算概率的一种近似方法是使用模拟。

代码语言:javascript
复制
nsamples = int(1e6)
samples = model.simulate(nsamples)
evidence = dict(x_test.iloc[0])
matching_samples = samples[np.logical_and.reduce([samples[k]==v for k, v in evidence.items()])]
p_evidence = matching_samples.shape[0] / nsamples

用模拟的方法,我们从模型中生成一些模拟数据,并检验这些样本中有多少与我们的数据点相匹配,这就是概率。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70604458

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档