Morten Lohne

Reputation: 433

Chain two iterators while lazily constructing the second one

I'd like a method like Iterator::chain() that only computes the argument iterator when it's needed. In the following code, expensive_function should never be called:

use std::{thread, time};

fn expensive_function() -> Vec<u64> {
    thread::sleep(time::Duration::from_secs(5));
    vec![4, 5, 6]
}

pub fn main() {
    let nums = [1, 2, 3];
    for &i in nums.iter().chain(expensive_function().iter()) {
        if i > 2 {
            break;
        } else {
            println!("{}", i);
        }
    }
}

Upvotes: 6

Views: 1281

Answers (2)

Shepmaster

Reputation: 430851

You can create your own custom iterator adapter that only evaluates a closure when the original iterator is exhausted.

trait IteratorExt: Iterator {
    fn chain_with<F, I>(self, f: F) -> ChainWith<Self, F, I::IntoIter>
    where
        Self: Sized,
        F: FnOnce() -> I,
        I: IntoIterator<Item = Self::Item>,
    {
        ChainWith {
            base: self,
            factory: Some(f),
            iterator: None,
        }
    }
}

impl<I: Iterator> IteratorExt for I {}

struct ChainWith<B, F, I> {
    base: B,
    factory: Option<F>,
    iterator: Option<I>,
}

impl<B, F, I> Iterator for ChainWith<B, F, I::IntoIter>
where
    B: Iterator,
    F: FnOnce() -> I,
    I: IntoIterator<Item = B::Item>,
{
    type Item = I::Item;
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(b) = self.base.next() {
            return Some(b);
        }

        // Exhausted the first, generate the second

        if let Some(f) = self.factory.take() {
            self.iterator = Some(f().into_iter());
        }

        self.iterator
            .as_mut()
            .expect("There must be an iterator")
            .next()
    }
}

use std::{thread, time};

fn expensive_function() -> Vec<u64> {
    // Panics if it is ever called, which proves the closure never runs here
    panic!("You lose, good day sir");
    thread::sleep(time::Duration::from_secs(5));
    vec![4, 5, 6]
}

pub fn main() {
    let nums = [1, 2, 3];
    for i in nums.iter().cloned().chain_with(|| expensive_function()) {
        if i > 2 {
            break;
        } else {
            println!("{}", i);
        }
    }
}

Upvotes: 8

E_net4

Reputation: 30003

One possible approach: delegate the expensive computation to an iterator adaptor.

let nums = [1, 2, 3];
for i in nums.iter()
    .cloned()
    .chain([()].into_iter().flat_map(|_| expensive_function()))
{
    if i > 2 {
        break;
    } else {
        println!("{}", i);
    }
}


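The iterator passed to chain is built by flat-mapping a dummy unit value () to the list of values. flat_map is lazy, so expensive_function only runs once the first iterator is exhausted and the chain starts polling the second one. Because the chained iterator has to own the values that computation produces, I chose to copy the numbers out of the array with cloned(), so both halves yield u64 rather than references.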
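As a variation on the same idea, here is a minimal sketch that replaces the dummy unit value with std::iter::once_with followed by flatten. The Cell-based call counter, the extra parameter on expensive_function, and the take_up_to helper are not from the question; they are assumptions added only to make it visible whether the expensive call happened.

```rust
use std::cell::Cell;
use std::iter;

// Hypothetical stand-in for the question's expensive_function.
// It records each call in `calls` instead of sleeping.
fn expensive_function(calls: &Cell<u32>) -> Vec<u64> {
    calls.set(calls.get() + 1);
    vec![4, 5, 6]
}

// Hypothetical helper: collects items while they are <= `limit`.
fn take_up_to(limit: u64, calls: &Cell<u32>) -> Vec<u64> {
    let nums = [1u64, 2, 3];
    nums.iter()
        .cloned()
        // once_with defers the closure until the chain first polls this half;
        // flatten then yields the elements of the returned Vec one by one.
        .chain(iter::once_with(|| expensive_function(calls)).flatten())
        .take_while(|&i| i <= limit)
        .collect()
}

fn main() {
    let calls = Cell::new(0);

    // Stops inside the first iterator, so the closure is never invoked.
    println!("{:?}, calls = {}", take_up_to(2, &calls), calls.get());

    // Runs past the first iterator, so the closure is invoked exactly once.
    println!("{:?}, calls = {}", take_up_to(5, &calls), calls.get());
}
```

Whether this reads better than the flat_map version is a matter of taste; both rely on the second half of the chain doing no work until the first half returns None.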

Upvotes: 8
