Johnny Metz

Reputation: 5965

Mock same Python function across different files in a single mock variable

Let's say I have the following python files:

# source.py
def get_one():
    return 1

# file1.py
from source import get_one
def func1():
    return get_one()

# file2.py
from source import get_one
def func2():
    return get_one()

# script.py
from file1 import func1
from file2 import func2
def main(a, b):
    count = 0
    for _ in range(a):
        count += func1()
    for _ in range(b):
        count += func2()
    return count

I know I can mock out get_one() for main() in script.py using the following setup (parenthesized with statements require Python 3.10+):

from unittest.mock import patch
from script import main

def test_mock():
    with (
        patch("file1.get_one") as mock1,
        patch("file2.get_one") as mock2,
    ):
        main(2, 3)
    assert mock1.call_count + mock2.call_count == 5

However, this gets increasingly verbose and hard to read if get_one() needs to be mocked out in many files. I would love to be able to mock out all of its locations in a single mock variable. Something like:

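# this test fails (patch's second positional argument is new, so file1.get_one would be replaced by the string "file2.get_one"); I'm just showing what this ideally would look like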
def test_mock():
    with patch("file1.get_one", "file2.get_one") as mock:
        main(2, 3)
    assert mock.call_count == 5

Is there any way to do this, or do I need to use multiple mocks?

Note: I know I can't just patch where the function is defined, e.g. patch("source.get_one"), because file1 and file2 each bind their own reference to get_one when they import it.
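To illustrate the note above, here is a minimal sketch that builds source.py and file1.py from the question in memory so it runs as a single file. The make_module helper and the return value 100 are illustrative assumptions, not part of the question. Because from source import get_one copies a reference into file1's namespace, patching source.get_one rebinds a name that func1 never looks at:

```python
import sys
import types
from unittest.mock import patch


def make_module(name, code):
    """Hypothetical helper: build a module from source text and register it
    in sys.modules so import statements and patch() can find it."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(code, module.__dict__)
    return module


make_module("source", "def get_one():\n    return 1\n")
file1 = make_module(
    "file1",
    "from source import get_one\n"
    "def func1():\n"
    "    return get_one()\n",
)

# Patching the definition site only rebinds source.get_one;
# file1.get_one still points at the original function.
with patch("source.get_one", return_value=100):
    patched_at_source = file1.func1()

# Patching the name where func1 actually looks it up does take effect.
with patch("file1.get_one", return_value=100):
    patched_at_use = file1.func1()

print(patched_at_source, patched_at_use)  # 1 100
```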

Upvotes: 2

Views: 1525

Answers (2)

aaron

Reputation: 43083

patch accepts a new argument, the object to replace the target with, so you can pass the mock created by the first patch to the second one:

from unittest.mock import patch
from script import main

def test_mock():
    with (
        patch("file1.get_one") as mock,
        patch("file2.get_one", new=mock),
    ):
        main(2, 3)
    assert mock.call_count == 5, mock.call_count
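A self-contained sketch of the same idea, assuming the question's four files are recreated in memory (make_module is a hypothetical helper) and using nested with statements so it also runs before Python 3.10. Since both locations share one MagicMock, a return_value set on it (10 here, an arbitrary choice) applies everywhere:

```python
import sys
import types
from unittest.mock import patch


def make_module(name, code):
    """Hypothetical helper: create an in-memory module so the example is one file."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(code, module.__dict__)
    return module


# Recreate the question's files; source must exist before file1/file2 import it.
make_module("source", "def get_one():\n    return 1\n")
make_module("file1", "from source import get_one\ndef func1():\n    return get_one()\n")
make_module("file2", "from source import get_one\ndef func2():\n    return get_one()\n")
script = make_module(
    "script",
    "from file1 import func1\n"
    "from file2 import func2\n"
    "def main(a, b):\n"
    "    count = 0\n"
    "    for _ in range(a):\n"
    "        count += func1()\n"
    "    for _ in range(b):\n"
    "        count += func2()\n"
    "    return count\n",
)

# The first patch creates the MagicMock; the second reuses it via new=,
# so calls from both files are recorded on the same object.
with patch("file1.get_one", return_value=10) as mock:
    with patch("file2.get_one", new=mock):
        total = script.main(2, 3)

print(mock.call_count, total)  # 5 50
```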

To avoid repeating this for every target, you can write a recursive helper context manager:

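import contextlib
from unittest.mock import DEFAULT, patch


@contextlib.contextmanager
def patch_same(target, *targets, new=DEFAULT):
    # Patch the first target; with new=DEFAULT this creates a fresh MagicMock.
    with patch(target, new=new) as mock:
        if targets:
            # Recurse on the remaining targets, reusing the same mock.
            with patch_same(*targets, new=mock):
                yield mock
        else:
            yield mock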

Usage:

from script import main

def test_mock():
    with patch_same("file1.get_one", "file2.get_one") as mock:
        main(2, 3)
    assert mock.call_count == 5
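If recursion feels heavy, the same helper can be written iteratively with contextlib.ExitStack. This is a variation, not the answer's own code: patch_shared is a hypothetical name, and make_module recreates the question's files in memory so the sketch runs on its own. ExitStack undoes the patches in reverse order on exit, so the original get_one bindings come back:

```python
import contextlib
import sys
import types
from unittest.mock import DEFAULT, patch


@contextlib.contextmanager
def patch_shared(target, *targets, new=DEFAULT):
    """Hypothetical helper: patch every target with one shared object."""
    with contextlib.ExitStack() as stack:
        # The first patch creates the mock (or uses new if one was given)...
        mock = stack.enter_context(patch(target, new=new))
        # ...and every remaining target is replaced by that same object.
        for other in targets:
            stack.enter_context(patch(other, new=mock))
        yield mock


def make_module(name, code):
    """Hypothetical helper: create an in-memory module so the example is one file."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(code, module.__dict__)
    return module


make_module("source", "def get_one():\n    return 1\n")
make_module("file1", "from source import get_one\ndef func1():\n    return get_one()\n")
make_module("file2", "from source import get_one\ndef func2():\n    return get_one()\n")
script = make_module(
    "script",
    "from file1 import func1\n"
    "from file2 import func2\n"
    "def main(a, b):\n"
    "    return sum(func1() for _ in range(a)) + sum(func2() for _ in range(b))\n",
)

with patch_shared("file1.get_one", "file2.get_one") as mock:
    script.main(2, 3)

print(mock.call_count)  # 5
```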

Upvotes: 2

Samwise

Reputation: 71454

A possible workaround: add a level of indirection in source.py so that get_one has a dependency that you can patch:

def get_one():
    return _get_one()

def _get_one():
    return 1

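Because get_one() looks up _get_one in source's module namespace every time it runs, patching source._get_one reaches every caller, no matter how they imported get_one. Now you can do this in your test: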

from script import main
from unittest.mock import patch

def test_mock():
    with patch("source._get_one") as mock1:
        main(2, 3)
    assert mock1.call_count == 5
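A self-contained sketch of this workaround, assuming the question's files are recreated in memory with the modified source.py (make_module is a hypothetical helper, and return_value=10 is an arbitrary choice). A single patch on source._get_one counts the calls from both file1 and file2:

```python
import sys
import types
from unittest.mock import patch


def make_module(name, code):
    """Hypothetical helper: create an in-memory module so the example is one file."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(code, module.__dict__)
    return module


# source.py with the extra level of indirection from this answer.
make_module(
    "source",
    "def get_one():\n"
    "    return _get_one()\n"
    "def _get_one():\n"
    "    return 1\n",
)
make_module("file1", "from source import get_one\ndef func1():\n    return get_one()\n")
make_module("file2", "from source import get_one\ndef func2():\n    return get_one()\n")
script = make_module(
    "script",
    "from file1 import func1\n"
    "from file2 import func2\n"
    "def main(a, b):\n"
    "    return sum(func1() for _ in range(a)) + sum(func2() for _ in range(b))\n",
)

# One patch covers every caller: each get_one() call resolves _get_one
# through source's globals at call time, so it sees the mock.
with patch("source._get_one", return_value=10) as mock:
    total = script.main(2, 3)

print(mock.call_count, total)  # 5 50
```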

Upvotes: 1
