Reputation: 1033
I have a class encompassing the calculation for some complicated vector formula. The detailed structure is probably unimportant; it basically calculates the supplied vector formula on a mesh, and stores/manages the result.
I want to be able to do basic arithmetic with instances of this class if they have the same spatio-temporal coordinates. My problem is that every arithmetic operation needs a considerable amount of type checking. This results in 5 copies (one each for +, -, *, /, **) of exactly the same code, bar an arithmetic sign in 3 places.
So much repeating code in an object oriented language looks suspicious to me. At the same time, I couldn't come up with an elegant solution to simplify it, and I get the feeling there's some method I'm unaware of.
How could I extract the repeating code in a best-practice way?
The code is below; I marked the 3 differences in `__sub__` compared to `__add__`:
class FieldVector(object):
def __init__(self, formula, fieldparams, meshparams, zero_comps=[]):
# NOTE(review): `zero_comps=[]` is a mutable default, so the same list object
# is shared by every call that omits the argument. The body is elided in this
# excerpt; confirm it never mutates `zero_comps`, or switch to a None sentinel.
[...]
def is_comparable_to(self, other):
    """Return True if *other* has the same spatio-temporal dimensions as self.

    *other* is comparable when it is a ``FieldVector`` whose ``meshparams``,
    ``framenum``, ``fieldparams.tnull`` and ``fieldparams.tmax`` all compare
    equal to this instance's. Any non-``FieldVector`` object yields False.
    """
    # `isinstance` is a builtin function, not a method on the object; calling
    # `other.isinstance(...)` would raise AttributeError for every argument.
    if not isinstance(other, FieldVector):
        return False
    return (
        self.meshparams == other.meshparams
        and self.framenum == other.framenum
        and self.fieldparams.tnull == other.fieldparams.tnull
        and self.fieldparams.tmax == other.fieldparams.tmax
    )
def _check_comparable(self, other):
if not self.is_comparable_to:
raise ValueError("The two fields have different spatio-temporal coordinates")
def __add__(self, other):
    """Return ``self + other`` as a new ``ModifiedFieldVector``.

    A ``Number`` is added to each of this field's non-zero components. A
    ``FieldVector`` is first checked with ``_check_comparable`` and then
    added component-wise over the union of both fields' non-zero components.

    Raises:
        TypeError: if *other* is neither a ``Number`` nor a ``FieldVector``.
    """
    if isinstance(other, Number):
        summed = {
            name: self.get_component(name) + other
            for name in self.nonzero_comps
        }
    elif isinstance(other, FieldVector):
        self._check_comparable(other)
        names = set(self.nonzero_comps).union(other.nonzero_comps)
        summed = {
            name: self.get_component(name) + other.get_component(name)
            for name in names
        }
    else:
        raise TypeError(f'unsupported operand type(s) for +: {self.__class__} and {other.__class__}')
    return ModifiedFieldVector(self, summed)
def __sub__(self, other):
    """Return ``self - other`` as a new ``ModifiedFieldVector``.

    A ``Number`` is subtracted from each of this field's non-zero components.
    A ``FieldVector`` is first checked with ``_check_comparable`` and then
    subtracted component-wise over the union of both fields' non-zero
    components.

    Raises:
        TypeError: if *other* is neither a ``Number`` nor a ``FieldVector``.
    """
    if isinstance(other, Number):
        # --- difference 1: - instead of +
        diffs = {
            name: self.get_component(name) - other
            for name in self.nonzero_comps
        }
    elif isinstance(other, FieldVector):
        self._check_comparable(other)
        names = set(self.nonzero_comps).union(other.nonzero_comps)
        # --- difference 2: - instead of +
        diffs = {
            name: self.get_component(name) - other.get_component(name)
            for name in names
        }
    else:
        # --- difference 3: - instead of +
        raise TypeError(f'unsupported operand type(s) for -: {self.__class__} and {other.__class__}')
    return ModifiedFieldVector(self, diffs)
[... __mul__, __truediv__, __pow__ defined the same way]
Upvotes: 0
Views: 35
Reputation: 36722
You could extract the calculation into a private method, and pass it the operator needed.
Maybe something like this:
import operator
...
def _xeq(self, other, op):
    """Apply the binary callable *op* component-wise and wrap the result.

    Args:
        other: a ``Number`` (combined with each non-zero component of this
            field) or a ``FieldVector`` (checked with ``_check_comparable``,
            then combined over the union of both fields' non-zero components).
        op: a two-argument callable such as ``operator.add``.

    Returns:
        A ``ModifiedFieldVector`` built from this field and the new
        component values.

    Raises:
        TypeError: if *other* is neither a ``Number`` nor a ``FieldVector``.
    """
    new_compvals = {}
    if isinstance(other, Number):
        for comp in self.nonzero_comps:
            new_compvals[comp] = op(self.get_component(comp), other)
    elif isinstance(other, FieldVector):
        self._check_comparable(other)
        nonzeros = set(self.nonzero_comps).union(other.nonzero_comps)
        for comp in nonzeros:
            new_compvals[comp] = op(self.get_component(comp), other.get_component(comp))
    else:
        # Formatting `op` directly yields e.g. "<built-in function add>";
        # translate the callable's name to its operator symbol so the message
        # matches the hand-written per-operator versions. Unknown callables
        # fall back to their default representation.
        symbols = {'add': '+', 'sub': '-', 'mul': '*', 'truediv': '/', 'pow': '**'}
        symbol = symbols.get(getattr(op, '__name__', None), op)
        raise TypeError(f'unsupported operand type(s) for {symbol}: {self.__class__} and {other.__class__}')
    return ModifiedFieldVector(self, new_compvals)
def __add__(self, other):
return self._xeq(other, operator.add)
def __sub__(self, other):
return self._xeq(other, operator.sub)
Upvotes: 1