拟合分类器中的奇怪错误

时间:2019-01-16 19:17:49

标签: scikit-learn

我正在使用Scikit-Learn和Tensorflow进行O'Reilly的动手机器学习

我正在训练MNIST数据集上的分类器,但出现错误

ValueError: The number of classes has to be greater than one; got 1 class

这是我的代码

mnist = fetch_openml('mnist_784', version=1, cache=True)

X, y = mnist["data"], mnist["target"]

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

y_train_5 = (y_train == 9)
y_test_5 = (y_test == 9)

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

我已经三遍检查了我的代码,但仍然不确定发生了什么。

2 个答案:

答案 0 :(得分:3)

sklearn中来自MNIST数据集的标签包含字符串,而不是整数。因此,设置

y_train_5 = (y_train == '9')
y_test_5 = (y_test == '9')

当您使用整数进行检查时,所有内容都为False,Python警告您只有一个类。

答案 1 :(得分:0)

此过程完全正确,只需将数字放入字符串中,因为scikit中的标签需要字符串。

var debug: some View {
        MyViewWithError(property: self.property)
}