In Python, given a module X and a class Y, how can I iterate or generate a list of all subclasses of Y that exist in module X?
Upvotes: 19
Views: 13430
Reputation: 3887
Although Quamrana's suggestion works fine, I'd like to suggest a couple of possible improvements to make it more Pythonic. They rely on two functions from the inspect module in the standard library:
inspect.getmembers()
inspect.isclass()
With those, you can reduce the whole thing to a single list comprehension if you like:
import inspect

def find_subclasses(module, clazz):
    return [
        cls
        for name, cls in inspect.getmembers(module)
        if inspect.isclass(cls) and issubclass(cls, clazz)
    ]
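To see what this returns in practice, here is a small self-contained sketch. The `demo` module and its class names are made up for illustration (built in memory with `types.ModuleType` so nothing has to exist on disk), and I pass `inspect.isclass` as the predicate to `getmembers()` instead of filtering inside the comprehension, which should be equivalent.

```python
import inspect
import types


def find_subclasses(module, clazz):
    """Return the classes in ``module`` that are subclasses of ``clazz``."""
    return [
        cls
        # The predicate makes getmembers() return only class objects.
        for name, cls in inspect.getmembers(module, inspect.isclass)
        if issubclass(cls, clazz)
    ]


# Hypothetical stand-in module so the example runs on its own.
demo = types.ModuleType("demo")
exec(
    "class Base(object): pass\n"
    "class Child(Base): pass\n"
    "class GrandChild(Child): pass\n"
    "class Unrelated(object): pass\n"
    "helper = 42\n",
    demo.__dict__,
)

names = [cls.__name__ for cls in find_subclasses(demo, demo.Base)]
print(names)  # ['Base', 'Child', 'GrandChild']
```

Note that `Base` itself shows up, because `issubclass(Base, Base)` is true. If that is not wanted, add `cls is not clazz` to the condition.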
Upvotes: 25
Reputation: 39414
Can I suggest that neither of the answers from Chris AtLee and zacherates fulfills the requirements? I think this modification to zacherates' answer is better:
def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)
        try:
            # Skip the base class itself; keep direct and indirect subclasses.
            if (o is not clazz) and issubclass(o, clazz):
                yield name, o
        except TypeError:
            # issubclass() raises TypeError when o is not a class.
            pass
The reason I disagree with the given answers is that the first does not produce classes that are indirect (more distant) subclasses of the given class, and the second includes the given class itself.
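A rough side-by-side sketch of the three behaviours may make the difference clearer. The `demo` module, its classes, and the three helper names below are all hypothetical, chosen only for this comparison; they are not taken from any of the answers verbatim.

```python
import inspect
import types

# Hypothetical module with a three-level hierarchy plus an unrelated class.
demo = types.ModuleType("demo")
exec(
    "class Animal(object): pass\n"
    "class Dog(Animal): pass\n"
    "class Puppy(Dog): pass\n"
    "class Rock(object): pass\n",
    demo.__dict__,
)


def _classes(module):
    """All class objects that are attributes of ``module``."""
    return [obj for _, obj in inspect.getmembers(module, inspect.isclass)]


def direct_only(module, clazz):
    # __bases__ check: finds direct children, misses grandchildren.
    return sorted(c.__name__ for c in _classes(module) if clazz in c.__bases__)


def including_base(module, clazz):
    # Plain issubclass(): finds all descendants, but also the base itself.
    return sorted(c.__name__ for c in _classes(module) if issubclass(c, clazz))


def strict(module, clazz):
    # issubclass() plus an identity check: all descendants, base excluded.
    return sorted(
        c.__name__
        for c in _classes(module)
        if c is not clazz and issubclass(c, clazz)
    )


print(direct_only(demo, demo.Animal))     # ['Dog']
print(including_base(demo, demo.Animal))  # ['Animal', 'Dog', 'Puppy']
print(strict(demo, demo.Animal))          # ['Dog', 'Puppy']
```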
Upvotes: 5
Reputation: 123000
Given the module foo.py:
class foo(object): pass
class bar(foo): pass
class baz(foo): pass
class grar(Exception): pass

def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)
        try:
            if issubclass(o, clazz):
                yield name, o
        except TypeError: pass
>>> import foo
>>> list(foo.find_subclasses(foo, foo.foo))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>)]
>>> list(foo.find_subclasses(foo, object))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>), ('grar', <class 'foo.grar'>)]
>>> list(foo.find_subclasses(foo, Exception))
[('grar', <class 'foo.grar'>)]
Upvotes: 1
Reputation: 8076
Here's one way to do it:
import inspect

def get_subclasses(mod, cls):
    """Yield the classes in module ``mod`` that inherit directly from ``cls``"""
    for name, obj in inspect.getmembers(mod):
        # Only direct base classes are checked, so grandchildren are not matched.
        if hasattr(obj, "__bases__") and cls in obj.__bases__:
            yield obj
Upvotes: 15