Reputation: 150
I've written a simple generator function that takes a list, which may contain nested sub-lists, and tries to flatten it.
So [1, [2, 3], 4, [5, [6, 7], 8]] should produce 1, 2, 3, 4, 5, 6, 7, 8.
If I just want to print the values (no generator), the code looks like this, and it works:
# Code A
def flatten_list_of_lists(my_list):
    for element in my_list:
        if isinstance(element, list):
            flatten_list_of_lists(element)
        else:
            print(element)

my_list = [1, [2, 3], 4, [5, [6, 7], 8]]
flatten_list_of_lists(my_list)
That prints 1, 2, 3, 4, 5, 6, 7, 8 (one per line), as expected.
However, when I change the code to this:
# Code B
def flatten_list_of_lists(my_list):
    for element in my_list:
        if isinstance(element, list):
            flatten_list_of_lists(element)
        else:
            yield element

# my_list is the same list defined in Code A
for i in flatten_list_of_lists(my_list):
    print(i)
In other words, I only switched the print to a yield, yet the program prints just 1 and 4.
I've pasted code below that actually works, but I'm wondering why the previous code doesn't. If Code A prints the numbers correctly, why wouldn't Code B yield them correctly?
It seems I have a fundamental misunderstanding of how generators work with recursion.
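To illustrate what is happening in Code B, here is a minimal sketch (the names events, noisy_gen and flatten_broken are made up for this example). Calling a generator function does not run its body; it only returns a generator object, and the body starts running when something iterates over it. Code B's recursive call therefore builds a generator for each sub-list and discards it unread, which is why only the top-level 1 and 4 come out.

```python
events = []

def noisy_gen():
    # Hypothetical generator: records when its body actually starts running
    events.append("body started")
    yield 1

gen = noisy_gen()   # only creates a generator object; the body has not run yet
print(events)       # []
print(next(gen))    # runs the body up to the first yield, prints 1
print(events)       # ['body started']

def flatten_broken(my_list):
    # Same logic as Code B
    for element in my_list:
        if isinstance(element, list):
            # Creates a generator for the sub-list but never iterates it,
            # so the sub-list's values are never produced
            flatten_broken(element)
        else:
            yield element

print(list(flatten_broken([1, [2, 3], 4, [5, [6, 7], 8]])))  # [1, 4]
```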
This code actually works:
# Code C
def flatten_list_of_lists_v2(my_list):
    for element in my_list:
        if isinstance(element, list):
            # Iterate over the recursive generator and re-yield each of its values
            for sub_element in flatten_list_of_lists_v2(element):
                yield sub_element
        else:
            yield element

for element in flatten_list_of_lists_v2(my_list):
    print(element)
And that prints 1, 2, 3, 4, 5, 6, 7, 8.
A little background: I just finished watching this video: https://www.youtube.com/watch?v=LelQTPiH3f4
In it, he explains that when designing a generator, you can first put a print where you want to yield, check that you get the right results, and then switch the print to a yield. So I guess his advice doesn't work in all circumstances; I just want to understand why.
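One way to see why the print-to-yield trick breaks here: in Code A, print is a side effect performed inside the recursive call, so merely calling flatten_list_of_lists(element) is enough. The sketch below makes that explicit (flatten_with_callback and its emit parameter are hypothetical names; emit stands in for print). With a generator, the recursive call has no effect until its result is iterated, so when you swap print for yield you also have to turn the bare recursive call into something that re-yields its values, as Code C does.

```python
def flatten_with_callback(my_list, emit):
    # Hypothetical callback version of Code A: `emit` plays the role of print
    for element in my_list:
        if isinstance(element, list):
            # The recursive call does its work immediately,
            # so calling it (and ignoring its return value) is enough
            flatten_with_callback(element, emit)
        else:
            emit(element)

collected = []
flatten_with_callback([1, [2, 3], 4, [5, [6, 7], 8]], collected.append)
print(collected)  # [1, 2, 3, 4, 5, 6, 7, 8]
```

Passing print as emit reproduces Code A exactly; passing collected.append gathers the values instead.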
Upvotes: 0
Views: 51
Reputation: 135227
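A simple mistake: calling flatten_list_of_lists(element) inside the generator only creates a new generator object. Nothing iterates over it, so the values from the sub-list are thrown away. Use yield from to pass them through (your Code C does the same thing with an explicit for loop):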
def flatten_list_of_lists(my_list):
    for element in my_list:
        if isinstance(element, list):
            # add yield from
            yield from flatten_list_of_lists(element)
        else:
            # yield, not print
            yield element

my_list = [1, [2, 3], 4, [5, [6, 7], 8]]
for e in flatten_list_of_lists(my_list):
    print(e)
Output
1
2
3
4
5
6
7
8
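For plain iteration like this, yield from (available since Python 3.3, PEP 380) produces the same values as the explicit for loop in Code C; it additionally forwards send(), throw() and the sub-generator's return value, none of which this example uses. A quick sketch that checks the two styles agree on a few inputs (the function names and sample lists are made up for the comparison):

```python
def flatten_yield_from(my_list):
    for element in my_list:
        if isinstance(element, list):
            # Delegate to the recursive generator; every value it yields is passed on
            yield from flatten_yield_from(element)
        else:
            yield element

def flatten_for_loop(my_list):
    # Same approach as Code C: re-yield each value from the recursive generator by hand
    for element in my_list:
        if isinstance(element, list):
            for sub_element in flatten_for_loop(element):
                yield sub_element
        else:
            yield element

samples = [
    [1, [2, 3], 4, [5, [6, 7], 8]],
    [],
    [[[]], [1, [2, [3]]]],
]
for sample in samples:
    a = list(flatten_yield_from(sample))
    b = list(flatten_for_loop(sample))
    print(a, a == b)
# [1, 2, 3, 4, 5, 6, 7, 8] True
# [] True
# [1, 2, 3] True
```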
Upvotes: 6