如何改进这段代码?

时间:2015-03-31 01:52:50

标签: java performance algorithm

我已经开发了一个代码,用于根据2的幂来表示数字,我在下面附上相同的代码。

但问题是表达的输出应该是最小长度。

我的输出为3^2+1^2+1^2+1^2,这不是最小长度。 我需要以这种格式输出:

package com.algo;
import java.util.Scanner;

public class GetInputFromUser {
    public static void main(String[] args) {
    // TODO Auto-generated method stub
        int n;
        Scanner in = new Scanner(System.in);

        System.out.println("Enter an integer");
        n = in.nextInt();

        System.out.println("The result is:");
        algofunction(n);
    }

    public static int algofunction(int n1)
    {
        int r1 = 0;
        int r2 = 0;
        int r3 = 0;
        //System.out.println("n1: "+n1);
        r1 = (int) Math.sqrt(n1);
        r2 = (int) Math.pow(r1, 2);
        // System.out.println("r1: "+r1);
        //System.out.println("r2: "+r2);
        System.out.print(r1+"^2");

        r3 = n1-r2;
        //System.out.println("r3: "+r3);
        if (r3 == 0)
            return 1;

        if(r3 == 1)
        {
            System.out.print("+1^2");
            return 1;
        } 
        else {
            System.out.print("+");
            algofunction(r3);
            return 1;
        }
    }
}

1 个答案:

答案 0 :(得分:1)

动态编程就是以这样的方式定义问题:如果你知道原版的较小版本的答案,你可以使用它来更快/更直接地回答主要问题。这就像应用数学归纳法一样。

在您的特定问题中,我们可以将MinLen(n)定义为n的最小长度表示。接下来,说,因为我们想要解决MinLen(12),假设我们已经知道MinLen(1),MinLen(2),MinLen(3),...,MinLen(11)的答案。我们怎么能用这些小问题的答案来弄清楚MinLen(12)?这是动态编程的另一半 - 弄清楚如何使用较小的问题来解决较大的问题。如果你提出一些较小的问题,但它无法将它们组合在一起,它对你没有帮助。

对于这个问题,我们可以做一个简单的陈述,“对于12,它的最小长度表示肯定有1 ^ 2,2 ^ 2或3 ^ 2。”并且通常,n的最小长度表示将具有小于或等于n的一些平方作为其一部分。你可以做出更好的声明,这会改善运行时间,但我会说它现在已经足够好了。

该陈述表示MinLen(12)= 1 ^ 2 + MinLen(11),OR 2 ^ 2 + MinLen(8),OR 3 ^ 2 + MinLen(3)。你检查所有这些并选择最好的一个,然后你将它保存为MinLen(12)。现在,如果你想解决MinLen(13),你也可以这样做。

独唱时的建议: 我自己测试这种程序的方法是插入1,2,3,4,5等,并看第一次出错。另外,我碰巧认为的任何假设都是个好主意,我质疑:“小于n的最大平方数是否在MinLen(n)的表示中是真的吗?”

您的代码:

r1 = (int) Math.sqrt(n1);
r2 = (int) Math.pow(r1, 2);

体现了这种假设(一种贪婪的假设),但这是错误的,因为你已经清楚地看到了MinLen(12)的答案。

相反,你想要更像这样的东西:

public ArrayList<Integer> minLen(int n)
{
    // base case of recursion
    if (n == 0)
        return new ArrayList<Integer>();

    ArrayList<Integer> best = null;
    int bestInt = -1;
    for (int i = 1; i*i <= n; ++i)
    {
        // Check what happens if we use i^2 as part of our representation
        ArrayList<Integer> guess = minLen(n - i*i);

        // If we haven't selected a 'best' yet (best == null)
        // or if our new guess is better than the current choice (guess.size() < best.size())
        // update our choice of best
        if (best == null || guess.size() < best.size())
        {
            best = guess;
            bestInt = i;
        }
    }

    best.add(bestInt);
    return best;
}

