Robert

Reputation: 1636

Python: how can I override a complicated function during unittest?

I am using Python's unittest module to test a script I am writing.

The script contains a loop like this:

# my_script.py

def my_loopy_function():
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c

I have no problems using unittest to test complicated_function. But I would like to test my_loopy_function by overriding complicated_function.

I tried modifying my script so that my_loopy_function takes complicated_function as an optional parameter so that I can pass in a simple version from the test:

# my_modified_script.py

def my_loopy_function(action_function=None):
    if action_function is not None:
        complicated_function = action_function
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c

# test_my_script.py

import unittest

from my_modified_script import my_loopy_function

class TestMyScript(unittest.TestCase):
    def test_loopy_function(self):
        def simple_function():
            return 1
        self.assertEqual(10, my_loopy_function(action_function=simple_function))

It has not worked as I had hoped. Are there any suggestions on how I should be doing this?

Upvotes: 6

Views: 14990

Answers (4)

Hi-Angel

Reputation: 5619

I couldn't get @mock.patch from the accepted answer to work in any combination. (I'm using pytest, but the answer seems to be testing-library-agnostic, so I don't know why.) I did find a more generic way to do this, so I'm sharing it.

There's a technique called "monkey-patching". Its purpose is exactly to override a function or method from another module, including a 3rd party one. It works with both top-level functions and class methods (in the latter case you need to add the class name in the assignment).

It works like this: you first do import some_module, and then, to replace its function foo with bar, you execute some_module.foo = bar.

Important: the module should be imported without from. If you do from some_module import foo and then foo = bar, it will not work, because that only rebinds the name foo in your own file; code inside some_module still looks up its own foo.

Example:

λ cat some_module.py
def hello():
    print("hello")

def some_func():
    hello()
λ cat test.py
import some_module

def my_override():
    print("hi")

some_module.hello = my_override

some_module.some_func()
λ python3 test.py
hi
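To make the "import without from" point concrete, here is a minimal sketch. It builds a stand-in for some_module in memory with types.ModuleType (that part is my assumption, only there so the snippet runs as a single file), and the functions return strings instead of printing so the results can be checked.

```python
import types

# Stand-in for some_module.py, built in memory so this runs as one file.
some_module = types.ModuleType("some_module")
exec(
    "def hello():\n"
    "    return 'hello'\n"
    "\n"
    "def some_func():\n"
    "    return hello()\n",
    some_module.__dict__,
)


def my_override():
    return "hi"


# Roughly what `from some_module import hello` followed by
# `hello = my_override` does: it only rebinds a name in *this* namespace.
hello = some_module.hello
hello = my_override
after_local_rebind = some_module.some_func()   # some_func still sees the original

# Assigning through the module object changes what some_func looks up.
some_module.hello = my_override
after_module_patch = some_module.some_func()

print(after_local_rebind, after_module_patch)
```

The first call still returns the original value because some_func resolves hello in some_module's own globals, which the local rebinding never touched.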

Upvotes: 0

Robert

Reputation: 1636

In the end I used Python's mock, which allows me to override complicated_function without having to adjust the original code in any way.

Here is the original script, and note that complicated_function is not passed in to my_loopy_function as an 'action_function' parameter (which was what I tried in my earlier solutions):

# my_script.py

def my_loopy_function():
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c

and here is the script I am using to test it:

# test_my_script.py

import unittest
import mock
from my_script import my_loopy_function

class TestMyModule(unittest.TestCase):
    @mock.patch('my_script.complicated_function')
    def test_1(self, mocked):
        mocked.return_value = 1
        self.assertEqual(10, my_loopy_function())

This works just as I had wanted:

  1. I am able to substitute functions with a simpler version of themselves that I can more easily test,
  2. I do not need to alter my original code in any way (as I had been trying to do, effectively by passing in function pointers); the mock module gives me access to the innards after the code has been written.

Thanks to austin for his suggestion to use mock. BTW I am using Python 2.7 and therefore used the pip-installable mock from PyPI.
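For anyone on Python 3.3 or later, the same library ships in the standard library as unittest.mock, so no pip install is needed. Below is a rough single-file sketch of the same test. Two things in it are my assumptions rather than part of the original setup: my_script is built in memory with types.ModuleType so the snippet is self-contained, and I use mock.patch.object on that module object instead of the string target 'my_script.complicated_function'. The stand-in complicated_function raises, so the test would fail loudly if the real function ran.

```python
import types
import unittest
from unittest import mock

# In-memory stand-in for my_script.py (assumption: only so this runs as one file).
my_script = types.ModuleType("my_script")
exec(
    "def complicated_function(x):\n"
    "    raise RuntimeError('the real function should not run in this test')\n"
    "\n"
    "def my_loopy_function():\n"
    "    aggregate_value = 0\n"
    "    for x in range(10):\n"
    "        aggregate_value = aggregate_value + complicated_function(x)\n"
    "    return aggregate_value\n",
    my_script.__dict__,
)


class TestMyModule(unittest.TestCase):
    def test_1(self):
        # Replace the module attribute only for the duration of the with block.
        with mock.patch.object(
            my_script, "complicated_function", return_value=1
        ) as mocked:
            self.assertEqual(10, my_script.my_loopy_function())
        # The loop should have called the mock once per iteration.
        self.assertEqual(10, mocked.call_count)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestMyModule)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Because the patch is undone when the with block exits, other tests still see the real complicated_function.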

Upvotes: 8

Daniel Pryden

Reputation: 60957

In your code, you shouldn't be able to overwrite complicated_function like that. If I try it, I get UnboundLocalError: local variable 'complicated_function' referenced before assignment.

But perhaps the problem is that in your actual code, you're referring to complicated_function in some other way (e.g. as a member of a module)? Then by overwriting it in your test, you're overwriting the actual complicated_function, so you won't be able to use it from other tests.

The correct way to do this is to overwrite the local variable with the global one, like so:

def my_loopy_function(action_function=None):
    if action_function is None:
        action_function = complicated_function
    aggregate_value = 0
    for x in range(10):
        # Use action_function here instead of complicated_function
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value
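Here is a small sketch of why the original version raises UnboundLocalError while this one does not. The names broken and fixed, and the trivial body of complicated_function, are made up for illustration. Any assignment to a name inside a function makes that name local for the whole function body, so in broken the global complicated_function is hidden even when the if branch never runs.

```python
def complicated_function(x):
    # Placeholder body, just so the difference is visible.
    return x * 100


def broken(action_function=None):
    if action_function is not None:
        # This assignment makes complicated_function a local name everywhere
        # in this function, hiding the global of the same name.
        complicated_function = action_function
    return complicated_function(1)


def fixed(action_function=None):
    if action_function is None:
        # Only the parameter is assigned, so the global stays visible.
        action_function = complicated_function
    return action_function(1)


try:
    broken()
    outcome = "no error"
except UnboundLocalError:
    outcome = "UnboundLocalError"

print(outcome, fixed(), fixed(lambda x: 1))
```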

Upvotes: 0

Thomas Orozco

Reputation: 55197

Don't try to override complicated_function with action_function, just use complicated_function as the default action_function:

def my_loopy_function(action_function=complicated_function):
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value
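A quick usage sketch of this approach, with a placeholder complicated_function and a hypothetical simple_function standing in for the test double. Note that the default is evaluated when the def statement runs, so complicated_function has to be defined above my_loopy_function, and rebinding the module-level name later does not change the default.

```python
def complicated_function(x):
    # Placeholder for the real, expensive function.
    return x * 100


def my_loopy_function(action_function=complicated_function):
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value


def simple_function(x):
    # Test double: ignores its argument and always returns 1.
    return 1


with_stub = my_loopy_function(action_function=simple_function)
with_default = my_loopy_function()

print(with_stub, with_default)
```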

Upvotes: 3
