alex314159

Reputation: 3247

Optimize tail-recursion in Clojure: exponential moving average

I'm new to Clojure and trying to implement an exponential moving average function using tail recursion. After battling stack overflows with lazy-seq and concat, I arrived at the following implementation, which works but is very slow:

(defn ema3 [c a]
  (loop [ct (rest c) res [(first c)]]
    (if (= (count ct) 0)
      res
      (recur
        (rest ct)
        (into ; NOT LAZY-SEQ OR CONCAT
          res
          [(+ (* a (first ct)) (* (- 1 a) (last res)))])))))

For a 10,000-item collection, Clojure takes around 1300 ms, whereas a Python Pandas call such as

s.ewm(alpha=0.3, adjust=True).mean()

takes only about 700 µs. How can I reduce that performance gap? Thank you.

Upvotes: 1

Views: 390

Answers (2)

amalloy

Reputation: 91917

Personally I would do this lazily with reductions. It's simpler to do than using loop/recur or building up a result vector by hand with reduce, and it also means you can consume the result as it is built up, rather than needing to wait for the last element to be finished before you can look at the first one.

If you care most about throughput then I suppose Taylor Wood's reduce is the best approach, but the lazy solution is only very slightly slower and is much more flexible.

(defn ema3-reductions [c a]
  (let [a' (- 1 a)]
    (reductions
     (fn [ave x]
       (+ (* a x)
          (* a' ave)))
     (first c)
     (rest c))))

user> (quick-bench (dorun (ema3-reductions (range 10000) 0.3)))

Evaluation count : 288 in 6 samples of 48 calls.
             Execution time mean : 2.336732 ms
    Execution time std-deviation : 282.205842 µs
   Execution time lower quantile : 2.125654 ms ( 2.5%)
   Execution time upper quantile : 2.686204 ms (97.5%)
                   Overhead used : 8.637601 ns
nil
user> (quick-bench (dorun (ema3-reduce (range 10000) 0.3)))
Evaluation count : 270 in 6 samples of 45 calls.
             Execution time mean : 2.357937 ms
    Execution time std-deviation : 26.934956 µs
   Execution time lower quantile : 2.311448 ms ( 2.5%)
   Execution time upper quantile : 2.381077 ms (97.5%)
                   Overhead used : 8.637601 ns
nil

Honestly, in that benchmark you can't even tell that the lazy version is slower than the vector version. I think my version is still slower, but the difference is vanishingly small.
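
To illustrate the flexibility point, here is a minimal sketch showing that a reductions-based EMA can be consumed before the input is exhausted, even when the input is infinite. The name ema-lazy is made up for this example, and the function assumes a non-empty input.

```clojure
;; ema-lazy is a hypothetical name; it assumes c is non-empty.
(defn ema-lazy [c a]
  (let [a' (- 1 a)]
    (reductions (fn [ave x]
                  (+ (* a x) (* a' ave)))
                (first c)
                (rest c))))

;; The input is infinite, but only the first three averages are realized.
(take 3 (ema-lazy (iterate inc 1) 0.5))
;; => (1 1.5 2.25)
```

An eager reduce over the same infinite input would never return, since it has to reach the end of the collection before producing anything.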

You can also speed things up if you tell Clojure to expect doubles, so it doesn't have to keep double-checking the types of a, c, and so on.

(defn ema3-reductions-prim [c ^double a]
  (let [a' (- 1.0 a)]
    (reductions (fn [ave x]
                  (+ (* a (double x))
                     (* a' (double ave))))
                (first c)
                (rest c))))

user> (quick-bench (dorun (ema3-reductions-prim (range 10000) 0.3)))
Evaluation count : 432 in 6 samples of 72 calls.
             Execution time mean : 1.720125 ms
    Execution time std-deviation : 385.880730 µs
   Execution time lower quantile : 1.354539 ms ( 2.5%)
   Execution time upper quantile : 2.141612 ms (97.5%)
                   Overhead used : 8.637601 ns
nil

Another 25% speedup, not too bad. I expect you could squeeze out a bit more by using primitives in either a reduce solution or with loop/recur if you were really desperate. It would be especially helpful in a loop because you wouldn't have to keep boxing and unboxing the intermediate results between double and Double.
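
As a rough sketch of what the loop/recur variant might look like (not benchmarked here), the running average can be kept as a primitive double loop local. The name ema3-loop-prim is made up for this example, and returning an empty vector for empty input is my own assumption.

```clojure
;; ema3-loop-prim is a hypothetical name. Returning [] for an empty
;; input is an assumption, not something the original code handles.
(defn ema3-loop-prim
  [c ^double a]
  (if-let [s (seq c)]
    (let [a' (- 1.0 a)
          x0 (double (first s))]
      ;; ave is initialized from a primitive double, so it stays
      ;; unboxed from one iteration to the next.
      (loop [xs  (next s)
             ave x0
             res [x0]]
        (if xs
          (let [nxt (+ (* a (double (first xs)))
                       (* a' ave))]
            (recur (next xs) nxt (conj res nxt)))
          res)))
    []))

(ema3-loop-prim [1 2 3] 0.5)
;; => [1.0 1.5 2.25]
```

Each value is still boxed once when it is stored in the result vector, but the arithmetic itself runs on primitives.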

Upvotes: 4

Taylor Wood

Reputation: 16194

If res is a vector (which it is in your example) then using peek instead of last yields much better performance:

(defn ema3 [c a]
  (loop [ct (rest c) res [(first c)]]
    (if (= (count ct) 0)
      res
      (recur
        (rest ct)
        (into
          res
          [(+ (* a (first ct)) (* (- 1 a) (peek res)))])))))

Your example on my computer:

(time (ema3 (range 10000) 0.3))
"Elapsed time: 990.417668 msecs"

Using peek:

(time (ema3 (range 10000) 0.3))
"Elapsed time: 9.736761 msecs"
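
The gap comes from how the two functions reach the final element: last walks the collection as a seq in linear time, while peek on a vector reads the end directly. A small sketch of their behavior (v is just an illustrative name):

```clojure
(def v (vec (range 10000)))

;; Same answer on a vector, very different cost.
(peek v)          ;; => 9999, without traversing the vector
(last v)          ;; => 9999, after walking all 10000 elements

;; Note that peek is not a drop-in replacement for last on lists:
(peek '(1 2 3))   ;; => 1 (lists peek at the front)
(last '(1 2 3))   ;; => 3
```

Because the loop calls it once per element, using last makes the whole function quadratic in the input size.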

Here's a version using reduce that's even faster on my computer:

(defn ema3 [c a]
  (reduce (fn [res ct]
            (conj
              res
              (+ (* a ct)
                 (* (- 1 a) (peek res)))))
          [(first c)]
          (rest c)))
;; "Elapsed time: 0.98824 msecs"

Take these timings with a grain of salt. Use something like criterium for more thorough benchmarking. You might be able to squeeze out more gains using mutability/transients.
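
For what it's worth, a transient-based variant of the reduce version could look roughly like this. It is an untimed sketch: ema3-transient is a made-up name, and the empty-input case returning [] is my own assumption. Since peek does not work on transient vectors, the previous value is read with nth instead.

```clojure
;; ema3-transient is a hypothetical name; [] for empty input is an assumption.
(defn ema3-transient
  [c a]
  (if (empty? c)
    []
    (persistent!
      (reduce (fn [res x]
                ;; transient vectors support count and nth, but not peek
                (let [prev (nth res (dec (count res)))]
                  (conj! res (+ (* a x) (* (- 1 a) prev)))))
              (transient [(first c)])
              (rest c)))))

(ema3-transient [1 2 3] 0.5)
;; => [1 1.5 2.25]
```

The transient never escapes the function, so callers still get an ordinary persistent vector back.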

Upvotes: 3
