Aguy

Reputation: 8059

A function that works on a function and returns a function

Suppose I want to define a meta-function that accepts a function as an argument and returns a new, modified function. Something like:

metaf(f) -> f**2

So whatever f returns, the function produced by metaf should return that result raised to the power of two (and if the result cannot be raised to the power of 2, so be it: raise an error).

Currently, the way I've found to do that requires an explicit reference to the argument of f in the definition of metaf, i.e., define

metaf = lambda f, x : f(x)**2

and then

mynewf = lambda x : metaf(f, x)

This works, but I wonder whether it will hold up for functions with more complex signatures, where the input arguments could vary in many ways.
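To make the concern concrete, here is a minimal sketch of where the explicit-argument approach runs out. The two-argument function `g` is hypothetical, used only for illustration: because `metaf` forwards exactly one argument, a wrapper for `g` would need its own, differently shaped `metaf`.

```python
# metaf as defined in the question: it forwards exactly one argument x to f
metaf = lambda f, x: f(x) ** 2

# Works for a one-argument function
print(metaf(lambda x: x + 1, 3))  # 16

# Hypothetical two-argument function, for illustration only
g = lambda a, b: a + b
mynewg = lambda x: metaf(g, x)

try:
    mynewg(3)
except TypeError as exc:
    # g needs two arguments, but metaf only passes one
    print("TypeError:", exc)
```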

So I'm wondering if you can suggest a different way, especially one that does not require specifying the argument in the definition of metaf.

Edit: Both answers below were helpful. For reference, following the answers I also realized how to define metaf using a lambda expression. It is slightly cumbersome, but might still be worth noting:

metaf = lambda func: lambda *args, **kwargs : func(*args, **kwargs)**2
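For example, this lambda-based `metaf` wraps functions of any signature, with positional or keyword arguments. The sample functions `add` and `scale` below are made up for illustration:

```python
# Curried version from the edit: metaf takes a function and returns a new one
metaf = lambda func: lambda *args, **kwargs: func(*args, **kwargs) ** 2

# Hypothetical example functions with different signatures
add = lambda a, b: a + b
scale = lambda x, factor=2: x * factor

square_add = metaf(add)
square_scale = metaf(scale)

print(square_add(3, 4))           # 49, i.e. (3 + 4) ** 2
print(square_scale(5))            # 100, i.e. (5 * 2) ** 2
print(square_scale(5, factor=3))  # 225, i.e. (5 * 3) ** 2
```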

Upvotes: 1

Views: 53

Answers (2)

Israel Unterman

Reputation: 13510

I believe this is what you are after:

def metaf(f):
    def func(*r, **kw):
        return f(*r, **kw) ** 2
    return func

And now let's define some f...

def f(a, b):
    return a + b

And here is how to convert it:

mynewf = metaf(f)

Try it:

In [148]: f(10, 20)
Out[148]: 30

In [149]: mynewf(10, b=20)
Out[149]: 900

Note the use of both a positional argument and a keyword argument in the call to mynewf. It works as you would expect.
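One optional refinement, not part of the original answer: decorating the inner function with `functools.wraps` copies the wrapped function's `__name__` and `__doc__` onto the result, and the same `metaf` can then be applied with `@` decorator syntax. A sketch, with `add` as a hypothetical example function:

```python
import functools

def metaf(f):
    # functools.wraps copies f's __name__, __doc__, etc. onto the wrapper
    @functools.wraps(f)
    def func(*args, **kwargs):
        return f(*args, **kwargs) ** 2
    return func

@metaf  # equivalent to: add = metaf(add)
def add(a, b):
    """Return a + b."""
    return a + b

print(add(10, b=20))  # 900
print(add.__name__)   # add (it would be "func" without functools.wraps)
print(add.__doc__)    # Return a + b.
```

Decorating replaces `add` with the squared version, so if you also need the original function, call `metaf(f)` explicitly as in the answer instead.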

Upvotes: 4

Ulrich Schwarz

Reputation: 7727

You should be able to use *args and **kwargs to gobble up all other arguments and pass them on like so:

def squarer(f):
    return lambda *args, **kwargs: f(*args, **kwargs)**2


>>> squarer(lambda x: x+1)(3)
16
>>> squarer(lambda x: x+1)(4)
25
>>> squarer(lambda x,y: x+1)(4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 2, in <lambda>
TypeError: <lambda>() takes exactly 2 arguments (1 given)
>>> squarer(lambda x,y=1: x+y)(4)
25
>>> squarer(lambda x,y=1: x+y)(4,2)
36
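As the question asks, results that cannot be squared need no special handling: the `**` inside the wrapper simply raises, and the error propagates to the caller. A small sketch using a hypothetical function that returns a string:

```python
def squarer(f):
    return lambda *args, **kwargs: f(*args, **kwargs) ** 2

# Hypothetical function whose result does not support ** 2
shout = squarer(lambda s: s.upper())

try:
    shout("hi")
except TypeError as exc:
    # str does not support the ** operator, so the wrapper raises TypeError
    print("TypeError:", exc)
```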

Upvotes: 3
