adam.ra

Reputation: 1078

Python: mock patch a module wherever it is imported from

I need to make sure that running unit tests never triggers a call to an expensive function that talks to the outside world, say, this one:

# bigbad.py
def request(param):
    return 'I searched the whole Internet for "{}"'.format(param)

Multiple modules use this function (bigbad.request), and they import it differently (in real life it may also be imported from an external library). Say there are two modules, a and b, where b depends on a and both use the function:

# a.py, from...import
from bigbad import request

def routine_a():
    return request('a')

# b.py, imports directly
import a
import bigbad

def routine_b():
    resp_a = a.routine_a()
    return 'resp_a: {}, resp_b=request(resp_a): {}'.format(resp_a, bigbad.request(resp_a))

Is there a way to make sure that bigbad.request is never called? This code mocks only one of the imports:

# test_b.py
import unittest
from unittest import mock
import b

with mock.patch('bigbad.request') as mock_request:
    mock_request.return_value = 'mocked'
    print(b.routine_b())
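Why only one call gets mocked: when a runs from bigbad import request, it binds its own name, a.request, to the same function object, and mock.patch('bigbad.request') only rebinds the attribute on bigbad. A minimal sketch that shows this, assuming in-memory stand-ins for bigbad and a (make_module and the shortened 'real ' + param body are hypothetical, used only so the example runs without real files):

```python
import sys
import types
from unittest import mock


def make_module(name, source):
    # Hypothetical helper: creates an in-memory module and registers it in
    # sys.modules, so the example does not need real bigbad.py / a.py files
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


bigbad = make_module('bigbad', "def request(param):\n    return 'real ' + param\n")
# Same as a.py: the from...import creates a second name, a.request
a = make_module('a', "from bigbad import request\n"
                     "def routine_a():\n"
                     "    return request('a')\n")

with mock.patch('bigbad.request', return_value='mocked'):
    via_module = bigbad.request('x')  # attribute looked up on bigbad: mocked
    via_copy = a.routine_a()          # uses a's own binding: still the real function

print(via_module)  # mocked
print(via_copy)    # real a
```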

Obviously I could refactor b and change the imports, but then I cannot guarantee that someone won't break this provision during future development. I believe tests should test behaviour rather than implementation details.

Upvotes: 23

Views: 22711

Answers (3)

Phanabani

Reputation: 23

For anyone coming to this question in the future, I wrote a function to patch all imports of a given symbol.

This function returns a list of patchers, one for each import of the given symbol (a whole module, a specific function, or any other object). These patchers can then be started and stopped in your test fixture's setup and teardown (see the docstring for an example).

How it works:

  • Iterate through every currently visible module in sys.modules
  • If the module's name starts with match_prefix (optional) and does not contain skip_substring (optional), iterate through every local in the module
  • If the local is target_symbol, create a patcher for it, local to the module it's imported in

I recommend using an argument like skip_substring='test' so that you don't patch things imported by your test suite.

from typing import Any, Optional
import unittest.mock as mock
import sys

def patch_all_symbol_imports(
        target_symbol: Any, match_prefix: Optional[str] = None,
        skip_substring: Optional[str] = None
):
    """
    Iterate through every visible module (in sys.modules) that starts with
    `match_prefix` to find imports of `target_symbol` and return a list
    of patchers for each import.

    This is helpful when you want to patch a module, function, or object
    everywhere in your project's code, even when it is imported with an alias.

    Example:

    ::

        import datetime

        # Setup
        patchers = patch_all_symbol_imports(datetime, 'my_project.', 'test')
        for patcher in patchers:
            mock_dt = patcher.start()
            # Do stuff with the mock

        # Teardown
        for patcher in patchers:
            patcher.stop()

    :param target_symbol: the symbol to search for imports of (may be a module,
        a function, or some other object)
    :param match_prefix: if not None, only search for imports in
        modules that begin with this string
    :param skip_substring: if not None, skip any module that contains this
        substring (e.g. 'test' to skip unit test modules)
    :return: a list of patchers for each import of the target symbol
    """

    patchers = []

    # Iterate through all currently imported modules
    # Make a copy in case sys.modules changes during iteration
    for module_name, module in list(sys.modules.items()):
        # sys.modules can contain None entries (used to block imports); skip them
        if module is None:
            continue
        # Use the sys.modules key rather than module.__name__, since that key
        # is what mock.patch imports when it resolves the target string
        name_matches = (
            match_prefix is None
            or module_name.startswith(match_prefix)
        )
        should_skip = (
            skip_substring is not None and skip_substring in module_name
        )
        if not name_matches or should_skip:
            continue

        # Iterate through this module's global names
        # Again, make a copy
        for local_name, local in list(module.__dict__.items()):
            if local is target_symbol:
                # Patch this symbol local to the module
                patchers.append(mock.patch(
                    f'{module_name}.{local_name}', autospec=True
                ))

    return patchers

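For this question specifically, the following code could be used. Note that b (and therefore a) must already be imported when patch_all_symbol_imports runs, because it only scans modules that are in sys.modules at that moment:

from bigbad import request
import b

# patch_all_symbol_imports is the function defined above
patchers = patch_all_symbol_imports(request, skip_substring='test')
for patcher in patchers:
    mock_request = patcher.start()
    mock_request.return_value = 'mocked'

try:
    print(b.routine_b())
finally:
    # Undo the patches even if the code under test raises
    for patcher in patchers:
        patcher.stop()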
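The docstring suggests starting the patchers in setup and stopping them in teardown. A sketch of that with unittest, assuming in-memory stand-ins for bigbad, a and b (make_module, the test class name and the shortened request body are hypothetical). It repeats a condensed copy of patch_all_symbol_imports so it runs on its own, and uses addCleanup so each patcher is stopped even if a later step fails:

```python
import sys
import types
import unittest
from unittest import mock


def make_module(name, source):
    # Hypothetical helper: builds in-memory stand-ins for bigbad, a and b
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


bigbad = make_module('bigbad', "def request(param):\n    return 'real ' + param\n")
a = make_module('a', "from bigbad import request\n"
                     "def routine_a():\n"
                     "    return request('a')\n")
b = make_module('b', "import a, bigbad\n"
                     "def routine_b():\n"
                     "    resp_a = a.routine_a()\n"
                     "    return resp_a + ' / ' + bigbad.request(resp_a)\n")


def patch_all_symbol_imports(target_symbol, match_prefix=None, skip_substring=None):
    # Condensed copy of the function from this answer, repeated so the sketch runs alone
    patchers = []
    for module_name, module in list(sys.modules.items()):
        if module is None:
            continue
        if match_prefix is not None and not module_name.startswith(match_prefix):
            continue
        if skip_substring is not None and skip_substring in module_name:
            continue
        for local_name, local in list(getattr(module, '__dict__', {}).items()):
            if local is target_symbol:
                patchers.append(mock.patch(f'{module_name}.{local_name}', autospec=True))
    return patchers


class RoutineBTest(unittest.TestCase):
    def setUp(self):
        for patcher in patch_all_symbol_imports(bigbad.request, skip_substring='test'):
            mock_request = patcher.start()
            mock_request.return_value = 'mocked'
            # Cleanups run even if setUp or the test fails afterwards
            self.addCleanup(patcher.stop)

    def test_routine_b_never_calls_real_request(self):
        self.assertEqual(b.routine_b(), 'mocked / mocked')


suite = unittest.defaultTestLoader.loadTestsFromTestCase(RoutineBTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())  # True
```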

Upvotes: 2

warvariuc

Reputation: 59604

# a.py, from...import
from bigbad import request

To ensure that the original request is never called, you'll have to patch all the places where the reference is imported:

from unittest import mock

with mock.patch('a.request', return_value='mocked') as mock_request:
    ...
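For the two modules in the question, patching every import site by hand could look like the sketch below. It assumes in-memory stand-ins for bigbad, a and b (make_module and the shortened function bodies are hypothetical); a real test would patch 'a.request' and 'bigbad.request' the same way:

```python
import sys
import types
from unittest import mock


def make_module(name, source):
    # Hypothetical helper: builds an in-memory module so the sketch runs on its own
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


bigbad = make_module('bigbad', "def request(param):\n    return 'real ' + param\n")
a = make_module('a', "from bigbad import request\n"
                     "def routine_a():\n"
                     "    return request('a')\n")
b = make_module('b', "import a, bigbad\n"
                     "def routine_b():\n"
                     "    resp_a = a.routine_a()\n"
                     "    return resp_a + ' / ' + bigbad.request(resp_a)\n")

# Every name that refers to the real function has to be patched separately
with mock.patch('a.request', return_value='mocked') as mock_a, \
        mock.patch('bigbad.request', return_value='mocked') as mock_bigbad:
    result = b.routine_b()

print(result)  # mocked / mocked
```

Every new module that does from bigbad import request adds one more patch target to this list, which is why the answer calls it tedious.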

This is tedious, so if possible, don't write from bigbad import request in your code; use import bigbad and call bigbad.request instead.

Another solution: if possible, change bigbad.py:

# bigbad.py
def _request(param):
    return 'I searched the whole Internet for "{}"'.format(param)


def request(param):
    return _request(param)

Then, even if some code does from bigbad import request, you can patch the helper instead with mock.patch('bigbad._request', return_value='mocked') as mock_request:, and every copy of request will end up calling the mock.
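A small check of the _request indirection, again assuming hypothetical in-memory modules built with make_module (the shortened function bodies are illustrative only):

```python
import sys
import types
from unittest import mock


def make_module(name, source):
    # Hypothetical helper: builds an in-memory module so the sketch runs on its own
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


bigbad = make_module('bigbad',
                     "def _request(param):\n"
                     "    return 'real ' + param\n"
                     "\n"
                     "def request(param):\n"
                     "    return _request(param)\n")
a = make_module('a', "from bigbad import request\n"
                     "def routine_a():\n"
                     "    return request('a')\n")

# a.request is the public wrapper, but the wrapper looks up _request in
# bigbad's globals at call time, so patching bigbad._request reaches it
with mock.patch('bigbad._request', return_value='mocked') as mock_request:
    result = a.routine_a()

print(result)  # mocked
```

Because request itself is never replaced, this also covers modules that imported request before the patch started.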

Upvotes: 6

Alex Hall

Reputation: 36033

import bigbad
bigbad.request = lambda param: 'mocked'  # some dummy function

This works as long as it runs before any module that does from bigbad import request is run or imported: modules imported after the assignment receive the dummy function, while modules that already imported the name keep a reference to the original.
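A sketch of that ordering rule, with hypothetical in-memory modules imported_before and imported_after standing in for real modules that use from bigbad import request (make_module and the shortened request body are also illustrative assumptions):

```python
import sys
import types


def make_module(name, source):
    # Hypothetical helper: builds an in-memory module so the sketch runs on its own
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


bigbad = make_module('bigbad', "def request(param):\n    return 'real ' + param\n")
imported_before = make_module('imported_before', "from bigbad import request\n")

# The replacement from this answer
bigbad.request = lambda param: 'mocked'

imported_after = make_module('imported_after', "from bigbad import request\n")

print(imported_before.request('x'))  # real x
print(imported_after.request('x'))   # mocked
```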

Upvotes: 7
