simple liquids

Reputation: 165

Python - create mock test for class method that has context manager

I'm trying to write a unit test for a class method that uses a context manager and makes several chained calls. I'm having a difficult time understanding how to mock the dependency properly so that I can test the return value.
The module I am trying to mock is db. As you can see below, I'm using patch, but I can't figure out how to make the mocked call chain return the value I expect. Instead of that value, I get a generic mock object back.

db_class.py

import db

class Foo():
    def __init__(self):
        pass
    def method(self):
        with db.a() as a:
            b = a.b
            return b.fetch()

unit_db.py

from mock import Mock, patch, MagicMock
from db_class import Foo

@patch('db_class.db')
def test(db_mock):
    expected_result = [5,10]
    db_mock.return_value = Mock(__enter__ = db_mock,
                                __exit___ = Mock(),
                                b = Mock(fetch=expected_result))

    foo = Foo()
    result = foo.method()
    assert result == expected_result

Upvotes: 7

Views: 8936

Answers (2)

Peter K

Reputation: 2484

Here is the same test, using pytest and the mocker fixture from pytest-mock:

import db_class


def test(mocker):
    mock_db = mocker.MagicMock(name='db')
    mocker.patch('db_class.db', new=mock_db)
    expected_result = [5, 10]
    mock_db.a.return_value.__enter__.return_value.b.fetch.return_value = expected_result

    foo = db_class.Foo()
    result = foo.method()
    assert result == expected_result
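
To see why that long attribute chain is the right one, it can help to walk through it link by link. The following is only a sketch: it uses the standard library's unittest.mock directly instead of the mocker fixture, and it repeats the body of Foo.method() inline against the mock rather than importing db_class. The names manager and result are introduced here purely for illustration.

```python
from unittest.mock import MagicMock

mock_db = MagicMock(name='db')
mock_db.a.return_value.__enter__.return_value.b.fetch.return_value = [5, 10]

# The same steps Foo.method() performs, written out against the mock:
manager = mock_db.a()        # calling db.a()  -> mock_db.a.return_value
with manager as a:           # entering `with` -> manager.__enter__.return_value
    b = a.b                  # attribute access only, so no return_value link
    result = b.fetch()       # calling fetch() -> b.fetch.return_value

assert result == [5, 10]
```

Every call in the code under test adds a return_value link, and every plain attribute access adds only the attribute name.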

You may find the way I wrote the test more interesting than the test itself: I have created a Python library to help me with the syntax.

Here is how I approached your problem in a systematic way:

We start with the test you want and my helper library:

import db_class

from mock_autogen.pytest_mocker import PytestMocker

def test(mocker):
    # this would output the mocks we need
    print(PytestMocker(db_class).mock_modules().prepare_asserts_calls().generate())

    # your original test, without the mocks
    expected_result = [5,10]
    foo = db_class.Foo()
    result = foo.method()
    assert result == expected_result

Now the test obviously fails (AttributeError: module 'db' has no attribute 'a'), but the print output is useful:

# mocked modules
mock_db = mocker.MagicMock(name='db')
mocker.patch('db_class.db', new=mock_db)
# calls to generate_asserts, put this after the 'act'
import mock_autogen
print(mock_autogen.generator.generate_asserts(mock_db, name='mock_db'))

Now, I'm placing the mocks before the call to Foo() and the generate_asserts after, just before your assert, like so (no need for the previous print, so I removed it):

def test(mocker):
    # mocked modules
    mock_db = mocker.MagicMock(name='db')
    mocker.patch('db_class.db', new=mock_db)

    # your original test, without the mocks
    expected_result = [5,10]
    foo = db_class.Foo()
    result = foo.method()

    # calls to generate_asserts, put this after the 'act'
    import mock_autogen
    print(mock_autogen.generator.generate_asserts(mock_db, name='mock_db'))

    assert result == expected_result

Now the assert fails (AssertionError: assert <MagicMock name='db.a().__enter__().b.fetch()' id='139996983259768'> == [5, 10]), but we have once more gained some valuable input:

mock_db.a.return_value.__enter__.assert_called_once_with()
mock_db.a.return_value.__enter__.return_value.b.fetch.assert_called_once_with()
mock_db.a.return_value.__exit__.assert_called_once_with(None, None, None)

Notice the second line: it is almost exactly what you need to mock. Replace the trailing assert_called_once_with() with a return_value assignment and you get mock_db.a.return_value.__enter__.return_value.b.fetch.return_value = expected_result. With that, we have the final version of the test:

def test(mocker):
    mock_db = mocker.MagicMock(name='db')
    mocker.patch('db_class.db', new=mock_db)
    expected_result = [5, 10]
    mock_db.a.return_value.__enter__.return_value.b.fetch.return_value = expected_result

    foo = db_class.Foo()
    result = foo.method()
    assert result == expected_result

You may add the additional automatically generated asserts, or alter them to include additional asserts if you find that useful.
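
As a rough illustration of what that could look like, here is a sketch that combines the return value check with the three generated asserts. It assumes the standard library's unittest.mock, and to stay self-contained it uses a hypothetical stand-in function, method(db), that mirrors Foo.method() but takes db as an argument so no patching is needed.

```python
from unittest.mock import MagicMock


def method(db):
    # Hypothetical stand-in for Foo.method(); db is passed in instead of imported.
    with db.a() as a:
        b = a.b
        return b.fetch()


def test_method_calls():
    mock_db = MagicMock(name='db')
    expected_result = [5, 10]
    mock_db.a.return_value.__enter__.return_value.b.fetch.return_value = expected_result

    result = method(mock_db)

    assert result == expected_result
    # The generated asserts: the context manager was entered once, fetch() was
    # called once with no arguments, and the block exited without an exception.
    mock_db.a.return_value.__enter__.assert_called_once_with()
    mock_db.a.return_value.__enter__.return_value.b.fetch.assert_called_once_with()
    mock_db.a.return_value.__exit__.assert_called_once_with(None, None, None)


test_method_calls()
```

The __exit__ assert is the one most worth keeping, since it fails if the code under test raises inside the with block.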

Upvotes: 3

simple liquids

Reputation: 165

Thanks to the commenters, I found a solution that works for me. The trick was to patch the right target: in this case db_class.db.a instead of db_class.db. After that, it is important that fetch remains a callable mock whose return_value is the expected result, rather than being set to the result itself (if I understand it correctly). The tricky parts of this problem for me were patching the correct thing and dealing with the context manager, which requires a bit of extra setup.

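from mock import MagicMock, Mock, patch  # on Python 3: from unittest.mock import ...
from db_class import Foo

@patch('db_class.db.a')
def test(db_a):
    expected_result = [5, 10]
    # b.fetch() must be callable, so set return_value instead of assigning the list.
    b_fetch = MagicMock()
    b_fetch.fetch.return_value = expected_result
    # db.a() returns the context manager. __enter__ is db_a itself, so the object
    # bound by `with ... as a` is db_a.return_value, i.e. this same Mock.
    db_a.return_value = Mock(b=b_fetch,
                             __enter__=db_a,
                             __exit__=Mock())
    foo = Foo()
    result = foo.method()
    assert result == expected_result

if __name__ == "__main__":
    test()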
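
The point about fetch() needing to be a method can be checked in isolation. The sketch below assumes the standard library's unittest.mock (the mock backport behaves the same way here); the names b_wrong, b_right, b_short, and call_failed are made up for the illustration.

```python
from unittest.mock import Mock

expected_result = [5, 10]

# Passing the list as a keyword makes `fetch` the list itself, so calling it fails.
b_wrong = Mock(fetch=expected_result)
try:
    b_wrong.fetch()
    call_failed = False
except TypeError:
    call_failed = True  # a list is not callable

# Setting return_value keeps `fetch` a callable mock that returns the list.
b_right = Mock()
b_right.fetch.return_value = expected_result

# The same configuration in one expression, using a dotted keyword.
b_short = Mock(**{'fetch.return_value': expected_result})
```

This is the difference between the Mock(fetch=expected_result) in the question and the b_fetch.fetch.return_value assignment in the working test.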

Upvotes: 6
