numpy meshgrid调整测试数据的大小

时间:2018-06-10 19:18:22

标签: python numpy mesh

我有一个形状为(100,49280)

的测试数据集

我试图根据http://scikit-learn.org/0.18/auto_examples/classification/plot_classifier_comparison.html中的示例来绘制svm分类中的数据,如下所示:

x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

所以我打印x_min,x_max,h,y_min,y_max:

-0.6541590243577957

7.5925517082214355

0.02

-0.9953588843345642

4.58396577835083

到目前为止一切顺利,根据我在文档中发现的内容,我希望得到一个包含这些值的网格数组。

检查我是否打印xx和yy并获取:

(279, 413)
(279, 413)
好像很可疑。当我到达这条线时:

Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])

我收到错误:ValueError:X.shape [1] = 2应该等于49280,即训练时的要素数

我从来没有使用过numpy的meshgrid函数,但是我在文档中读到的所有东西看起来都很正常,并且记录得很好,因为大多数numpy都是。我确信我错过了一些愚蠢的东西,但对于我的生活,我找不到它。有谁知道我哪里出错了?一旦我弄清楚问题是什么,我会尝试为这个问题提供一个更具信息性的标题。

1 个答案:

答案 0 :(得分:0)

np.arange函数不适合此方案。您可能希望使用np.linspace,如下所示

xx, yy = np.meshgrid(np.linspace(x_min, x_max), np.linspace(y_min, y_max))

np.linspace

的文档

https://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html