我正在尝试构建一个名为SafeModel
的工厂类,其generate
方法接受scikit-learn类的实例,更改其某些属性,以及返回相同的实例。具体来说,对于此示例,我想访问返回模型的coef_
属性,在案例1中,如果基础scikit-learn类包含coef_
,则返回该类{{1如果基础scikit-learn类包含coef_
,则返回该类feature_importances_
。
我已经成功完成了Python类实例的猴子修补属性。我对Python类实例的修补魔术方法的成功较少。我的案例的警告是:属性feature_importances
和coef_
永远不会在scikit-learn类实例化时定义;相反,它们仅在对各自的类调用feature_importances
方法后定义。出于这个原因,我无法覆盖属性定义本身。
fit
答案 0 :(得分:0)
我无法确定您的代码有什么问题。我发现了一个可能适用于您的用例的工作流程
我使用了不同的策略,因为我只是使用SafeModel.__getattr__
作为模型getattr
方法的包装而不是猴子修补。
from sklearn.utils.validation import NotFittedError
from sklearn.ensemble import RandomForestClassifier
class SafeModel(object):
def __init__(self, model):
self.FALLBACK_ATTRIBUTES = {
'coef_': ['feature_importances_'],
}
self.model = model
def __getattr__(self, name):
try:
return getattr(self.model, name)
except AttributeError:
pass
for fallback_attribute in self.FALLBACK_ATTRIBUTES[name]:
try:
return getattr(self.model, fallback_attribute)
except NotFittedError as e:
# NotFittedError inherits AttributeError.
raise e
except AttributeError:
continue
else:
raise AttributeError(
"{} object has no attribute {}.".format(
self.__class__.__name__, name) +
" Could not retrieve any fallback attribute.")
model = SafeModel(RandomForestClassifier())
model.coef_
输出:
NotFittedError: Estimator not fitted, call `fit` before `feature_importances_`.
请注意,这是正常行为,正如您所提到的,在您适应随机森林之前,您无法访问feature_importances_
。
不可否认,异常捕获在这里相当脆弱(您需要添加一堆可能会被引发的异常),但如果您在尝试访问时不关心提出正确的异常属性应该没问题。
如果这对你有用,请告诉我。如果你发现你发布的代码发生了什么,我也会对解释感兴趣!