我试图用sklearn的kde来拟合一些正常数据的内核密度估计。下面是一个例子:
import numpy as np
from sklearn.neighbors.kde import KernelDensity as kde
x = np.random.normal(scale = 2, size = [10000, 1])
np.var(x) # 4.0
test_kde = kde()
test_kde.fit(x)
np.var(test_kde.sample(10000)) # 5.0差异增加了1。我在这里做了什么蠢事吗?
发布于 2019-07-10 14:34:50
问题是,您没有指定正确的bandwidth来缩放单个密度函数,这就是为什么您过度平滑估计的密度函数的原因。由于示例数据遵循正态分布,因此带宽为
>>> h = ((4 * np.std(x)**5) / (3 * len(x)))**(1/5)
>>> h
0.33549590926904804会是最理想的。可以找到一个解释,维基百科。
>>> test_kde = kde(bandwidth=h)
>>> test_kde.fit(x)
>>> samples = test_kde.sample(10000)
>>> np.var(samples)
4.068727474888099 # close enough to 4但为什么我需要这样的比例?
内核密度估计通过使用内核函数(通常是正态分布的密度函数)来估计数据分布的密度。一般的想法是,通过把你的样本参数化的许多密度函数之和,最终,给出足够的样本,近似于原来的密度函数:
我们可以对您的数据进行可视化:
from matplotlib.colors import TABLEAU_COLORS
def gauss_kernel(x, m=0, s=1):
return (1/np.sqrt(2 * np.pi * s**2) * np.e**(-((x - m)**2 / (2*s**2))))
from matplotlib.colors import TABLEAU_COLORS
x_plot = np.linspace(-2, 2, 10)
h = 1
for xi, color in zip(x_plot, TABLEAU_COLORS.values()):
plt.plot(xi, gauss_kernel(xi, m=0, s=2) * 0.001, 'x', color=color)
plt.plot(x, 1 / (len(x) * h) * gauss_kernel((xi - x) / h), 'o', color=color)
plt.plot(xi, (1 / (len(x) * h) * gauss_kernel((xi - x) / h)).sum() * 0.001, 'o', color=color)

此图显示了[-2; 2]中某些点的估计密度和真实密度,以及每个点(相同颜色曲线)的核函数。估计的密度只是相应核函数的和。
可以看到,单个核函数的右/左越远,其和值就越低(因此也就越高)。要解释这一点,你必须记住,我们的原始数据点中心在0附近,因为它们是从正态分布中取样的,平均值为0,方差为2。因此,离中心越远,数据点就越少。因此,这意味着将这些点作为输入的高斯核函数最终会将所有数据点放在其平坦的尾部段中,并使其重量接近于零,这就是为什么这个核函数的和在那里非常小的原因。我们也可以说,我们是在用高斯密度函数来加窗。
通过设置h=2,可以清楚地看到带宽参数的影响。
h = 2
for xi, color in zip(x_plot, TABLEAU_COLORS.values()):
plt.plot(xi, gauss_kernel(xi, m=0, s=2) * 0.001, 'x', color=color)
plt.plot(x, 1 / (len(x) * h) * gauss_kernel((xi - x) / h), 'o', color=color)
plt.plot(xi, (1 / (len(x) * h) * gauss_kernel((xi - x) / h)).sum() * 0.001, 'o', color=color)

单个内核函数更平滑,因此,估计的密度也更平滑。其原因在于光滑算子的形成。内核被称为
1/h K((x - xi)/h)在高斯核的情况下,计算正态分布的密度,均值为xi,方差为h。因此:h越高,每个密度估计越平滑!
在学习的情况下,可以通过使用网格搜索来估计最佳带宽,例如通过测量密度估计的质量来进行网格搜索。这个例子向你展示了如何。如果您选择了一个好的带宽,您可以很好地估计密度函数:

https://stackoverflow.com/questions/56967554
复制相似问题