Reputation: 11234
Let's say I have a bunch of functions a, b, c, d and e, and I want to find out whether they call any method from the random module:
def a():
    pass

def b():
    import random

def c():
    import random
    random.randint(0, 1)

def d():
    import random as ra
    ra.randint(0, 1)

def e():
    from random import randint as ra
    ra(0, 1)
I want to write a function uses_module so that I can expect these assertions to pass:
assert uses_module(a) == False
assert uses_module(b) == False
assert uses_module(c) == True
assert uses_module(d) == True
assert uses_module(e) == True
(uses_module(b) is False because random is only imported; none of its methods is ever called.)
I can't modify a, b, c, d and e. So I thought it might be possible to use ast for this and walk along the function's code, which I get from inspect.getsource. But I'm open to any other proposals; this was only an idea of how it could work.
This is as far as I've come with ast:
def uses_module(function):
    import ast
    import inspect
    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    for node in nodes:
        print(node.__dict__)
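Printing every node's __dict__ is noisy. A shorter way to see which nodes matter is to keep only the import and name nodes. The sketch below does this for function e, with its source given as a string (an assumption for illustration, so inspect.getsource is not needed):

```python
import ast

# Source of function e from the question, given as a string for illustration.
SOURCE = '''
def e():
    from random import randint as ra
    ra(0, 1)
'''

seen = []
# ast.walk visits nodes breadth-first, so statements come before their children.
for node in ast.walk(ast.parse(SOURCE)):
    if isinstance(node, ast.ImportFrom):
        seen.append(('ImportFrom', node.module))
    elif isinstance(node, ast.alias):
        seen.append(('alias', node.name, node.asname))
    elif isinstance(node, ast.Name):
        seen.append(('Name', node.id, type(node.ctx).__name__))

for entry in seen:
    print(entry)
# ('ImportFrom', 'random')
# ('alias', 'randint', 'ra')
# ('Name', 'ra', 'Load')
```

The alias node records which name the import binds (ra), and the Name node with a Load context is where that name is used; linking the two is what a uses_module check has to do.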
Upvotes: 6
Views: 624
Reputation: 36249
You can replace the random module with a mock object that provides custom attribute access and hence intercepts function calls. Whenever one of the functions tries to import random (or import from it), it will actually get the mock object. The mock object can also be designed as a context manager that hands back the original random module after the test.
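import sys

class Mock:
    import random  # the real module, kept as a class attribute

    def __enter__(self):
        # Any "import random" / "from random import ..." now gets this object.
        sys.modules['random'] = self
        self.method_called = False
        return self

    def __exit__(self, *args):
        # Hand the original module back.
        sys.modules['random'] = self.random

    def __getattr__(self, name):
        # Called for attributes not found normally, e.g. random.randint.
        def mock(*args, **kwargs):
            self.method_called = True
            # Call the real function so the caller still gets a result.
            return getattr(self.random, name)(*args, **kwargs)
        return mock

def uses_module(func):
    with Mock() as m:
        func()
    return m.method_called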
A more flexible version takes the module's name as a parameter:
import importlib
import sys

class Mock:
    def __init__(self, name):
        self.name = name
        self.module = importlib.import_module(name)  # the real module

    def __enter__(self):
        sys.modules[self.name] = self
        self.method_called = False
        return self

    def __exit__(self, *args):
        sys.modules[self.name] = self.module

    def __getattr__(self, name):
        def mock(*args, **kwargs):
            self.method_called = True
            # Call through to the real function instead of returning it.
            return getattr(self.module, name)(*args, **kwargs)
        return mock

def uses_module(func, name='random'):
    with Mock(name) as m:
        func()
    return m.method_called
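As an alternative to the hand-written Mock class, the standard library's unittest.mock can do the bookkeeping. This is a swapped-in technique, not the answer's own code: a MagicMock created with wraps= forwards every call to the real module and records it in method_calls, and mock.patch.dict puts it into sys.modules only while func runs. Like the class above, it assumes the import happens inside the function being tested.

```python
import importlib
import sys
from unittest import mock


def uses_module(func, name='random'):
    """Return True if func calls something on module `name` that it imports."""
    real = importlib.import_module(name)
    # The spy forwards every attribute call to the real module and records it.
    spy = mock.MagicMock(wraps=real)
    # Imports executed inside func resolve to the spy; sys.modules is
    # restored when the with-block exits.
    with mock.patch.dict(sys.modules, {name: spy}):
        func()
    return bool(spy.method_calls)


def a():
    pass

def b():
    import random

def c():
    import random
    random.randint(0, 1)

def d():
    import random as ra
    ra.randint(0, 1)

def e():
    from random import randint as ra
    ra(0, 1)

print([uses_module(f) for f in (a, b, c, d, e)])  # [False, False, True, True, True]
```

Because the spy wraps the real module, the functions still receive real random numbers while being observed.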
Upvotes: 1
Reputation: 36249
You can simply place a mock random.py in your local (test) directory, containing the following code:
# Python 3.7+: module-level __getattr__.
def __getattr__(name):
    def mock(*args, **kwargs):
        raise RuntimeError(f'{name}: {args}, {kwargs}')  # For example.
    return mock

# Python 3.6 and earlier (use instead of the above): replace the module
# object with an instance whose class defines __getattr__.
class Wrapper:
    def __getattr__(self, name):
        def mock(*args, **kwargs):
            raise RuntimeError('{}: {}, {}'.format(name, args, kwargs))  # For example.
        return mock

import sys
sys.modules[__name__] = Wrapper()
Then you simply test your functions as follows:
def uses_module(func):
    try:
        func()
    except RuntimeError as err:
        print(err)
        return True
    return False
This works because the test script's directory comes first on sys.path, so Python imports the mock random.py instead of the standard-library random module. The mock emulates custom attribute access and can therefore intercept the function calls.
If you don't want to interrupt the functions by raising an exception, you can still use the same approach: import the original random module inside the mock module (modifying sys.path appropriately) and then fall back on the original functions.
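A minimal sketch of that non-raising variant, under stated assumptions: MOCK_RANDOM is a hypothetical name for the mock file's contents, the mock finds the real module with importlib.machinery.PathFinder.find_spec on a copy of sys.path that skips its own directory (one possible reading of "modifying sys.path appropriately"), and it records calls in a hypothetical calls list. To keep the example self-contained, uses_module writes the mock into a temporary directory and puts that directory first on sys.path.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical contents of the mock random.py (Python 3.7+).
MOCK_RANDOM = textwrap.dedent('''
    import importlib.machinery
    import importlib.util
    import os
    import sys

    # Find the real random module on sys.path, skipping this file's directory.
    _here = os.path.realpath(os.path.dirname(__file__))
    _path = [p for p in sys.path if os.path.realpath(p or os.curdir) != _here]
    _spec = importlib.machinery.PathFinder.find_spec('random', _path)
    _real = importlib.util.module_from_spec(_spec)
    _spec.loader.exec_module(_real)

    calls = []  # (name, args, kwargs) of every intercepted call

    def __getattr__(name):
        # Only reached for names this file does not define, e.g. randint.
        target = getattr(_real, name)
        if not callable(target):
            return target

        def wrapper(*args, **kwargs):
            calls.append((name, args, kwargs))
            return target(*args, **kwargs)  # fall back on the original function
        return wrapper
''')


def uses_module(func):
    """Run func with the recording mock in place of random; True if it called random.*."""
    saved = sys.modules.pop('random', None)  # the real module, if already imported
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, 'random.py').write_text(MOCK_RANDOM)
        sys.path.insert(0, tmp)  # the mock's directory must come first
        importlib.invalidate_caches()
        try:
            spy = importlib.import_module('random')
            func()
            return bool(spy.calls)
        finally:
            sys.path.remove(tmp)
            sys.path_importer_cache.pop(tmp, None)
            sys.modules.pop('random', None)
            if saved is not None:
                sys.modules['random'] = saved


def b():
    import random

def c():
    import random
    random.randint(0, 1)

def e():
    from random import randint as ra
    ra(0, 1)

print(uses_module(b), uses_module(c), uses_module(e))  # False True True
```

The functions run to completion with real results, and the decision is made afterwards from the recorded calls instead of from an exception.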
Upvotes: 1
Reputation: 2287
This is a work in progress, but perhaps it will spark a better idea. I am using the types of the nodes in the AST to try to establish that a module is imported and that some function it provides is used.
I have collected what may be the necessary pieces of information in a checker defaultdict, which can be evaluated against some set of conditions, but I am not using all of its key-value pairs to establish an assertion for your use cases.
def uses_module(function):
    """
    (WIP) assert that a function uses a module
    """
    import ast
    import inspect
    from collections import defaultdict  # needed for checker below

    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    checker = defaultdict(set)
    for node in nodes:
        if type(node) in [ast.alias, ast.Import, ast.Name, ast.Attribute]:
            nd = node.__dict__
            if type(node) == ast.alias:
                checker['alias'].add(nd.get('name'))
            if nd.get('name') and nd.get('asname'):
                checker['name'].add(nd.get('name'))
                checker['asname'].add(nd.get('asname'))
            if nd.get('ctx') and nd.get('attr'):
                checker['attr'].add(nd.get('attr'))
            if nd.get('id'):
                checker['id'].add(hex(id(nd.get('ctx'))))
            if nd.get('value') and nd.get('ctx'):
                checker['value'].add(hex(id(nd.get('ctx'))))
    # print(dict(checker)) for debug
    # This check passes your use cases, but probably needs to be expanded
    if checker.get('alias') and checker.get('id'):
        return True
    return False
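One way the check could be expanded is to tie each call back to a name bound by an import of the module. The following is a minimal sketch, not the answer's code, and it assumes the module is reached only through import statements inside the function (a module-level import random, or a re-assignment such as r = random, is not tracked). calls_module_in_source is a hypothetical helper name.

```python
import ast
import inspect
import textwrap


def calls_module_in_source(source, module='random'):
    """Return True if `source` calls a function reached through `module`.

    Assumption: the module is bound by import statements inside `source`;
    module-level imports and re-assignments such as r = random are not tracked.
    """
    tree = ast.parse(textwrap.dedent(source))
    module_names = set()    # import random / import random as ra
    function_names = set()  # from random import randint [as ra]
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name == module:
                    module_names.add(alias.asname or alias.name)
        elif (isinstance(node, ast.ImportFrom)
              and node.level == 0 and node.module == module):
            for alias in node.names:
                function_names.add(alias.asname or alias.name)

    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        # random.randint(...) or ra.randint(...)
        if (isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id in module_names):
            return True
        # randint(...) or ra(...) after "from random import randint as ra"
        if isinstance(func, ast.Name) and func.id in function_names:
            return True
    return False


def uses_module(function, module='random'):
    """Apply the check to a function whose source file is available."""
    return calls_module_in_source(inspect.getsource(function), module)
```

Unlike the alias/id condition above, this returns False for a function that imports random but only calls an unrelated function such as print, because that call is not traced back to an import of random. It is still a static check: it cannot tell whether the call is actually reached at runtime.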
Upvotes: 2