Clojure与Numpy的矩阵乘法

时间:2012-01-17 18:31:01

标签: python matrix numpy clojure

我正在使用Clojure中的一个应用程序,它需要繁殖大型矩阵,并且与相同的Numpy版本相比,遇到了一些大的性能问题。 Numpy似乎能够在一秒钟内通过其转置乘以1,000,000x23矩阵,而等效的clojure代码需要超过六分钟。 (我可以从Numpy打印出结果矩阵,所以它肯定会评估所有内容。)

我在这个Clojure代码中做了哪些非常错误的事情?我可以尝试模仿Numpy的一些技巧吗?

这是python:

import numpy as np

def test_my_mult(n):
    A = np.random.rand(n*23).reshape(n,23)
    At = A.T

    t0 = time.time()
    res = np.dot(A.T, A)
    print time.time() - t0
    print np.shape(res)

    return res

# Example (returns a 23x23 matrix):
# >>> results = test_my_mult(1000000)
# 
# 0.906938076019
# (23, 23)

和clojure:

(defn feature-vec [n]
  (map (partial cons 1)
       (for [x (range n)]
         (take 22 (repeatedly rand)))))

(defn dot-product [x y]
  (reduce + (map * x y)))

(defn transpose
  "returns the transposition of a `coll` of vectors"
  [coll]
  (apply map vector coll))

(defn matrix-mult
  [mat1 mat2]
  (let [row-mult (fn [mat row]
                   (map (partial dot-product row)
                        (transpose mat)))]
    (map (partial row-mult mat2)
         mat1)))

(defn test-my-mult
  [n afn]
  (let [xs  (feature-vec n)
        xst (transpose xs)]
    (time (dorun (afn xst xs)))))

;; Example (yields a 23x23 matrix):
;; (test-my-mult 1000 i/mmult) => "Elapsed time: 32.626 msecs"
;; (test-my-mult 10000 i/mmult) => "Elapsed time: 628.841 msecs"

;; (test-my-mult 1000 matrix-mult) => "Elapsed time: 14.748 msecs"
;; (test-my-mult 10000 matrix-mult) => "Elapsed time: 434.128 msecs"
;; (test-my-mult 1000000 matrix-mult) => "Elapsed time: 375751.999 msecs"


;; Test from wikipedia
;; (def A [[14 9 3] [2 11 15] [0 12 17] [5 2 3]])
;; (def B [[12 25] [9 10] [8 5]])

;; user> (matrix-mult A B)
;; ((273 455) (243 235) (244 205) (102 160))

更新:我使用JBLAS库实现了相同的基准测试,并发现了大规模,大规模的速度提升。感谢大家的投入!是时候把这个傻瓜包裹在Clojure中了。这是新代码:

