Reputation: 1636
I am using Python's unittest module to test a script I am writing. The script contains a loop like this:
# my_script.py
def my_loopy_function():
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c
I have no problems using unittest to test complicated_function. But I would like to test my_loopy_function by overriding complicated_function.
I tried modifying my script so that my_loopy_function takes complicated_function as an optional parameter, so that I can pass in a simple version from the test:
# my_modified_script.py
def my_loopy_function(action_function=None):
    if action_function is not None:
        complicated_function = action_function
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c
# test_my_script.py
import unittest
from my_modified_script import my_loopy_function

class TestMyScript(unittest.TestCase):
    def test_loopy_function(self):
        def simple_function(x):
            return 1
        self.assertEqual(10, my_loopy_function(action_function=simple_function))
It has not worked as I had hoped. Are there any suggestions on how I should be doing this?
Upvotes: 6
Views: 14990
Reputation: 5619
I couldn't make @mock.patch from the accepted answer work in any combination (worth noting: I'm using pytest, though the answer seems to be test-library-agnostic, so I'm not sure why), but I found a more generic way to do it, so I'm sharing it.
There's a technique called "monkey-patching". Its purpose is exactly that: to override a function or method of another (e.g. 3rd party) module at runtime. It works both with top-level functions and with class methods (in the latter case you need to add the class name in the assignment).
It works like this: you first do import some_module, and then, to override its function foo with bar, you execute some_module.foo = bar.
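A minimal sketch of the class-method case, assuming a hypothetical some_module that defines a class Greeter (stood in for here by a types.ModuleType object so the snippet runs on its own). Assigning to the attribute on the class changes the method for every instance, including instances created before the patch, and keeping a reference to the original lets you undo it afterwards.

```python
import types


class Greeter:
    def greet(self):
        return "hello"


# Hypothetical stand-in for an imported some_module that defines Greeter.
some_module = types.ModuleType("some_module")
some_module.Greeter = Greeter


def my_greet(self):
    # Replacement method: it still receives `self`, like the original.
    return "hi"


existing = some_module.Greeter()             # instance created before patching
original_greet = some_module.Greeter.greet   # keep a reference so the patch can be undone

some_module.Greeter.greet = my_greet         # module name + class name + method name
patched_result = existing.greet()            # method lookup goes through the class

some_module.Greeter.greet = original_greet   # restore the original method
restored_result = existing.greet()

print(patched_result, restored_result)
```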
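Important: the module should be imported without from. That is, if you do from some_module import foo and then foo = bar, it will not work, because that only rebinds the name foo in your own module; some_module itself still refers to the original foo.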
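To see the difference side by side, here is a minimal sketch that rebuilds the example's some_module in memory and registers it in sys.modules (an assumption made only so the snippet is self-contained; normally it is a file on disk). The functions return strings instead of printing them so the results can be compared. some_func looks up hello in some_module's own globals at call time, which is why only the attribute assignment on the module has any effect.

```python
import sys
import types

# Hypothetical in-memory stand-in for some_module.py from the example.
some_module = types.ModuleType("some_module")
exec(
    "def hello():\n"
    "    return 'hello'\n"
    "\n"
    "def some_func():\n"
    "    return hello()\n",
    some_module.__dict__,
)
sys.modules["some_module"] = some_module


def my_override():
    return "hi"


# Does NOT work: this only rebinds the name `hello` in this namespace.
from some_module import hello
hello = my_override
wrong = some_module.some_func()      # still the original hello

# Works: rebind the attribute on the module object itself.
import some_module
some_module.hello = my_override
right = some_module.some_func()      # now calls my_override

print(wrong, right)
```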
Example:
λ cat some_module.py
def hello():
    print("hello")

def some_func():
    hello()

λ cat test.py
import some_module

def my_override():
    print("hi")

some_module.hello = my_override
some_module.some_func()

λ python3 test.py
hi
Upvotes: 0
Reputation: 1636
In the end I used Python's mock, which allows me to override complicated_function without having to adjust the original code in any way.
Here is the original script; note that complicated_function is not passed in to my_loopy_function as an 'action_function' parameter (which was what I tried in my earlier solutions):
# my_script.py
def my_loopy_function():
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c
and here is the script I am using to test it:
# test_my_script.py
import unittest
import mock
from my_script import my_loopy_function

class TestMyModule(unittest.TestCase):
    @mock.patch('my_script.complicated_function')
    def test_1(self, mocked):
        mocked.return_value = 1
        self.assertEqual(10, my_loopy_function())
This works just as I had wanted: the mock module gives me post-coding access to the innards. Thanks to austin for his suggestion to use mock.
BTW, I am using Python 2.7 and therefore used the pip-installable mock from PyPI.
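On Python 3.3 and later, mock ships with the standard library as unittest.mock, so the PyPI package is not needed. A minimal sketch of the same test under that assumption: to keep it self-contained, my_script is built in memory and registered in sys.modules, and its complicated_function is a hypothetical placeholder that raises, which shows the test never calls the real one.

```python
import sys
import types
import unittest
from unittest import mock  # standard library since Python 3.3

# Hypothetical in-memory stand-in for my_script.py.
my_script = types.ModuleType("my_script")
exec(
    "def my_loopy_function():\n"
    "    aggregate_value = 0\n"
    "    for x in range(10):\n"
    "        aggregate_value = aggregate_value + complicated_function(x)\n"
    "    return aggregate_value\n"
    "\n"
    "def complicated_function(x):\n"
    "    raise RuntimeError('too expensive for a unit test')\n",
    my_script.__dict__,
)
sys.modules["my_script"] = my_script


class TestMyScript(unittest.TestCase):
    # Every call to the patched function returns 1 while this test runs.
    @mock.patch("my_script.complicated_function", return_value=1)
    def test_loopy_function(self, mocked):
        self.assertEqual(10, my_script.my_loopy_function())
        self.assertEqual(10, mocked.call_count)   # once per loop iteration
        mocked.assert_called_with(9)              # the last call received x == 9
```

In a real project my_script.py is a file on disk, so the types/exec setup disappears and you run the tests with python -m unittest. Because patch swaps the module attribute only for the duration of the test and then restores it, other tests still see the real complicated_function.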
Upvotes: 8
Reputation: 60957
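In your code, you can't overwrite complicated_function like that: because complicated_function is assigned inside my_loopy_function, Python treats it as a local name throughout the whole function. If I call my_loopy_function() without an argument, I get UnboundLocalError: local variable 'complicated_function' referenced before assignment.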
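A minimal reproduction, using a hypothetical cheap placeholder for complicated_function. The exact wording of the message depends on the Python version (3.11 and later say "cannot access local variable ..." instead of "referenced before assignment"), but it is the same UnboundLocalError, and passing a function in does avoid it.

```python
def complicated_function(x):
    # Hypothetical cheap placeholder for the real function.
    return x * 100


def my_loopy_function(action_function=None):
    # The assignment below makes complicated_function local to this whole
    # function, so the module-level function of the same name is never used.
    if action_function is not None:
        complicated_function = action_function
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value


error_message = None
try:
    my_loopy_function()          # no argument: the local name is never bound
except UnboundLocalError as exc:
    error_message = str(exc)

print(error_message)
```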
But perhaps the problem is that in your actual code, you're referring to complicated_function in some other way (e.g. as a member of a module)? Then by overwriting it in your test, you're overwriting the actual complicated_function, so you won't be able to use it from other tests.
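If you do overwrite the module attribute by hand, a minimal way to avoid leaking the stub into other tests is to save the original and restore it in a finally block, which is essentially what mock.patch automates. The sketch below assumes a hypothetical in-memory stand-in for my_script with a cheap placeholder complicated_function; run_with_stub is an illustrative helper, not part of any library.

```python
import types

# Hypothetical in-memory stand-in for my_script.py.
my_script = types.ModuleType("my_script")
exec(
    "def my_loopy_function():\n"
    "    aggregate_value = 0\n"
    "    for x in range(10):\n"
    "        aggregate_value = aggregate_value + complicated_function(x)\n"
    "    return aggregate_value\n"
    "\n"
    "def complicated_function(x):\n"
    "    return x * 100\n",
    my_script.__dict__,
)


def run_with_stub():
    original = my_script.complicated_function
    my_script.complicated_function = lambda x: 1   # overwrite the module attribute
    try:
        return my_script.my_loopy_function()
    finally:
        # Without this, every later caller would still see the stub.
        my_script.complicated_function = original


stubbed_total = run_with_stub()
real_total = my_script.my_loopy_function()   # the real function is back
print(stubbed_total, real_total)
```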
The correct way to do this is the other way round: when no replacement is passed in, assign the global function to the local parameter, like so:
def my_loopy_function(action_function=None):
    if action_function is None:
        action_function = complicated_function
    aggregate_value = 0
    for x in range(10):
        # Use action_function here instead of complicated_function
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value
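Putting that together with a test, a minimal runnable sketch: complicated_function here is a hypothetical cheap placeholder, and the test mirrors the one in the question, with simple_function accepting the x argument that the loop passes in.

```python
import unittest


def complicated_function(x):
    # Hypothetical cheap placeholder for the real, expensive function.
    return x * 100


def my_loopy_function(action_function=None):
    if action_function is None:
        # Resolved at call time, so the real function is used by default.
        action_function = complicated_function
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value


class TestMyScript(unittest.TestCase):
    def test_loopy_function_with_stub(self):
        def simple_function(x):
            # Same signature as complicated_function, since the loop passes x.
            return 1
        self.assertEqual(10, my_loopy_function(action_function=simple_function))

    def test_loopy_function_default(self):
        # 100 * (0 + 1 + ... + 9) with the placeholder above.
        self.assertEqual(4500, my_loopy_function())
```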
Upvotes: 0
Reputation: 55197
Don't try to override complicated_function with action_function; just use complicated_function as the default action_function:
def my_loopy_function(action_function=complicated_function):
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value
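One caveat worth knowing: a default argument value is evaluated once, when the def statement runs. So complicated_function must already be defined above my_loopy_function (in the question's script it is defined below, which would raise NameError when the module is imported), and later rebinding the module-level complicated_function (for example with mock.patch) will not change the stored default. A minimal sketch with a hypothetical placeholder body:

```python
def complicated_function(x):
    # Hypothetical placeholder; it must be defined *before* my_loopy_function,
    # because the default below is evaluated when the def statement runs.
    return x * 100


def my_loopy_function(action_function=complicated_function):
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value


injected = my_loopy_function(lambda x: 1)   # 10: the test double is used


def complicated_function(x):
    # Rebinding the global name afterwards...
    return 1


still_default = my_loopy_function()          # ...still 4500: the old function was stored
print(injected, still_default)
```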
Upvotes: 3