我正在使用 unittest 包编写一些测试函数。但出于某种原因,我无法修补 GridSearchCV。如果我运行这样的代码:
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from unittest import TestCase
from unittest.mock import patch
def grid():
np.random.seed(1)
x_train = pd.DataFrame(np.random.choice(['0', '1'], size=100, replace=True).reshape(100,))
np.random.seed(3)
y_train = np.random.choice(['0', '1'], size=100, replace=True).reshape(100,)
stratified_cv_grid_search = GridSearchCV(
estimator=RandomForestClassifier(random_state=1),
param_grid={
'n_estimators': [100],
'max_depth': [None, 5],
'min_samples_split': [2, 3],
'min_samples_leaf': [2, 3],
},
cv=StratifiedKFold(2, shuffle=True, random_state=1),
scoring='f1_macro'
)
print(GridSearchCV)
print(pd.merge)
print(stratified_cv_grid_search)
# run cv
stratified_cv_grid_search.fit(x_train, y_train)
print(stratified_cv_grid_search)
print('--- Best params and best score ---')
print(stratified_cv_grid_search.best_params_)
class clfTest(TestCase):
@patch("pandas.merge")
@patch("sklearn.model_selection.GridSearchCV")
@patch("sklearn.model_selection.StratifiedKFold")
def test_clf(self, mock_kfold, mock_gridsearch, mock_merge):
grid()
输出为:
<class 'sklearn.model_selection._search.GridSearchCV'>
<MagicMock name='merge' id='1948904870352'>
GridSearchCV(cv=StratifiedKFold(n_splits=2, random_state=1, shuffle=True),
estimator=RandomForestClassifier(random_state=1),
param_grid={'max_depth': [None, 5], 'min_samples_leaf': [2, 3],
'min_samples_split': [2, 3], 'n_estimators': [100]},
scoring='f1_macro').
GridSearchCV(cv=StratifiedKFold(n_splits=2, random_state=1, shuffle=True),
estimator=RandomForestClassifier(random_state=1),
param_grid={'max_depth': [None, 5], 'min_samples_leaf': [2, 3],
'min_samples_split': [2, 3], 'n_estimators': [100]},
scoring='f1_macro')
--- Best params and best score ---
Error
Traceback (most recent call last):
File "C:\Users\larwinkl\.conda\envs\useenv\lib\unittest\mock.py", line 1183, in patched
return func(*args, **keywargs)
File "C:\Users\larwinkl\XCross Testing\X-Cross\tests\utils\for stackexchange.py", line 44, in test_clf
grid()
File "C:\Users\larwinkl\XCross Testing\X-Cross\tests\utils\for stackexchange.py", line 33, in grid
print(stratified_cv_grid_search.best_params_)
AttributeError: 'GridSearchCV' object has no attribute 'best_params_'
我们可以看到 pd.merge 被正确修补并作为 MagicMock 返回,而 GridSearchCV 没有任何反应。如何正确模拟 GridSearchCV 并相应地模拟 best_params 之类的属性?