
Clojure code using Java primitive arrays is 70x slower than the Scala version

Stack Overflow user
Asked 2016-07-30 16:49:06
1 answer, 235 views, 0 followers, 2 votes

I wrote an edit-distance algorithm in both Clojure and Scala.

The Scala version runs 70 times faster than the Clojure version.

Clojure:

(defn edit-distance                                                                                                                                                                                                                                                             
  "['seq of char' 'seq of char']"                                                                                                                                                                                                                                               
  [s0 s1]                                                                                                                                                                                                                                                                       
  (let [n0 (count s0)                                                                                                                                                                                                                                                           
        n1 (count s1)                                                                                                                                                                                                                                                           
        distances (make-array Long/TYPE (inc n0) (inc n1))]                                                                                                                                                                                                                     
    ;;initialize distances                                                                                                                                                                                                                                                      
    (doseq [i (range 1 (inc n0))] (aset-long distances i 0 i))                                                                                                                                                                                                                  
    (doseq [j (range 1 (inc n1))] (aset-long distances 0 j j))                                                                                                                                                                                                                  

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]                                                                                                                                                                                                                         
      (let [ins (aget distances i (dec j))                                                                                                                                                                                                                                      
            del (aget distances (dec i) j)                                                                                                                                                                                                                                      
            match (aget distances (dec i) (dec j))                                                                                                                                                                                                                              
            min-dist (min ins del match)]                                                                                                                                                                                                                                       
        (cond                                                                                                                                                                                                                                                                   
          (not= match min-dist) (aset-long distances i j (inc min-dist))                                                                                                                                                                                                        
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset-long distances i j (inc min-dist))                                                                                                                                                                                     
          :else (aset-long distances i j min-dist))))                                                                                                                                                                                                                           
    (aget distances n0 n1)))     

Scala:

 def editDistance(s0: Array[Char], s1: Array[Char]):Int = {                                                                                                                                                                                                                   
      val n0 = s0.length                                                                                                                                                                                                                                                        
      val n1 = s1.length                                                                                                                                                                                                                                                        
      val distances = Array.fill(n0+1)(ArrayBuffer.fill(n1+1)(0))                                                                                                                                                                                                               
      for(j <- 0 to n1){distances(0)(j) = j}                                                                                                                                                                                                                                    
      for(i <- 0 to n0){distances(i)(0) = i}                                                                                                                                                                                                                                    
      for(i <- 1 to n0; j <- 1 to n1){                                                                                                                                                                                                                                          
         val ins = distances(i)(j-1)                                                                                                                                                                                                                                            
         val del = distances(i-1)(j)                                                                                                                                                                                                                                            
         val matches = distances(i-1)(j-1)                                                                                                                                                                                                                                      
         val minDist = (ins::del::matches::Nil).reduceLeft(_ min _)                                                                                                                                                                                                             
         if (matches != minDist)                                                                                                                                                                                                                                                
            distances(i)(j) = minDist + 1                                                                                                                                                                                                                                       
         else if (s0(i-1) == s1(j-1))                                                                                                                                                                                                                                           
            distances(i)(j) = minDist                                                                                                                                                                                                                                           
         else                                                                                                                                                                                                                                                                   
            distances(i)(j) = minDist + 1                                                                                                                                                                                                                                       
      }                                                                                                                                                                                                                                                                         
      distances(n0)(n1)                                                                                                                                                                                                                                                         
   }                                 

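I am using Java arrays in Clojure to get the best performance. I considered adding type hints wherever I call aget, but my code then performed worse (probably because make-array already creates a typed array). I also overrode Clojure's defaults in projects.clj. Even so, the smallest performance gap I could get was 70x.

What is wrong with the way I use Java arrays in Clojure?

Thanks for your insight.

1 Answer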

Stack Overflow user

Accepted answer

Answered 2016-07-30 18:52:07

I think I know where the problem is.

As you mentioned in the comments, reflective calls take up most of the time. Here is why.

Before analyzing the code, I set *warn-on-reflection* to true:

(set! *warn-on-reflection* true)

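Then, if you look at the source of the macro that generates the aset functions, you will see that for 4 or more arguments it uses apply to call the function. The same is true of aget for 3 or more arguments. I am not 100% sure, but I believe the information about argument types is lost when the call goes through apply. Also, if you look closely at the definitions of aget and aset, you may notice that these functions can be inlined at compile time, but only in their single-index form. We definitely want that inlining:
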
(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; I've unrolled every multi-index aget/aset call into nested single-index
    ;; calls, so the compiler can inline them.
    ;; I'm also type hinting the first argument of the outer aget/aset calls.
    ;; The reason is explained next.
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    ;; we can leave this, since it is not placed within loop
    (aget distances n0 n1)))
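
To look at the two call shapes in isolation, here is a minimal sketch. The names get-slow and get-fast and the 3x3 array are made up for illustration; they are not part of the original code.

```clojure
(set! *warn-on-reflection* true)

(defn get-slow
  "Multi-index form: with two indices, aget takes its variadic arity,
   which is implemented with apply and is not inlined."
  [arr i j]
  (aget arr i j))

(defn get-fast
  "Nested single-index form: each aget has exactly one index, so it can be
   inlined, and the hints tell the compiler which array types are involved."
  [^"[[J" arr ^long i ^long j]
  (aget ^longs (aget arr i) j))

(let [^"[[J" arr (make-array Long/TYPE 3 3)]
  ;; write through the nested form, then read back both ways
  (aset ^longs (aget arr 1) 2 42)
  [(get-slow arr 1 2) (get-fast arr 1 2)])
```

Both functions return the same value; the difference is only in how the call is compiled.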

Let's compile our new function. Remember the global variable we set at the beginning? With it set to true, the compiler emits a series of warnings during compilation:

Reflection warning, core.clj:75:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:76:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:77:25 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
...

The problem is that Clojure cannot determine the type of (make-array Long/TYPE (inc n0) (inc n1)) and marks it as unknown. We need to add a type hint:

(let [...
      ;; type hint for 2d array of primitive longs
      ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))
      ...]
   ...)
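
If the string in that hint looks opaque, it is simply the JVM's internal class name for long[][]. A small sketch for finding such names at the REPL follows; array-class-name is a hypothetical helper introduced only for this illustration.

```clojure
(defn array-class-name
  "Hypothetical helper: returns the JVM class name of x as a string."
  [x]
  (.getName ^Class (class x)))

;; 2D array of primitive longs: this is the string used in the hint
(array-class-name (make-array Long/TYPE 2 3))   ;; => "[[J"

;; 1D array of primitive longs: the type that the ^longs hint refers to
(array-class-name (long-array 0))               ;; => "[J"

;; The string form resolves to the same class object
(= (Class/forName "[[J") (class (make-array Long/TYPE 2 3)))  ;; => true
```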

At this point we seem to be all set. The final version looks like this:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    (aget distances n0 n1)))
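
As a possible further variation, not benchmarked here, the same recurrence can be written with dotimes, which binds its counter as a primitive long instead of drawing boxed numbers from range. The name edit-distance-loop is made up for this sketch, and converting the inputs with vec is an assumption made so that nth is a constant-time lookup for any input sequence.

```clojure
(defn edit-distance-loop
  "Sketch: same recurrence as edit-distance, iterating with dotimes."
  [s0 s1]
  (let [s0 (vec s0)   ;; assumption: vectors give O(1) nth
        s1 (vec s1)
        n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; initialize first column and first row
    (dotimes [i (inc n0)] (aset ^longs (aget distances i) 0 i))
    (dotimes [j (inc n1)] (aset ^longs (aget distances 0) j j))
    (dotimes [i0 n0]
      (let [i (inc i0)
            ^longs row  (aget distances i)    ;; current row
            ^longs prev (aget distances i0)]  ;; previous row
        (dotimes [j0 n1]
          (let [j (inc j0)
                ins (aget row j0)
                del (aget prev j)
                match (aget prev j0)
                min-dist (min ins del match)]
            ;; keep min-dist only when the diagonal is the minimum
            ;; and the characters are equal; otherwise add one
            (aset row j
                  (if (and (= match min-dist)
                           (= (nth s0 i0) (nth s1 j0)))
                    min-dist
                    (inc min-dist)))))))
    (aget ^longs (aget distances n0) n1)))

(edit-distance-loop "kitten" "sitting")  ;; => 3
```

Fetching each row once per outer iteration also avoids repeating the outer aget inside the inner loop.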

Here are the benchmarks.

Before:

> (time (edit-distance i1 i2))
"Elapsed time: 4601.025555 msecs"
291

After:

> (time (edit-distance i1 i2))
"Elapsed time: 27.782828 msecs"
291
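
A single call to time also includes whatever JIT warm-up happens during that call, so the numbers above are best read as a rough comparison. A minimal sketch of a steadier measurement using only the standard library follows; mean-msecs is a hypothetical helper, and the warm-up and run counts in the usage line are arbitrary.

```clojure
(defn mean-msecs
  "Hypothetical helper: calls f `warmup` times untimed, then `runs` times
   timed, and returns the mean wall-clock time per call in milliseconds."
  [f warmup runs]
  (dotimes [_ warmup] (f))
  (let [start (System/nanoTime)]
    (dotimes [_ runs] (f))
    (/ (- (System/nanoTime) start) 1e6 runs)))

;; Usage, assuming edit-distance, i1 and i2 are defined as in the answer:
;; (mean-msecs #(edit-distance i1 i2) 20 20)
```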
4 votes
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/38676298
