在Java中获取k个最小(或最大)数组元素的最快方法是什么?

时间:2016-08-29 11:34:24

标签: java algorithm optimization priority-queue

我有一个元素数组(在这个例子中,这些只是整数),使用一些自定义比较器进行比较。在此示例中,我通过定义i SMALLER j来模拟此比较器,当且仅当scores[i] <= scores[j]

我有两种方法:

  • 使用当前k候选人的堆
  • 使用当前k候选人的数组

我按以下方式更新上面的两个结构:

  • 堆:方法PriorityQueue.pollPriorityQueue.offer
  • 数组:存储候选数组中前k个候选中最差的索引top。如果新看到的示例比索引top处的元素更好,则后者由前者替换,并且top通过迭代遍历数组的所有k个元素来更新。

然而,当我测试时,哪种方法更快,我发现这是第二种。问题是:

  
      
  • 我使用PriorityQueue次优?
  •   
  • 计算k个最小元素的最快方法是什么?
  •   

我对这种情况感兴趣,当例子的数量可以很大,但是邻居的数量相对较小(在10到20之间)。

以下是代码:

public static void main(String[] args) {
    long kopica, navadno, sortiranje;

    int numTries = 10000;
    int numExamples = 1000;
    int numNeighbours = 10;

    navadno = testSimple(numExamples, numNeighbours, numTries);
    kopica = testHeap(numExamples, numNeighbours, numTries);

    sortiranje = testSort(numExamples, numNeighbours, numTries, false);
    System.out.println(String.format("tries: %d examples: %d neighbours: %d\n time heap[ms]: %d\n time simple[ms]: %d", numTries, numExamples, numNeighbours, kopica, navadno));
}

