关于scikit-learn Class Instance的Monkey-Patching Magic方法

时间:2015-08-17 21:06:09

标签: python scikit-learn monkeypatching magic-methods

我正在尝试构建一个名为SafeModel的工厂类,其generate方法接受scikit-learn类的实例,更改其某些属性,以及返回相同的实例。具体来说,对于此示例,我想访问返回模型的coef_属性,在案例1中,如果基础scikit-learn类包含coef_,则返回该类{{1如果基础scikit-learn类包含coef_,则返回该类feature_importances_

我已经成功完成了Python类实例的猴子修补属性。我对Python类实例的修补魔术方法的成功较少。我的案例的警告是:属性feature_importancescoef_永远不会在scikit-learn类实例化时定义;相反,它们仅在对各自的类调用feature_importances方法后定义。出于这个原因,我无法覆盖属性定义本身。

fit

1 个答案:

答案 0 :(得分:0)

我无法确定您的代码有什么问题。我发现了一个可能适用于您的用例的工作流程 我使用了不同的策略,因为我只是使用SafeModel.__getattr__作为模型getattr方法的包装而不是猴子修补。

from sklearn.utils.validation import NotFittedError
from sklearn.ensemble import RandomForestClassifier

class SafeModel(object):

    def __init__(self, model):
        self.FALLBACK_ATTRIBUTES = {
        'coef_': ['feature_importances_'],
    }
        self.model = model

    def __getattr__(self, name):
        try:
            return getattr(self.model, name)
        except AttributeError:
            pass
        for fallback_attribute in self.FALLBACK_ATTRIBUTES[name]:
            try:
                return getattr(self.model, fallback_attribute)
            except NotFittedError as e:
                # NotFittedError inherits AttributeError.
                raise e
            except AttributeError:
                continue
        else:
            raise AttributeError(
                "{} object has no attribute {}.".format(
                    self.__class__.__name__, name) + 
                " Could not retrieve any fallback attribute.")                    


model = SafeModel(RandomForestClassifier())
model.coef_   

输出:

NotFittedError: Estimator not fitted, call `fit` before `feature_importances_`.

请注意,这是正常行为,正如您所提到的,在您适应随机森林之前,您无法访问feature_importances_

不可否认,异常捕获在这里相当脆弱(您需要添加一堆可能会被引发的异常),但如果您在尝试访问时不关心提出正确的异常属性应该没问题。

如果这对你有用,请告诉我。如果你发现你发布的代码发生了什么,我也会对解释感兴趣!