Reputation: 484
I am trying to write a decorator that can handle a varying number and type of returned values, and I am having trouble dealing with numbers and iterables of numbers.
Say I want my decorator to apply "+1" to each value returned by the function it decorates. I found the following way to do it, although I don't find it elegant (especially the "try/except" block and the "return tuple(x) if len(x) > 1 else x[0]" line):
# Plus-one decorator
def plus_one_decorator(func):
    """Decorator that adds one to each returned value"""
    def decorated(*args, **kwargs):
        raw_res = func(*args, **kwargs)
        # Making raw_res iterable (since it could be any length)
        try:
            raw_res = tuple(raw_res)
        except:
            raw_res = [raw_res]
        # Creating a list to store the decorated outputs
        output_list = []
        for res in raw_res:
            output_list.append(res + 1)
        # Sugar to not return a one-tuple
        return tuple(output_list) if len(output_list) > 1 else output_list[0]
    return decorated
# Decorated func
dec_minus = plus_one_decorator(lambda x: -x)
dec_dbl_tpl = plus_one_decorator(lambda x: (x*2, x*3))
# Checking
print(dec_minus(1)) # >>> 0 (as expected)
print(dec_dbl_tpl(3)) # >>> (7, 10) (as expected)
So this works for plain numbers, but what if I use it with a numpy.ndarray:
import numpy as np
foo = np.array([1, 1, 1])
print(dec_minus(foo)) # >>> (0, 0, 0) (! Expected array([0, 0, 0]))
print(dec_dbl_tpl(foo)) # >>> (array([3, 3, 3]), array([4, 4, 4])) (as expected)
For the second function, which returns a tuple, it works as expected, because the raw returned value is already a tuple (tuple((array_1, array_2)) --> (array_1, array_2)). But in the first case, the ndarray array([0, 0, 0]) is converted into a tuple (0, 0, 0).
So, in the end, my question is:
Is there an elegant way to make the returned values of a function iterable, when the number and types of these values can vary?
I guess I could test the type of each returned value, but again it doesn't seem very elegant.
Cheers
Upvotes: 1
Views: 369
Reputation: 267
Yes, there is.
For instance:
from collections.abc import Iterable
from copy import deepcopy
...
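# in your decorator
if not isinstance(raw_res, Iterable):
    raw_res = [raw_res]
# A tuple is immutable, so it is copied into a list before index assignment
was_tuple = isinstance(raw_res, tuple)
output = list(raw_res) if was_tuple else deepcopy(raw_res)
for i, res in enumerate(output):
    output[i] = res + 1
if was_tuple:
    output = tuple(output)
return output if len(output) > 1 else output[0]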
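In your example you built a new list named output_list, which loses the type of the returned value. Instead, copy the returned data into output and then modify the copy in place, so that the original container type (for example an ndarray) is preserved.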
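As a possible alternative, here is a minimal sketch of a complete decorator that does not make the result iterable at all. It rests on one assumption that is not stated in the question: a function that "returns several values" always returns a tuple, so only a tuple is unpacked, and anything else (a number, an ndarray, ...) is treated as a single value that supports "+ 1" itself. The use of functools.wraps is an addition for illustration.

```python
from functools import wraps


def plus_one_decorator(func):
    """Add one to each value returned by func (illustrative sketch)."""
    @wraps(func)
    def decorated(*args, **kwargs):
        raw_res = func(*args, **kwargs)
        # Assumption: only a tuple means "several returned values".
        if isinstance(raw_res, tuple):
            return tuple(res + 1 for res in raw_res)
        # Anything else is one value; let its own "+" handle the addition,
        # so an ndarray stays an ndarray instead of being unpacked.
        return raw_res + 1
    return decorated


dec_minus = plus_one_decorator(lambda x: -x)
dec_dbl_tpl = plus_one_decorator(lambda x: (x * 2, x * 3))

print(dec_minus(1))    # 0
print(dec_dbl_tpl(3))  # (7, 10)
```

With this variant, no try/except and no one-tuple special case is needed. The trade-off is that a function returning a list of numbers would fail on "list + 1", so this only fits if tuples are the sole multi-value return type.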
Upvotes: 1