user3639782

Reputation: 507

Clojure code using Java primitive arrays 70x slower than Scala version

I wrote an edit distance algorithm in both Clojure and Scala.

The Scala version runs 70x faster than the Clojure one.

clojure:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [i (range 1 (inc n0))] (aset-long distances i 0 i))
    (doseq [j (range 1 (inc n1))] (aset-long distances 0 j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget distances i (dec j))
            del (aget distances (dec i) j)
            match (aget distances (dec i) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset-long distances i j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset-long distances i j (inc min-dist))
          :else (aset-long distances i j min-dist))))
    (aget distances n0 n1)))

scala:

import scala.collection.mutable.ArrayBuffer

def editDistance(s0: Array[Char], s1: Array[Char]): Int = {
  val n0 = s0.length
  val n1 = s1.length
  val distances = Array.fill(n0 + 1)(ArrayBuffer.fill(n1 + 1)(0))
  for (j <- 0 to n1) { distances(0)(j) = j }
  for (i <- 0 to n0) { distances(i)(0) = i }
  for (i <- 1 to n0; j <- 1 to n1) {
    val ins = distances(i)(j - 1)
    val del = distances(i - 1)(j)
    val matches = distances(i - 1)(j - 1)
    val minDist = (ins :: del :: matches :: Nil).reduceLeft(_ min _)
    if (matches != minDist)
      distances(i)(j) = minDist + 1
    else if (s0(i - 1) == s1(j - 1))
      distances(i)(j) = minDist
    else
      distances(i)(j) = minDist + 1
  }
  distances(n0)(n1)
}

I am using Java arrays in Clojure to get the best performance. I have considered type hinting wherever aget is called, but then my code performs even worse (which might be expected, as make-array already defines a typed array). I have also overridden Clojure's :jvm-opts in project.clj. Yet the smallest performance gap I get is 70x.

What's wrong with my use of Java arrays in Clojure?

Thanks for insight.

Upvotes: 2

Views: 251

Answers (1)

OlegTheCat

Reputation: 4513

I think I figured out where the problem lies.

As you mentioned in the comment, the reflection calls consume most of the time. Here's why.

Before analyzing the code, I set *warn-on-reflection* to true:

(set! *warn-on-reflection* true)

Then, if you look at the source of aset, or of the macro that generates the aset-long function, you'll see that for arities of 4 or more it uses apply to invoke the function. The same goes for aget with arities of 3 or more. I'm not 100% sure, but I believe that information about the argument types is lost when a function is applied this way. Also, if you look closely at the definitions of aget and aset in clojure.core, you may notice that both carry :inline metadata, so calls to them can be inlined during compilation. We definitely want that:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; I've unrolled every multi-index aget/aset call into nested single-index
    ;; calls, so the compiler can inline them.
    ;; I'm also type hinting the first argument of each outer aget/aset call.
    ;; The reason is explained next.
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    ;; This multi-index call can stay as is, since it runs once, outside the loops.
    (aget distances n0 n1)))
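
The inlining point can be checked at the REPL. This is a small sketch, assuming a stock clojure.core; get2 is a hypothetical helper name introduced only for illustration and is not part of the code above. It prints the arities for which aget and aset are inlined, then reads one cell with two nested single-index aget calls:

```clojure
;; Only the arities listed in :inline-arities are inlined by the compiler;
;; calls with more indices fall through to the variadic arity that uses apply.
(println (:inline-arities (meta #'aget)))
(println (:inline-arities (meta #'aset)))

;; Hypothetical helper: two nested single-index agets, each of which can inline.
(defn get2
  ^long [^"[[J" arr ^long i ^long j]
  (aget ^longs (aget arr i) j))

(let [^"[[J" arr (make-array Long/TYPE 2 3)]
  ;; Write through the hinted row, then read the same cell back.
  (aset ^longs (aget arr 1) 2 42)
  (println (get2 arr 1 2))) ; prints 42
```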

Let's compile the new edit-distance function. Remember the global variable we set at the beginning? With it set to true, the compiler produces a bunch of reflection warnings during compilation:

Reflection warning, core.clj:75:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:76:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:77:25 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
...

The problem is that Clojure cannot figure out the type of (make-array Long/TYPE (inc n0) (inc n1)), marking it as unknown. We need to type hint it:

(let [...
      ;; type hint for 2d array of primitive longs
      ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))
      ...]
   ...)
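
If the "[[J" string looks arbitrary, it is simply the JVM's class name for long[][], and it can be confirmed at the REPL. A minimal sketch (grid is a hypothetical local name):

```clojure
;; One "[" per array dimension, then "J", the JVM descriptor for primitive long.
(let [^"[[J" grid (make-array Long/TYPE 2 3)]
  (println (.getName (class grid)))                 ; prints [[J
  (println (= (class grid) (Class/forName "[[J")))  ; prints true
  ;; Each row is a plain long[], which is what the ^longs hint refers to.
  (println (.getName (class (aget grid 0)))))       ; prints [J
```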

At this point, it seems that we're all set. The final version is below:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    (aget distances n0 n1)))
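
As a possible follow-up, the same hints can be combined with dotimes, whose loop index is a primitive long, instead of doseq over a range. The sketch below is not the version benchmarked in this answer and has not been timed. It assumes the inputs are Strings rather than arbitrary seqs of chars, and edit-distance-loop is a hypothetical name:

```clojure
(defn edit-distance-loop
  "Sketch only: same recurrence as edit-distance, assuming String inputs."
  [^String s0 ^String s1]
  (let [n0 (.length s0)
        n1 (.length s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; First column and first row hold the distance from the empty prefix.
    (dotimes [i (inc n0)] (aset ^longs (aget distances i) 0 i))
    (dotimes [j (inc n1)] (aset ^longs (aget distances 0) j j))
    (dotimes [i0 n0]
      (let [i (inc i0)
            ;; Fetch the current and previous rows once per outer iteration.
            ^longs row (aget distances i)
            ^longs prev (aget distances i0)]
        (dotimes [j0 n1]
          (let [j (inc j0)
                ins (aget row j0)
                del (aget prev j)
                match (aget prev j0)
                min-dist (min ins del match)]
            ;; Keep min-dist only when the diagonal is minimal and the chars match.
            (aset row j
                  (if (and (= match min-dist)
                           (= (.charAt s0 i0) (.charAt s1 j0)))
                    min-dist
                    (inc min-dist)))))))
    (aget ^longs (aget distances n0) n1)))

(println (edit-distance-loop "kitten" "sitting")) ; prints 3
```

Fetching each row once per outer iteration also removes the repeated (aget distances i) lookups from the inner loop.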

Here are benchmarks:

before:

> (time (edit-distance i1 i2))
"Elapsed time: 4601.025555 msecs"
291

after:

> (time (edit-distance i1 i2))
"Elapsed time: 27.782828 msecs"
291

Upvotes: 4