public static long testHeap(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        PriorityQueue<Integer> myHeap = new PriorityQueue(numberNeighbours, new Comparator<Integer>(){
            @Override
            public int compare(Integer o1, Integer o2) {
                return -Double.compare(scores[o1], scores[o2]);
            }
        });

        int top;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                myHeap.offer(i);
            } else{
                top = myHeap.peek();
                if(scores[top] > scores[i]){
                    myHeap.poll();
                    myHeap.offer(i);
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

public static long testSimple(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        int[] candidates = new int[numberNeighbours];
        int top = 0;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                candidates[i] = i;
                if(scores[candidates[top]] < scores[candidates[i]]) top = i;
            } else{
                if(scores[candidates[top]] > scores[i]){
                    candidates[top] = i;
                    top = 0;
                    for(int j = 1; j < numberNeighbours; j++){
                        if(scores[candidates[top]] < scores[candidates[j]]) top = j;                            
                    }
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

这会产生以下结果:

tries: 10000 examples: 1000 neighbours: 10
   time heap[ms]: 393
   time simple[ms]: 388

3 个答案:

答案 0 :(得分:3)

创建最快的算法绝非易事,您需要考虑很多事情。例如,k元素需要返回排序与否,您的研究需要stable(如果两个元素等于你需要在第一个之前提取或不需要)或不是?

在本次竞赛中,理论上最好的解决方案是将k个最小元素保存在有序数据结构中。因为插入可能经常发生在这个数据结构的中间,所以平衡的排序树似乎是一个最佳解决方案。

但现实与此截然不同。

根据原始数组的大小和k的值,不同数据结构之间的混合可能是最佳解决方案:

  • 如果k很少使用数组来保存k个最小值
  • 如果k很大,请使用平衡树
  • 如果k非常大且接近数组的维度,只需对数组进行排序(如果不能创建它的新排序副本),则提取前k个元素。

这种算法名为hibryd algorithm。一个着名的混合算法是Tim Sort,它在java类中用于对集合进行排序。

注意:如果可以使用多线程不同的算法,那么可以使用不同的数据结构。

关于微观基准的补充说明。您的绩效指标会受到与算法效率无关的外部因素的强烈影响。像在两个函数中一样创建对象时,可能需要不可用的内存,要求GC完成额外的工作。这种因素对你的结果影响很大。至少尝试最小化与要调查的代码部分不紧密相关的代码。以不同的顺序重复测试,在调用测试之前等待,以确保没有GC正在运行。

答案 1 :(得分:1)

第一个解决方案的时间复杂度为O(numberExamples * log numberNeighbours),而第二个解决方案为O(numberExamples * numberNeighbours),因此对于足够大的输入,它必须更慢。第二种解决方案更快,因为您测试小numberNeighbours,并且PriorityQueue比简单数组具有更大的开销。 您使用PriorityQueue最佳。

更快,但不是最优,只是对数组进行排序,然后最小元素位于k位置。

无论如何,你可能想要实现QuickSelect算法,如果你聪明地选择pivot元素,你应该有更好的性能。您可能希望看到此https://discuss.leetcode.com/topic/55501/2ms-java-quick-select-only-2-points-to-mention

答案 2 :(得分:1)

首先,您的基准测试方法不正确。您正在测量输入数据创建以及算法性能,并且您在测量之前没有预热JVM。通过JMH测试代码的结果:

Benchmark                     Mode  Cnt      Score   Error  Units
CounterBenchmark.testHeap    thrpt    2  18103,296          ops/s
CounterBenchmark.testSimple  thrpt    2  59490,384          ops/s

修改后的基准pastebin

关于两个提供的解决方案之间的3倍差异。在big-O表示法方面,你的第一个算法可能看起来更好,但实际上big-O表示法只能告诉你算法在缩放方面有多好,它从不告诉你它的执行速度有多快(见{{3也)。在您的情况下,缩放不是问题,因为您的numNeighbours限制为20.换句话说,big-O表示法描述了完成算法需要多少滴答,但它并不限制持续时间一个滴答声,它只是说当输入变化时,滴答持续时间不会改变。就嘀嗒复杂度而言,你的第二个算法肯定会获胜。

  

计算k个最小元素的最快方法是什么?

我想出了下一个解决方案,我相信它允许question完成它的工作:

@Benchmark
public void testModified(Blackhole bh) {
    final double[] scores = sampleData;
    int[] candidates = new int[numberNeighbours];
    for (int i = 0; i < numberNeighbours; i++) {
        candidates[i] = i;
    }
    // sorting candidates so scores[candidates[0]] is the largest
    for (int i = 0; i < numberNeighbours; i++) {
        for (int j = i+1; j < numberNeighbours; j++) {
            if (scores[candidates[i]] < scores[candidates[j]]) {
                int temp = candidates[i];
                candidates[i] = candidates[j];
                candidates[j] = temp;
            }
        }
    }
    // processing other scores, while keeping candidates array sorted in the descending order
    for (int i = numberNeighbours; i < numberExamples; i++) {
        if (scores[i] > scores[candidates[0]]) {
            continue;
        }
        // moving all larger candidates to the left, to keep the array sorted
        int j; // here the branch prediction should kick-in
        for (j = 1; j < numberNeighbours && scores[i] < scores[candidates[j]]; j++) {
            candidates[j - 1] = candidates[j];
        }
        // inserting the new item
        candidates[j - 1] = i;
    }
    bh.consume(candidates);
}

基准测试结果(比当前解决方案快2倍):

(10 neighbours) CounterBenchmark.testModified    thrpt    2  136492,151          ops/s
(20 neighbours) CounterBenchmark.testModified    thrpt    2  118395,598          ops/s

其他人提到了branch prediction,但正如人们所预料的那样,该算法的复杂性忽略了它在你的情况下的强势:

@Benchmark
public void testQuickSelect(Blackhole bh) {
    final int[] candidates = new int[sampleData.length];
    for (int i = 0; i < candidates.length; i++) {
        candidates[i] = i;
    }
    final int[] resultIndices = new int[numberNeighbours];
    int neighboursToAdd = numberNeighbours;

    int left = 0;
    int right = candidates.length - 1;
    while (neighboursToAdd > 0) {
        int partitionIndex = partition(candidates, left, right);
        int smallerItemsPartitioned = partitionIndex - left;
        if (smallerItemsPartitioned <= neighboursToAdd) {
            while (left < partitionIndex) {
                resultIndices[numberNeighbours - neighboursToAdd--] = candidates[left++];
            }
        } else {
            right = partitionIndex - 1;
        }
    }
    bh.consume(resultIndices);
}

private int partition(int[] locations, int left, int right) {
    final int pivotIndex = ThreadLocalRandom.current().nextInt(left, right + 1);
    final double pivotValue = sampleData[locations[pivotIndex]];
    int storeIndex = left;
    for (int i = left; i <= right; i++) {
        if (sampleData[locations[i]] <= pivotValue) {
            final int temp = locations[storeIndex];
            locations[storeIndex] = locations[i];
            locations[i] = temp;

            storeIndex++;
        }
    }
    return storeIndex;
}

在这种情况下,基准测试结果非常令人沮丧:

CounterBenchmark.testQuickSelect  thrpt    2   11586,761          ops/s