Reputation: 1466
Suppose I have two files:
some_module.py:

def get_value_1():
    return 1

def print_value_1():
    print(get_value_1())

and the main file, main.py:

from unittest.mock import patch
import some_module
from some_module import print_value_1, get_value_1

def mocked_value():
    return 2

if __name__ == '__main__':
    print_value_1()  # prints 1
    with patch.object(some_module, 'get_value_1', mocked_value):
        print_value_1()  # prints 2
        print(some_module.get_value_1())  # prints 2
        print(get_value_1())  # prints 1; the desired result is for this to print 2 as well

As you can see, because I imported get_value_1 explicitly, the patch does not affect it. I understand roughly why: the import binds a reference to the original function object in main.py's namespace before the main block runs (I checked this by calling id() on each invoked function and comparing the addresses). Can I somehow hijack the imported reference as well?
(Patching it only in main.py is not enough; I want it patched across the whole project. For example, another module, some_other_module.py, may contain
from some_module import get_value_1
and when get_value_1() is called there, it should call the patched function and return 2.)
Upvotes: 3
Views: 1197
Reputation: 16805
If you use patch or patch.object, there is no way around patching every module that holds its own reference to the function. In your case this would be, for example:

import sys  # required for sys.modules

if __name__ == '__main__':
    print_value_1()
    with patch.object(some_module, 'get_value_1', mocked_value):
        with patch.object(sys.modules[__name__], 'get_value_1', mocked_value):
            print_value_1()
            print(some_module.get_value_1())
            print(get_value_1())

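The same idea as a self-contained sketch, assuming two in-memory modules with hypothetical names: library (playing some_module) and consumer (playing a module that did `from some_module import get_value_1`). A function defined inside consumer resolves get_value_1 through consumer's own globals, which is why that module has to be patched as well.

```python
import types
from unittest.mock import patch

# Hypothetical stand-in for some_module
library = types.ModuleType("library")
exec("def get_value_1():\n    return 1\n", library.__dict__)

# Hypothetical module that did `from library import get_value_1`; read_value is defined
# inside its namespace, so it looks up get_value_1 in consumer's globals
consumer = types.ModuleType("consumer")
consumer.get_value_1 = library.get_value_1
exec("def read_value():\n    return get_value_1()\n", consumer.__dict__)

def mocked_value():
    return 2

before = consumer.read_value()
with patch.object(library, "get_value_1", mocked_value):
    only_library = consumer.read_value()  # consumer's own reference is untouched
    with patch.object(consumer, "get_value_1", mocked_value):
        both = consumer.read_value()      # consumer's reference is patched too
after = consumer.read_value()

print(before, only_library, both, after)  # 1 1 2 1
```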
Depending on how your application is structured, you could iterate over all modules that reference the function to be patched, e.g.:
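
import sys
import some_other_module  # any other module that did `from some_module import get_value_1`

def get_modules():
    # every module that holds its own reference to get_value_1
    return (sys.modules[__name__], some_module, some_other_module)

if __name__ == '__main__':
    patches = []
    for module in get_modules():
        p = patch.object(module, 'get_value_1', mocked_value)
        p.start()
        patches.append(p)
    try:
        print(some_module.get_value_1())
        print(some_other_module.get_value_1())
        print(get_value_1())
    finally:
        # restore the original references even if something above raises
        for p in patches:
            p.stop()
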
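A self-contained sketch of that start()/stop() loop, assuming in-memory stand-ins with hypothetical names for some_module and the modules that imported get_value_1 from it. Here patch.stopall() from unittest.mock replaces the list of patcher objects: it stops every patch that was started with start(). Note that it stops such patches process-wide, including ones started elsewhere, so keeping your own list (as above) is the more targeted choice in a larger test suite.

```python
import types
from unittest.mock import patch

def get_value_1():
    return 1

def mocked_value():
    return 2

# Hypothetical stand-ins: each module holds its own reference to get_value_1
modules = []
for name in ("some_module", "some_other_module", "yet_another_module"):
    module = types.ModuleType(name)
    module.get_value_1 = get_value_1
    modules.append(module)

for module in modules:
    patch.object(module, "get_value_1", mocked_value).start()
try:
    inside = [m.get_value_1() for m in modules]
finally:
    # Stops every patcher started with start(), restoring the original references
    patch.stopall()

outside = [m.get_value_1() for m in modules]
print(inside, outside)  # [2, 2, 2] [1, 1, 1]
```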
If you don't know all the modules to patch beforehand, you have to collect them at runtime (so get_modules gets more complicated), for example by iterating over all loaded modules, finding the function to be patched by name, and mocking it in each of them (assuming all relevant modules are already loaded at that point).
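A minimal sketch of that runtime search, assuming every module that imports the function has already been loaded. patch_all_references is a hypothetical helper, not part of unittest.mock, and it uses contextlib.ExitStack instead of a manual start()/stop() list. It matches by object identity rather than by name (a variation on the by-name lookup described above), so aliased imports such as `from some_module import get_value_1 as gv` are caught and unrelated functions that happen to share the name are left alone. References held outside module namespaces (closures, default arguments, containers) are not patched.

```python
import contextlib
import sys
import types
from unittest.mock import patch

@contextlib.contextmanager
def patch_all_references(original, replacement):
    """Hypothetical helper: patch every module-level name bound to `original`."""
    with contextlib.ExitStack() as stack:
        # Snapshot sys.modules; an import during iteration would change it
        for module in list(sys.modules.values()):
            namespace = getattr(module, "__dict__", None)
            if not isinstance(namespace, dict):
                continue
            # Collect names first so patching does not mutate the dict mid-iteration
            names = [attr for attr, value in list(namespace.items()) if value is original]
            for attr in names:
                stack.enter_context(patch.object(module, attr, replacement))
        yield  # every patch is undone when the ExitStack closes

# Usage with hypothetical in-memory modules registered in sys.modules
def get_value_1():
    return 1

lib = types.ModuleType("demo_lib")
lib.get_value_1 = get_value_1
user = types.ModuleType("demo_user")
user.gv = get_value_1  # like `from demo_lib import get_value_1 as gv`
sys.modules["demo_lib"] = lib
sys.modules["demo_user"] = user

with patch_all_references(get_value_1, lambda: 2):
    inside = (lib.get_value_1(), user.gv())
outside = (lib.get_value_1(), user.gv())
print(inside, outside)  # (2, 2) (1, 1)
```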
Upvotes: 2