I was working through a Scala-99 problem that reduces a complex nested list to a flat list. The code is given below:
def flatten(l: List[Any]): List[Any] = l flatMap {
  case ms: List[_] => flatten(ms)
  case l => List(l)
}
val L = List(List(1, 1), 2, List(3, List(5, 8)))
val flattenedList = flatten(L)
For the input L above, I understood the problem by drawing this tree:
List(List(1, 1), 2, List(3, List(5, 8)))   (1)
+-- List(1, 1)                             (2)
|   +-- List(1)                            (3)
|   +-- List(1)                            (3)
+-- List(2)                                (2)
+-- List(3, List(5, 8))                    (2)
    +-- List(3)                            (3)
    +-- List(5, 8)                         (3)
        +-- List(5)                        (4)
        +-- List(8)                        (4)
My understanding is that the program adds the leaf nodes to a list that Scala maintains internally, like:
li = List(List(1), List(1), List(2), List(3), List(5), List(8))
and that this result is then passed to the flatten method, which produces the final answer:
List(1, 1, 2, 3, 5, 8)
Is my understanding correct?
EDIT: Sorry, I forgot to add this. If my understanding is correct, why does replacing flatMap with map in the definition of flatten above produce this list:
List(List(List(1), List(1)), List(2), List(List(3), List(List(5), List(8))))
Isn't flatMap just map followed by flatten? Shouldn't I get the list I mentioned above:
li = List(List(1), List(1), List(2), List(3), List(5), List(8))
Answer:
You're right that flatMap is just map followed by flatten, but note that this flatten is not the flatten you defined: for a List, the standard-library flatten only concatenates the inner lists one level deep.
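A small check of this point, assuming Scala 2.13 or later: the standard-library flatten on a List removes exactly one level of nesting, and flatMap(g) gives the same result as map(g) followed by flatten. The object and value names (OneLevelFlatten, nested, g, and so on) are made up for illustration.

```scala
object OneLevelFlatten {
  // A list whose inner lists contain a further nested list.
  val nested: List[List[Any]] = List(List(1, List(2)), List(3))

  // The standard-library flatten concatenates the inner lists only one level deep,
  // so the innermost List(2) survives untouched.
  val once: List[Any] = nested.flatten

  // flatMap(g) behaves like map(g) followed by the standard-library flatten.
  val g: Int => List[Int] = x => List(x, x * 10)
  val viaFlatMap: List[Int] = List(1, 2).flatMap(g)
  val viaMapThenFlatten: List[Int] = List(1, 2).map(g).flatten

  def main(args: Array[String]): Unit = {
    println(once)              // List(1, List(2), 3)
    println(viaFlatMap)        // List(1, 10, 2, 20)
    println(viaMapThenFlatten) // List(1, 10, 2, 20)
  }
}
```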
A very useful way to unpack this is the substitution model, just like in maths. Suppose I define it like this (calling it f to avoid confusion between your flatten and the flatten in the standard library):
def f(l: List[Any]): List[Any] = l map {
  case ms: List[_] => f(ms)
  case l => List(l)
}
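then (here f(1) and f(2) are shorthand for applying the pattern-matching function to a single non-list element, which wraps it in a List):
f(List(List(1, 1), 2))
= List(f(List(1, 1)), f(2))               // apply f to each element of the outermost list
= List(List(f(1), f(1)), f(2))            // apply f to each element of the inner list
= List(List(List(1), List(1)), List(2))   // no more recursion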
Notice that map doesn't change the structure of your list; it only applies the function to each element. This explains the result you get when you replace flatMap with map.
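Running the map version on the full input L from the question is a quick way to confirm this. A minimal sketch, assuming Scala 2.13 or later; the object name MapVersion is made up, and the leaf pattern variable is renamed to x to avoid shadowing the parameter l:

```scala
object MapVersion {
  // f with map instead of flatMap: each element is replaced in place,
  // so the nesting of the input is preserved.
  def f(l: List[Any]): List[Any] = l map {
    case ms: List[_] => f(ms)   // recurse into a nested list
    case x           => List(x) // wrap a leaf element
  }

  val L: List[Any] = List(List(1, 1), 2, List(3, List(5, 8)))

  def main(args: Array[String]): Unit =
    println(f(L)) // List(List(List(1), List(1)), List(2), List(List(3), List(List(5), List(8))))
}
```

The printed value matches the list from the question's edit: every level of the original tree is still there, only each leaf has been wrapped in a List.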
Now if you use flatMap instead of map, the flatten step simply concatenates the results:
def f(l: List[Any]): List[Any] = l flatMap {
  case ms: List[_] => f(ms)
  case l => List(l)
}
then
f(List(List(1, 1), 2))
= f(List(1, 1)) ++ f(2)             // apply f to each element and concatenate
= (f(1) ++ f(1)) ++ f(2)
= (List(1) ++ List(1)) ++ List(2)
= List(1, 1) ++ List(2)
= List(1, 1, 2)
or, in another way, using the standard-library flatten (written as a function here) instead of ++:
f(List(List(1, 1), 2))
= flatten(List(f(List(1, 1)), f(2)))                        // map and flatten
= flatten(List(flatten(List(f(1), f(1))), f(2)))            // again map and flatten
= flatten(List(flatten(List(List(1), List(1))), List(2)))
= flatten(List(List(1, 1), List(2)))
= List(1, 1, 2)
Now you can see that flatten is called multiple times, once at every level where f is applied recursively, which collapses your tree one level at a time into one big list.
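The two derivations above can be checked in code. This is a sketch assuming Scala 2.13 or later, with three hypothetically named versions of f: fFlatMap is the answer's flatMap definition, fConcat spells out "apply f to each element and concatenate with ++" as a right fold, and fMapFlatten uses an explicit map followed by the standard-library flatten. The right fold is a swapped-in way of writing the ++ chain, not how flatMap is implemented internally.

```scala
object ThreeEquivalentVersions {
  // 1. The answer's f, using flatMap.
  def fFlatMap(l: List[Any]): List[Any] = l flatMap {
    case ms: List[_] => fFlatMap(ms)
    case x           => List(x)
  }

  // 2. "Apply f to each element and concatenate with ++", written as a right fold.
  def fConcat(l: List[Any]): List[Any] =
    l.foldRight(List.empty[Any]) {
      case (ms: List[_], acc) => fConcat(ms) ++ acc // nested list: flatten it first
      case (x, acc)           => List(x) ++ acc     // leaf: wrap it in a List
    }

  // 3. An explicit map step followed by the standard-library flatten.
  //    flatten runs once per recursive call, removing one level each time.
  def fMapFlatten(l: List[Any]): List[Any] = l.map {
    case ms: List[_] => fMapFlatten(ms)
    case x           => List(x)
  }.flatten

  val L: List[Any] = List(List(1, 1), 2, List(3, List(5, 8)))

  def main(args: Array[String]): Unit = {
    println(fFlatMap(L))    // List(1, 1, 2, 3, 5, 8)
    println(fConcat(L))     // List(1, 1, 2, 3, 5, 8)
    println(fMapFlatten(L)) // List(1, 1, 2, 3, 5, 8)
  }
}
```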
To answer your comment about why List(1, 1) is turned into flatten(List(List(1), List(1))): this is the simple case, but consider List(1, List(2)). Here f is applied to both 1 and List(2). Because the next step is the standard-library flatten, both 1 and List(2) must be turned into Lists (List(1) and f(List(2)) = List(2)) so that the result has the right shape.
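A small sketch of the List(1, List(2)) case, assuming Scala 2.13 or later. step is a hypothetical name for the function that f applies to each element; printing input.map(step) shows the intermediate shape, in which every element has become a List that the standard-library flatten can concatenate. (With flatMap, a branch that returned the bare leaf, such as case leaf => leaf, would not type-check, because flatMap expects each result to be a collection.)

```scala
object ShapeExample {
  // Per-element step: nested lists are flattened recursively,
  // leaves are wrapped in a List so that flatten can concatenate them.
  def step(x: Any): List[Any] = x match {
    case ms: List[_] => f(ms)
    case leaf        => List(leaf)
  }

  // map each element to a List, then remove one level of nesting.
  def f(l: List[Any]): List[Any] = l.map(step).flatten

  val input: List[Any] = List(1, List(2))

  def main(args: Array[String]): Unit = {
    println(input.map(step)) // List(List(1), List(2))
    println(f(input))        // List(1, 2)
  }
}
```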