On Databricks, I have one notebook of code and one notebook of unit tests.
The code is "imported" into the unit test notebook using the `%run` command.
How can I mock one of the functions in the code notebook from the unit test notebook? I'd typically use the `unittest.mock.patch` context manager for this.
Here is the code notebook (`get_name`) with the function to be patched (`get_name_func`):
```python
# Databricks notebook source
def get_name_func():
    return 'name1'
```
Here is the unit test code:
```python
# Databricks notebook source
from unittest.mock import patch
import inspect
# COMMAND ----------
# MAGIC %run ./get_name
# COMMAND ----------
def local_get_name():
    return 'name_local'
# COMMAND ----------
get_name_func()
# COMMAND ----------
print(inspect.getmodule(get_name_func))
print(inspect.getsourcefile(get_name_func))
# COMMAND ----------
inspect.unwrap(get_name_func)
# COMMAND ----------
with patch('get_name_func') as mock_func:
    print(mock_func)
# COMMAND ----------
with patch('local_get_name') as mock_func:
    print(mock_func)
```
Both patch attempts, for the local function and the function in the code notebook, give the same error:
TypeError: Need a valid target to patch. You supplied: 'get_name_func'
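For context, this error comes from how `patch` parses its target: it expects a dotted `"module.attribute"` string, splits it at the last dot, imports the module part, and replaces the attribute on that module. A bare name such as `'get_name_func'` has no module part, so `patch` rejects it immediately, before anything is imported. The plain-Python sketch below reproduces both cases; the module name `notebook_code` is hypothetical, standing in for an importable module that holds the function.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical module "notebook_code", registered in sys.modules so that
# patch() can import it by name like any other module.
notebook_code = types.ModuleType("notebook_code")
exec("def get_name_func():\n    return 'name1'\n", notebook_code.__dict__)
sys.modules["notebook_code"] = notebook_code

# A bare name has no "module." prefix, so patch() raises TypeError right away.
try:
    patch("get_name_func")
except TypeError as err:
    bare_name_error = str(err)

# A dotted target works: patch imports "notebook_code" and swaps the attribute
# for a MagicMock whose return_value is "mocked".
with patch("notebook_code.get_name_func", return_value="mocked"):
    patched_result = notebook_code.get_name_func()

# The original function is restored when the with block exits.
restored_result = notebook_code.get_name_func()
```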
The `inspect` calls return the following. The two `print` calls output:
<module '__main__' from '/local_disk0/tmp/1625490167313-0/PythonShell.py'>
<command-6807918>
and `inspect.unwrap(get_name_func)` returns:
Out[38]: <function __main__.get_name_func()>
I've tried various combinations for the module path with no luck.
Strangely, `__name__` returns `'__main__'`, but using the path `'__main__.get_name_func'` in the `patch` call does not work.
My belief is that if the object exists in the notebook (which it definitely does), then it must be patchable.
Any suggestions?
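One possible explanation for why `'__main__.get_name_func'` fails, offered as an unconfirmed assumption about Databricks internals: the source file `<command-6807918>` suggests each cell is compiled and executed, and if cells run in a namespace dictionary whose `__name__` is `'__main__'` but which is not the `__main__` module's own `__dict__`, you would see exactly these symptoms. A function records its `__module__` from that `__name__`, so `inspect.getmodule` reports `__main__`, yet the function is not an attribute of the real `__main__` module, so `patch` cannot find it there. The plain-Python sketch below reproduces that model; `notebook_namespace` is a hypothetical stand-in for the notebook's namespace.

```python
import inspect
import sys
from unittest.mock import patch

# ASSUMPTION: stand-in for the notebook's namespace. It is a plain dict whose
# __name__ is "__main__", but it is NOT sys.modules["__main__"].__dict__.
notebook_namespace = {"__name__": "__main__"}
exec("def get_name_func():\n    return 'name1'\n", notebook_namespace)
func = notebook_namespace["get_name_func"]

# The function takes its __module__ from the namespace's __name__ ...
reported_module = func.__module__
# ... so inspect.getmodule maps it to the real __main__ module.
getmodule_is_main = inspect.getmodule(func) is sys.modules["__main__"]

# The real __main__ module has no such attribute, so the patch fails on entry.
found_on_main = hasattr(sys.modules["__main__"], "get_name_func")
main_patch_error = None
try:
    with patch("__main__.get_name_func"):
        pass
except AttributeError as err:
    main_patch_error = err
```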
I had to write my own patching context manager:
```python
import traceback
from typing import Callable

class FunctionPatch:
    '''
    Context manager that patches a function "imported" from another notebook using %run.
    The function being replaced must be at global scope (i.e. top level), since it is swapped in globals().
    '''
    def __init__(self, real_func_name: str, patch_func: Callable):
        self._real_func_name = real_func_name
        self._patch_func = patch_func
        self._backup_real_func = None

    def __enter__(self):
        # Keep the real function so it can be restored, then swap in the replacement.
        self._backup_real_func = globals()[self._real_func_name]
        globals()[self._real_func_name] = self._patch_func
        return self._patch_func

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
        # Always restore the real function, even if the with block raised.
        globals()[self._real_func_name] = self._backup_real_func
```
Usage:
```python
def test_function_patch_real_func():
    return 'real1'

def test_function_patch():
    assert test_function_patch_real_func() == 'real1'
    def mock_func():
        return 'mock1'
    with FunctionPatch('test_function_patch_real_func', mock_func):
        assert test_function_patch_real_func() == 'mock1'
    # The real function is back once the with block exits.
    assert test_function_patch_real_func() == 'real1'

test_function_patch()
```
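For comparison, the standard library can do a similar swap-and-restore: `unittest.mock.patch.dict` accepts any dictionary, including `globals()`, so a `MagicMock` can be put in place of the function for the duration of a `with` block. This is an alternative to the class above, not what the author used, and it is sketched here in plain Python rather than verified on Databricks; the `get_name_func` defined below is a stand-in for the one `%run` would provide. One caveat: on exit, `patch.dict` restores the dictionary from a copy taken on entry, so any new top-level names assigned inside the `with` block are dropped. Running the patched code inside a function, as below, avoids that.

```python
from unittest.mock import MagicMock, patch


def get_name_func():
    # Stand-in for the function that %run would define in the notebook's namespace.
    return 'name1'


def check_patch_dict():
    mock_func = MagicMock(return_value='mock1')
    # Swap the global entry for the mock; patch.dict restores the whole dict on exit.
    with patch.dict(globals(), {'get_name_func': mock_func}):
        patched = get_name_func()
    mock_func.assert_called_once_with()
    # After the with block, the name refers to the real function again.
    return patched, get_name_func()


patched_result, restored_result = check_patch_dict()
```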