如何使用 unittest.mock.patch 修补(mock)GridSearchCV?

时间:2021-07-26 10:27:44

标签: python scikit-learn mocking python-unittest

我正在使用 unittest 包编写一些测试函数。但出于某种原因,我无法修补 GridSearchCV。如果我运行这样的代码:

import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from unittest import TestCase
from unittest.mock import patch


def grid() -> None:
    """Run a small random-forest grid search on random data and print the results.

    Builds 100 seeded random '0'/'1' feature values and 100 seeded random
    '0'/'1' labels, constructs a GridSearchCV over a RandomForestClassifier
    with a 2-fold shuffled StratifiedKFold, fits it, and prints the search
    object and its best parameters.

    ``GridSearchCV`` and ``StratifiedKFold`` are resolved through the names
    this module imported; ``pd.merge`` is resolved as an attribute of the
    pandas module each time this function runs.
    """
    # Separate fixed seeds so features and labels are reproducible but differ.
    np.random.seed(1)
    # choice(..., size=100) is already 1-D with 100 entries; the reshape keeps that shape.
    x_train = pd.DataFrame(np.random.choice(['0', '1'], size=100, replace=True).reshape(100,))
    np.random.seed(3)
    y_train = np.random.choice(['0', '1'], size=100, replace=True).reshape(100,)

    stratified_cv_grid_search = GridSearchCV(
        estimator=RandomForestClassifier(random_state=1),
        # 1 * 2 * 2 * 2 = 8 parameter combinations.
        param_grid={
            'n_estimators': [100],
            'max_depth': [None, 5],
            'min_samples_split': [2, 3],
            'min_samples_leaf': [2, 3],
        },
        cv=StratifiedKFold(2, shuffle=True, random_state=1),
        scoring='f1_macro'
    )
    # Diagnostics: show what each name resolves to at call time
    # (the real object or a mock installed by a test).
    print(GridSearchCV)
    print(pd.merge)
    print(stratified_cv_grid_search)
    # Run the cross-validated search.
    stratified_cv_grid_search.fit(x_train, y_train)
    print(stratified_cv_grid_search)
    print('--- Best params and best score ---')
    # NOTE(review): best_params_ only exists once fit() has populated it;
    # otherwise this access raises AttributeError.
    print(stratified_cv_grid_search.best_params_)
        

class clfTest(TestCase):
    """Test case that calls grid() with pandas.merge and two sklearn names patched."""

    # Decorators are applied bottom-up, so the mocks are passed in the order
    # StratifiedKFold, GridSearchCV, merge.
    #
    # NOTE(review): the two sklearn patches replace attributes on the
    # ``sklearn.model_selection`` module. grid() does not look the classes up
    # there; it uses the names this module bound earlier with
    # ``from sklearn.model_selection import ...``. Those module-level
    # references are untouched, so grid() still gets the real classes.
    # ``pandas.merge`` is affected because grid() reads it as ``pd.merge``,
    # an attribute lookup on the patched module at call time.
    # To mock the names grid() actually uses, patch them where they are
    # looked up, e.g. ``@patch(f"{__name__}.GridSearchCV")``, and then set
    # ``mock_gridsearch.return_value.best_params_`` to the desired value.
    @patch("pandas.merge")
    @patch("sklearn.model_selection.GridSearchCV")
    @patch("sklearn.model_selection.StratifiedKFold")
    def test_clf(self, mock_kfold, mock_gridsearch, mock_merge):
            # No assertions: this only runs grid() while the patches are active.
            grid()

输出为:

<class 'sklearn.model_selection._search.GridSearchCV'>
<MagicMock name='merge' id='1948904870352'>
GridSearchCV(cv=StratifiedKFold(n_splits=2, random_state=1, shuffle=True),
             estimator=RandomForestClassifier(random_state=1),
             param_grid={'max_depth': [None, 5], 'min_samples_leaf': [2, 3],
                         'min_samples_split': [2, 3], 'n_estimators': [100]},
             scoring='f1_macro')
GridSearchCV(cv=StratifiedKFold(n_splits=2, random_state=1, shuffle=True),
             estimator=RandomForestClassifier(random_state=1),
             param_grid={'max_depth': [None, 5], 'min_samples_leaf': [2, 3],
                         'min_samples_split': [2, 3], 'n_estimators': [100]},
             scoring='f1_macro')
--- Best params and best score ---

Error
Traceback (most recent call last):
  File "C:\Users\larwinkl\.conda\envs\useenv\lib\unittest\mock.py", line 1183, in patched
    return func(*args, **keywargs)
  File "C:\Users\larwinkl\XCross Testing\X-Cross\tests\utils\for stackexchange.py", line 44, in test_clf
    grid()
  File "C:\Users\larwinkl\XCross Testing\X-Cross\tests\utils\for stackexchange.py", line 33, in grid
    print(stratified_cv_grid_search.best_params_)
AttributeError: 'GridSearchCV' object has no attribute 'best_params_'

我们可以看到 pd.merge 已被正确修补,打印出来的是一个 MagicMock,而 GridSearchCV 却完全没有被修补,仍然是原来的类。如何才能正确地模拟 GridSearchCV,并相应地模拟 best_params_ 之类的属性?

0 个答案:

没有答案
相关问题