itarato

Reputation: 865

How are Haskell's thunks so efficient?

The following Haskell implementation of Fibonacci runs in linear time:

fib = 0 : 1 : zipWith (+) fib (tail fib)

From what I understand, the references to fib are thunks here, so they are evaluated progressively and lazily.

However, what really surprises me is that it runs in linear time. Coming from a procedural language, I would expect all the inner fib calls to be evaluated recursively, taking roughly exponential time.

I've tried a somewhat equivalent version in Ruby:

def fib
  Enumerator.new do |yielder|
    yielder.to_proc.call(0)
    yielder.to_proc.call(1)

    gen_a = fib
    gen_b = fib
    gen_b.next()

    while true
      yielder.to_proc.call(gen_a.next() + gen_b.next())
    end
  end
end

gen = fib
20.times { puts(gen.next()) }

and Python:

def fib():
    yield 0
    yield 1

    lhs = fib()
    rhs = fib()
    next(rhs)

    while True:
        yield next(lhs) + next(rhs)


gen = fib()
for _ in range(20):
    print(next(gen))

Both are lazy and run in exponential time, as I'd expect.

It's possible my implementation is wrong, but I can't help wondering whether thunks leverage the immutability of pure functions to store computed values for reuse, and whether that's why Haskell's execution can skip most of the recursive calls.
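
For example, if I add explicit caching to a plainly recursive Python version, it also becomes linear, which is what makes me suspect the thunks are doing something similar:

```python
from functools import lru_cache

# Plain recursion here would be exponential; caching every result
# for reuse means each value is computed only once.
@lru_cache(maxsize=None)
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(50))  # 12586269025, returned instantly
```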

Is this true? Or am I totally wrong? Is there any good material to read about this?

Upvotes: 3

Views: 134

Answers (2)

willeM_ Van Onsem

Reputation: 477200

There is a difference with respect to the Python version: Haskell does not create multiple generators.

What has happened is that you made something that looks like:

       +-------+
       |  (:)  |
       +---+---+
fib -> | o | o |
    |  +-|-+-|-+
    |    v   v
    |    0  +---+---+
    |       |  (:)  |<----------.
    |       +---+---+           |
    |       | o | o |           |
    |       +-|-+-|-+           |
    |         v   v             |
    |         1  +-----------+  |
    |            |  zipWith  |  |
    |            +-----------+  |
    |            | o | o | o |  |
    |            +-|-+-|-+-|-+  |
    |              v   |   '----'
    |             (+)  |   
    '------------------'

If you need the next item, Haskell is forced to evaluate the zipWith node: zipWith gives us 0 + 1, and as the tail of that list it performs zipWith (+) on the tails of the two lists, so:

       +-------+
       |  (:)  |
       +---+---+
fib -> | o | o |
       +-|-+-|-+
         v   v
         0  +---+---+
            |  (:)  |<-----------------.
            +---+---+                  |
            | o | o |                  |
            +-|-+-|-+                  |
              v   v                    |
              1  +---+---+             |
                 |  (:)  |<----------. |
                 +---+---+           | |
                 | o | o |           | |
                 +-|-+-|-+           | |
                   |   v             | |
                   |  +-----------+  | |
                   |  |  zipWith  |  | |
                   |  +-----------+  | |
                   |  | o | o | o |  | |
                   |  +-|-+-|-+-|-+  | |
                   |    v   |   '----' |
                    |   (+)  |          |
                    v        '----------'
             +-------+
             |  (+)  |
             +---+---+
             | o | o |
             +-|-+-|-+
               v   v
               0   1

It thus does not spawn new generators, which would blow up memory and also produce exponentially many cursors that all have to be moved forward.

We could implement the same idea in Python with a linked-list approach:

class LazyNode:
    def __init__(self, head, tail):
        self._head = head
        self._tail = tail

    @property
    def head(self):
        return self._head

    @property
    def tail(self):
        # a callable tail is a thunk: force it at most once, then
        # overwrite it with the evaluated node, like Haskell does
        if callable(self._tail):
            subhead, subtail = self._tail()
            self._tail = LazyNode(subhead, subtail)
        return self._tail

    def __iter__(self):
        yield self.head
        yield from self.tail
and we can construct this with:

def produce(node0=None, node1=None):
    # producenext builds the thunk for the next pair of tails
    def producenext(x, y):
        return lambda: call(x.tail, y.tail)

    # call computes one element plus the thunk for the rest
    def call(x0=node0, x1=node1):
        if x0 is None:
            x0 = f0
        if x1 is None:
            x1 = f1
        return (x0.head + x1.head), producenext(x0, x1)

    return call


f1 = LazyNode(1, produce())
f0 = LazyNode(0, f1)
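
As a sanity check that the memoized tails really avoid recomputation, here is the same code in one self-contained piece, with a counter added (the counter and the driver at the bottom are instrumentation for illustration, not part of the code above):

```python
import itertools

class LazyNode:
    def __init__(self, head, tail):
        self._head = head
        self._tail = tail

    @property
    def head(self):
        return self._head

    @property
    def tail(self):
        # a callable tail is a thunk: force it once, then overwrite it
        if callable(self._tail):
            subhead, subtail = self._tail()
            self._tail = LazyNode(subhead, subtail)
        return self._tail

    def __iter__(self):
        yield self.head
        yield from self.tail

additions = 0  # counts how many (+) evaluations actually happen

def produce(node0=None, node1=None):
    def producenext(x, y):
        return lambda: call(x.tail, y.tail)

    def call(x0=node0, x1=node1):
        global additions
        if x0 is None:
            x0 = f0
        if x1 is None:
            x1 = f1
        additions += 1
        return (x0.head + x1.head), producenext(x0, x1)

    return call

f1 = LazyNode(1, produce())
f0 = LazyNode(0, f1)

first20 = list(itertools.islice(f0, 20))
print(first20)    # [0, 1, 1, 2, 3, 5, ...]
print(additions)  # 18: one addition per new element, i.e. linear
```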

Upvotes: 2

sepp2k

Reputation: 370367

def fib
  Enumerator.new do |yielder|
    yielder.to_proc.call(0)
    yielder.to_proc.call(1)

    gen_a = fib
    gen_b = fib
    gen_b.next()

    while true
      yielder.to_proc.call(gen_a.next() + gen_b.next())
    end
  end
end

Here you're defining a function that returns an enumerator, and this function is called twice by the enumerator, creating a new independent enumerator each time. Clearly this will have exponential runtime. The Haskell version doesn't do that: it creates a single list and no functions.
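
To make the blow-up concrete, here is the Python version from the question with a counter added (the counter is just instrumentation for illustration):

```python
created = 0  # counts how many generator bodies start running

def fib():
    global created
    created += 1
    yield 0
    yield 1
    lhs = fib()
    rhs = fib()
    next(rhs)
    while True:
        yield next(lhs) + next(rhs)

gen = fib()
values = [next(gen) for _ in range(20)]
print(values[-1])  # 4181
print(created)     # thousands of generators for just 20 values
```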

To be equivalent to the Haskell version, your Ruby version would have to get rid of the method and instead define fib as fib = Enumerator.new do ... end. However, it's not quite that easy, because if we just do that, we end up with this code (which won't work):

fib = Enumerator.new do |yielder|
  yielder << 0
  yielder << 1

  gen_a = fib
  gen_b = fib
  gen_b.next()

  while true
    yielder << gen_a.next() + gen_b.next()
  end
end

So why doesn't this work?

  1. Because now gen_a = fib and gen_b = fib aren't method calls anymore; instead they make gen_a and gen_b references to the same enumerator. That is closer to the Haskell version, which doesn't have multiple lists/enumerators either. But unlike Haskell's lists, enumerators are mutable, so calling gen_b.next() now also affects gen_a and your entire logic breaks; and
  2. Because enumerators don't allow you to access the enumerator being created while creating it.
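
Point 1 is easy to see with Python generators too, which are mutable in exactly the same way (this snippet is only an illustration, not part of the question's code):

```python
def nums():
    # a simple infinite counter, standing in for any generator
    n = 0
    while True:
        yield n
        n += 1

gen_a = nums()
gen_b = gen_a  # an alias to the same generator, not an independent copy

print(next(gen_a))  # 0
print(next(gen_b))  # 1 -- advancing gen_b also moved gen_a's position
print(next(gen_a))  # 2
```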

To get something that works like the Haskell example, you can define a Ruby version of a lazy list like this:

class LazyList
  attr_reader :head

  def initialize(head, &blk)
    @head = head
    @tail = :unevaluated
    @tail_blk = blk
  end

  def tail
    if @tail == :unevaluated
      @tail = @tail_blk[]
    end
    @tail
  end

  def take(n)
    result = []
    list = self
    while n > 0 && list
      result << list.head
      list = list.tail
      n -= 1
    end
    result
  end

  def zip_with(other_list, &blk)
    LazyList.new(blk[head, other_list.head]) do
      tail.zip_with(other_list.tail, &blk)
    end
  end
end

fib = LazyList.new(0) do
  LazyList.new(1) do
    fib.zip_with(fib.tail, &:+)
  end
end

p fib.take(42)

This works in linear time (if we pretend integer operations are O(1) anyway) just like the Haskell version.


PS: Since you've specifically asked about thunks, here's another version that uses explicit thunks just to show that thunks aren't magical:

class Thunk
  def initialize(&blk)
    @value = :unevaluated
    @blk = blk
  end

  def get
    if @value == :unevaluated
      @value = @blk[]
    end
    @value
  end
end

class LazyList
  attr_reader :head

  def initialize(head, tail)
    @head = head
    @tail = tail
  end

  def tail
    @tail.get
  end

  def take(n)
    result = []
    list = self
    while n > 0 && list
      result << list.head
      list = list.tail
      n -= 1
    end
    result
  end

  def zip_with(other_list, &blk)
    LazyList.new(blk[head, other_list.head], Thunk.new do
      tail.zip_with(other_list.tail, &blk)
    end)
  end
end

fib = LazyList.new(0, Thunk.new do
  LazyList.new(1, Thunk.new do
    fib.zip_with(fib.tail, &:+)
  end)
end)

p fib.take(42)

Upvotes: 7
