首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Sklearn.mixture.dpgmm不能正常工作

Sklearn.mixture.dpgmm不能正常工作
EN

Stack Overflow用户
提问于 2015-11-20 08:38:14
回答 2查看 385关注 0票数 0

我和sklearn.mixture.dpgmm有点问题。主要问题是它没有为合成数据返回正确的协方差(2个分离的二维高斯),而它实际上应该没有问题。特别是,当我执行dpgmm._get_covars()时,协方差矩阵的对角元素总是恰好大于1.0,而不管输入数据的分布如何。这似乎是一个bug,因为gmm工作得很好(当限制到已知的确切组数时)

另一个问题是dpgmm.weights_没有任何意义,它们加起来都是1,但是值看起来毫无意义。

有没有人对此有解决方案,或者清楚地看到我的例子有什么问题?

下面是我正在运行的确切脚本:

代码语言:javascript
复制
import itertools
import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib as mpl
import pdb

from sklearn import mixture

# Generate 2D random sample, two gaussians each with 10000 points
rsamp1 =     np.random.multivariate_normal(np.array([5.0,5.0]),np.array([[1.0,-0.2],[-0.2,1.0]]),10000)
rsamp2 = np.random.multivariate_normal(np.array([0.0,0.0]),np.array([[0.2,-0.0],[-0.0,3.0]]),10000)
X = np.concatenate((rsamp1,rsamp2),axis=0)

# Fit a mixture of Gaussians with EM using 2
gmm = mixture.GMM(n_components=2, covariance_type='full',n_iter=10000)
gmm.fit(X)

# Fit a Dirichlet process mixture of Gaussians using 10 components
dpgmm = mixture.DPGMM(n_components=10, covariance_type='full',min_covar=0.5,tol=0.00001,n_iter = 1000000)
dpgmm.fit(X)

print("Groups With data in them")
print(np.unique(dpgmm.predict(X)))

##print the input and output covars as example, should be very similar
correct_c0 = np.array([[1.0,-0.2],[-0.2,1.0]])
print "Input covar"
print correct_c0

covars = dpgmm._get_covars()
c0 = np.round(covars[0],decimals=1)
print "Output Covar"
print c0

print("Output Variances Too Big by 1.0")
EN

回答 2

Stack Overflow用户

发布于 2016-09-30 02:51:00

根据dpgmm docs的说法,这个类在版本0.18中是不推荐使用的,将在版本0.20中删除

您应该改用BayesianGaussianMixture类,并使用选项"dirichlet_process"设置参数weight_concentration_prior_type

希望能有所帮助

票数 1
EN

Stack Overflow用户

发布于 2017-03-23 17:50:46

而不是写作

代码语言:javascript
复制
from sklearn.mixture import GMM
gmm = GMM(2, covariance_type='full', random_state=0)

你应该这样写:

代码语言:javascript
复制
from sklearn.mixture import BayesianGaussianMixture
gmm = BayesianGaussianMixture(2, covariance_type='full', random_state=0)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/33817028

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档