使用mock在python unittest中模拟实例方法

时间:2015-08-17 08:16:16

标签: python unit-testing mocking

我正在尝试模拟另一个对象中包含的对象的实例方法。

我有两个对象

class CreditCard(object):
    def charge(self, amount):
        if amount < 0:
            raise Exception('Invalid amount')
        print "charging %d" % amount

class Transaction(object):
    def __init__(self, credit_card, amount):
        self.amount = amount
        self.credit_card = credit_card
        self.status = 'PRISTINE'

    def pay(self):
        self.credit_card.charge(self.amount)
        self.status = 'COMPLETE'

这是我正在尝试实现的简化示例

class TestTransaction(TestCase):

    def setUp(self):
        cc = CreditCard()
        t = Transaction(cc, 10)

    def test_should_not_mark_transaction_as_complete_if_charge_failed(self):
        with mock.patch('t.credit_card.charge') as mock_charge:
            mock_charge.side_effect = Exception
            with self.assertRaises(Exception):
                t.pay()
            self.assertEquaL(t.status, 'PRISTINE')

charge内部有很多逻辑,所以我试图通过模拟收费来隔离交易测试。

谢谢,python 2.7

2 个答案:

答案 0 :(得分:1)

由于您正在模拟类的方法,因此可以直接对其进行修补:

import mock
import unittest

from your_code import CreditCard, Transaction

class TestTransaction(unittest.TestCase):
    def setUp(self):
        self.cc = CreditCard()
        self.t = Transaction(self.cc, 10)

    @mock.patch('your_code.CreditCard.charge', side_effect=Exception)
    def test_should_not_mark_transaction_as_complete_if_charge_failed(self, mock_charge):
        with self.assertRaises(Exception):
            self.t.pay()
        self.assertEqual(self.t.status, 'PRISTINE')

请务必修改补丁中的模块路径,在哪里定义您的类。

此外,由于多个错误(在测试用例中使用全局变量而不是实例属性),我更正了您的初始示例。

答案 1 :(得分:0)

由于您要将CreditCard个实例传递到Transaction,只需创建一个Mock object并传递它。

class TestTransaction(TestCase):

    def setUp(self):
         self.fake_cc = Mock()
         # set up any side effects, return values etc. of the fake cc
         # in your tests as needed...
         self.t = Transaction(self.fake_cc, 10)

    def test_should_not_mark_transaction_as_complete_if_charge_failed(self):
        self.fake_cc.side_effect = Exception
        with self.assertRaises(Exception):
            self.t.pay()
相关问题