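Java multithreaded merge sort speed problem

Time: 2014-05-10 07:09:00

Tags: java multithreading performance sorting mergesort

I'm trying to implement a multithreaded merge sort in Java. The idea is to start new threads recursively at every level of the recursion. Everything works correctly, but the problem is that the regular single-threaded version appears to be much faster. Please help me fix it. I have tried playing with .join(), but it did not help. My code: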

import java.util.logging.Level;
import java.util.logging.Logger;

public class MergeThread implements Runnable {

    private final int begin;
    private final int end;

    public MergeThread(int b, int e) {
        this.begin = b;
        this.end = e;
    }

    @Override
    public void run() {
        try {
            MergeSort.mergesort(begin, end);
        } catch (InterruptedException ex) {
            Logger.getLogger(MergeThread.class.getName()).log(Level.SEVERE, null, ex);
        }
    }
}

public class MergeSort {
    private static volatile int[] numbers;
    private static volatile int[] helper;

    private int number;

    public void sort(int[] values) throws InterruptedException {
        MergeSort.numbers = values;
        number = values.length;
        MergeSort.helper = new int[number];
        mergesort(0, number - 1);
    }

    public static void mergesort(int low, int high) throws InterruptedException {
        // check if low is smaller than high, if not then the array
        // is sorted
        if (low < high) {
            // Get the index of the element which is in the middle
            int middle = low + (high - low) / 2;
            // Sort the left and right halves, each in its own new thread
            Thread left = new Thread(new MergeThread(low, middle));
            Thread right = new Thread(new MergeThread(middle + 1, high));

            left.start();
            right.start();
            left.join();
            right.join();

            // combine the sides
            merge(low, middle, high);
        }
    }

    private static void merge(int low, int middle, int high) {
        // Copy both parts into the helper array
        for (int i = low; i <= high; i++) {
            helper[i] = numbers[i];
        }

        int i = low;
        int j = middle + 1;
        int k = low;
        // Copy the smallest value from either the left or right side
        // back to the original array
        while (i <= middle && j <= high) {
            if (helper[i] <= helper[j]) {
                numbers[k] = helper[i];
                i++;
            } else {
                numbers[k] = helper[j];
                j++;
            }
            k++;
        }
        // Copy the rest of the left side of the array
        // (leftover right-side elements are already in place)
        while (i <= middle) {
            numbers[k] = helper[i];
            k++;
            i++;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        int[] array = new int[1000];
        for(int pos = 0; pos<1000; pos++) {
            array[pos] = 1000-pos;
        }
        long start = System.currentTimeMillis();
        new MergeSort().sort(array);
        long finish = System.currentTimeMillis();

        for(int i = 0; i<array.length; i++) {
            System.out.print(array[i]+" ");
        }
        System.out.println();
        System.out.println(finish-start);

    }
}

3 Answers:

Answer 0 (score: 3)

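There are several factors here. First of all, you are spawning too many threads, far more than the number of cores your processor has. If I understand your algorithm correctly, you start a new thread for every recursive call, so the recursion tree is about log2(n) levels deep and you end up creating roughly 2n threads in total.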

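Given that you are doing processor-intensive computation that involves no I/O, performance starts degrading quickly once your thread count exceeds the number of cores. Reaching something like several thousand threads will slow things down and can eventually crash the VM.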

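If you actually want to benefit from a multi-core processor in this computation, you should try a fixed-size thread pool (with an upper bound at or near the number of cores) or an equivalent thread-reuse policy.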
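One way to read the fixed-size pool suggestion is sketched here. It is only an illustration under assumptions: the class name `PooledMergeSort` and its `threads` parameter are made up for this example, each chunk is sorted with `Arrays.sort` as a stand-in for a sequential merge sort, and the sorted chunks are then merged one after another on the calling thread. Tasks never wait on other tasks inside the pool, which avoids the starvation deadlock a recursive submit-and-wait design can hit with a bounded pool.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical helper class, not part of the question's code.
class PooledMergeSort {

    // Sorts 'values' in place using at most 'threads' pool threads.
    static void sort(int[] values, int threads) {
        int n = values.length;
        int chunks = Math.max(1, Math.min(threads, n));
        ExecutorService pool = Executors.newFixedThreadPool(chunks);
        try {
            // Chunk c covers the half-open range [bounds[c], bounds[c + 1]).
            int[] bounds = new int[chunks + 1];
            for (int c = 0; c <= chunks; c++) {
                bounds[c] = (int) ((long) n * c / chunks);
            }
            // Phase 1: one task per chunk; the chunks never overlap.
            List<Future<?>> pending = new ArrayList<>();
            for (int c = 0; c < chunks; c++) {
                final int from = bounds[c];
                final int to = bounds[c + 1];
                pending.add(pool.submit(() -> Arrays.sort(values, from, to)));
            }
            for (Future<?> f : pending) {
                f.get(); // wait for the chunk; get() also publishes its writes
            }
            // Phase 2: fold the sorted chunks together on the calling thread.
            int[] helper = new int[n];
            for (int c = 1; c < chunks; c++) {
                merge(values, helper, 0, bounds[c], bounds[c + 1]);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    // Merges the sorted ranges [low, mid) and [mid, high) of 'a' via 'helper'.
    private static void merge(int[] a, int[] helper, int low, int mid, int high) {
        System.arraycopy(a, low, helper, low, high - low);
        int i = low, j = mid, k = low;
        while (i < mid && j < high) {
            a[k++] = (helper[i] <= helper[j]) ? helper[i++] : helper[j++];
        }
        while (i < mid) {
            a[k++] = helper[i++];
        }
        // Whatever is left of the right range is already in its final place.
    }
}
```

A call could look like `PooledMergeSort.sort(array, Runtime.getRuntime().availableProcessors())`. The sequential merge phase costs more as the chunk count grows, so this shape only makes sense for a small number of chunks.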

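Second, if you want a valid comparison, you should test with computations that run longer (sorting 100 numbers does not qualify). Otherwise the cost of creating threads makes up a significant part of what you measure.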
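A rough way to set up such a longer-running comparison is sketched here. The class `SortTiming`, the method `bestMillis`, the array size and the run counts are all assumptions chosen for illustration, and `Arrays.sort` / `Arrays.parallelSort` only stand in for your own single-threaded and multithreaded versions. It uses `System.nanoTime()`, runs a few untimed warm-up rounds so the JIT has compiled the code, and sorts a fresh copy of the input each time.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.Consumer;

// Hypothetical timing helper; a dedicated benchmark tool would be more reliable.
class SortTiming {

    // Returns the best wall-clock time in milliseconds over 'runs' timed runs
    // (runs should be at least 1). 'input' itself is never modified.
    static double bestMillis(Consumer<int[]> sorter, int[] input, int warmups, int runs) {
        for (int w = 0; w < warmups; w++) {
            sorter.accept(input.clone()); // untimed warm-up
        }
        long best = Long.MAX_VALUE;
        for (int r = 0; r < runs; r++) {
            int[] copy = input.clone();   // every run sorts the same unsorted data
            long start = System.nanoTime();
            sorter.accept(copy);
            best = Math.min(best, System.nanoTime() - start);
        }
        return best / 1_000_000.0;
    }

    public static void main(String[] args) {
        int[] data = new Random(1).ints(2_000_000).toArray();
        System.out.println("Arrays.sort:         " + bestMillis(Arrays::sort, data, 3, 5) + " ms");
        System.out.println("Arrays.parallelSort: " + bestMillis(Arrays::parallelSort, data, 3, 5) + " ms");
    }
}
```

To time your own code, pass a lambda that calls it, for example `values -> mySort(values)`, where `mySort` is whatever entry point your implementation exposes.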

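Answer 1 (score: 1)

Start with a thread count equal to the number of cores, or fewer.

The link below also includes a performance analysis.

There is a good example here: https://courses.cs.washington.edu/courses/cse373/13wi/lectures/03-13/MergeSort.java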
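To make the "thread count no larger than the core count" advice concrete, here is a minimal sketch that keeps the question's recursive structure but stops creating threads once a budget is used up. It is not the code from the linked page; the class name `BoundedThreadMergeSort` and the `budget` parameter are assumptions for illustration, and the default budget comes from `Runtime.getRuntime().availableProcessors()`.

```java
// Hypothetical class, not from the question: threads are created only until
// the thread budget (by default the core count) is used up.
class BoundedThreadMergeSort {

    static void sort(int[] values) {
        sort(values, Runtime.getRuntime().availableProcessors());
    }

    // 'maxThreads' is roughly how many threads may work at the same time.
    static void sort(int[] values, int maxThreads) {
        int[] helper = new int[values.length];
        mergesort(values, helper, 0, values.length - 1, Math.max(1, maxThreads));
    }

    private static void mergesort(int[] a, int[] helper, int low, int high, int budget) {
        if (low >= high) {
            return;
        }
        int middle = low + (high - low) / 2;
        if (budget <= 1) {
            // Budget exhausted: ordinary single-threaded recursion.
            mergesort(a, helper, low, middle, 1);
            mergesort(a, helper, middle + 1, high, 1);
        } else {
            int leftBudget = budget / 2;
            int rightBudget = budget - leftBudget;
            Thread left = new Thread(() -> mergesort(a, helper, low, middle, leftBudget));
            left.start();
            // The current thread sorts the right half itself instead of idling.
            mergesort(a, helper, middle + 1, high, rightBudget);
            try {
                left.join(); // join() also makes the left thread's writes visible
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while sorting", e);
            }
        }
        merge(a, helper, low, middle, high);
    }

    // Same merge as in the question, with the arrays passed as parameters.
    private static void merge(int[] a, int[] helper, int low, int middle, int high) {
        System.arraycopy(a, low, helper, low, high - low + 1);
        int i = low, j = middle + 1, k = low;
        while (i <= middle && j <= high) {
            a[k++] = (helper[i] <= helper[j]) ? helper[i++] : helper[j++];
        }
        while (i <= middle) {
            a[k++] = helper[i++];
        }
    }
}
```

With a budget of 4, only three extra threads are started near the top of the recursion, and everything below that runs as plain recursive calls.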

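Answer 2 (score: 0)

The following is an iterative, serial version of MergeSort. It is indeed faster than the recursive version, and it does not compute a midpoint, so it avoids the overflow error associated with that calculation. However, overflow can still occur for other integers. You can try parallelizing it if you are interested.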

protected static int[] ASC(int input_array[]) // Sorts in ascending order
{
    int num = input_array.length;
    int[] temp_array = new int[num];
    int temp_indx;
    int left;
    int mid,j;
    int right;
    int[] swap;
    int LIMIT = 1;
    while (LIMIT < num)
    {
        left = 0;
        mid = LIMIT ; // The mid point
        right = LIMIT << 1;
        while (mid < num)
        {
            if (right > num){ right = num; }
            temp_indx = left;
            j = mid;
            while ((left < mid) && (j < right))
            {
                if (input_array[left] < input_array[j]){  temp_array[temp_indx++] = input_array[left++];  }
                else{  temp_array[temp_indx++] = input_array[j++];  }
            }
            while (left < mid){  temp_array[temp_indx++] = input_array[left++];  }
            while (j < right){  temp_array[temp_indx++] = input_array[j++];  }

            // Do not copy back the elements to input_array
            left = right;
            mid = left + LIMIT;
            right = mid + LIMIT;
        }
        // Instead of copying back in previous loop, copy remaining elements to temp_array, then swap the array pointers
        while (left < num){  temp_array[left] = input_array[left++];  }

        swap = input_array;
        input_array = temp_array;
        temp_array = swap;

        LIMIT <<= 1;
    }
    return input_array; // may be the scratch array rather than the caller's array, so use the return value
}
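For readers unfamiliar with the midpoint overflow mentioned in this answer, here is a small illustration. The class and method names are made up for the example. Note that the question's code already uses the safe form `low + (high - low) / 2`; the problem only affects the naive `(low + high) / 2`.

```java
// Hypothetical demo class showing three ways to compute a midpoint.
class MidpointOverflow {

    // Naive form: low + high can exceed Integer.MAX_VALUE and wrap around.
    static int naiveMid(int low, int high) {
        return (low + high) / 2;
    }

    // Safe form used in the question: high - low cannot overflow when 0 <= low <= high.
    static int safeMid(int low, int high) {
        return low + (high - low) / 2;
    }

    // Alternative: the unsigned shift treats the wrapped sum as an unsigned value.
    static int shiftMid(int low, int high) {
        return (low + high) >>> 1;
    }

    public static void main(String[] args) {
        int low = 1_500_000_000;
        int high = 2_000_000_000;
        System.out.println(naiveMid(low, high)); // prints -397483648
        System.out.println(safeMid(low, high));  // prints 1750000000
        System.out.println(shiftMid(low, high)); // prints 1750000000
    }
}
```

The indices involved only get this large for arrays with more than about a billion elements, which is why the bug rarely shows up in practice.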

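Using the Java executor service is a lot faster, even when the number of threads exceeds the number of cores (you can build scalable multithreaded applications with it). I have code that uses only raw threads, but it is very slow. I'm new to executors, so I can't help much, but it is an interesting area to explore.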
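Since this answer does not include executor code, here is one possible sketch using `ForkJoinPool`, which is an `ExecutorService` implementation designed for recursive divide-and-conquer tasks. The class name `ForkJoinMergeSort` and the `THRESHOLD` value are assumptions for illustration, and ranges below the threshold fall back to `Arrays.sort` rather than a hand-written sequential merge sort.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical task class: sorts a[low..high] (inclusive) in place.
class ForkJoinMergeSort extends RecursiveAction {

    // Assumed cutoff; ranges smaller than this are sorted sequentially.
    private static final int THRESHOLD = 8_192;

    private final int[] a;
    private final int[] helper;
    private final int low;
    private final int high;

    ForkJoinMergeSort(int[] a, int[] helper, int low, int high) {
        this.a = a;
        this.helper = helper;
        this.low = low;
        this.high = high;
    }

    // Entry point: runs the root task on the JVM-wide common pool.
    static void sort(int[] values) {
        int[] helper = new int[values.length];
        ForkJoinPool.commonPool().invoke(
                new ForkJoinMergeSort(values, helper, 0, values.length - 1));
    }

    @Override
    protected void compute() {
        if (high - low < THRESHOLD) {
            Arrays.sort(a, low, high + 1); // small range: no further splitting
            return;
        }
        int middle = low + (high - low) / 2;
        // Both halves become tasks; the pool's worker threads pick them up.
        invokeAll(new ForkJoinMergeSort(a, helper, low, middle),
                  new ForkJoinMergeSort(a, helper, middle + 1, high));
        merge(middle);
    }

    // Merges the sorted halves a[low..middle] and a[middle+1..high].
    private void merge(int middle) {
        System.arraycopy(a, low, helper, low, high - low + 1);
        int i = low, j = middle + 1, k = low;
        while (i <= middle && j <= high) {
            a[k++] = (helper[i] <= helper[j]) ? helper[i++] : helper[j++];
        }
        while (i <= middle) {
            a[k++] = helper[i++];
        }
    }
}
```

The number of tasks can be much larger than the number of cores here, because tasks are lightweight objects queued onto a small, fixed set of worker threads instead of each getting its own thread.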

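Parallelism also has a cost, because thread management is expensive, so only go parallel for large N. If you are looking for a serial alternative to merge sort, I suggest Dual-Pivot-QuickSort or 3-Partition-Quick-Sort, as they are known to beat merge sort often. The reason is that they have lower constant factors than MergeSort, and the worst-case time complexity occurs with a probability of only about 1/(n!). If N is large, the worst-case probability becomes very small, so the average case is by far the most likely. You could multithread both and see which of the four programs (one serial and one multithreaded version each of DPQ and 3PQ) runs fastest.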

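Dual-Pivot-QuickSort works best when there are no or almost no duplicate keys, while 3-Partition-Quick-Sort works best when there are many duplicate keys. I have never seen 3-Partition-Quick-Sort beat Dual-Pivot-QuickSort when there are no or very few duplicate keys, but I have seen Dual-Pivot-QuickSort beat 3-Partition-Quick-Sort a very small number of times when there are many duplicate keys. In case you are interested, the serial DPQ code is below (both ascending and descending).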

protected static void ASC(int[]a, int left, int right, int div)
{
    int len = 1 + right - left;
    if (len < 27)
    {
        // insertion sort for small array
        int P1 = left + 1;
        int P2 = left;
        while ( P1 <= right )
        {
            div = a[P1];
            while(( P2 >= left )&&( a[P2] > div ))
            {
                a[P2 + 1] = a[P2];
                P2--;
            }
            a[P2 + 1] = div;
            P2 = P1;
            P1++;
        }
        return;
    }
    int third = len / div;
    // "medians"
    int P1 = left + third;
    int P2 = right - third;
    if (P1 <= left)
    {
        P1 = left + 1;
    }
    if (P2 >= right)
    {
        P2 = right - 1;
    }
    int temp;
    if (a[P1] < a[P2])
    {
        temp = a[P1]; a[P1] = a[left]; a[left] = temp;
        temp = a[P2]; a[P2] = a[right]; a[right] = temp;
    }
    else
    {
        temp = a[P1];  a[P1] = a[right];  a[right] = temp;
        temp = a[P2];  a[P2] = a[left];  a[left] = temp;
    }
    // pivots
    int pivot1 = a[left];
    int pivot2 = a[right];
    // pointers
    int less = left + 1;
    int great = right - 1;
    // sorting
    for (int k = less; k <= great; k++)
    {
        if (a[k] < pivot1)
        {
            temp = a[k];  a[k] = a[less];  a[less] = temp;
            less++;
        }
        else if (a[k] > pivot2)
        {
            while (k < great && a[great] > pivot2)
            {
                great--;
            }
            temp = a[k];  a[k] = a[great];  a[great] = temp;
            great--;
            if (a[k] < pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
        }
    }
    int dist = great - less;
    if (dist < 13)
    {
        div++;
    }
    temp = a[less-1];  a[less-1] = a[left];  a[left] = temp;
    temp = a[great+1];  a[great+1] = a[right];  a[right] = temp;
    // subarrays
    ASC(a, left, less - 2, div);
    ASC(a, great + 2, right, div);
    // equal elements
    if (dist > len - 13 && pivot1 != pivot2)
    {
        for (int k = less; k <= great; k++)
        {
            if (a[k] == pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
            else if (a[k] == pivot2)
            {
                temp = a[k];  a[k] = a[great];  a[great] = temp;
                great--;
                if (a[k] == pivot1)
                {
                    temp = a[k];  a[k] = a[less];  a[less] = temp;
                    less++;
                }
            }
        }
    }
    // subarray
    if (pivot1 < pivot2)
    {
        ASC(a, less, great, div);
    }
}

protected static void DSC(int[]a, int left, int right, int div)
{
    int len = 1 + right - left;
    if (len < 27)
    {
        // insertion sort for small array
        int P1 = left + 1;
        int P2 = left;
        while ( P1 <= right )
        {
            div = a[P1];
            while(( P2 >= left )&&( a[P2] < div ))
            {
                a[P2 + 1] = a[P2];
                P2--;
            }
            a[P2 + 1] = div;
            P2 = P1;
            P1++;
        }
        return;
    }
    int third = len / div;
    // "medians"
    int P1 = left + third;
    int P2 = right - third;
    // index bounds checks, same as in ASC (these compare positions, not values)
    if (P1 <= left)
    {
        P1 = left + 1;
    }
    if (P2 >= right)
    {
        P2 = right - 1;
    }
    int temp;
    if (a[P1] > a[P2])
    {
        temp = a[P1]; a[P1] = a[left]; a[left] = temp;
        temp = a[P2]; a[P2] = a[right]; a[right] = temp;
    }
    else
    {
        temp = a[P1];  a[P1] = a[right];  a[right] = temp;
        temp = a[P2];  a[P2] = a[left];  a[left] = temp;
    }
    // pivots
    int pivot1 = a[left];
    int pivot2 = a[right];
    // pointers
    int less = left + 1;
    int great = right - 1;
    // sorting
    for (int k = less; k <= great; k++)
    {
        if (a[k] > pivot1)
        {
            temp = a[k];  a[k] = a[less];  a[less] = temp;
            less++;
        }
        else if (a[k] < pivot2)
        {
            while (k < great && a[great] < pivot2)
            {
                great--;
            }
            temp = a[k];  a[k] = a[great];  a[great] = temp;
            great--;
            if (a[k] > pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
        }
    }
    int dist = great - less;
    if (dist < 13)
    {
        div++;
    }
    temp = a[less-1];  a[less-1] = a[left];  a[left] = temp;
    temp = a[great+1];  a[great+1] = a[right];  a[right] = temp;
    // subarrays
    DSC(a, left, less - 2, div);
    DSC(a, great + 2, right, div);
    // equal elements
    if (dist > len - 13 && pivot1 != pivot2)
    {
        for (int k = less; k <= great; k++)
        {
            if (a[k] == pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
            else if (a[k] == pivot2)
            {
                temp = a[k];  a[k] = a[great];  a[great] = temp;
                great--;
                if (a[k] == pivot1)
                {
                    temp = a[k];  a[k] = a[less];  a[less] = temp;
                    less++;
                }
            }
        }
    }
    // subarray
    if (pivot1 > pivot2)
    {
        DSC(a, less, great, div);
    }
}
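The 3-Partition-Quick-Sort mentioned in this answer is not shown, so here is a minimal sketch of the classic three-way (Dutch national flag) partitioning scheme for comparison. It is not the answerer's code: the class name, the middle-element pivot choice, and the absence of an insertion-sort cutoff are all simplifying assumptions.

```java
// Hypothetical class: quicksort with three-way partitioning, ascending order.
class ThreeWayQuickSort {

    static void sort(int[] a) {
        sort(a, 0, a.length - 1);
    }

    private static void sort(int[] a, int left, int right) {
        while (left < right) {
            int pivot = a[left + (right - left) / 2];
            int lt = left;   // a[left..lt-1]   <  pivot
            int i = left;    // a[lt..i-1]      == pivot
            int gt = right;  // a[gt+1..right]  >  pivot
            while (i <= gt) {
                if (a[i] < pivot) {
                    swap(a, lt++, i++);
                } else if (a[i] > pivot) {
                    swap(a, i, gt--); // the swapped-in element is examined next
                } else {
                    i++;
                }
            }
            // Keys equal to the pivot are final. Recurse into the smaller side
            // and loop on the larger one to keep the stack depth logarithmic.
            if (lt - left < right - gt) {
                sort(a, left, lt - 1);
                left = gt + 1;
            } else {
                sort(a, gt + 1, right);
                right = lt - 1;
            }
        }
    }

    private static void swap(int[] a, int x, int y) {
        int temp = a[x];
        a[x] = a[y];
        a[y] = temp;
    }
}
```

Because every key equal to the pivot lands in the middle block and is never touched again, inputs with many duplicate keys shrink quickly, which matches the behaviour described for the duplicate-heavy case.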