ozgeneral

Reputation: 6819

Check if a function raises NotImplementedError before calling it in Python

I have the following simplified scheme:

class NetworkAnalyzer(object):
    def __init__(self):
        print('is _score_funct implemented?')

    @staticmethod
    def _score_funct(network):
        raise NotImplementedError

class LS(NetworkAnalyzer):
    @staticmethod
    def _score_funct(network):
        return network

and I am looking for what I should use instead of print('is _score_funct implemented?') in order to figure out if a subclass has already implemented _score_funct(network) or not.

Note: If there is a more Pythonic or conventional way to structure this code, I would appreciate hearing about it. I defined it this way because some NetworkAnalyzer subclasses define _score_funct, and the ones that don't will initialize their variables differently, although they all share the same structure.

Upvotes: 6

Views: 3120

Answers (2)

Duncan

Reputation: 95762

Use an abstract base class and you won't be able to instantiate the class unless it implements all of the abstract methods:

import abc

class NetworkAnalyzerInterface(abc.ABC):
    @staticmethod
    @abc.abstractmethod
    def _score_funct(network):
        pass

class NetworkAnalyzer(NetworkAnalyzerInterface):
    def __init__(self):
        pass

class LS(NetworkAnalyzer):
    @staticmethod
    def _score_funct(network):
        return network

class Bad(NetworkAnalyzer):
    pass

ls = LS()   # Ok
b = Bad()   # raises TypeError: Can't instantiate abstract class Bad with abstract methods _score_funct
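If the goal is to ask the question without triggering the TypeError, the standard library can report whether a class still has unimplemented abstract methods. The following is a minimal sketch, assuming the abstract method is declared directly on NetworkAnalyzer; the helper is_implemented is a hypothetical name introduced here for illustration, not part of the answer above.

```python
import abc
import inspect


class NetworkAnalyzer(abc.ABC):
    @staticmethod
    @abc.abstractmethod
    def _score_funct(network):
        """Subclasses are expected to override this."""


class LS(NetworkAnalyzer):
    @staticmethod
    def _score_funct(network):
        return network


class Bad(NetworkAnalyzer):
    pass


def is_implemented(cls, name):
    # Hypothetical helper: True when `name` is not listed among the
    # abstract methods that `cls` still has to implement.
    return name not in getattr(cls, "__abstractmethods__", frozenset())


# inspect.isabstract reports whether the class can be instantiated at all.
ls_is_abstract = inspect.isabstract(LS)
bad_is_abstract = inspect.isabstract(Bad)
```

Both checks work on the class itself, so no instance has to be created first.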

Upvotes: 4

Jean-François Fabre

Reputation: 140307

I'm not a metaclass/class specialist but here's a method that works in your simple case (not sure it works as-is in a complex/nested class namespace):

To check whether the method was overridden, you could call getattr with the function name, then check the qualified name (the class part is enough, using string partitioning):

class NetworkAnalyzer(object):
    def __init__(self):
        funcname = "_score_funct"
        d = getattr(self, funcname)
        # True only if the method is defined directly in the instance's own class
        print(d.__qualname__.partition(".")[0] == self.__class__.__name__)

If _score_funct is defined in LS, d.__qualname__ is LS._score_funct; otherwise it is NetworkAnalyzer._score_funct.

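That works if the method is implemented directly in the instantiated class (LS here). If it may instead be inherited from an intermediate subclass, you could replace the comparison with: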

d.__qualname__.partition(".")[0] != "NetworkAnalyzer"
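As a rough illustration of how the two comparisons differ, here is a small runnable sketch. SubLS and Plain are hypothetical classes added only for the demonstration, and the attribute names defined_in_own_class and overridden are likewise made up; the results are stored on the instance rather than printed.

```python
class NetworkAnalyzer(object):
    def __init__(self):
        d = getattr(self, "_score_funct")
        # First component of the qualified name is the defining class.
        owner = d.__qualname__.partition(".")[0]
        # First variant: defined directly in the instance's class?
        self.defined_in_own_class = owner == self.__class__.__name__
        # Second variant: defined anywhere other than the base class?
        self.overridden = owner != "NetworkAnalyzer"

    @staticmethod
    def _score_funct(network):
        raise NotImplementedError


class LS(NetworkAnalyzer):
    @staticmethod
    def _score_funct(network):
        return network


class SubLS(LS):  # hypothetical: inherits the implementation from LS
    pass


class Plain(NetworkAnalyzer):  # hypothetical: never overrides the method
    pass


sub = SubLS()
plain = Plain()
```

For SubLS the first comparison is False (the owner is LS, not SubLS) while the second is True, which is why the second form is the safer one once there is more than one level of inheritance.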

Of course, if the method is overridden with code that itself raises a NotImplementedError, this won't detect it. This approach doesn't inspect the method's code (which would be hazardous anyway).
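A possible alternative that avoids string handling altogether (not what the code above does, just a sketch under the same assumptions) is to compare the function objects by identity. Looking up a staticmethod on a class returns the underlying function, so a subclass that has not overridden it yields the very same object as the base class. The attribute name has_score_funct and the classes SubLS and Plain are hypothetical.

```python
class NetworkAnalyzer(object):
    def __init__(self):
        # Class-level lookup of a staticmethod returns the plain function,
        # so an identity test shows whether any subclass replaced it.
        self.has_score_funct = (
            type(self)._score_funct is not NetworkAnalyzer._score_funct
        )

    @staticmethod
    def _score_funct(network):
        raise NotImplementedError


class LS(NetworkAnalyzer):
    @staticmethod
    def _score_funct(network):
        return network


class SubLS(LS):  # hypothetical: inherits the implementation from LS
    pass


class Plain(NetworkAnalyzer):  # hypothetical: never overrides the method
    pass


ls = LS()
sub = SubLS()
plain = Plain()
```

Like the qualified-name check, this only detects that the method was replaced; an override that itself raises NotImplementedError would still count as implemented.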

Upvotes: 2
