模幂运算中的整数溢出

时间:2012-01-02 21:18:22

标签: .net f# overflow

我正在写一个powMod函数,我必须非常密集地使用它。起点是自定义pow函数:

// Compute power using multiplication and square.
// pow (*) (^2) 1 x n = x^n
let pow mul sq one x n =   
    let rec loop x' n' acc =
       match n' with
       | 0 -> acc
       | _ -> let q = n'/2
              let r = n'%2
              let x2 = sq x'
              if r = 0 then
                 loop x2 q acc
              else
                 loop x2 q (mul x' acc)
    loop x n one

在检查输入范围后,我选择了int64,因为它足以表示输出,我可以避免使用bigint进行昂贵的计算:

let mulMod m a b = (a*b)%m
let squareMod m a = mulMod m a a
let powMod m = pow (mulMod m) (squareMod m) 1L

我假设模数(m)大于乘数(ab),函数仅适用于非负数。 powMod函数对大多数情况都是正确的;但是,问题在于mulMod函数,其中a*b可能超过int64范围但(a*b)%m不是。 以下示例演示了溢出问题:

let a = (pown 2L 40) - 1L
let b = (pown 2L 32) - 1L
let p = powMod a b 2 // p = -8589934591L -- wrong

有没有办法避免int64溢出而不诉诸bigint类型?

5 个答案:

答案 0 :(得分:2)

根据Wikipedia,以下公式是等效的。你的代码正在使用第一个,更改为第二个应解决溢出问题。

c = (a x b) mod(m)
c = (a x (b mod(m))) mod(m) 

希望这有帮助。

根据您的评论 - 如果a< = m且b< = m,并且m> sqrt(maxint64),那么我不确定如果没有更大的存储空间就可以实现解决方案。对于m的大值,b mod m将返回b,因此使用上述等价公式没有任何好处。

好消息是你应该能够将更改限制为单行,然后将值重新键入64位[因为我们知道(a * b)%c不应该溢出],然后再继续计算。这将费用(在执行性能方面)限制为尽可能小的代码段。

答案 1 :(得分:2)

我对f#几乎一无所知,但我认为你可以应用以下事实:

如果b为奇数且nb = 2n + 1

a * b mod(m) = 2 * a * n + a mod(m)
             = 2 * (a*n mod(m)) + a mod(m)

并且类似地如果b是偶数。显然,您可以根据需要在an上多次重复此操作,直到您最终获得适合int64的产品。如果m>我认为仍有可能溢出maxint64 / 2。

答案 2 :(得分:2)

你遇到的问题是你的所有中间计算都是隐含的mod 2 64 ,而且通常情况并非如此

  

a·b mod m =(a·b mod 2 64 )mod m

这是你正在计算的。

我想不出使用64位数字进行正确计算的简单方法,但是你不必一直到bigints;如果ab最多有64位,那么它们的完整产品最多只有128位,因此您可以用两个64位整数(这里捆绑为自定义结构)来跟踪产品:

// bit width of a uint64, needed for mod calculation
let width = 
    let rec loop w = function
    | 0uL -> w
    | n -> loop (w+1) (n >>> 1)
    loop 0

[<Struct; CustomComparison; CustomEquality>]
type UInt128 =
    val hi : uint64
    val lo : uint64
    new (hi,lo) = { lo = lo; hi = hi }
    new (lo) = { lo = lo; hi = 0uL }
    static member (+)(x:UInt128, y:UInt128) =
        if x.lo > 0xffffffffuL - y.lo then
            UInt128(x.hi + y.hi + 1uL, x.lo + y.lo)
        else
            UInt128(x.hi + y.hi, x.lo + y.lo)
    static member (-)(x:UInt128, y:UInt128) =
        if y.lo > x.lo then
            UInt128(x.hi - y.hi - 1uL, x.lo - y.lo)
        else
            UInt128(x.hi - y.hi, x.lo - y.lo)

    static member ( * )(x:UInt128, y:UInt128) =
        let a1 = ((x.lo &&& 0xffffffffuL) * (y.lo &&& 0xffffffffuL)) >>> 32
        let a2 =  (x.lo &&& 0xffffffffuL) * (y.lo >>> 32)
        let a3 =  (x.lo >>> 32) * (y.lo &&& 0xffffffffuL)
        let sum = ((a1 + a2 + a3) >>> 32) + (x.lo >>> 32) * (y.lo >>> 32)
        let sum =
            if a2 > 0xffffffffffffffffuL - a1 || a1 + a2 > 0xffffffffffffffffuL - a3 then
                0x100000000uL + sum
            else
                sum
        UInt128(x.hi * y.lo + x.lo * y.hi + sum, x.lo * y.lo)

    static member (>>>)(x:UInt128, n) =
        UInt128(x.hi >>> n, x.lo >>> n)

    static member (<<<)(x:UInt128, n) =
        UInt128((x.hi <<< n) + (x.lo >>> (64 - n)), x.lo <<< n)

    interface System.IComparable with
        member x.CompareTo(y) =
            match y with
            | :? UInt128 as y ->
                match x.hi.CompareTo(y.hi) with
                | 0 -> x.lo.CompareTo(y.lo)
                | n -> n

    override x.Equals(y) = 
        match y with
        | :? UInt128 as y -> x.hi = y.hi && x.lo = y.lo
        | _ -> false

    override x.GetHashCode() = x.hi.GetHashCode() + x.lo.GetHashCode() * 7

    (* calculate mod via long-division *)
    static member (%)(x:UInt128, d) =
        let rec reduce (r:UInt128) d' =
            if r.hi = 0uL then r.lo % d
            else
                let r' = if r < d' then r else r - d'
                reduce r' (d' >>> 1)
        let shift = width x.hi + (64 - width d)
        reduce x (UInt128(0uL,d) <<< shift)

