Alumashka

Reputation: 195

Effective check for value existence in the Scala Stream (in functional manner)

I came across this question on Reddit and was quite puzzled by it: I have no idea how to solve it (and there seems to have been no satisfying answer). I dare to copy it here:

Suppose we have a stream in which each element depends on its predecessor, as in a simple pseudo-random sequence, e.g.:

def neum(a:Int): Stream[Int] = Stream.iterate(a)(a => (a*a/100)%10000)

This is von Neumann's middle-square randomizer from the exercise referenced by the question

and, starting with a given value, we want to know when the sequence comes to a loop. We can easily do this in imperative form, using a Set to store the values seen so far:

// like in Java
Set<Integer> values = new HashSet<>();
int x = start;                    // start: the given initial value
while (true) {
    if (values.contains(x)) {
        break;                    // x was seen before: the sequence has looped
    }
    values.add(x);
    x = nextValueInSequence(x);
}
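For comparison, the same imperative loop can be written in Scala with a mutable set. This is only a baseline sketch, assuming the middle-square step from `neum` above; `ImperativeBaseline` and `imperativeCount` are hypothetical names. It returns the number of distinct values seen before the first repeat.

```scala
import scala.collection.mutable

object ImperativeBaseline {
  // One step of the middle-square generator used by `neum` above
  def step(x: Int): Int = x * x / 100 % 10000

  // Hypothetical helper: counts distinct values until the first repeat,
  // mirroring the Java loop with a mutable HashSet and a var
  def imperativeCount(start: Int): Int = {
    val values = mutable.HashSet.empty[Int]
    var x = start
    while (!values.contains(x)) {
      values += x
      x = step(x)
    }
    values.size
  }
}
```

For example, `imperativeCount(4100)` walks 4100, 8100, 6100, 2100 and stops when 4100 comes back, returning 4.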

However, in Scala it is interesting to come up with a "functional" solution. The author of the question seems to have no idea how this could be achieved while preserving O(N) time complexity, and neither do I. The only comment there seems to suggest a straightforward O(N^2) solution.
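One guess at what that straightforward O(N^2) solution looks like (an assumption, since the comment itself is not quoted here): keep the previous values in an immutable `List` and scan it on every step. `QuadraticBaseline` and `quadraticCount` are hypothetical names. It is purely functional, but `List.contains` is linear, so N steps cost O(N^2) in total.

```scala
import scala.annotation.tailrec

object QuadraticBaseline {
  // One step of the middle-square generator used by `neum`
  def step(x: Int): Int = x * x / 100 % 10000

  // Hypothetical helper: `seen` holds all previous values; `contains` on a List
  // is O(N), so the whole walk costs O(N^2)
  @tailrec
  def quadraticCount(x: Int, seen: List[Int] = Nil, count: Int = 0): Int =
    if (seen.contains(x)) count
    else quadraticCount(step(x), x :: seen, count + 1)
}
```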

Upvotes: 2

Views: 297

Answers (5)

Rodion Gorkovenko

Reputation: 2852

I believe it is possible to use the Set-based solution as is, with the help of a tail-recursive function:

import scala.annotation.tailrec

@tailrec
def neumannCount(x: Int, m: Set[Int] = Set[Int]()): Int = {
    if (m.contains(x)) m.size else neumannCount(x * x / 100 % 10000, m + x)
}

The function simply takes the current value and the set of previous elements. It checks whether the value exists in the Set; if not, it generates the next element and a new set with the current element added, and passes both to another call of the same function. When a value is finally encountered again, we just return the size of the Set as the result. The recursive call is the last operation in the function, so the function is tail-recursive.

I believe this should be O(N) in time (each Set lookup and insertion is effectively constant time) and O(N) in space, thanks to immutable collections which share structure with each other (if I understand them right).
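The same tail-recursive idea does not depend on the particular generator. A minimal generic sketch, assuming the step function is passed in explicitly (`SetCount`, `distinctBeforeRepeat` and `neumStep` are hypothetical names):

```scala
import scala.annotation.tailrec

object SetCount {
  // Hypothetical generic helper: number of distinct elements of
  // start, f(start), f(f(start)), ... before the first repeated value
  def distinctBeforeRepeat[A](start: A)(f: A => A): Int = {
    @tailrec
    def go(x: A, seen: Set[A]): Int =
      if (seen.contains(x)) seen.size
      else go(f(x), seen + x) // the new set shares structure with the old one
    go(start, Set.empty[A])
  }

  // The middle-square step from `neum`
  def neumStep(a: Int): Int = a * a / 100 % 10000
}
```

With the `neum` formula, `distinctBeforeRepeat(93)(neumStep)` visits 93, 86, 73, 53, 28, 7, 0 and returns 7.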

Upvotes: 1

Odomontois

Reputation: 16308

Here is a simplified version of Kolmar's answer, running in O(N) time and O(1) space. It basically does this:

  1. Find a repeated element using the fast/slow strategy. The index where the pointers meet is the smallest positive integer that is divisible by the real period and not less than the length of the non-cyclic prefix.
  2. Find the real period.
  3. Find the first element that is repeated one period later, i.e. the start of the cycle.
  4. Return the non-cyclic prefix and the cyclic part.

code:

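def cycleOf[T](seq: => Iterator[T]): (Iterator[T], Iterator[T]) = {
  // every second element x0, x2, x4, ...; seq is by-name, so each use re-evaluates it into a fresh iterator
  def fast = seq.sliding(1, 2) map (_.head)
  // pairs (x_i, x_2i) for i >= 1, skipped until slow and fast meet
  val meet = seq zip fast drop 1 dropWhile { case (x, y) => x != y }
  val met = meet.next()
  // the same pair reappears exactly one period later
  val period = (meet indexOf met) + 1
  // the first index i with x_(i + period) == x_i is the start of the cycle
  val start = seq drop period zip seq indexWhere { case (x, y) => x == y }
  (seq take start, seq.slice(start, start + period))
}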

You can try it with

val (prefix, cycle) = cycleOf(neum(5761).iterator)

then prefix.toList is

List(5761, 1891, 5758, 1545, 3870, 9769, 4333, 7748, 315, 992, 9840, 8256, 1615, 6082, 9907, 1486, 2081, 3305, 9230, 1929, 7210, 9841, 8452, 4363, 357, 1274, 6230, 8129, 806, 6496, 1980, 9204, 7136, 9224, 821, 6740, 4276, 2841, 712, 5069, 6947, 2608, 8016, 2562, 5638, 7870, 9369, 7781, 5439, 5827, 9539, 9925, 5056, 5631, 7081, 1405, 9740, 8676, 2729, 4474, 166, 275, 756, 5715, 6612, 7185, 6242, 9625, 6406, 368, 1354, 8333, 4388, 2545, 4770, 7529, 6858, 321, 1030, 609, 3708, 7492, 1300, 6900)

and cycle.toList is

List(6100, 2100, 4100, 8100)

Also note SpiderPig's advice: you could simply replace Stream with Iterator in your neum definition to get a more memory-efficient version.
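The same phases can also be written without `sliding` and `zip`, as plain tail-recursive functions over an explicit step function `f`; this is the classic Floyd tortoise-and-hare scheme. A sketch assuming elements are compared with `==`; `Floyd`, `iterateN`, `prefixAndPeriod` and `neumStep` are hypothetical names. It returns the prefix length and the period instead of two iterators, and only ever holds a constant number of elements:

```scala
import scala.annotation.tailrec

object Floyd {
  // The middle-square step from `neum`
  def neumStep(a: Int): Int = a * a / 100 % 10000

  // Hypothetical helper: apply f to x n times
  @tailrec
  def iterateN[A](x: A, n: Int)(f: A => A): A =
    if (n <= 0) x else iterateN(f(x), n - 1)(f)

  // Returns (prefix length P, period L) of x0, f(x0), f(f(x0)), ...
  def prefixAndPeriod[A](x0: A)(f: A => A): (Int, Int) = {
    // Phase 1: slow moves 1 step, fast moves 2 steps, until they meet inside the cycle
    @tailrec
    def meet(slow: A, fast: A): A =
      if (slow == fast) slow else meet(f(slow), f(f(fast)))
    val m = meet(f(x0), f(f(x0)))

    // Phase 2: walk once around the cycle from the meeting point to get its length
    @tailrec
    def period(x: A, n: Int): Int =
      if (x == m) n else period(f(x), n + 1)
    val l = period(f(m), 1)

    // Phase 3: x_i == x_(i + L) holds exactly when i >= P, so the first match is at i = P
    @tailrec
    def start(a: A, b: A, i: Int): Int =
      if (a == b) i else start(f(a), f(b), i + 1)
    val p = start(x0, iterateN(x0, l)(f), 0)

    (p, l)
  }
}
```

For example, `prefixAndPeriod(93)(neumStep)` gives `(6, 1)`: six values before the sequence settles on the fixed point 0.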

Upvotes: 2

Kolmar

Reputation: 14224

I believe there is an algorithm with O(N log N) time complexity (it may be possible to improve that to O(N)) and O(1) total memory consumption. That is, we don't have to remember most of the previous numbers. The constant factor is quite high, though.

This memory bound applies not to a Stream but to a general number sequence defined by a starting element and a recursive formula, for example Iterator.iterate(start)(a => a * a / 100 % 10000). A Stream would memoize previous results and effectively make it O(N) memory.

Let's say that the sequence has P ≥ 0 elements before the loop starts, and L ≥ 1 elements in the loop. For example, a sequence [2, 10, 13, 9, 11, 17, 11, 17, ...] has P = 4 and L = 2. And we need to find P + L.
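To make P and L concrete, here is a brute-force reference sketch (not the O(1)-memory algorithm described below): it remembers the index of every value it has seen in a `Map`, so it needs O(N) memory. `PrefixAndLoop`, `prefixAndLoop` and `exampleNext` are hypothetical names, and the example sequence is encoded as a `Map` from each value to its successor, read off the listed values:

```scala
import scala.annotation.tailrec

object PrefixAndLoop {
  // Successor of each value in the example [2, 10, 13, 9, 11, 17, 11, 17, ...]
  val exampleNext: Map[Int, Int] = Map(2 -> 10, 10 -> 13, 13 -> 9, 9 -> 11, 11 -> 17, 17 -> 11)

  // Brute force: when x reappears, P is the index of its first occurrence
  // and L is the distance from there to the current index
  def prefixAndLoop[A](start: A)(f: A => A): (Int, Int) = {
    @tailrec
    def go(x: A, i: Int, firstIndex: Map[A, Int]): (Int, Int) =
      firstIndex.get(x) match {
        case Some(j) => (j, i - j)
        case None    => go(f(x), i + 1, firstIndex + (x -> i))
      }
    go(start, 0, Map.empty[A, Int])
  }
}
```

`prefixAndLoop(2)(exampleNext)` returns `(4, 2)`, i.e. P = 4 and L = 2, so P + L = 6.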

In the algorithm we have to iterate through the sequence. I'll call the current position a "pointer". In a number sequence the pointer means just a number. Initially the pointer is equal to the starting element of the sequence, and to move the pointer 1 step forward we have to apply the recursive formula to it.

Now to the algorithm:

  1. Start with two pointers to the start of the sequence: "slow" and "fast". Slow pointer moves 1 step at a time, and fast moves 2 steps at a time (i.e. 2 applications of recursive formula).
  2. Initially the pointers are equal. Start moving them forward until they are equal again, and keep track of the number of steps of the slow pointer. Let's call the number of steps for the pointers to become equal again K0. It's possible to prove that P ≤ K0 < P + L and K0 ≡ 0 (mod L).

    In this step we should also pay special attention to the case P = 0: when the pointers become equal, if they are also equal to the starting element, we should set K0 = 0, to be able to distinguish this case later.

    Time complexity of this step is O(N).

  3. Now the pointers are certainly inside the loop of the sequence. Again start moving them forward and keep track of the number of steps of the slow pointer until they get equal once more. This number of steps is the length of loop L of the sequence. (You can also move only the slow pointer in this step until it gets back to the same position, but I'll reuse the function to move both, which doesn't increase time complexity)

    Time complexity of this step is O(N).

  4. Now we have to calculate P. Notice that if in step 2 of the algorithm we start the "fast" pointer not from the start but with some shift S, 0 ≤ S < L, the result will be different: either K_S = K0 - S, if S ≤ K0 - P; or K_S = K0 + L - S otherwise. So we can use binary search to find the maximum shift S*, 0 ≤ S* < L, for which K_S* = K0 - S*. Then we can find P = K0 - S* and return P + L = K0 - S* + L.

    Time complexity of this step is O(N log N) because each step in the binary search takes O(N).
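The step 4 formula can be checked by hand on the example sequence [2, 10, 13, 9, 11, 17, ...] with P = 4 and L = 2. The sketch below (`ShiftCheck`, `forward`, `k` and `predicted` are hypothetical names) runs the slow/fast walk of step 2 with the fast pointer started S steps ahead, counting slow-pointer steps K_S, and compares the count with the predicted value. Here K0 = 4, K_0 = 4 and K_1 = 5, so S* = 0 and P + L = 4 - 0 + 2 = 6.

```scala
import scala.annotation.tailrec

object ShiftCheck {
  // Successor function of the example sequence [2, 10, 13, 9, 11, 17, 11, 17, ...]
  val f: Map[Int, Int] = Map(2 -> 10, 10 -> 13, 13 -> 9, 9 -> 11, 11 -> 17, 17 -> 11)

  // Move a pointer n steps forward
  @tailrec
  def forward(x: Int, n: Int): Int = if (n <= 0) x else forward(f(x), n - 1)

  // K_S: slow starts at x0 = 2, fast starts at x_S; count slow steps until they meet
  def k(shift: Int): Int = {
    @tailrec
    def go(slow: Int, fast: Int, count: Int): Int = {
      val (s, q) = (f(slow), f(f(fast)))
      if (s == q) count else go(s, q, count + 1)
    }
    go(2, forward(2, shift), 1)
  }

  // The step 4 formula for K_S, given K0, P and L
  def predicted(shift: Int, k0: Int, p: Int, l: Int): Int =
    if (shift <= k0 - p) k0 - shift else k0 + l - shift
}
```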

So we have an algorithm that works in O(N log N) time with O(1) memory. Here is a code sample:

case class Sequence[T](start: T)(f: T => T) {
  def next = Sequence(f(start))(f)
  def forward(steps: Int) =
    Sequence(Function.chain(List.fill(steps)(f))(start))(f)
}

object Sequence {
  def neum(a: Int) = Sequence(a)(a => a * a / 100 % 10000)

  def movesToEquality[T](
    slow: Sequence[T], fast: Sequence[T], count: Int = 1
  ): (Sequence[T], Int) = {
    val nextSlow = slow.next
    val nextFast = fast.forward(2)
    if (nextSlow == nextFast) (nextSlow, count)
    else movesToEquality(nextSlow, nextFast, count+1)
  }

  def findLoopStart[T](seq: Sequence[T]): Int = {
    val (inLoop, k0) = movesToEquality(seq, seq) match {
      case (c, k) if c == seq => (c, 0)
      case other => other
    }
    val (_, loopSize) = movesToEquality(inLoop, inLoop)

    def binarySearch(lo: Int, hi: Int): Int = {
      if (lo + 1 >= hi) lo
      else {
        val mid = (lo + hi) / 2
        if (movesToEquality(seq, seq.forward(mid))._2 == k0 - mid)
          binarySearch(mid, hi)
        else
          binarySearch(lo, mid)
      }
    }

    k0 - binarySearch(0, loopSize) + loopSize
  }
}

object Main extends App {
  println(Sequence.findLoopStart(Sequence.neum(1)))
  println(Sequence.findLoopStart(Sequence.neum(4100)))
  println(Sequence.findLoopStart(Sequence.neum(5761)))
}

Upvotes: 2

Tesseract

Reputation: 8139

Does it have to be streams? Iterators are faster for this. Here are two different solutions. Both are functional and don't mutate any state.

def neumann(seed: Int): Int = {
  Iterator.iterate(seed)(s => ((s * s)/100)%10000)
          .scanLeft(Set.empty[Int])((set, n) => if(set(n)) Set(-1) else set + n)
          .takeWhile(_ != Set(-1)).size - 1
}

def neumann(seed: Int): Int = {
  def search(s: Int, set: Set[Int], count: Int): Int = {
    if(set(s)) count
    else search(((s * s)/100)%10000, set + s, count + 1)
  }
  search(seed, Set.empty[Int], 0)
}

Upvotes: 0

Chris Martin

Reputation: 30736

I'm sure there's a nicer way to write this, but here's my first stab at it:

def loop[A](xs: Stream[A]): Set[A] =
  xs.scanLeft(Set.empty[A])(_ + _).sliding(2)
    .find(_.map(_.size).toSet.size == 1).get.head

scala> neum(93).take(8).toList
res0: List[Int] = List(93, 86, 73, 53, 28, 7, 0, 0)

scala> loop(neum(93))
res1: Set[Int] = Set(0, 93, 28, 53, 73, 86, 7)
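Following the Iterator suggestions in the other answers (and since Stream is deprecated in favour of LazyList as of Scala 2.13), the same scanLeft/sliding idea can be sketched over an Iterator. `LoopSet`, `neum` and `loop` here are hypothetical names; the iterator does not memoize the elements, although the accumulated sets still grow to P + L elements.

```scala
object LoopSet {
  // Iterator-based generator with the same step as `neum`
  def neum(a: Int): Iterator[Int] = Iterator.iterate(a)(x => x * x / 100 % 10000)

  // Accumulate the sets of values seen so far; the first pair of consecutive
  // sets with equal size means the next element was already in the set
  def loop[A](xs: Iterator[A]): Set[A] =
    xs.scanLeft(Set.empty[A])(_ + _)
      .sliding(2)
      .find(w => w(0).size == w(1).size)
      .get
      .head
}
```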

Upvotes: 1
