Fitting Gaussians with a Gaussian mixture model

Time: 2017-07-13 11:25:40

Tags: python scikit-learn histogram gaussian mixture-model

I have some data.

I plot its histogram with matplotlib:

n, bins, _= plt.hist(data, bins = 1000)
plt.show()

The result is: [Histogram of the dataset]

One can make out three or even four Gaussian distributions. To fit Gaussians to the histogram, I followed an example:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture

# Define simple gaussian
def gauss_function(x, amp, x0, sigma):
    return amp * np.exp(-(x - x0) ** 2. / (2. * sigma ** 2.))


n, bins, _= plt.hist(data, bins = 1000)
samples = n
# Fit GMM
gmm = GaussianMixture(n_components = 3, covariance_type="full",  tol=0.0001)
gmm = gmm.fit(X=np.expand_dims(samples, 1))


minimum = np.min(bins)
maximum = np.max(bins)
# Evaluate GMM
gmm_x = np.linspace(minimum, maximum, 5000)
gmm_y = np.exp(gmm.score_samples(gmm_x.reshape(-1, 1)))

# Construct function manually as sum of gaussians
gmm_y_sum = np.full_like(gmm_x, fill_value=0, dtype=np.float32)
for m, c, w in zip(gmm.means_.ravel(), gmm.covariances_.ravel(), gmm.weights_.ravel()):
    gauss = gauss_function(x=gmm_x, amp=1, x0=m, sigma=np.sqrt(c))
    gmm_y_sum += gauss / np.trapz(gauss, gmm_x) * w

# Plot the GMM density and the manually constructed sum of Gaussians
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=[8, 5])
ax.plot(gmm_x, gmm_y, color="crimson", lw=4, label="GMM")
ax.plot(gmm_x, gmm_y_sum, color="black", lw=4, label="Gauss_sum", linestyle="dashed")

# Annotate diagram
ax.set_ylabel("Probability density")
ax.set_xlabel("Arbitrary units")

# Make legend
plt.legend()

plt.show()

However, the result is: [Result of the Gaussian mixture model]

Can someone help me understand why the fit is so bad?
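
One possible cause, offered as an assumption rather than a confirmed diagnosis: the code above fits the mixture to `n`, the per-bin counts returned by `plt.hist`, while `GaussianMixture.fit` expects the raw observations as rows. The sketch below fits on the observations themselves. The synthetic `data` array (three made-up components) is only a stand-in, because the real dataset is not shown in the question.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic stand-in for `data` (an assumption; the real dataset is not shown).
rng = np.random.default_rng(0)
data = np.concatenate([
    rng.normal(-5.0, 1.0, 3000),
    rng.normal(0.0, 0.5, 2000),
    rng.normal(4.0, 1.5, 5000),
])

# Fit on the raw observations (one row per sample), not on the bin counts.
gmm = GaussianMixture(n_components=3, covariance_type="full",
                      tol=1e-4, random_state=0)
gmm.fit(data.reshape(-1, 1))

# Evaluate the fitted mixture density over the range of the data.
gmm_x = np.linspace(data.min(), data.max(), 5000)
gmm_y = np.exp(gmm.score_samples(gmm_x.reshape(-1, 1)))

# Density-normalised histogram, so that it is on the same scale as gmm_y.
hist_density, bin_edges = np.histogram(data, bins=100, density=True)

# Component means recovered by the fit, sorted for easier inspection.
fitted_means = np.sort(gmm.means_.ravel())
```

To compare visually, the histogram can be drawn with `plt.hist(data, bins=100, density=True)` and the curve overlaid with `plt.plot(gmm_x, gmm_y)`. Without `density=True` the histogram shows counts, which are not on the same scale as a probability density.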

0 answers:

No answers yet