Neinstein

Reputation: 1033

How can I simplify my repetitive arithmetic functions?

I have a class that encapsulates the calculation of a complicated vector formula. The detailed structure is probably unimportant; it basically evaluates the supplied vector formula on a mesh, and stores/manages the result.

I want to be able to do basic arithmetic with instances of this class when they share the same spatio-temporal coordinates. My problem is that every arithmetic operation needs a considerable amount of type checking. This results in 5 copies (one each for +, -, *, /, **) with exactly the same code except for the arithmetic sign in 3 places.

So much repeated code in an object-oriented language looks suspicious to me. At the same time, I couldn't come up with an elegant way to simplify it, and I get the feeling there's some technique I'm unaware of.

How could I extract the repeated code in a best-practice way?

The code is below; I marked the 3 differences between __sub__ and __add__:

from numbers import Number


class FieldVector(object):
    def __init__(self, formula, fieldparams, meshparams, zero_comps=[]):
        [...]


    def is_comparable_to(self, other):
        "has the same spatio-temporal dimensions as the other one"
        if not isinstance(other, FieldVector):
            return False
        return (
            self.meshparams == other.meshparams and
            self.framenum == other.framenum and
            self.fieldparams.tnull == other.fieldparams.tnull and
            self.fieldparams.tmax == other.fieldparams.tmax
        )


    def _check_comparable(self, other):
        if not self.is_comparable_to(other):
            raise ValueError("The two fields have different spatio-temporal coordinates")



    def __add__(self, other):
        new_compvals = {}
        if isinstance(other, Number):
            for comp in self.nonzero_comps:
                new_compvals[comp] = self.get_component(comp) + other
        elif isinstance(other,FieldVector):
            self._check_comparable(other)
            nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
            for comp in nonzeros:
                new_compvals[comp] = self.get_component(comp) + other.get_component(comp)
        else:
            raise TypeError(f'unsupported operand type(s) for +: {self.__class__} and {other.__class__}')

        return ModifiedFieldVector(self, new_compvals)


    def __sub__(self, other):
        new_compvals = {}
        if isinstance(other, Number):
            for comp in self.nonzero_comps:
                # --- difference 1: - instead of +
                new_compvals[comp] = self.get_component(comp) - other
        elif isinstance(other,FieldVector):
            self._check_comparable(other)
            nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
            for comp in nonzeros:
                # --- difference 2: - instead of +
                new_compvals[comp] = self.get_component(comp) - other.get_component(comp) 
        else:
            # --- difference 3: - instead of +
            raise TypeError(f'unsupported operand type(s) for -: {self.__class__} and {other.__class__}')

        return ModifiedFieldVector(self, new_compvals)

    [... __mul__, __truediv__, __pow__ defined the same way]

Upvotes: 0

Views: 35

Answers (1)

Reblochon Masque

Reputation: 36722

You could extract the calculation into a private method and pass it the operator function it needs.
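
The operator functions meant here come from the standard-library `operator` module, where each of the five symbols in the question has a plain two-argument function counterpart. A small lookup table (the name `OPS` is only illustrative) shows that the "arithmetic sign" can be passed around as data:

```python
import operator

# Each arithmetic symbol from the question paired with its operator-module function.
OPS = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": operator.truediv,
    "**": operator.pow,
}

# operator.add(a, b) computes the same result as a + b, and so on.
for symbol, func in OPS.items():
    print(f"7 {symbol} 2 = {func(7, 2)}")
```

Keeping the symbol next to the function is also convenient if an error message should name the operator.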

Maybe something like this:

import operator

...
        
        
    def _xeq(self, other, op, symbol):
        # op is a two-argument function such as operator.add;
        # symbol ('+', '-', ...) is only used for the error message
        new_compvals = {}
        if isinstance(other, Number):
            for comp in self.nonzero_comps:
                new_compvals[comp] = op(self.get_component(comp), other)
        elif isinstance(other, FieldVector):
            self._check_comparable(other)
            nonzeros = list(set(self.nonzero_comps).union(other.nonzero_comps))
            for comp in nonzeros:
                new_compvals[comp] = op(self.get_component(comp), other.get_component(comp))
        else:
            raise TypeError(f'unsupported operand type(s) for {symbol}: {self.__class__} and {other.__class__}')
        return ModifiedFieldVector(self, new_compvals)

    def __add__(self, other):
        return self._xeq(other, operator.add, '+')

    def __sub__(self, other):
        return self._xeq(other, operator.sub, '-')

    # __mul__, __truediv__ and __pow__ follow the same pattern
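
To try the idea end to end without the rest of the original class, here is a minimal self-contained sketch. `ToyField`, its `components`/`grid` attributes and `_binary_op` are hypothetical stand-ins for `FieldVector`, its mesh parameters and `_xeq`; a single `grid` label replaces the full spatio-temporal comparison. Unlike the answer, unsupported operand types return `NotImplemented` instead of raising `TypeError` directly, which is the standard way for a Python binary-operator method to decline an operand.

```python
import operator
from numbers import Number


class ToyField:
    """Hypothetical stand-in for FieldVector: named components on a labelled grid."""

    def __init__(self, components, grid):
        self.components = dict(components)  # e.g. {"x": 1.0, "y": 2.0}
        self.grid = grid  # stands in for meshparams, framenum, tnull, tmax

    def get_component(self, comp):
        # Components that are not stored are treated as zero.
        return self.components.get(comp, 0)

    def _binary_op(self, other, op):
        """Apply the two-argument function `op` component by component."""
        if isinstance(other, Number):
            # Scalar operand: combine it with every stored component.
            return ToyField({c: op(v, other) for c, v in self.components.items()}, self.grid)
        if isinstance(other, ToyField):
            if self.grid != other.grid:
                raise ValueError("The two fields have different spatio-temporal coordinates")
            # Union of stored components, sorted for a deterministic order.
            comps = sorted(set(self.components) | set(other.components))
            return ToyField(
                {c: op(self.get_component(c), other.get_component(c)) for c in comps},
                self.grid,
            )
        # Unsupported type: Python then tries the reflected method (e.g. other.__radd__)
        # and, if that also fails, raises its standard TypeError naming the operator.
        return NotImplemented

    def __add__(self, other):
        return self._binary_op(other, operator.add)

    def __sub__(self, other):
        return self._binary_op(other, operator.sub)

    def __mul__(self, other):
        return self._binary_op(other, operator.mul)

    def __truediv__(self, other):
        return self._binary_op(other, operator.truediv)

    def __pow__(self, other):
        return self._binary_op(other, operator.pow)


a = ToyField({"x": 1.0, "y": 2.0}, grid="mesh-1")
b = ToyField({"x": 3.0}, grid="mesh-1")
print((a + b).components)  # {'x': 4.0, 'y': 2.0}
print((a ** 2).components)  # {'x': 1.0, 'y': 4.0}
```

With this layout, expressions such as `2 * a` would still fail, because the left operand is an `int`; reflected methods like `__rmul__` could be added with the same one-line pattern, swapping the argument order where the operation is not commutative.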

Upvotes: 1
