在python单元测试中模拟嵌套函数

时间:2021-06-28 05:52:41

标签: python unit-testing mocking pytest python-unittest

我试图使用模拟和补丁对 Some.py 中的 some_fun() 进行单元测试,但它没有按预期工作我无法正确修补 db_helper.fetchallmetrics 函数中的 cursor.fetchall()。下面的函数 test_some() 抛出这样的错误

“断言错误:((35235, 4), (342, 3)) !=

db_helper.py

import settings
import pyodbc


class Connection:
    def __init__(self):
        self.SERVER= settings.SERVER
        self.DATABASE = settings.DATABASE
        self.USER = settings.USER
        self.PASSWORD = settings.PASSWORD


    def connect_db(self):
        try:
            conn = pyodbc.connect('DRIVER={ODBC Driver 17 for SQL Server};SERVER='+self.SERVER+';DATABASE='+self.DATABASE+';UID='+self.USER+';PWD='+self.PASSWORD)
            self.LOGGER.info("Connected to PSP Database")
            return conn
        except Exception as e:
            self.LOGGER.error("Unable to connect Database. {0}".format(e))
def fetchallmetrics(self,connection, query):
    cursor = connection.cursor()
    cursor.execute(query)
    metrics=cursor.fetchall()
    return metrics

一些.py

import settings
import db_helper  as dh
import constant

class SomeMetrics:
    def __init__(self):
        # Set the date of which metric is collected
        self.metric_collected_date = settings.METRICS_COLLECTED_DATE
        self.LOGGER = settings.LOGGER

    def some_fun(self, connection):
        try:
            count = dh.Connection().fetchallmetrics(connection,"SELECT * FROM TABLE")
            return count

        except Exception as e:
            self.LOGGER.error("{0} at some_fun()".format(e))

test_some.py

import unittest
from unittest import mock
from mock import patch, Mock
from Some import SomeMetrics

class TestSomeMetrics(unittest.TestCase):
    
    def setUp(self):
        self.km = SomeMetrics()
    
    @mock.patch("db_helper.Connection.connect_db", autospec=True)
    @mock.patch("db_helper.Connection.fetchallmetrics", autospec=True)
    def test_some_fun(self, mock_some_connection, mock_fetchallmetrics):
        posting_count = ((35235, 4), (342, 3))
        mock_data_interface = Mock()
        mock_fetchallmetrics_interface = Mock()
        mock_fetchallmetrics.return_value = mock_fetchallmetrics_interface
        mock_fetchallmetrics_interface.cursor.return_value.fetchall.return_value = posting_count
        mock_some_connection.return_value = mock_data_interface
        self.assertEqual(posting_count,self.km.some_fun(mock_some_connection))

1 个答案:

答案 0 :(得分:0)

在我看来,如果您创建一个单独的类 MockConnection

class MockConnection:
    def __init__(self,metrics):
        self.metrics = metrics
    def fetchallmetrics(self,connection, query):
        return self.metrics

# ...

def test_some_fun(self):
    posting_count = ((35235, 4), (342, 3))
    mock_connection = MockConnection(posting_count)
    self.assertEqual(posting_count, self.km.some_fun(mock_connection))