首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用pyplot.Spectral创建的调色板在循环中绘制散点图?

如何使用pyplot.Spectral创建的调色板在循环中绘制散点图?
EN

Stack Overflow用户
提问于 2019-07-08 21:43:56
回答 1查看 471关注 0票数 1

我有一个二维点的数据集,我想使用K-means技术对其进行分类。

数据:

代码语言:javascript
复制
import numpy as np

x1 = np.array([3,1,1,2,1,6,6,6,5,6,7,8,9,8,9,9,8])
x2 = np.array([5,4,5,6,5,8,6,7,6,7,1,2,1,2,3,2,3])
X = np.array(list(zip(x1,x2))).reshape(len(x1), 2)

我想对从1到9的集群数量进行迭代,以测试散点图上的最终分布。所以我计算了数据集的质心。

代码语言:javascript
复制
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt

max_k = 10
K = range(1,max_k)
centroid = [sum(X)/len(X) for k in K]
sst = sum(np.min(cdist(X, centroid, "euclidean"), axis = 1))

然后使用cm.Spectral为每个迭代创建一个具有一种rgb颜色的调色板。

代码语言:javascript
复制
color_palette = [plt.cm.Spectral(float(k)/max_k) for k in K]

并在循环中使用它,我在循环中迭代了k

代码语言:javascript
复制
from sklearn.cluster import KMeans
import pandas as pd

ssw = []
for k in K:
    kmeanModel = KMeans(n_clusters=k).fit(X)

    centers = pd.DataFrame(kmeanModel.cluster_centers_)
    labels = kmeanModel.labels_

    ssw_k = sum(np.min(cdist(X, kmeanModel.cluster_centers_), axis = 1))
    ssw.append(ssw_k)

    label_color = [color_palette[i] for i in labels]

    plt.plot()
    plt.xlim([0,10])
    plt.ylim([0,10])
    plt.title("Clustering for k = %s"%str(k))
    plt.scatter(x1,x2, c=label_color)
    plt.scatter(centers[0], centers[1], c=color_palette, marker = "x")
    plt.show()

我在我的Python 3.7.3版本中重现了这段代码,从这段代码的源代码中我知道它在旧版本中工作得很好。当matplotlib.pyplot.cm中的函数Spectral以小写形式(spectral)编写时。

结果就是下一个。

代码语言:javascript
复制
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, **kwargs)
   4237                     valid_shape = False
-> 4238                     raise ValueError
   4239             except ValueError:

ValueError: 

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-26-2f513f9c616c> in <module>
     24     plt.title("Clustering for k = %s"%str(k))
     25     plt.scatter(x1,x2, c=label_color)
---> 26     plt.scatter(centers[0], centers[1], c=[i for i in color_palette], marker = "x")
     27     plt.show()

~/anaconda3/lib/python3.7/site-packages/matplotlib/pyplot.py in scatter(x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, data, **kwargs)
   2860         vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
   2861         verts=verts, edgecolors=edgecolors, **({"data": data} if data
-> 2862         is not None else {}), **kwargs)
   2863     sci(__ret)
   2864     return __ret

~/anaconda3/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1808                         "the Matplotlib list!)" % (label_namer, func.__name__),
   1809                         RuntimeWarning, stacklevel=2)
-> 1810             return func(ax, *args, **kwargs)
   1811 
   1812         inner.__doc__ = _add_data_doc(inner.__doc__,

~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, **kwargs)
   4243                         "acceptable for use with 'x' with size {xs}, "
   4244                         "'y' with size {ys}."
-> 4245                         .format(nc=n_elem, xs=x.size, ys=y.size)
   4246                     )
   4247                 # Both the mapping *and* the RGBA conversion failed: pretty

ValueError: 'c' argument has 9 elements, which is not acceptable for use with 'x' with size 1, 'y' with size 1.

我希望每个组的中心颜色就像组本身一样。

提前谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-07-08 21:49:00

尝试通过索引使用相应大小的调色板,该索引对应于x和y值的长度,如下所示。

附言:您的代码在matplotlib 2.2.2中运行良好

代码语言:javascript
复制
for i, k in enumerate(K):
    # rest of your code

    plt.scatter(centers[0], centers[1], c=color_palette[0:i+1], marker = "x")
    print (centers[0].values)
    plt.show()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56936255

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档