我对Project Euler#92的尝试太慢了

时间:2013-05-01 17:11:19

标签: haskell functional-programming

我正在尝试用Haskell解决Project Euler problem #92。我最近开始学习Haskell。这是我试图用Haskell解决的第一个Project Euler问题,但是我的代码片段在10分钟内也没有终止。我知道你不直接给我答案,但我应该再次警告我用c ++找回答并不能给出欧拉的答案或解决欧拉的新逻辑。我只是好奇为什么那个人不能快速工作,我该怎么办才能让它更快?

{--EULER 92--}
import Data.List


myFirstFunction 1 = 0
myFirstFunction 89 = 1
myFirstFunction x= myFirstFunction (giveResult x)

giveResult 0 = 0
giveResult x = (square (mod x 10)) + (giveResult (div x 10))

square x = x*x


a=[1..10000000]


main = putStrLn(show (sum (map myFirstFunction a))) 

2 个答案:

答案 0 :(得分:22)

当然,使用更好的算法可以获得最大的加速。不过,我并没有深入探讨这一点。

原始算法调整

因此,让我们专注于改进使用过的算法,而不是真正改变它。

  1. 您永远不会给出任何类型签名,因此类型默认为任意精度Integer。这里的所有内容都很容易适应Int,没有溢出的危险,所以让我们使用它。添加类型签名myFirstFunction :: Int -> Int有助于:时间从Total time 13.77s ( 13.79s elapsed)降至Total time 6.24s ( 6.24s elapsed),总分配下降约15倍。对于这种简单的更改,这不是坏事。

  2. 您使用divmod。这些总是计算非负余数和相应的商,因此如果涉及一些负数,它们需要一些额外的检查。函数quotrem映射到机器分区指令,它们不涉及此类检查,因此更快一些。如果通过LLVM后端(-fllvm)进行编译,那么这也会利用您总是除以一个已知数字(10)的事实,并将除法转换为乘法和位移。现在时间:Total time 1.56s ( 1.56s elapsed)

  3. 我们不是单独使用quotrem,而是使用同时计算两者的quotRem函数,这样我们就不会重复除法(即使乘法+移位需要一点时间):

    giveResult x = case x `quotRem` 10 of
                     (q,r) -> r*r + giveResult q
    

    这并没有多大收获,只有一点点:Total time 1.49s ( 1.49s elapsed)

  4. 您正在使用列表a = [1 .. 10000000]map该列表中的函数,然后sum生成的列表。这是惯用的,整洁的,但不是超快的,因为分配所有这些列表单元和垃圾收集它们也需要时间 - 不是很多,因为GHC 非常擅长这一点,但是将其转换为循环

    main = print $ go 0 1
        where
            go acc n
                | n > 10000000 = acc
                | otherwise    = go (acc + myFirstFunction n) (n+1)
    

    让我们稍微停顿一下:Total time 1.34s ( 1.34s elapsed),分配从最后一个列表版本的880,051,856 bytes allocated in the heap下降到51,840 bytes allocated in the heap

  5. giveResult是递归的,因此无法内联。同样适用于myFirstFunction,因此每次计算都需要两次函数调用(至少)。我们可以通过将giveResult重写为非递归包装器和递归本地循环来避免这种情况,

    giveResult x = go 0 x
        where
            go acc 0 = acc
            go acc n = case n `quotRem` 10 of
                         (q,r) -> go (acc + r*r) q
    

    这样可以内联:Total time 1.04s ( 1.04s elapsed)

  6. 这些是最明显的观点,进一步的改进 - 除了哈马尔在评论中提到的备忘录 - 需要一些思考。

    我们现在在

    module Main (main) where
    
    myFirstFunction :: Int -> Int
    myFirstFunction 1 = 0
    myFirstFunction 89 = 1
    myFirstFunction x= myFirstFunction (giveResult x)
    
    giveResult :: Int -> Int
    giveResult x = go 0 x
        where
            go acc 0 = acc
            go acc n = case n `quotRem` 10 of
                         (q,r) -> go (acc + r*r) q
    
    main :: IO ()
    main = print $ go 0 1
        where
            go acc n
                | n > 10000000 = acc
                | otherwise    = go (acc + myFirstFunction n) (n+1)
    

    使用-O2 -fllvm,在此处运行1.04秒,但使用本机代码生成器(仅-O2),需要3.5秒。这种差异是由于GHC本身并没有将除法转换为乘法和位移的事实。如果我们手动完成,我们可以从本机代码生成器获得相同的性能。

    因为我们知道编译器没有的东西,即我们从不在这里处理负数,并且数字不会变大,我们甚至可以产生更好的乘法和移位(会比编译器产生错误的负数或大股息结果,本机代码生成器的时间缩短为0.9秒,LLVM后端的时间缩短为0.73秒:

    import Data.Bits
    
    qr10 :: Int -> (Int, Int)
    qr10 n = (q, r)
      where
        q = (n * 0x66666667) `unsafeShiftR` 34
        r = n - 10 * q
    

    注意:这要求Int是64位类型,它不能使用32位Int,它会产生错误的结果对于否定n,对于大n,乘法将溢出。我们正在进入肮脏的黑客领域。我们可以使用Word代替Int来减轻肮脏,只留下溢出(n <= 10737418236 Wordn <= 5368709118 Int不会发生溢出} #include <stdio.h> unsigned int myFirstFunction(unsigned int i); unsigned int giveResult(unsigned int i); int main(void) { unsigned int sum = 0; for(unsigned int i = 1; i <= 10000000; ++i) { sum += myFirstFunction(i); } printf("%u\n",sum); return 0; } unsigned int myFirstFunction(unsigned int i) { if (i == 1) return 0; if (i == 89) return 1; return myFirstFunction(giveResult(i)); } unsigned int giveResult(unsigned int i) { unsigned int acc = 0, r, q; while(i) { q = (i*0x66666667UL) >> 34; r = i - q*10; i = q; acc += r*r; } return acc; } ,所以在这里我们舒适地处于安全区。时间没有受到影响。

    相应的C程序

    gcc -O3

    执行类似,使用clang -O3编译,运行0.78秒,<= 7*9²运行0.71。

    在不改变算法的情况下,这几乎就是结束。


    Memoisation

    现在,算法的一个微小变化就是记忆。如果我们为数字module Main (main) where import Data.Array.Unboxed import Data.Array.IArray import Data.Array.Base (unsafeAt) import Data.Bits qr10 :: Int -> (Int, Int) qr10 n = (q, r) where q = (n * 0x66666667) `unsafeShiftR` 34 r = n - 10 * q digitSquareSum :: Int -> Int digitSquareSum = go 0 where go acc 0 = acc go acc n = case qr10 n of (q,r) -> go (acc + r*r) q table :: UArray Int Int table = array (0,567) $ assocs helper where helper :: Array Int Int helper = array (0,567) [(i, f i) | i <- [0 .. 567]] f 0 = 0 f 1 = 0 f 89 = 1 f n = helper ! digitSquareSum n endPoint :: Int -> Int endPoint n = table `unsafeAt` digitSquareSum n main :: IO () main = print $ go 0 1 where go acc n | n > 10000000 = acc | otherwise = go (acc + endPoint n) (n+1) 建立一个查找表,我们只需要计算每个数字的数字平方和,而不是迭代它直到我们达到1或89,所以让我们回忆一下,

    gcc -O3

    手动进行记忆而不是使用库会使代码更长,但我们可以根据需要定制代码。我们可以使用一个未装箱的数组,我们可以省略数组访问的边界检查。两者都显着加快了计算速度。本机代码生成器的时间现在为0.18秒,LLVM后端的时间为0.13秒。相应的C程序用clang -O3编译运行0.16秒,用ghc -O2 -fllvm编译0.145秒(Haskell击败C,w00t!)。


    缩放和提示更好的算法

    然而,所使用的算法不能很好地扩展,比线性更差,并且对于10 8 的上限(具有适当调整的记忆限制),它在1.5秒内运行(clang -O3),resp。 1.64秒(gcc -O3)和1.87秒(10 = 1×3² + 1×1² 10 = 2×2² + 2×1² 10 = 1×2² + 6×1² 10 = 10×1² )[本机代码生成器的2.02秒]。

    使用不同的算法,通过将这些数字划分为数字的平方和来计算序列以1结尾的数字(直接产生1的唯一数字是10的幂。我们可以写

    <= 10^10

    从第一次开始,我们获得13,31,103,130,301,310,1003,1030,1300,3001,3010,3100,...... 从第二个,我们获得1122,1212,1221,2112,2121,2211,11022,11202 ...... 从第三个1111112,1111121,......

    只有13,31,103,130,301,310是数字100 = 1×9² + 1×4² + 3×1² ... 100 = 1×8² + 1×6² ... 数字的可能平方和,因此只需要进一步调查。我们可以写

    $ time ./problem92 7
    8581146
    
    real    0m0.010s
    user    0m0.008s
    sys     0m0.002s
    $ time ./problem92 8
    85744333
    
    real    0m0.022s
    user    0m0.018s
    sys     0m0.003s
    $ time ./problem92 9
    854325192
    
    real    0m0.040s
    user    0m0.033s
    sys     0m0.006s
    $ time ./problem92 10
    8507390852
    
    real    0m0.074s
    user    0m0.069s
    sys     0m0.004s
    

    这些分区中的第一个不生成子项,因为它需要五个非零数字,另一个明确给出生成两个子项68和86(如果限制为10 8 ,则为608,对于更大的限制,更多)),我们可以获得更好的缩放和更快的算法。

    当我解决这个问题时,我写回来的相当未被优化的程序运行(输入是限制的10的指数)

    {{1}}

    在另一个联赛中。

答案 1 :(得分:8)

首先,我冒昧地清理你的代码:

endsAt89 1  = 0
endsAt89 89 = 1
endsAt89 n  = endsAt89 (sumOfSquareDigits n)

sumOfSquareDigits 0 = 0
sumOfSquareDigits n = (n `mod` 10)^2 + sumOfSquareDigits (n `div` 10)    

main = print . sum $ map endsAt89 [1..10^7]

在我糟糕的上网本上是1分13秒。让我们看看我们是否可以改善这一点。

由于数字很小,我们可以先使用机器大小的Int而不是任意大小的Integer。这只是添加类型签名的问题,例如

sumOfSquareDigits :: Int -> Int

这大大缩短了20秒的运行时间。

由于这些数字都是正数,我们可以将divmod替换为稍快一点quotrem,或者甚至同时使用{{1} }}:

quotRem

现在运行时间为17秒。让它的尾部递归剃掉另一秒:

sumOfSquareDigits :: Int -> Int
sumOfSquareDigits 0 = 0
sumOfSquareDigits n = r^2 + sumOfSquareDigits q
  where (q, r) = quotRem x 10

为了进一步改进,我们可以注意到sumOfSquareDigits :: Int -> Int sumOfSquareDigits n = loop n 0 where loop 0 !s = s loop n !s = loop q (s + r^2) where (q, r) = quotRem n 10 对于给定的输入数字最多返回sumOfSquareDigits,因此我们可以记住小数字以减少所需的迭代次数。这是我的最终版本(使用data-memocombinators包进行记忆):

567 = 7 * 9^2

在我的机器上运行时间不到9秒。