Handmade Estimator modifies parameters in __init__?

Date: 2017-12-04 15:45:12

Tags: python scikit-learn

I am putting together a tailor-made preprocessing stage that is supposed to become part of a sklearn.pipeline.Pipeline. This is the code of the preprocessor:

import string
from nltk import wordpunct_tokenize
from nltk.stem.snowball import SnowballStemmer
from nltk import sent_tokenize
from sklearn.base import BaseEstimator, TransformerMixin
from . import stopwords

class NLTKPreprocessor(BaseEstimator, TransformerMixin):
    def __init__(self, stopwords=stopwords.STOPWORDS_DE,
                 punct=string.punctuation,
                 lower=True, strip=True, lang='german'):
        """
        Based on:
        https://bbengfort.github.io/tutorials/2016/05/19/text-classification-nltk-sckit-learn.html
        """

        self.lower = lower
        self.strip = strip
        self.stopwords = set(stopwords)
        self.punct = set(punct)
        self.stemmer = SnowballStemmer(lang)
        self.lang = lang

    def fit(self, X, y=None):
        return self

    def inverse_transform(self, X):
        return [" ".join(doc) for doc in X]

    def transform(self, X):
        return [
            list(self.tokenize(doc)) for doc in X
        ]

    def tokenize(self, document):
        # Break the document into sentences
        for sent in sent_tokenize(document, self.lang):
            for token in wordpunct_tokenize(sent):
                # Apply preprocessing to the token
                token = token.lower() if self.lower else token
                token = token.strip() if self.strip else token
                token = token.strip('_') if self.strip else token
                token = token.strip('*') if self.strip else token

                # If stopword, ignore token and continue
                if token in self.stopwords:
                    continue

                # If punctuation, ignore token and continue
                if all(char in self.punct for char in token):
                    continue

                # Lemmatize the token and yield
                # lemma = self.lemmatize(token, tag)
                stem = self.stemmer.stem(token)
                yield stem
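
For context, transform turns raw documents into lists of stemmed tokens, which is why the vectorizer further below is given a passthrough tokenizer. A quick standalone check might look like this (requires NLTK's punkt data; the exact output depends on the STOPWORDS_DE list):

docs = ['Das ist ein kleines Beispiel.', 'Noch ein ganz kurzer Satz.']
pre = NLTKPreprocessor()
print(pre.transform(docs))
# e.g. [['klein', 'beispiel'], ['ganz', 'kurz', 'satz']]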

Next, this is the pipeline I built:

pipeline = Pipeline(
    [
        ('preprocess', nltkPreprocessor),
        ('vectorize', TfidfVectorizer(tokenizer=identity, preprocessor=None, lowercase=False)),
        ('clf', SGDClassifier(max_iter=1000, tol=1e-3))       
    ]
)
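
For reference, the pipeline uses two names that are not shown above: nltkPreprocessor, presumably an instance of the class, and identity, presumably a passthrough tokenizer so that TfidfVectorizer does not re-tokenize the already-tokenized documents. A minimal sketch of that setup, under those assumptions:

from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier

def identity(tokens):
    # Passthrough: NLTKPreprocessor already produced lists of tokens
    return tokens

nltkPreprocessor = NLTKPreprocessor()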

This all works nicely in a single pass; e.g. pipeline.fit(X, y) works well. However, when I put this pipeline into a grid search

import numpy as np
from sklearn.model_selection import GridSearchCV

parameters = {
    'vectorize__use_idf': (True, False),
    'vectorize__max_df': np.arange(0.8, 1.01 ,0.05),
    'vectorize__smooth_idf': (True, False),
    'vectorize__sublinear_tf': (True, False),
    'vectorize__norm': ('l1', 'l2'),
    'clf__loss':  ('hinge', 'log', 'modified_huber', 'squared_hinge', 'perceptron'),
    'clf__alpha': (0.00001, 0.000001),
    'clf__penalty': ('l1', 'l2', 'elasticnet')
}
grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1)
grid_search.fit(X_train, y_train)

I get the following warning:

/Users/user/anaconda3/envs/myenv/lib/python3.6/site-packages/sklearn/base.py:115: DeprecationWarning: Estimator NLTKPreprocessor modifies parameters in __init__. This behavior is deprecated as of 0.18 and support for this behavior will be removed in 0.20.
  % type(estimator).__name__, DeprecationWarning)

I don't understand what I should change or fix in the implementation. How can I keep the functionality and get rid of the warning?

2 answers:

Answer 0 (score: 1):

Take a look at the developer's guide of sklearn, here, in particular the following paragraph. I would try to follow it as closely as possible to make sure these messages are avoided (even if you never intend to contribute your estimator).

It stipulates that an estimator should have no logic in its __init__ method! That is most likely what causes your warning.

I do the validation or conversion of the init parameters at the beginning of the fit() method (as the guide also suggests), which has to be called in any case.

Also, note the check_estimator utility, which you can use to test whether your estimator conforms to the scikit-learn API.
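
A minimal sketch of running that check on the class from the question (older scikit-learn versions take the class, newer ones take an instance; a text transformer like this one may still fail some of the generic numeric-input checks):

from sklearn.utils.estimator_checks import check_estimator

# Raises an informative error wherever the estimator violates the API
# contract, e.g. if __init__ alters its parameters.
check_estimator(NLTKPreprocessor)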

Edit (in response to the comment, but with code formatting):

Well, no logic at all. Quoting the link: "To summarize, an __init__ should look like:

def __init__(self, param1=1, param2=2):
    self.param1 = param1
    self.param2 = param2

There should be no logic, not even input validation, and the parameters should not be changed."

So I guess, as @uberwach detailed, the set construction and the creation of the SnowballStemmer instance violate the "should not be changed" part.

编辑2:

In addition to the comment below: this would be a generic way to do it (another one, specific to your case, is the lazy initialization in your tokenize method that @uberwach mentioned later):

class NLTKPreprocessor(BaseEstimator, TransformerMixin):
    def __init__(self, stopwords=stopwords.STOPWORDS_DE,
                 punct=string.punctuation,
                 lower=True, strip=True, lang='german'):
        self.lower = lower
        self.strip = strip
        self.stopwords = stopwords
        self.punct = punct
        self.lang = lang

    def fit(self, X, y=None):
        self.stopword_set = set(self.stopwords)
        self.punct_set = set(self.punct)
        self.stemmer = SnowballStemmer(self.lang)
        return self
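
With this change, tokenize should then look up the attributes derived in fit. A sketch of the adjusted method, assuming everything else stays as in the question (scikit-learn convention would also suffix attributes set during fit with a trailing underscore, e.g. stopword_set_, though that is not needed to silence the warning):

    def tokenize(self, document):
        for sent in sent_tokenize(document, self.lang):
            for token in wordpunct_tokenize(sent):
                token = token.lower() if self.lower else token
                token = token.strip().strip('_').strip('*') if self.strip else token

                if token in self.stopword_set:       # built in fit()
                    continue
                if all(char in self.punct_set for char in token):
                    continue

                yield self.stemmer.stem(token)       # created in fit()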

Answer 1 (score: 1):

I read the code under https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/base.py and could reproduce the warning message. It went away after two changes:

  1. frozenset instead of set. Since a set is considered mutable, it ends up as a different object after being copied.

  2. Initialize self.stemmer in the tokenize method instead of in __init__ (see the sketch below).
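
A sketch of what those two changes could look like, assuming the rest of the class stays as in the question:

class NLTKPreprocessor(BaseEstimator, TransformerMixin):
    def __init__(self, stopwords=stopwords.STOPWORDS_DE,
                 punct=string.punctuation,
                 lower=True, strip=True, lang='german'):
        self.lower = lower
        self.strip = strip
        # frozenset(f) returns f itself when f is already a frozenset, so
        # after clone() the parameter is the identical object and sklearn's
        # consistency check passes.
        self.stopwords = frozenset(stopwords)
        self.punct = frozenset(punct)
        self.lang = lang

    def tokenize(self, document):
        # Create the stemmer lazily instead of storing it in __init__
        stemmer = SnowballStemmer(self.lang)
        for sent in sent_tokenize(document, self.lang):
            for token in wordpunct_tokenize(sent):
                token = token.lower() if self.lower else token
                token = token.strip().strip('_').strip('*') if self.strip else token
                if token in self.stopwords:
                    continue
                if all(char in self.punct for char in token):
                    continue
                yield stemmer.stem(token)

Note that this re-creates the stemmer on every tokenize call; caching it in fit(), as in the other answer, avoids that overhead.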
