Panfeng Li

Reputation: 3616

Confused about the lambda expression in python

I understand the normal lambda expression, such as

g = lambda x: x**2

However, I am a little confused by some more complex ones. For example:

for split in ['train', 'test']:
    sets = (lambda split=split: newspaper(split, newspaper_devkit_path))

def get_imdb():
    return sets()

Here newspaper is a function. What exactly is sets, and why can the get_imdb function return the value of sets()?

Thanks for your help!

Added: The code actually comes from factory.py.

Upvotes: 5

Views: 293

Answers (3)

Mad Physicist

Reputation: 114300

sets is being assigned a lambda that is not really supposed to accept inputs, which you can see from the way it is invoked. Lambdas in general behave like normal functions, and can therefore be assigned to variables like g or sets. The definition of sets is surrounded by an extra set of parentheses that have no effect here, so you can ignore them.

Lambdas can have all the same types of positional, keyword and default arguments a normal function can. The lambda sets has a default parameter named split. This is a common idiom to ensure that sets in each iteration of the loop gets the value of split corresponding to that iteration rather than just the one from the last iteration in all cases.

Without a default parameter, split would be looked up in the enclosing namespace each time the lambda is called. Once the loop completes, split in that namespace is simply the last value it took in the loop.

Default parameters are evaluated immediately when a function object is created. This means that the default value of the parameter split is whatever value split has in the loop iteration that creates the lambda.
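As a minimal sketch of that timing difference (the names with_default and without_default are made up for illustration), the two lambdas below are created while split is 'train' and only called after it has been rebound:

```python
split = 'train'

# The default is evaluated right here, when the lambda object is created.
with_default = lambda split=split: split

# No default: the name split is looked up only when the lambda is called.
without_default = lambda: split

split = 'test'  # rebind the outer name afterwards

result_with_default = with_default()        # still sees 'train'
result_without_default = without_default()  # sees the current value, 'test'
print(result_with_default, result_without_default)
```

The stored default can also be inspected directly through with_default.__defaults__.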

Your example is a bit misleading because it discards all the actual values of sets besides the last one, making the default parameter to the lambda meaningless. Here is an example illustrating what happens if you keep all the lambdas. First with the default parameter:

sets = []
for split in ['train', 'test']:
    sets.append(lambda split=split: split)
print([fn() for fn in sets])

I have simplified the lambdas to just return their input parameter for purposes of illustration. This example will print ['train', 'test'], as expected.

If you do the same thing without the default parameter, the output will be ['test', 'test'] instead:

sets = []
for split in ['train', 'test']:
    sets.append(lambda: split)
print([fn() for fn in sets])

This is because 'test' is the value of split by the time any of the lambdas is called.
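For a version closer to the code in the question, here is a rough sketch that keeps one lambda per split in a dictionary. It rests on several assumptions: newspaper is replaced by a hypothetical stand-in that just returns its arguments, the devkit path is a placeholder, and get_imdb is given a name parameter that the question's version does not have.

```python
# Hypothetical stand-in for the real newspaper function from the question.
def newspaper(split, devkit_path):
    return (split, devkit_path)

newspaper_devkit_path = '/path/to/devkit'  # placeholder value, not a real path

# Keep every lambda instead of overwriting a single variable.
sets = {}
for split in ['train', 'test']:
    # split=split freezes the current loop value as the default.
    sets[split] = lambda split=split: newspaper(split, newspaper_devkit_path)

def get_imdb(name):
    # Look up the stored lambda and call it with no arguments.
    return sets[name]()

print(get_imdb('train'))
print(get_imdb('test'))
```

With the default parameter in place, each key builds the object for its own split; without it, both entries would use the final value of split.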

Upvotes: 2

Maybe you are confused about the split=split part. This has the same meaning as it would have in a regular function: the split on the left is a parameter of the lambda function and the split on the right is the default value the left split takes when no value is provided. In this case, the default value would be the variable split defined in the for loop.

So, answering your first question (what is sets?):

sets is a variable to which an anonymous function (or lambda function) is assigned. This allows the lambda function to be referenced and used via the variable sets.

To your second question (why can sets() be returned?), I respond:

Since sets is a variable that refers to a function, adding parentheses after it calls the lambda function. Because no arguments are given, the parameter split takes its default value. The lambda that sets refers to after the loop was created in the last iteration, so that default is 'test'. It is worth noting that, since sets is not defined inside the function get_imdb, the interpreter looks for a definition of sets outside the scope of get_imdb (and finds the one that refers to the lambda function).
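That outside lookup happens each time get_imdb runs, not when get_imdb is defined. A small sketch, with the lambdas reduced to placeholder return values for illustration:

```python
sets = lambda: 'first'

def get_imdb():
    # sets is not a local name here, so it is resolved in the
    # surrounding scope at the moment get_imdb is called.
    return sets()

before = get_imdb()

sets = lambda: 'second'  # rebind the outer name to a different lambda
after = get_imdb()

print(before, after)
```

The second call picks up the new lambda, which shows that get_imdb only holds the name sets, not a fixed function object.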

Upvotes: 0

Mateen Ulhaq

Reputation: 27201

A lambda function:

func = lambda x: x**2

can be rewritten almost equivalently:

def func(x):
    return x**2

Either way, you can call the function like this:

func(4)
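One visible reason for the "almost" is the function's name: a def statement records the name it was defined with, while a lambda is always named '<lambda>'. A short sketch, using two differently named variables (square_lambda and square_def, chosen here for illustration) so both versions can coexist:

```python
square_lambda = lambda x: x**2

def square_def(x):
    return x**2

# Both are ordinary function objects and give the same result.
print(square_lambda(4), square_def(4))

# They differ in the name stored on the function object,
# which is what shows up in tracebacks and reprs.
print(square_lambda.__name__, square_def.__name__)
```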

In your example,

sets = lambda split=split: newspaper(split, newspaper_devkit_path)

can be rewritten:

def sets(split=split):
    return newspaper(split, newspaper_devkit_path)

and so can be called:

sets()

When you write the following:

def get_imdb():
    return sets()

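get_imdb refers to a name, sets, that it does not define itself. No copy of sets is stored inside get_imdb; the name is looked up in the enclosing scope each time get_imdb is called, so sets only needs to exist by then. If sets were a local variable of an enclosing function rather than a module-level name, get_imdb would be a "closure" over it.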

Upvotes: 0
