了解快速取幂函数

时间:2015-08-20 20:04:15

标签: c algorithm exponentiation

我无法理解为什么这个功能有效?有人能解释一下它一步一步做什么吗?我知道这个想法是如果n是偶数,则^ n等于(a ^(n / 2))^ 2或者如果n是奇数,则a(a ^((n-1)/ 2))^ 2,但是这个功能是如何做到的?

double pow(double a, int n) {
    double ret = 1;
    while(n) {
        if(n%2 == 1) ret *= a;
        a *= a; n /= 2;
    }
    return ret;
}

3 个答案:

答案 0 :(得分:3)

本计划使用的平等点如下:

  1. invariant of the loop是:(在循环的每一步),a^n * ret是结果。事实上,在开头ret1,而在循环结束时n == 0,因此a^0 * ret是结果,自a^0 == 1以来, ret是预期的结果。
  2. 如果n为奇数(即n%2 == 1),则存在b≥0n=b*2+1。在这种情况下,我们使用以下等式:a^(b*2+1)=(a^(b*2))*a。因此ret乘以a
  3. 在下一个语句中,使用以下相等性:a^(b*2) = (a^2)^b,以便a与自身相乘,n除以2,最终保持不变量。< / LI>

    请注意,在循环内部,整数除法用于n /= 2,因此在两种情况下结果始终为bn奇数,即n=b*2+1 ,或n是偶数,即n=b*2)。

    最后请注意,正如@chux在评论中指出的那样,该函数无法正确管理n的负值。

答案 1 :(得分:1)

这是我的Python递归代码,它是IMO更易读和可理解的(我知道在Python中创建递归函数并不是一个好主意,但我之所以选择Python是因为其语法简单来演示这个想法)。

def pow(n, e):
    if e == 0:
        return 1

    if e % 2 == 1:
        return n * pow(n, e - 1)

    # this step makes the algorithm to run in O(lg n) time
    tmp = pow(n, e / 2)

    return tmp * tmp

我将再次强调,tmp = pow(n, e / 2)是降低时间复杂度的线。

算法不是将数乘以e乘以n,而是重用一些先前计算的结果。例如,2 ^ 8将被计算为2 ^ 4 * 2 ^ 4。这里2 ^ 4将只计算一次,并且将以这种方式跳过一半的迭代。同样适用于2 ^ 4等。

我试图以某种方式更直观地解释它,而没有深入研究这种优化背后的理论。如果你想更深入地理解它以及它在位级上的工作原理,那么这是一个很好的tutorial

答案 2 :(得分:1)

我将从一些更明显的代码开始:

double pow(double a, int n) {
    int k = 0, m = 1, n2 = n;
    double pow_k = 1.0, pow_m = a;
    assert (n2 * m + k == n);

    while (n2 != 0) {
        if (n2 % 2 != 0) { k += m; pow_k *= pow_m; n2 -= 1; }
        assert (n2 * m + k == n); assert (n2 % 2 == 0);
        m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
        assert (n2 * m + k == n);
    }

    return pow_k;
}

在循环中的每个点,pow_k = a ^ k和pow_m = a ^ m。 n2 * m + k == n始终为真。当n2 == n,m == 1,k == 0时,它最初为真。

在循环中的第一个if语句之前,n2是偶数,因此断言保持为真且n2保持偶数。或者n2是奇数。在那种情况下,n2减少1,使n2 * m减小m; k增加m,使n2 * m + k保持不变。并且n2是均匀的。

然后m加倍并且n2正好减半,因为n2是偶数,再次保持n2 * m + k不变。

由于在每次迭代中n2除以2,因此n2最终变为0,因此循环结束。具有n2 == 0的断言意味着0 * m + k == n或k == n,因此pow_k = a ^ k = a ^ n。因此返回的结果是^ n。

现在我们省略了k,m和断言,它们没有改变计算:

double pow(double a, int n) {
    int n2 = n;
    double pow_k = 1.0, pow_m = a;

    while (n2 != 0) {
        if (n2 % 2 != 0) { pow_k *= pow_m; n2 -= 1; }
        m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
    }

    return pow_k;
}

当n2为奇数时,我们可以删除n2 - = 1,因为在除以2之后它不会产生差异。由于没有使用n,我们可以使用n而不是n2:

double pow(double a, int n) {
    double pow_k = 1.0, pow_m = a;

    while (n != 0) {
        if (n % 2 != 0) pow_k *= pow_m;
        pow_m = pow_m * pow_m; n /= 2;
    }

    return pow_k;
}

现在我们将pow_k更改为ret,将pow_m更改为a,并将n%2!= 0更改为n%2 == 1,我们将获得原始代码:

double pow(double a, int n) {
    double ret = 1.0;

    while (n != 0) {
        if (n % 2 == 1) ret *= a;
        a *= a; n /= 2;
    }

    return ret;
}
相关问题