为ST monad编写有效的迭代循环

时间:2013-06-15 18:55:59

标签: loops haskell state monads

go worker tail-recursive loop pattern似乎非常适合编写纯代码。为ST monad编写这种循环的等效方法是什么?更具体地说,我想避免循环迭代中的新堆分配。我的猜测是它包含CPS transformationfixST来重写代码,使得循环中所有正在变化的值都会在每次迭代中传递,从而产生寄存器位置(或堆栈以防spill)可用于迭代的那些值。我有一个简化的例子(不要尝试运行它 - 它可能会因为分段错误而崩溃!),涉及一个名为findSnakes的函数,该函数具有go工作者模式但不会传递更改的状态值通过累加器参数:

{-# LANGUAGE BangPatterns #-}
module Test where

import Data.Vector.Unboxed.Mutable as MU
import Data.Vector.Unboxed as U hiding (mapM_)
import Control.Monad.ST as ST
import Control.Monad.Primitive (PrimState)
import Control.Monad as CM (when,forM_)
import Data.Int

type MVI1 s  = MVector (PrimState (ST s)) Int

-- function to find previous y
findYP :: MVI1 s -> Int -> Int -> ST s Int
findYP fp k offset = do
              y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x
              y1 <- MU.unsafeRead fp (k+offset+1)
              if y0 > y1 then return y0
              else return y1
{-#INLINE findYP #-}

findSnakes :: Vector Int32 -> MVI1 s ->  Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp !k !ct !op = go 0 k
     where
           offset=1+U.length a
           go x k'
            | x < ct = do
                yp <- findYP fp k' offset
                MU.unsafeWrite fp (k'+offset) (yp + k')
                go (x+1) (op k' 1)
            | otherwise = return ()
{-#INLINE findSnakes #-}

查看cmm中的ghc 7.6.1输出(我对cmm知之甚少 - 如果我弄错了请纠正我),我看到这种呼叫流程,循环进入s1tb_info(在每次迭代中导致堆分配和堆检查):

findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check)
-> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out 
what that register points to)

-- I am guessing this one below is for go loop
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1
then s1td_info else R1

s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info
(a loop) else s1tb_info (after executing a different block of code)

我的猜测是arg >= 1代码中cmm格式的检查是确定go循环是否已终止。如果这是正确的,似乎除非重写go循环以跨越循环传递yp,否则将在循环中为新值发生堆分配(我猜yp导致堆分配) 。在上面的示例中,编写go循环的有效方法是什么?我想yp必须作为go循环中的参数传递,或通过fixSTCPS转换等效传递。我想不出一个好的方法来重写上面的go循环来删除堆分配,并且会很感激它的帮助。

1 个答案:

答案 0 :(得分:3)

我重写了你的函数以避免任何显式递归并删除了一些计算偏移量的冗余操作。这比使用原始函数编译得更好。

顺便说一下,核心可能是分析编译代码以进行此类分析的更好方法。使用ghc -ddump-simpl查看生成的核心输出或ghc-core

等工具
import Control.Monad.Primitive                                                                               
import Control.Monad.ST                                                                                      
import Data.Int                                                                                              
import qualified Data.Vector.Unboxed.Mutable as M                                                            
import qualified Data.Vector.Unboxed as U                                                                    

type MVI1 s  = M.MVector (PrimState (ST s)) Int                                                              

findYP :: MVI1 s -> Int -> ST s Int                                                                                                                                                     
findYP fp offset = do                                                                                      
    y0 <- M.unsafeRead fp (offset+0)                                                                       
    y1 <- M.unsafeRead fp (offset+2)                                                                       
    return $ max (y0 + 1) y1                                                                                  

findSnakes :: U.Vector Int32 -> MVI1 s ->  Int -> Int -> (Int -> Int -> Int) -> ST s ()                                                                                                         
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0                                       
    where writeAt k = do    
              let offset = U.length a + k                                                                 
              yp <- findYP fp offset                                                                        
              M.unsafeWrite fp (offset + 1) (yp + k)

          -- or inline findYP manually
          writeAt k = do
             let offset = U.length a + k
             y0 <- M.unsafeRead fp (offset + 0)
             y1 <- M.unsafeRead fp (offset + 2)
             M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)

此外,您将U.Vector Int32传递给findSnakes,仅计算其长度,不再使用a。为什么不直接传递长度?