优化Haskell递归列表

时间:2011-11-15 19:36:35

标签: optimization haskell

来自previous的另一个Haskell优化问题。我需要递归生成一个列表,类似于许多介绍性Haskell文章中的fibs函数:

generateSchedule :: [Word32] -> [Word32]
generateSchedule blkw = take 80 ws
    where
    ws          = blkw ++ zipWith4 gen (drop 13 ws) (drop 8 ws) (drop 2 ws) ws
    gen a b c d = rotate (a `xor` b `xor` c `xor` d) 1

上面的功能已经超过了我的最多时间和分配消耗功能。分析器为我提供了以下统计信息:

COST CENTRE        MODULE             %time %alloc  ticks     bytes
generateSchedule   Test.Hash.SHA1     22.1   40.4   31        702556640

我想过应用未装箱的向量来计算列表,但由于列表是递归的,因此无法找到方法。这将在C中有一个自然的实现,但我没有看到一种方法来使这更快(除了展开和写80行变量声明)。有什么帮助吗?

更新:我实际上已经快速展开它以查看它是否有帮助。代码是here。它很丑,事实上它更慢。

COST CENTRE        MODULE             %time %alloc  ticks     bytes
generateSchedule   GG.Hash.SHA1       22.7   27.6   40        394270592

2 个答案:

答案 0 :(得分:5)

import Data.Array.Base
import Data.Array.ST
import Data.Array.Unboxed

generateSchedule :: [Word32] -> UArray Int Word32
generateSchedule ws0 = runSTUArray $ do
    arr <- unsafeNewArray_ (0,79)
    let fromList i [] = fill i 0
        fromList i (w:ws) = do
            unsafeWrite arr i w
            fromList (i+1) ws
        fill i j
          | i == 80 = return arr
          | otherwise = do
              d <- unsafeRead arr j
              c <- unsafeRead arr (j+2)
              b <- unsafeRead arr (j+8)
              a <- unsafeRead arr (j+13)
              unsafeWrite arr i (gen a b c d)
              fill (i+1) (j+1)
    fromList 0 ws0

将创建与您的列表对应的未装箱数组。它依赖于list参数包含至少14个且最多80个项目的假设,否则它将是错误的。我认为它总是16项(64字节),所以这对你来说应该是安全的。 (但是最好从ByteString直接开始填充,而不是构建一个中间列表。)

通过在进行散列轮次之前对此进行严格评估,可以节省列表构造和使用延迟构造列表进行散列之间的切换,从而减少所需的时间。通过使用未装箱的阵列,我们避免了列表的分配开销,这可能会进一步提高速度(但是ghc的分配器速度非常快,因此不要期望太大的影响)。

在您的哈希回合中,通过Word32获取所需的unsafeAt array t以避免不必要的边界检查。

附录:如果你对每个wn发出爆炸声,那么展开列表的创建可能会更快,尽管我不确定。既然你已经有了代码,那么添加刘海和检查并不算太多,是吗?我很好奇。

答案 1 :(得分:1)

我们可以使用惰性数组在直接变异和使用纯列表之间获得中途。你可以获得递归定义的好处,但是由于这个原因,仍然需要付出懒惰和拳击的代价 - 尽管不如列表那样。以下代码使用标准来测试两个惰性数组解决方案(使用标准数组和向量)以及上面的原始列表代码和Daniel的可变uarray代码:

module Main where
import Data.Bits
import Data.List
import Data.Word
import qualified Data.Vector as LV
import Data.Array.ST
import Data.Array.Unboxed
import qualified Data.Array as A
import Data.Array.Base
import Criterion.Main

gen :: Word32 -> Word32 -> Word32 -> Word32 -> Word32
gen a b c d = rotate (a `xor` b `xor` c `xor` d) 1

gss blkw = LV.toList v
    where v = LV.fromList $ blkw ++ rest
          rest = map (\i -> gen (LV.unsafeIndex v (i + 13))
                                (LV.unsafeIndex v (i + 8))
                                (LV.unsafeIndex v (i + 2))
                                (LV.unsafeIndex v i)
                     )
                 [0..79 - 14]

gss' blkw = A.elems v
    where v = A.listArray (0,79) $ blkw ++ rest
          rest = map (\i -> gen (unsafeAt v (i + 13))
                                (unsafeAt v (i + 8))
                                (unsafeAt v (i + 2))
                                (unsafeAt v i)
                     )
                 [0..79 - 14]

generateSchedule :: [Word32] -> [Word32]
generateSchedule blkw = take 80 ws
    where
    ws          = blkw ++ zipWith4 gen (drop 13 ws) (drop 8 ws) (drop 2 ws) ws

gs :: [Word32] -> [Word32]
gs ws = elems (generateSched ws)

generateSched :: [Word32] -> UArray Int Word32
generateSched ws0 = runSTUArray $ do
    arr <- unsafeNewArray_ (0,79)
    let fromList i [] = fill i 0
        fromList i (w:ws) = do
            unsafeWrite arr i w
            fromList (i+1) ws
        fill i j
          | i == 80 = return arr
          | otherwise = do
              d <- unsafeRead arr j
              c <- unsafeRead arr (j+2)
              b <- unsafeRead arr (j+8)
              a <- unsafeRead arr (j+13)
              unsafeWrite arr i (gen a b c d)
              fill (i+1) (j+1)
    fromList 0 ws0

args = [0..13]

main = defaultMain [
        bench "list"   $ whnf (sum . generateSchedule) args
       ,bench "vector" $ whnf (sum . gss) args
       ,bench "array"  $ whnf (sum . gss') args
       ,bench "uarray" $ whnf (sum . gs) args
       ]

我使用-O2-funfolding-use-threshold=256编译代码以强制进行大量内联。

标准基准测试表明,矢量解决方案略好一些,阵列解决方案稍微好一点,但未装箱的可变解决方案仍然以压倒性优势获胜:

benchmarking list
mean: 8.021718 us, lb 7.720636 us, ub 8.605683 us, ci 0.950
std dev: 2.083916 us, lb 1.237193 us, ub 3.309458 us, ci 0.950

benchmarking vector
mean: 6.829923 us, lb 6.725189 us, ub 7.226799 us, ci 0.950
std dev: 882.3681 ns, lb 76.20755 ns, ub 2.026598 us, ci 0.950

benchmarking array
mean: 6.212669 us, lb 5.995038 us, ub 6.635405 us, ci 0.950
std dev: 1.518521 us, lb 946.8826 ns, ub 2.409086 us, ci 0.950

benchmarking uarray
mean: 2.380519 us, lb 2.147896 us, ub 2.715305 us, ci 0.950
std dev: 1.411092 us, lb 1.083180 us, ub 1.862854 us, ci 0.950

我也运行了一些基本的分析,并注意到懒惰/盒装数组解决方案比列表解决方案稍微好一点,但再次明显比未装箱数组方法更差。