我有一个形状为(100,49280)
的测试数据集我试图根据http://scikit-learn.org/0.18/auto_examples/classification/plot_classifier_comparison.html中的示例来绘制svm分类中的数据,如下所示:
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
所以我打印x_min,x_max,h,y_min,y_max:
-0.6541590243577957
7.5925517082214355
0.02
-0.9953588843345642
4.58396577835083
到目前为止一切顺利,根据我在文档中发现的内容,我希望得到一个包含这些值的网格数组。
检查我是否打印xx和yy并获取:
(279, 413)
(279, 413)
好像很可疑。当我到达这条线时:
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
我收到错误:ValueError:X.shape [1] = 2应该等于49280,即训练时的要素数
我从来没有使用过numpy的meshgrid函数,但是我在文档中读到的所有东西看起来都很正常,并且记录得很好,因为大多数numpy都是。我确信我错过了一些愚蠢的东西,但对于我的生活,我找不到它。有谁知道我哪里出错了?一旦我弄清楚问题是什么,我会尝试为这个问题提供一个更具信息性的标题。
答案 0 :(得分:0)
np.arange
函数不适合此方案。您可能希望使用np.linspace
,如下所示
xx, yy = np.meshgrid(np.linspace(x_min, x_max), np.linspace(y_min, y_max))
np.linspace
https://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html