优化Haskell中的n-queens

时间:2018-04-04 13:47:36

标签: performance haskell optimization n-queens

此代码:

{-# LANGUAGE BangPatterns #-}


module Main where

import Data.Bits
import Data.Word
import Control.Monad
import System.CPUTime
import Data.List

-- | The n-queens problem (German: "Damenproblem"): count the solutions for a
-- 14x14 board and report how much CPU time the computation took.
-- Wiki: https://de.wikipedia.org/wiki/Damenproblem
main :: IO ()
main = do
  start <- getCPUTime
  -- Printing the count forces the whole computation before the second
  -- clock reading, so the measurement covers all of the work.
  print $ dame 14
  end <- getCPUTime
  -- getCPUTime reports picoseconds; divide by 10^12 to get seconds.
  let seconds = fromIntegral (end - start) / 1e12 :: Double
  -- putStrLn, not print: print would show the message with quotes/escapes.
  putStrLn $ "Needed " ++ show seconds ++ " Seconds"

-- | Attack masks carried from one row to the next, as
-- (left-moving diagonal, right-moving diagonal, column); bit @i@ stands for
-- column @i@. 'nextState' shifts the first component with 'moveLeft', the
-- second with 'moveRight', and leaves the third unchanged.
type BitState = (Word64, Word64, Word64)

-- | Count the solutions of the n-queens problem on an @n x n@ board.
--
-- The first queen is tried in every column of row 0; 'recur' then walks the
-- remaining rows, where @depth == 0@ means "this is the last row".
--
-- * @n <= 0@ yields 0 (no rows to fill, as before).
-- * @n == 1@ yields 1. The generic recursion would start at depth
--   @n - 2 = -1@, never reach its depth-0 base case, and wrongly return 0.
--
-- NOTE(review): columns are bits of a 'Word64', so results are only
-- meaningful for @n <= 64@.
dame :: Int -> Int
dame n
  | n <= 0 = 0
  | n == 1 = 1
  | otherwise = foldl' (+) 0 (map fn row)
  where
    -- Place the first queen at column mask @col@ and descend to row 1.
    fn col = recur (n - 2) $ nextState (col, col, col)
    -- Sum the solutions reachable from @state@ over every column of the row.
    recur !depth !state =
      foldl' (+) 0 (map (getPossible depth (getStateVal state) state) row)
    -- 0 if @col@ is attacked, 1 on the last row, otherwise go one row deeper.
    getPossible depth !stateVal state col
      | (col .&. stateVal) /= 0 = 0
      | depth == 0 = 1
      | otherwise = recur (depth - 1) (nextState (addBitToState col state))
    -- One single-bit mask per column: 1, 2, 4, ...
    row = take n $ iterate moveLeft 1

-- | Every column of the current row that an earlier queen attacks, whether
-- along either diagonal or along the column itself.
getStateVal :: BitState -> Word64
getStateVal state =
  case state of
    (diagL, diagR, cols) -> diagL .|. (diagR .|. cols)

-- | Mark the column mask @col@ as occupied in all three attack masks.
addBitToState :: Word64 -> BitState -> BitState
addBitToState col state =
  case state of
    (diagL, diagR, cols) -> (diagL .|. col, diagR .|. col, cols .|. col)

-- | Advance the attack masks to the next row. The first diagonal is shifted
-- with 'moveLeft' and the second with 'moveRight'; the column mask is kept.
nextState :: BitState -> BitState
nextState state =
  case state of
    (diagL, diagR, cols) -> (moveLeft diagL, moveRight diagR, cols)

-- | Shift a mask one bit towards the least significant end.
moveRight :: Word64 -> Word64
moveRight = (`shiftR` 1)

-- | Shift a mask one bit towards the most significant end; the top bit falls
-- off.
moveLeft :: Word64 -> Word64
moveLeft = (`shift` 1)

这段代码大约需要60秒才能执行完。如果使用-O2启用编译器优化,大约需要7秒;而-O1反而更快,大约需要5秒。我还测试了这段代码的Java版本(用for循环代替对列表做map),只需大约1秒(!)。我一直在尝试优化,但在网上找到的技巧带来的提升都不超过半秒。请帮忙。

编辑:Java版本:

/**
 * Bitmask-based n-queens solution counter: the Java version used for the
 * timing comparison with the Haskell code.
 */
public class Queens {
    /**
     * Board size. The original snippet used N without declaring it, so it did
     * not compile; 14 matches the Haskell benchmark ({@code dame 14}). It is
     * not final, so callers can choose another size. Columns are bits of a
     * {@code long}, so N must be at most 64.
     */
    static int N = 14;

    /**
     * Counts all placements of N mutually non-attacking queens on an N x N
     * board.
     *
     * @return the number of solutions; 0 for N <= 0
     */
    static int getQueens() {
        if (N <= 0) {
            return 0;
        }
        // run() below starts at depth N - 2, which never reaches its depth-0
        // base case when N == 1, so the single 1x1 solution is special-cased.
        if (N == 1) {
            return 1;
        }
        int res = 0;
        for (int i = 0; i < N; i++) {
            // 1L, not 1: an int shift overflows for i >= 31.
            long pos = 1L << i;
            res += run(pos << 1, pos >>> 1, pos, N - 2);
        }
        return res;
    }

    /**
     * Counts the completions of a partial board, one row at a time.
     *
     * @param diagR columns attacked via the diagonal that is shifted left each row
     * @param diagL columns attacked via the diagonal that is shifted right each row
     * @param mid   columns already occupied
     * @param depth rows remaining after the current one (0 = last row)
     * @return the number of ways to finish the board
     */
    static int run(long diagR, long diagL, long mid, int depth) {
        long valid = mid | diagL | diagR;
        int resBuffer = 0;

        for (int i = 0; i < N; i++) {
            long pos = 1L << i;
            // != 0 rather than > 0: with bit 63 set the masked value is negative.
            if ((valid & pos) != 0) {
                continue;
            }
            if (depth == 0) {
                resBuffer++;
                continue;
            }
            long nMid = mid | pos;
            // >>> so that a set top bit is not smeared by sign extension.
            long nDiagL = (diagL >>> 1) | (pos >>> 1);
            long nDiagR = (diagR << 1) | (pos << 1);

            resBuffer += run(nDiagR, nDiagL, nMid, depth - 1);
        }
        return resBuffer;
    }
}

编辑:运行环境为Windows,CPU为3.2GHz的i5 650,编译器为ghc 8.4.1。

1 个答案:

答案 0 :(得分:4)

假设您的算法是正确的(我还没有验证过),我能稳定地跑到约900毫秒(比Java实现还快!)。在我的机器上,-O2和-O3的效果差不多。

值得注意的改动(编辑:其中最重要的是从List切换到Vector):切换到GHC 8.4.1;大量使用严格求值(bang patterns);BitState现在是字段全部严格的三元数据类型。使用Vector对提升速度非常重要——在我看来,即使有融合(fusion)优化,链表也无法达到相当的速度。使用未装箱(unboxed)的Vector也很重要,因为已知Vector中的元素总是Word64或Int。

{-# LANGUAGE BangPatterns #-}

module Main (main) where

import Data.Bits ((.&.), (.|.), shiftR, shift)
import Data.Vector.Unboxed (Vector)
import qualified Data.Vector.Unboxed as Vector
import Data.Word (Word64)
import Prelude hiding (max, sum)
import System.CPUTime (getCPUTime)

-- | The n-queens problem (German: "Damenproblem"): count the solutions for a
-- 14x14 board and report how much CPU time the computation took.
-- Wiki: https://de.wikipedia.org/wiki/Damenproblem
main :: IO ()
main = do
  start <- getCPUTime
  -- Printing the count forces the whole computation before the second
  -- clock reading, so the measurement covers all of the work.
  print $ dame 14
  end <- getCPUTime
  -- getCPUTime reports picoseconds; divide by 10^12 to get seconds.
  let seconds = fromIntegral (end - start) / 1e12 :: Double
  -- putStrLn, not print: print would show the message with quotes/escapes.
  putStrLn $ "Needed " ++ show seconds ++ " Seconds"

-- | Attack masks carried from one row to the next; bit @i@ stands for column
-- @i@. 'nextState' shifts the first field with 'moveLeft' and the second with
-- 'moveRight', and keeps the third, so the fields are the two diagonal masks
-- and the column mask. The bangs make all three fields strict.
data BitState = BitState !Word64 !Word64 !Word64

-- | Apply the same function to each of the three masks.
bmap :: (Word64 -> Word64) -> BitState -> BitState
bmap g state =
  case state of
    BitState a b c -> BitState (g a) (g b) (g c)
{-# INLINE bmap #-}

-- | Combine the three masks from left to right, i.e. @combine (combine a b) c@.
bfold :: (Word64 -> Word64 -> Word64) -> BitState -> Word64
bfold combine (BitState a b c) = combine (combine a b) c
{-# INLINE bfold #-}

-- | A state in which all three masks equal the given one.
singleton :: Word64 -> BitState
singleton mask = mask `seq` BitState mask mask mask
{-# INLINE singleton #-}

-- | Count the solutions of the n-queens problem on an @n x n@ board.
--
-- The first queen is tried in every column of row 0; 'recur' then walks the
-- remaining rows, where @depth == 0@ means "this is the last row".
--
-- * @n <= 0@ yields 0 (an empty column vector, as before).
-- * @n == 1@ yields 1. The generic recursion would start at depth
--   @n - 2 = -1@, never reach its depth-0 base case, and wrongly return 0.
--
-- NOTE(review): columns are bits of a 'Word64', so results are only
-- meaningful for @n <= 64@.
dame :: Int -> Int
dame !n
  | n <= 0 = 0
  | n == 1 = 1
  | otherwise = sumWith fn row
  where
    -- Place the first queen at column mask @col@ and descend to row 1.
    fn !col = recur (n - 2) $ nextState $ singleton col
    -- 0 if @col@ is attacked, 1 on the last row, otherwise go one row deeper.
    getPossible !depth !stateVal !state !col
      | (col .&. stateVal) /= 0 = 0
      | depth == 0 = 1
      | otherwise = recur (depth - 1) (nextState (addBitToState col state))
    -- Sum the solutions reachable from @state@ over every column of the row.
    recur !depth !state = sumWith (getPossible depth (getStateVal state) state) row
    -- One single-bit mask per column: 1, 2, 4, ... (empty when n <= 0).
    !row = Vector.iterateN n moveLeft 1

-- | Sum of @f@ applied to every element, accumulated strictly from the left.
sumWith :: (Vector.Unbox a, Vector.Unbox b, Num b) => (a -> b) -> Vector a -> b
sumWith f xs = Vector.foldl' step 0 xs
  where
    step acc x = acc + f x
{-# INLINE sumWith #-}

-- | Union of the three masks: every column of the current row that an
-- earlier queen attacks.
getStateVal :: BitState -> Word64
getStateVal (BitState a b c) = (a .|. b) .|. c

-- | Mark the column mask @col@ as occupied in all three attack masks.
addBitToState :: Word64 -> BitState -> BitState
addBitToState !col (BitState a b c) =
  BitState (a .|. col) (b .|. col) (c .|. col)

-- | Advance the attack masks to the next row. The first diagonal is shifted
-- with 'moveLeft' and the second with 'moveRight'; the column mask is kept.
nextState :: BitState -> BitState
nextState state =
  case state of
    BitState diagL diagR cols -> BitState (moveLeft diagL) (moveRight diagR) cols

-- | Shift a mask one bit towards the least significant end.
moveRight :: Word64 -> Word64
moveRight = (`shiftR` 1)
{-# INLINE moveRight #-}

-- | Shift a mask one bit towards the most significant end; the top bit falls
-- off.
moveLeft :: Word64 -> Word64
moveLeft = (`shift` 1)
{-# INLINE moveLeft #-}

我用ghc dame.hs -O2 -fforce-recomp -ddump-simpl -dsuppress-all检查了生成的GHC Core,看起来很不错(即所有值都已拆箱,循环也都很好)。我曾担心getPossible的部分应用可能会成为问题,但事实证明并非如此。我觉得如果更好地理解这个算法,或许能写出更好/更高效的版本,但我并不太在意——这个版本已经能够击败Java实现了。