let mulMod m a b =
    UInt128(a) * UInt128(b) % m

(* squareMod, powMod basically as before: *)
let squareMod m a = mulMod m a a  
let powMod m = pow (mulMod m) (squareMod m) 1uL  

let a = (pown 2uL 40) - 1uL  
let b = (pown 2uL 32) - 1uL  
let p = powMod a b 2

话虽如此,由于bigint s会给你正确答案,为什么不用bigint来进行中间计算,最后转换为long(保证是给定m的范围无损转换)?我怀疑使用bigints的性能损失对于大多数应用程序来说应该是可接受的(与维护自己的数学例程的头痛相比)。

答案 3 :(得分:0)

我不太了解F#,但是(伪代码)

pmod a b n // a^b % n
pmod a 0 n = 1
pmod a 1 n = a%n
pmod a b n = match b%2
   | 0 -> ((pmod (a) (b/2) n) ^ 2) % n
   | 1 -> ((pmod (a) (b-1) n) * a ) % n

pmod还不行,但应该用作帮助函数

PowMod a b n = pmod (a%n) b n

你可以看到,只要结果的平方将是uint64,那么这个结果就会出错,因此n必须适合uint32

答案 4 :(得分:0)

不是你问题的答案,但以这种方式编写函数会使它更通用,更方便使用,它似乎也更有效:

let inline pow x n =
    let zero = LanguagePrimitives.GenericZero
    let rec loop x acc = function
        | n when n = zero -> acc
        | n ->
            let q = n >>> 1
            let acc = if n = (q <<< 1) then acc else x * acc
            loop (x * x) acc q
    loop x LanguagePrimitives.GenericOne n;;

for x = 0 to 1000000 do
    pow 3UL 31UL |> ignore

另外,我认为unsigned long还不够吗?

编辑:以下算法比大bigint快3倍,因为执行较少的乘法 - 可能有助于您选择bigint:

let inline pow2 x n =
    let zero = LanguagePrimitives.GenericZero
    let one = LanguagePrimitives.GenericOne
    let rec loop x data = function
        | c when c <<< 1 <= n ->
            let c = c <<< 1
            let x = x * x
            loop x (Map.add -c x data) c
        | c -> reduce x data (n - c)
    and reduce acc data = function
        | c when c = zero -> acc
        | c ->
            let next, value = data |> Seq.pick (fun (KeyValue (n, v)) -> if -n <= c then Some (-n, v) else None)
            reduce (acc * value) data (c - next)
    loop x (Map [-1, x]) one;;

for x = 10000 downto 9000 do
    pow2 7I x |> ignore