然后,一旦你有了你的清单,就可以对它进行排序(不保证它按排序顺序排列),并按照你想要的方式打印出来。

最后,您可能会注意到,对于较大的n值(1000可能太大),您插入上述递归,它将开始变得非常慢。这是因为我们不断重新计算所有小的子问题 - 例如,当我们调用MinLen(4)时,我们找出MinLen(3),因为4 - 1 ^ 2 = 3.但我们为MinLen(7)计算出两次 - &GT; 3 = 7 - 2 ^ 2,但3也是7 - 1 ^ 2 - 1 ^ 2 - 1 ^ 2 - 1 ^ 2。你走得越大就越糟糕。

这个问题的解决方案是使用一种名为Memoization的技术,它可以让你快速解决n = 1,000,000或更多。这意味着一旦我们弄清楚MinLen(3),我们将它保存在某个地方,让我们说一个全球位置让它变得简单。然后,每当我们尝试重新计算它时,我们首先检查全局缓存,看看我们是否已经这样做了。如果是这样,那么我们只是使用它,而不是重做所有的工作。

import java.util.*;

class SquareRepresentation
{
    private static HashMap<Integer, ArrayList<Integer>> cachedSolutions;
    public static void main(String[] args)
    {
        cachedSolutions = new HashMap<Integer, ArrayList<Integer>>();
        for (int j = 100000; j < 100001; ++j)
        {
            ArrayList<Integer> answer = minLen(j);
            Collections.sort(answer);
            Collections.reverse(answer);
            for (int i = 0; i < answer.size(); ++i)
            {
                if (i != 0)
                    System.out.printf("+");
                System.out.printf("%d^2", answer.get(i));
            }
            System.out.println();
        }
    }

    public static ArrayList<Integer> minLen(int n)
    {
        // base case of recursion
        if (n == 0)
            return new ArrayList<Integer>();

        // new base case: problem already solved once before
        if (cachedSolutions.containsKey(n))
        {
            // It is a bit tricky though, because we need to be careful!
            // See how below that we are modifying the 'guess' array we get in?
            // That means we would modify our previous solutions! No good!
            // So here we need to return a copy
            ArrayList<Integer> ans = cachedSolutions.get(n);
            ArrayList<Integer> copy = new ArrayList<Integer>();
            for (int i: ans) copy.add(i);
            return copy;
        }

        ArrayList<Integer> best = null;
        int bestInt = -1;
        // THIS IS WRONG, can you figure out why it doesn't work?:
        // for (int i = 1; i*i <= n; ++i)
        for (int i = (int)Math.sqrt(n); i >= 1; --i)
        {
            // Check what happens if we use i^2 as part of our representation
            ArrayList<Integer> guess = minLen(n - i*i);

            // If we haven't selected a 'best' yet (best == null)
            // or if our new guess is better than the current choice (guess.size() < best.size())
            // update our choice of best
            if (best == null || guess.size() < best.size())
            {
                best = guess;
                bestInt = i;
            }
        }

        best.add(bestInt);

        // check... not needed unless you coded wrong
        int sum = 0;
        for (int i = 0; i < best.size(); ++i)
        {
            sum += best.get(i) * best.get(i);
        }
        if (sum != n)
        {
            throw new RuntimeException(String.format("n = %d, sum=%d, arr=%s\n", n, sum, best));
        }

        // New step: Save the solution to the global cache
        cachedSolutions.put(n, best);

        // Same deal as before... if you don't return a copy, you end up modifying your previous solutions
        // 
        ArrayList<Integer> copy = new ArrayList<Integer>();
        for (int i: best) copy.add(i);
        return copy;
    }
}

我的程序大约需要5秒才能运行n = 100,000。显然,如果我们希望它更快,并解决更大的n,还有更多工作要做。现在的主要问题是,在存储以前答案的整个结果列表时,我们会占用大量内存。所有这些复制!你可以做的更多,比如只存储一个整数和指向子问题的指针,但我会让你这样做。

由by,1000 = 30 ^ 2 + 10 ^ 2.