Kevin Maguire

Reputation: 93

In Databricks, using unittest.mock.patch on a function in a different notebook

On Databricks, I have a notebook of code and a notebook of unit tests.

The code is "imported" into the unit test notebook using the "%run" command.

How can I make a mock object of one of the functions in the code notebook from the unit test notebook? I'd typically use the patch context manager for this.

Here is the code notebook (get_name) with the function to be patched (get_name_func):

# Databricks notebook source
def get_name_func():
  return 'name1'

Here is the unit test code:

# Databricks notebook source

from unittest.mock import patch
import inspect


# COMMAND ----------


# MAGIC %run ./get_name


# COMMAND ----------


def local_get_name():
  return 'name_local'


# COMMAND ----------


get_name_func()


# COMMAND ----------


print(inspect.getmodule(get_name_func))
print(inspect.getsourcefile(get_name_func))


# COMMAND ----------


inspect.unwrap(get_name_func)


# COMMAND ----------


with patch('get_name_func') as mock_func:
  print(mock_func)


# COMMAND ----------
with patch('local_get_name') as mock_func:
  print(mock_func)

Both patch attempts, for the local function and for the function in the code notebook, fail with the same error (each naming its own target), for example:

TypeError: Need a valid target to patch. You supplied: 'get_name_func'
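
For context, here is a minimal sketch of why this error appears. It runs in plain Python rather than Databricks, so the environment is an assumption. patch expects a dotted target of the form 'module.attribute' and rejects a bare name before it looks anything up, so the message says nothing about whether the function exists:

```python
from unittest.mock import patch


def get_name_func():
    # Stand-in for the function defined in the code notebook.
    return 'name1'


message = None
try:
    # A bare name has no '.' separating module from attribute,
    # so patch() raises before trying to resolve anything.
    with patch('get_name_func'):
        pass
except TypeError as exc:
    message = str(exc)

print(message)
```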

The inspect commands return:

<module '__main__' from '/local_disk0/tmp/1625490167313-0/PythonShell.py'>
<command-6807918>

and

Out[38]: <function __main__.get_name_func()>

I've tried various combinations for the module path with no luck.

Strangely, __name__ returns '__main__', but using the path '__main__.get_name_func' in the patch call does not work either.

My belief is that if the object exists in the notebook (which it definitely does), then it must be patchable.

Any suggestions?

Upvotes: 2

Views: 1339

Answers (1)

Kevin Maguire

Reputation: 93

I had to write my own patching context manager:

import traceback
from typing import Callable

class FunctionPatch():
  '''
  Context manager that allows patching of functions "imported" from another notebook using %run.
  
  The function being patched must be at global scope (i.e. top level).
  '''
  def __init__(self, real_func_name: str, patch_func: Callable):
    self._real_func_name = real_func_name
    self._patch_func = patch_func
    self._backup_real_func = None
    
  def __enter__(self):
    # Save the real function, then swap in the replacement by name.
    self._backup_real_func = globals()[self._real_func_name]
    globals()[self._real_func_name] = self._patch_func
    
  def __exit__(self, exc_type, exc_value, tb):
    # Print any exception raised inside the with block; returning None lets it propagate.
    if exc_type is not None:
      traceback.print_exception(exc_type, exc_value, tb)
      
    # Always restore the real function.
    globals()[self._real_func_name] = self._backup_real_func

Usage:

def test_function_patch_real_func():
  return 'real1'

def test_function_patch():
  
  assert test_function_patch_real_func() == 'real1'
  
  def mock_func():
    return 'mock1'
  
  with FunctionPatch('test_function_patch_real_func', mock_func):
    assert test_function_patch_real_func() == 'mock1' 
    
  assert test_function_patch_real_func() == 'real1'
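
A possible standard-library alternative, as a sketch only: unittest.mock.patch.dict can temporarily replace an entry in a dictionary and restore it on exit, so it can be pointed at globals(). This assumes, as the class above does, that the function brought in by %run lives in the same global namespace as the test code. I have not verified it on Databricks, and the names get_name_func and fake_get_name below are stand-ins for illustration.

```python
from unittest.mock import patch


def get_name_func():
    # Stand-in for the function "imported" from the code notebook via %run.
    return 'name1'


def fake_get_name():
    # Hypothetical replacement used only for this illustration.
    return 'mock1'


def test_patch_dict():
    assert get_name_func() == 'name1'

    # patch.dict swaps the entry in globals() and restores it on exit,
    # even if the body of the with block raises.
    with patch.dict(globals(), {'get_name_func': fake_get_name}):
        assert get_name_func() == 'mock1'

    assert get_name_func() == 'name1'


test_patch_dict()
```

As with FunctionPatch, this only affects callers that look the function up by name in this namespace at call time.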

Upvotes: 2