(import '[org.jblas FloatMatrix])

(defn feature-vec [n]
  (FloatMatrix.
   (into-array (for [x (range n)]
                 (float-array (cons 1 (take 22 (repeatedly rand))))))))

(defn test-mult [n]
  (let [xs  (feature-vec n)
        xst (.transpose xs)]
    (time (let [result (.mmul xst xs)]
            [(.rows result)
             (.columns result)]))))

;; user> (test-mult 10000)
;; "Elapsed time: 6.99 msecs"
;; [23 23]

;; user> (test-mult 100000)
;; "Elapsed time: 43.88 msecs"
;; [23 23]

;; user> (test-mult 1000000)
;; "Elapsed time: 383.439 msecs"
;; [23 23]

(defn matrix-stream [rows cols]
  (repeatedly #(FloatMatrix/randn rows cols)))

(defn square-benchmark
  "Times the multiplication of a square matrix."
  [n]
  (let [[a b c] (matrix-stream n n)]
    (time (.mmuli a b c))
    nil))

;; forma.matrix.jblas> (square-benchmark 10)
;; "Elapsed time: 0.113 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 100)
;; "Elapsed time: 0.548 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 1000)
;; "Elapsed time: 107.555 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 2000)
;; "Elapsed time: 793.022 msecs"
;; nil

9 个答案:

答案 0 :(得分:32)

Python版本正在编译为C中的循环,而Clojure版本正在为此代码中的每个调用构建一个新的中间序列。您看到的性能差异很可能来自数据结构的差异。

为了获得更好的效果,你可以使用像Incanter这样的库,或按照this SO question中的说明编写自己的版本。另请参阅this oneneanderthalnd4j。如果您真的希望保留序列以保持惰性评估属性等,那么通过查看内部矩阵计算的transients可以获得真正的提升

编辑:忘了添加调整clojure的第一步,打开"警告反思"

答案 1 :(得分:27)

Numpy正在链接到BLAS / Lapack例程,这些例程已经在机器架构层面上进行了数十年的优化,而Clojure则是以最简单和天真的方式实现乘法。

任何时候你都要执行非平凡的矩阵/向量运算,你应该链接到BLAS / LAPACK。

唯一不会更快的是来自语言的小矩阵,其中在语言运行时和LAPACK之间转换数据表示的开销超过了计算所花费的时间。

答案 2 :(得分:14)

我刚刚在Incanter 1.3和jBLAS 1.2.1之间进行了一次小小的枪战。这是代码:

(ns ml-class.experiments.mmult
  [:use [incanter core]]
  [:import [org.jblas DoubleMatrix]])

(defn -main [m]
  (let [n 23 m (Integer/parseInt m)
        ai (matrix (vec (double-array (* m n) (repeatedly rand))) n)
        ab (DoubleMatrix/rand m n)
        ti (copy (trans ai))
        tb (.transpose ab)]
    (dotimes [i 20]
      (print "Incanter: ") (time (mmult ti ai))
      (print "   jBLAS: ") (time (.mmul tb ab)))))

在我的测试中,在普通矩阵乘法中,Incanter始终比jBLAS慢约 45%。但是,Incanter trans函数不会创建矩阵的新副本,因此jBLAS中的(.mmul (.transpose ab) ab)需要两倍的内存,并且仅比{{1> 15%在Incanter。

考虑到Incanter丰富的功能集(特别是它的绘图库),我认为我不会很快切换到jBLAS。尽管如此,我还是希望看到jBLAS和Parallel Colt之间的另一次枪战,也许值得考虑在Incanter用jBLAS替换Parallel Colt? : - )


编辑:这是绝对数字(以毫秒为单位)。我上了(相当慢)PC:

(mmult (trans ai) ai)

对于每个库,我选择了20次运行中的最佳时间,矩阵大小为23x400000。

PS。 Haskell hmatrix结果接近于numpy,但我不确定如何正确地对其进行基准测试。

答案 3 :(得分:12)

Numpy代码使用内置库,在过去几十年中用Fortran编写,并由作者,CPU供应商和操作系统分销商(以及Numpy人员)进行优化,以获得最佳性能。您只是采用了完全直接,明显的矩阵乘法方法。实际上,性能不同并不奇怪。

但是如果你在Clojure中坚持做,那么考虑查找better algorithms,使用直接循环而不是像reduce这样的高阶函数,或者为Java找到合适的矩阵代数库(我怀疑在Clojure中有好的,但我真的不知道)是由一位称职的数学家写的。

最后,查看如何正确编写快速Clojure。使用类型提示,在代码上运行一个分析器(出乎意料!你的点产品功能耗尽了大部分时间),并将高级功能放入紧密循环中。

答案 4 :(得分:9)

正如@littleidea和其他人已经指出你的numpy版本正在使用LAPACK / BLAS / ATLAS,这将比你在clojure中做的任何事情都要快得多,因为它经过了多年的精心调整。 :)

那说Clojure代码的最大问题在于它使用了双打,就像盒装双打一样。我称之为“懒惰的双重”问题,我在工作中遇到了很多次。截至目前,即使使用1.3,clojure的集合也不是原始友好的。 (你可以创建一个基元向量,但它不会帮助你,因为所有的seq。函数最终会装箱!我还应该说1.3中的原始改进非常好并最终帮助..我们只是不是100%有集合中的WRT原始支持。)

在clojure中进行任何类型的矩阵数学时,你真的需要使用java数组,或者更好的是矩阵库。 Incanter确实使用了parrelcolt,但是你需要注意你使用的incanter功能...因为很多它们使得矩阵可以选择最终装箱双打,让你获得与你目前看到的相似的性能。 (顺便说一句,我有自己设置的parrelcolt包装器,如果你认为它们会有用,我可以发布它。)

为了使用BLAS库,您在java-land中有几个选项。使用所有这些选项,您必须支付JNA税...所有数据必须先复制才能处理。当你进行像矩阵分解这样的CPU绑定操作并且处理时间比复制数据所花费的时间更长时,这种税是有意义的。对于使用小矩阵的简单操作,保留在java-land中可能会更快。你只需要像上面那样做一些测试,看看什么最适合你。

以下是从java中使用BLAS的选项:

http://jblas.org/

http://code.google.com/p/netlib-java/

我应该指出parrelcolt使用netlib-java项目。这意味着,我相信,如果你正确地设置它将使用BLAS。但是,我没有证实这一点。有关jblas和netlib-java之间差异的解释,请参阅我在jblas的邮件列表上开始的这个帖子:

http://groups.google.com/group/jblas-users/browse_thread/thread/c9b3867572331aa5

我还应该指出Universal Java Matrix Package库:

http://sourceforge.net/projects/ujmp/

它包含了我提到的所有库,然后是一些库!虽然知道它的抽象是多么漏洞,但我并没有太多关注API。这似乎是一个很好的项目。我最终使用自己的parrelcolt clojure包装,因为它们足够快,我实际上非常喜欢colt API。 (Colt使用函数对象,这意味着我能够轻松地传递clojure函数!)

答案 5 :(得分:5)

如果你想在Clojure中做数字,我强烈建议使用Incanter,而不是尝试滚动你自己的矩阵函数等。

Incanter在引擎盖下使用Parallel Colt,这非常快。

修改

截至2013年初,如果你想在Clojure中做数字,我强烈建议你查看core.matrix

答案 6 :(得分:4)

Numpy针对线性代数进行了高度优化。当然对于大型矩阵,大多数处理都在本机C代码中。

为了匹配这种性能(假设它在Java中是可能的)你将不得不剥离Clojure的大部分抽象:在迭代大型矩阵时不要使用带有匿名函数的map,添加类型提示以启用原始Java的使用数组等。

可能最好的选择就是使用为数值计算优化的现成Java库(http://math.nist.gov/javanumerics/或类似的)。

答案 7 :(得分:0)

我没有任何具体的答案;只是一些建议。

  1. 使用分析器确定花费的时间
  2. 设置警告反射并在需要时使用类型提示
  3. 您可能不得不放弃一些高级构造并使用loop-recur来平衡最后一盎司的性能
  4. IME,Clojure代码应该非常接近Java(2或3X)。但你必须努力。

答案 8 :(得分:-3)

如果有意义,只使用map()。这意味着:如果你有一个特定的问题,如乘以两个矩阵,不要尝试map()它,只需乘以矩阵。

我倾向于只在语言有意义时使用map()(即如果程序真的比没有它时更可读)。乘法矩阵是如此明显的循环,映射它是没有意义的。

此致。

Pedro Fortuny。