Adding a legend to a scatterplot in matplotlib

Posted: 2017-06-05 11:39:02

Tags: python-3.x matplotlib classification

I have the following code that visualizes the decision boundary of a binary classification problem:

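import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets
from sklearn.model_selection import train_test_split

# X_C2, y_C2: two-feature dataset with binary labels (0 and 1), defined elsewhere
X_train, X_test, y_train, y_test = train_test_split(X_C2, y_C2,
                                                    random_state=0)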
n_neighbors = [1, 3, 11]
weights = 'uniform'
h = .02  # step size in the mesh

# Create color maps (http://htmlcolorcodes.com/)
cmap_light = ListedColormap(['#F9F999', '#F3F3F3'])
cmap_bold = ListedColormap(['#FFFF00', '#000000'])
# For example, FF AA AA = RGB(255, 170, 170) 

for n in n_neighbors:
    # we create an instance of Neighbours Classifier and fit the data.
    clf = neighbors.KNeighborsClassifier(n, weights=weights)
    clf.fit(X_train, y_train)

    # ----------------------------- Mesh color i.e. background color begins ------------------------
    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
    y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure(figsize=(7,5), dpi=100)
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light, shading='auto')  # shading='auto' needs matplotlib >= 3.3
    # ------------------------------ Mesh color i.e. background color ends -------------------------

    # Plot also the training points
    # plt.figure(figsize=(10,4), dpi=80)
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cmap_bold, s=40, label = "class 0") # scatter plot of height vs width
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cmap_bold, s=40, label = "class 1") # scatter plot of height vs width
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("2-Class classification (k = %i, weights = '%s')" % (n, weights))
    plt.legend()
    plt.show()

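The best I could get is this: [image: resulting plot] but the legend does not match the actual classes. X_train has 2 columns, one per feature, and y_train contains the labels for this data, i.e. 0 and 1. How can I get a legend whose entries have the same colors as the points?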
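A minimal sketch of one common fix, under stated assumptions: in the code above, both plt.scatter calls draw every training point, so neither legend entry corresponds to a single class. Drawing one scatter per class, and taking that class's color from cmap_bold by integer index, gives one legend entry per class in the matching color. The data here is a hypothetical stand-in for X_train / y_train (two random Gaussian blobs), and the Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

# Hypothetical stand-in for the question's X_train / y_train:
# two Gaussian blobs of 50 points each, labelled 0 and 1.
rng = np.random.default_rng(0)
X_train = np.vstack([rng.normal(0.0, 1.0, size=(50, 2)),
                     rng.normal(3.0, 1.0, size=(50, 2))])
y_train = np.repeat([0, 1], 50)

# Same colors as the question's cmap_bold: class 0 -> yellow, class 1 -> black.
cmap_bold = ListedColormap(['#FFFF00', '#000000'])

plt.figure(figsize=(7, 5), dpi=100)
for k in np.unique(y_train):
    mask = y_train == k  # keep only the points that belong to class k
    # An integer argument indexes the colormap's color list directly,
    # so cmap_bold(0) is '#FFFF00' and cmap_bold(1) is '#000000'.
    plt.scatter(X_train[mask, 0], X_train[mask, 1],
                color=cmap_bold(int(k)), s=40, label=f"class {k}")
plt.legend()
# plt.show()  # uncomment in an interactive session
```

In the question's loop, this would replace the two plt.scatter calls; the pcolormesh background is unaffected. With c=y_train and cmap=cmap_bold, label 0 is mapped to the first color and label 1 to the second, which is the same assignment cmap_bold(int(k)) uses, so the legend should agree with the points. As an alternative that keeps a single scatter call, matplotlib 3.1 and later provide PathCollection.legend_elements(), whose returned handles and labels can be passed to plt.legend.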
