Reputation: 73
Is there a way to mock the return value of a function that is called from within another function? For example:
import unittest

def bar():
    return "baz"

def foo():
    return bar()

class Tests(unittest.TestCase):
    def test_code(self):
        # hijack bar() here to return "bat" instead
        assert(foo() == "bat")
I've tried using @mock.patch, but it seems to only let me mock the function I'm calling, not a function that gets called as a result of calling a different function.
Upvotes: 1
Views: 1642
Reputation: 32073
Alternatively, you could use the setattr method of pytest's monkeypatch fixture. As its first argument, it accepts either the object to be patched or a string, which is interpreted as a dotted import path whose last part is the attribute name:
# foo_module.py
def bar():
    return "baz"

def foo():
    return bar()

# test_foo.py
from foo_module import foo

def test_foo(monkeypatch):
    monkeypatch.setattr('foo_module.bar', lambda: 'bat')
    assert foo() == "bat"
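The dotted string has to name the module where foo looks bar up (foo_module), not the test module. `from foo_module import foo` only copies a reference; each time foo runs, it still resolves bar through foo_module's globals. Below is a minimal sketch of that distinction under two stated assumptions: it uses the standard library's unittest.mock in place of monkeypatch so it runs without pytest, and it builds foo_module in memory with exec purely so the example fits in one file.

```python
import sys
import types
from unittest import mock

# Assumption: an in-memory stand-in for foo_module.py, built with exec only
# so this sketch is a single file. In a real project it is a file on disk.
foo_module = types.ModuleType("foo_module")
exec(
    'def bar():\n'
    '    return "baz"\n'
    '\n'
    'def foo():\n'
    '    return bar()\n',
    foo_module.__dict__,
)
sys.modules["foo_module"] = foo_module

from foo_module import foo, bar  # copies both names into this namespace


def call_with_local_patch():
    # Rebinds `bar` in *this* module only; foo() still finds the
    # original bar in foo_module's globals.
    with mock.patch.dict(globals(), {"bar": lambda: "bat"}):
        return foo()


def call_with_module_patch():
    # Rebinds `bar` on foo_module, where foo() looks it up.
    with mock.patch("foo_module.bar", lambda: "bat"):
        return foo()


print(call_with_local_patch())   # baz
print(call_with_module_patch())  # bat
```

Only the second patch changes what foo() returns, which is why the monkeypatch string above is 'foo_module.bar' rather than the name of the test module.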
Upvotes: 1
Reputation: 114260
You can write a context manager to temporarily swap out objects in the global namespace:
class Hijack:
    def __init__(self, name, replacement, namespace):
        self.name = name
        self.replacement = replacement
        self.namespace = namespace

    def __enter__(self):
        self.original = self.namespace[self.name]
        self.namespace[self.name] = self.replacement

    def __exit__(self, *args):
        self.namespace[self.name] = self.original
You can then use it to swap in your mock function for the duration of the test:
import unittest

def bar():
    return "baz"

def bar_mock():
    return "bat"

def foo():
    return bar()

class Tests(unittest.TestCase):
    def test_code(self):
        with Hijack('bar', bar_mock, globals()):
            assert(foo() == "bat")
This is a fairly general approach that is useful outside of unit testing as well. In fact, it is simple to generalize it to any mutable object that can be treated as some sort of mapping:
class Hijack:
    def __init__(self, name, replacement, namespace, getter=None, setter=None):
        self.name = name
        self.replacement = replacement
        self.namespace = namespace
        self.getter = type(namespace).__getitem__ if getter is None else getter
        self.setter = type(namespace).__setitem__ if setter is None else setter

    def __enter__(self):
        self.original = self.getter(self.namespace, self.name)
        self.setter(self.namespace, self.name, self.replacement)

    def __exit__(self, *args):
        self.setter(self.namespace, self.name, self.original)
For classes and other objects, you would use getter=getattr and setter=setattr. For situations where None is preferable to KeyError, you can use getter=dict.get, etc.
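A short sketch of both variations, using the generalized Hijack class from above (repeated so the example runs on its own). Greeter and config are hypothetical names made up for illustration. Note one consequence of getter=dict.get: when the key was missing, __exit__ writes the None it read back into the dict, so the key is left behind with the value None rather than removed.

```python
class Hijack:
    """Generalized version from above, repeated so this sketch runs standalone."""

    def __init__(self, name, replacement, namespace, getter=None, setter=None):
        self.name = name
        self.replacement = replacement
        self.namespace = namespace
        self.getter = type(namespace).__getitem__ if getter is None else getter
        self.setter = type(namespace).__setitem__ if setter is None else setter

    def __enter__(self):
        self.original = self.getter(self.namespace, self.name)
        self.setter(self.namespace, self.name, self.replacement)

    def __exit__(self, *args):
        self.setter(self.namespace, self.name, self.original)


class Greeter:  # hypothetical class used only for this example
    def greet(self):
        return "hello"


g = Greeter()

# Attribute access: swap a method on the class via getattr/setattr
with Hijack("greet", lambda self: "hi", Greeter, getter=getattr, setter=setattr):
    print(g.greet())  # hi
print(g.greet())  # hello

# Missing key: dict.get returns None instead of raising KeyError,
# and that None is what gets written back on exit.
config = {}  # hypothetical settings dict
with Hijack("debug", True, config, getter=dict.get):
    print(config["debug"])  # True
print(config)  # {'debug': None}
```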
Upvotes: 2
Reputation: 114260
Generic Patch
unittest.mock.patch does exactly what my other answer suggested, out of the box. You can stack as many @patch decorators as you need, and each target you name will be patched:
import unittest
from unittest.mock import patch

def bar():
    return "baz"

def foo():
    return bar()

class Tests(unittest.TestCase):
    @patch(__name__ + '.bar', lambda: 'bat')
    def test_code(self):
        assert(foo() == "bat")
In this configuration, the original bar is restored once test_code completes. If you want the same patch to apply to all test cases in your class, decorate the whole class (patch then wraps every method whose name starts with test):
@patch(__name__ + '.bar', lambda: 'bat')
class Tests(unittest.TestCase):
    def test_code(self):
        assert(foo() == "bat")
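Two further variations, sketched with the same bar/foo layout; the class names are hypothetical. The first starts the patcher by hand in setUp and registers patcher.stop with addCleanup, which is another way to have every test in the class see the patch. The second omits the replacement, so patch creates a MagicMock and passes it into the test, where you can set its return value and check how it was called. Like the examples above, this assumes the module is importable under __name__ (for instance, run as a script).

```python
import unittest
from unittest.mock import patch


def bar():
    return "baz"


def foo():
    return bar()


class PatchInSetUp(unittest.TestCase):
    def setUp(self):
        # Start the patch manually; addCleanup undoes it after each test,
        # even when the test fails.
        patcher = patch(__name__ + '.bar', lambda: 'bat')
        patcher.start()
        self.addCleanup(patcher.stop)

    def test_code(self):
        self.assertEqual(foo(), "bat")


class PatchWithMagicMock(unittest.TestCase):
    # No replacement given: patch creates a MagicMock and passes it in
    @patch(__name__ + '.bar')
    def test_code(self, mock_bar):
        mock_bar.return_value = "bat"
        self.assertEqual(foo(), "bat")
        mock_bar.assert_called_once_with()
```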
Patch Globals
You can also apply unittest.mock.patch.dict to your global namespace for the same result:
class Tests(unittest.TestCase):
    @patch.dict(globals(), {'bar': lambda: 'bat'})
    def test_code(self):
        assert(foo() == "bat")
Upvotes: 3