Jonathan Mines

Reputation: 73

Can I mock a function return that is called within another function call in a Python Test?

Is there a way to mock a function return that is called within another function call? For example:

import unittest

def bar():
    return "baz"

def foo():
    return bar()

class Tests(unittest.TestCase):
    def test_code(self):
        # hijack bar() here to return "bat" instead
        assert(foo() == "bat")

I've tried using @mock.patch, but found that it only allows me to mock the function I'm calling, not a function that is called as a result of calling a different function.

Upvotes: 1

Views: 1642

Answers (3)

alex_noname

Reputation: 32073

Also, you could use the setattr method of pytest's monkeypatch fixture. Its first argument can be either the object to be patched or a string, which is interpreted as a dotted import path whose last part is the attribute name:

# foo_module.py
def bar():
    return "baz"


def foo():
    return bar()

# test_foo.py
from foo_module import foo


def test_foo(monkeypatch):
    monkeypatch.setattr('foo_module.bar', lambda: 'bat')
    assert foo() == "bat"
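The example above uses the dotted-string form. Below is a minimal sketch of the other form, where the first argument is the object itself and the second is the attribute name. The Greeter class is a hypothetical stand-in used only so the example fits in one file, and the test relies on pytest injecting the monkeypatch fixture.

```python
# Hypothetical single-file example; run it with pytest.


class Greeter:
    """Hypothetical class whose foo() calls bar() internally."""

    def bar(self):
        return "baz"

    def foo(self):
        return self.bar()


def test_foo_object_form(monkeypatch):
    # Object form: monkeypatch.setattr(target_object, "attribute_name", new_value).
    # The replacement is stored on the class, so instances pass it self.
    monkeypatch.setattr(Greeter, "bar", lambda self: "bat")
    assert Greeter().foo() == "bat"
```

monkeypatch undoes the change when the test finishes, so Greeter().foo() returns "baz" again in later tests.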

Upvotes: 1

Mad Physicist

Reputation: 114260

You can write a context manager to temporarily swap out objects in the global namespace:

class Hijack:
    def __init__(self, name, replacement, namespace):
        self.name = name
        self.replacement = replacement
        self.namespace = namespace

    def __enter__(self):
        self.original = self.namespace[self.name]
        self.namespace[self.name] = self.replacement

    def __exit__(self, *args):
        self.namespace[self.name] = self.original

You can then call your function while bar is temporarily replaced by the mock:

import unittest

def bar():
    return "baz"

def bar_mock():
    return "bat"

def foo():
    return bar()

class Tests(unittest.TestCase):
    def test_code(self):
        with Hijack('bar', bar_mock, globals()):
            assert(foo() == "bat")
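Here globals() is the test module's own namespace, which works only because foo is defined in the same file. foo looks bar up in the namespace of the module that defines it, so when foo lives elsewhere you would pass that module's dictionary, vars(foo_module), which is the same object as foo_module.foo.__globals__. A sketch, where foo_module is a hypothetical module built in memory with types.ModuleType purely so the example is self-contained:

```python
import types


class Hijack:
    # Same context manager as above: swap namespace[name] for the block's duration.
    def __init__(self, name, replacement, namespace):
        self.name = name
        self.replacement = replacement
        self.namespace = namespace

    def __enter__(self):
        self.original = self.namespace[self.name]
        self.namespace[self.name] = self.replacement

    def __exit__(self, *args):
        self.namespace[self.name] = self.original


# Hypothetical stand-in for a separate foo_module.py, created in memory.
foo_module = types.ModuleType("foo_module")
exec(
    "def bar():\n"
    "    return 'baz'\n"
    "\n"
    "def foo():\n"
    "    return bar()\n",
    vars(foo_module),
)

# foo resolves bar through the module that defines it, not through the caller.
assert foo_module.foo.__globals__ is vars(foo_module)

with Hijack("bar", lambda: "bat", vars(foo_module)):
    assert foo_module.foo() == "bat"
assert foo_module.foo() == "baz"  # original restored on exit
```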

This is a general approach that can be used outside of unit testing as well. In fact, it is simple to generalize it to work on any mutable object that can be treated as some sort of mapping:

class Hijack:
    def __init__(self, name, replacement, namespace, getter=None, setter=None):
        self.name = name
        self.replacement = replacement
        self.namespace = namespace
        self.getter = type(namespace).__getitem__ if getter is None else getter
        self.setter = type(namespace).__setitem__ if setter is None else setter

    def __enter__(self):
        self.original = self.getter(self.namespace, self.name)
        self.setter(self.namespace, self.name, self.replacement)

    def __exit__(self, *args):
        self.setter(self.namespace, self.name, self.original)

For classes and other objects, you would use getter=getattr and setter=setattr. For situations where None is preferable to KeyError, you can use getter=dict.get, etc.
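As a sketch of those options, the block below repeats the generalized Hijack and applies it to an attribute of a types.SimpleNamespace (a hypothetical helpers object standing in for a class or module) and to a plain dict with getter=dict.get. Note one consequence of the dict.get variant: a key that was missing is restored as None rather than deleted.

```python
import types


class Hijack:
    # Generalized version from above: getter/setter default to the mapping's
    # own __getitem__/__setitem__.
    def __init__(self, name, replacement, namespace, getter=None, setter=None):
        self.name = name
        self.replacement = replacement
        self.namespace = namespace
        self.getter = type(namespace).__getitem__ if getter is None else getter
        self.setter = type(namespace).__setitem__ if setter is None else setter

    def __enter__(self):
        self.original = self.getter(self.namespace, self.name)
        self.setter(self.namespace, self.name, self.replacement)

    def __exit__(self, *args):
        self.setter(self.namespace, self.name, self.original)


# Attribute access: pass getattr/setattr for objects that are not mappings.
helpers = types.SimpleNamespace(bar=lambda: "baz")  # hypothetical helper object


def foo():
    return helpers.bar()


with Hijack("bar", lambda: "bat", helpers, getter=getattr, setter=setattr):
    assert foo() == "bat"
assert foo() == "baz"

# Possibly missing key: dict.get returns None instead of raising KeyError.
config = {}  # hypothetical settings dict
with Hijack("debug", True, config, getter=dict.get):
    assert config["debug"] is True
# The "original" value None is written back; the key is not removed.
assert config == {"debug": None}
```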

Upvotes: 2

Mad Physicist

Reputation: 114260

Generic Patch

unittest.mock.patch does exactly what my other answer suggested, out of the box. You can stack as many @patch decorators as you need, and each target you name will be patched:

import unittest
from unittest.mock import patch

def bar():
    return "baz"

def foo():
    return bar()

class Tests(unittest.TestCase):
    @patch(__name__ + '.bar', lambda: 'bat')
    def test_code(self):
        assert(foo() == "bat")
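A sketch of stacking, with a second, hypothetical helper qux. When the replacement object is omitted, patch creates a MagicMock (configured here through return_value) and passes it to the test as an extra argument. Stacked decorators are applied bottom-up, so the mock for the decorator nearest the method comes first.

```python
import unittest
from unittest.mock import patch


def bar():
    return "baz"


def qux():  # hypothetical second helper
    return "quux"


def foo():
    return bar() + qux()


class Tests(unittest.TestCase):
    # Bottom decorator ('qux') supplies the first mock argument.
    @patch(__name__ + ".bar", return_value="bat")
    @patch(__name__ + ".qux", return_value="!")
    def test_code(self, mock_qux, mock_bar):
        self.assertEqual(foo(), "bat!")
        # The injected mocks also record how the helpers were called.
        mock_bar.assert_called_once_with()
        mock_qux.assert_called_once_with()


if __name__ == "__main__":
    unittest.main()
```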

In this configuration, the original bar is restored once test_code completes. If you want the same patch to apply to every test in your class, decorate the whole class instead (patch then wraps each method whose name starts with test):

@patch(__name__ + '.bar', lambda: 'bat')
class Tests(unittest.TestCase):
    def test_code(self):
        assert(foo() == "bat")

Patch Globals

You can also use unittest.mock.patch.dict on your global namespace for the same result:

class Tests(unittest.TestCase):
    @patch.dict(globals(), {'bar': lambda: 'bat'})
    def test_code(self):
        assert(foo() == "bat")

Upvotes: 3
