从sklearn库运行示例时引发AssertionError

时间:2016-01-27 13:49:39

标签: python scikit-learn

import pandas as pd
import numpy as np
from sklearn.learning_curve import learning_curve
import matplotlib.pyplot as plt


def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
                        n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):
    """
    Plot the training and cross-validation learning curves of an estimator.

    Parameters
    ----------
    estimator : object implementing ``fit`` and ``predict``
        Passed unchanged to ``learning_curve``, which clones it for each fit.

    title : str
        Title of the figure.

    X : array-like, shape (n_samples, n_features)
        Training data: one row per sample, one column per feature.

    y : array-like, shape (n_samples,) or (n_samples, n_features), or None
        Targets relative to ``X`` for classification or regression;
        ``None`` for unsupervised learning.

    ylim : tuple (ymin, ymax), optional
        If given, fixes the y-axis limits of the plot.

    cv : int or cross-validation generator, optional
        An integer is the number of folds (defaults to 3). A
        cross-validation object from ``sklearn.cross_validation`` may be
        passed instead.

    n_jobs : int, optional
        Number of parallel jobs handed to ``learning_curve`` (default 1).

    train_sizes : array-like, optional
        Training-set sizes forwarded to ``learning_curve``; by default five
        evenly spaced fractions from 0.1 to 1.0.

    Returns
    -------
    module
        The ``matplotlib.pyplot`` module, so the caller can ``show()`` it.
    """
    # Figure skeleton: title, optional fixed y-range and axis labels.
    plt.figure()
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel("Training examples")
    plt.ylabel("Score")

    sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)

    # One row per curve: (mean over CV splits, std over CV splits,
    # colour, legend label). The training curve comes first, then CV.
    curves = [
        (np.mean(train_scores, axis=1), np.std(train_scores, axis=1),
         "r", "Training score"),
        (np.mean(test_scores, axis=1), np.std(test_scores, axis=1),
         "g", "Cross-validation score"),
    ]

    plt.grid()

    # Shade a +/- one-standard-deviation band around each mean curve ...
    for mean, spread, colour, _ in curves:
        plt.fill_between(sizes, mean - spread, mean + spread, alpha=0.1,
                         color=colour)
    # ... then draw the mean curves themselves, in the same order.
    for mean, _, colour, label in curves:
        plt.plot(sizes, mean, 'o-', color=colour, label=label)

    plt.legend(loc="best")
    return plt



# The original snippet used `ensemble` and `cross_validation` without
# importing them, which fails with "NameError: name 'ensemble' is not
# defined". These are the pre-0.20 scikit-learn module names, matching the
# `sklearn.learning_curve` import used by this code.
from sklearn import cross_validation, ensemble

# NOTE(review): `alldata` (feature matrix) and `y` (targets) are not defined
# in this snippet; they are assumed to be loaded earlier.
forest = ensemble.RandomForestClassifier(
    bootstrap=True, class_weight=None, max_depth=None, max_features='auto',
    max_leaf_nodes=None, min_samples_leaf=1, min_samples_split=6,
    min_weight_fraction_leaf=0.0, n_estimators=300, n_jobs=-1,
    oob_score=False, random_state=111, verbose=0, warm_start=False)

# 10 random shuffles over the rows of `alldata`, each holding out 20% for
# testing.
cv = cross_validation.ShuffleSplit(alldata.shape[0], n_iter=10,
                                   test_size=0.2, random_state=0)

# NOTE: n_jobs=-1 requests parallel worker processes. When this is run as a
# script on Windows, this top-level code must sit under
# `if __name__ == '__main__':`, because workers re-import the main module.
title = "Learning Curve (Random Forest)"
plot_learning_curve(forest, title, alldata, y, ylim=None, cv=cv, n_jobs=-1)

plt.show()

当我在 IPython Notebook(Python 2.7)中运行此代码时,可以在 cmd 窗口中看到以下错误。plot_learning_curve 函数取自 scikit-learn 文档中的 “Plotting Learning Curves” 示例(原文此处为该示例的链接)。

(原文此处为 cmd 中错误信息的截图,图片未能保留。)

2 个答案:

答案 0 :(得分:0)

错误是由多进程(multiprocessing)引起的。在 Windows 上使用多进程与在 Unix 上不同:Windows 没有 fork,子进程启动时会重新导入主模块,因此没有保护的顶层代码会在每个子进程中再执行一遍。您需要将主代码放在 if __name__ == '__main__': 子句下:

if __name__ == '__main__':
    # On Windows, multiprocessing starts workers by re-importing the main
    # module. This guard stops each worker from re-running the code below,
    # which itself requests parallel workers via n_jobs=-1.

    # `ensemble` and `cross_validation` must be imported. Without them this
    # code fails with "NameError: name 'ensemble' is not defined". These are
    # the pre-0.20 scikit-learn module names.
    from sklearn import cross_validation, ensemble

    # NOTE(review): `alldata` and `y` are assumed to be loaded earlier; they
    # are not defined in this snippet.
    forest = ensemble.RandomForestClassifier(
        bootstrap=True, class_weight=None, max_depth=None,
        max_features='auto', max_leaf_nodes=None, min_samples_leaf=1,
        min_samples_split=6, min_weight_fraction_leaf=0.0, n_estimators=300,
        n_jobs=-1, oob_score=False, random_state=111, verbose=0,
        warm_start=False)

    # 10 random shuffles, each holding out 20% of the rows for testing.
    cv = cross_validation.ShuffleSplit(alldata.shape[0], n_iter=10,
                                       test_size=0.2, random_state=0)

    title = "Learning Curve (Random Forest)"
    plot_learning_curve(forest, title, alldata, y, ylim=None, cv=cv, n_jobs=-1)

    plt.show()

答案 1 :(得分:0)

运行你的代码时,我得到了以下错误:

milenko@milenko-X58-USB3:~$ python k1.py 
Traceback (most recent call last):
  File "k1.py", line 68, in <module>
    forest = ensemble.RandomForestClassifier(bootstrap=True, class_weight=None, max_depth=None, max_features='auto', max_leaf_nodes=None,min_samples_leaf=1, min_samples_split=6,min_weight_fraction_leaf=0.0, n_estimators=300, n_jobs=-1,oob_score=False, random_state=111, verbose=0, warm_start=False)
NameError: name 'ensemble' is not defined

我的 Python 版本:

Python 2.7.11 :: Anaconda 2.4.1 (64-bit)

我认为你需要先定义 ensemble,例如用 from sklearn import ensemble 导入该模块。