Why does my program get slower when the number of threads increases?

Time: 2019-12-24 08:34:22

Tags: java multithreading performance

I am a beginner in Java.

Recently I have been writing a program that computes a matrix multiplication, so I wrote a class to do it.

public class MultiThreadsMatrixMultipy{
 public   int[][] multipy(int[][] matrix1,int[][] matrix2) {
     if(!utils.CheckDimension(matrix1,matrix2)){
         return null;
     }
     int row1 = matrix1.length;
     int col1 = matrix1[0].length;
     int row2 = matrix2.length;
     int col2 = matrix2[0].length;
     int[][] ans = new int[row1][col2];
     Thread[][]  threads = new SingleRowMultipy[row1][col2];

     for(int i=0;i<row1;i++){
         for(int j=0;j<col2;j++){
             threads[i][j] = new SingleRowMultipy(i,j,matrix1,matrix2,ans);
             threads[i][j].start();
         }
     }
     return ans;
 }
}
public class SingleRowMultipy extends Thread{
        private int row;
        private int col;
        private int[][] A;
        private int[][] B;
        private int[][] ans;
        public SingleRowMultipy(int row,int col,int[][] A,int[][] B,int[][] C){
            this.row = row;
            this.col = col;
            this.A = A;
            this.B = B;
            this.ans = C;
        }
        public void run(){
            int sum =0;
            for(int i=0;i<A[row].length;i++){
                 sum+=(A[row][i]*B[i][col]);
            }
            ans[row][col] = sum;
        }
}

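I want to use one thread to compute each matrix1[i][:] * matrix2[:][j]. The matrices are 1000*5000 and 5000*1000, so the number of threads is 1000 * 1000. The program runs very slowly and takes about 38 s. If I compute the result with only a single thread, it takes 17 s. The single-threaded code is as follows: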

public class SimpleMatrixMultipy
{
    public int[][] multipy(int[][] matrix1,int[][] matrix2){
        int row1 = matrix1.length;
        int col1 = matrix1[0].length;
        int row2 = matrix2.length;
        int col2 = matrix2[0].length;
        int[][] ans = new int[row1][col2];
        for(int i=0;i<row1;i++){
            for(int j=0;j<col2;j++){
                for(int k=0;k<col1;k++){
                    ans[i][j] += matrix1[i][k]*matrix2[k][j];
                }
            }
        }
        return ans;
    }

}

How can I speed up the program?

1 answer:

Answer 0 (score: 0)

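As @Turing85 said, the number of threads needs to be managed. There are two ways to do this: use Executors.newFixedThreadPool for a fixed number of threads, or use Executors.newCachedThreadPool, which reuses idle threads and creates new ones only when none is available.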
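To make the point about managing the thread count concrete, here is a minimal sketch (not part of the original answer) that submits many small tasks to a fixed pool and counts how many distinct worker threads ran them. The class name PoolReuseDemo, the method distinctWorkers, and the pool size of 4 are illustrative assumptions.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical demo class; the names here are not from the original post.
public class PoolReuseDemo {

    // Runs taskCount trivial tasks on a fixed pool and returns how many
    // distinct worker threads executed them.
    static int distinctWorkers(int poolSize, int taskCount) {
        Set<String> workerNames = ConcurrentHashMap.newKeySet();
        ExecutorService pool = Executors.newFixedThreadPool(poolSize);
        for (int i = 0; i < taskCount; i++) {
            pool.execute(() -> workerNames.add(Thread.currentThread().getName()));
        }
        // Stop accepting new tasks, then wait for the queued ones to finish.
        pool.shutdown();
        try {
            if (!pool.awaitTermination(1, TimeUnit.MINUTES)) {
                throw new IllegalStateException("tasks did not finish in time");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return workerNames.size();
    }

    public static void main(String[] args) {
        // 10,000 tasks are submitted, but the pool never holds more than 4 threads.
        System.out.println(distinctWorkers(4, 10_000));
    }
}
```

The number of tasks and the number of threads are independent: the tasks wait in the pool's queue and are picked up by the same few workers, so the cost of creating a thread is paid only a handful of times instead of once per matrix cell.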

Another important point is to avoid extending the Thread class directly and to implement Runnable instead.
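One reason for this advice is that a Runnable only describes the work, so the same object can be run directly, on a dedicated thread, or by a pool. The following is a minimal sketch under that reading; RunnableDemo and runThreeWays are hypothetical names, not code from the original answer.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class; the names here are not from the original post.
public class RunnableDemo {

    // Runs the same Runnable three different ways and returns how often it ran.
    static int runThreeWays() {
        AtomicInteger calls = new AtomicInteger();
        Runnable task = calls::incrementAndGet; // the work, with no threading logic

        try {
            task.run();                         // 1. directly on the calling thread

            Thread thread = new Thread(task);   // 2. on a dedicated thread
            thread.start();
            thread.join();

            ExecutorService pool = Executors.newFixedThreadPool(2);
            pool.execute(task);                 // 3. on a pooled thread
            pool.shutdown();
            if (!pool.awaitTermination(1, TimeUnit.MINUTES)) {
                throw new IllegalStateException("task did not finish in time");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return calls.get();
    }

    public static void main(String[] args) {
        System.out.println(runThreeWays()); // prints 3
    }
}
```

A class that extends Thread can only be started as its own thread, which is what forced one thread per cell in the question. The answer's listing follows.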

import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
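import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

public class MultiThreadsMatrixMultipy {

    public static void main(final String[] args) {

    }

    public int[][] multipy(final int[][] matrix1, final int[][] matrix2) {
        if(!utils.CheckDimension(matrix1,matrix2)){
            return null;
        }
        final int row1 = matrix1.length;
        final int col2 = matrix2[0].length;
        final int[][] ans = new int[row1][col2];
        // final ExecutorService executor = Executors.newCachedThreadPool(new CustomThreadFactory("Multiplier"));
        final ExecutorService executor = Executors.newFixedThreadPool(20, new CustomThreadFactory("Multiplier"));

        // One task per result cell; the 20 pooled threads work through the queue.
        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                executor.execute(new SingleRowMultipy(i, j, matrix1, matrix2, ans));
            }
        }

        // Stop accepting tasks and wait for the queued ones, so that ans is
        // complete before it is returned. The one-hour limit is an arbitrary upper bound.
        executor.shutdown();
        try {
            if (!executor.awaitTermination(1, TimeUnit.HOURS)) {
                throw new IllegalStateException("Multiplication did not finish in time");
            }
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for multiplication tasks", e);
        }
        return ans;
    }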
}

class CustomThreadFactory implements ThreadFactory {
    private int counter;
    private final String name;
    private final List<String> stats;

    public CustomThreadFactory(final String name) {
        counter = 1;
        this.name = name;
        stats = new ArrayList<>();
    }

    @Override
    public Thread newThread(final Runnable runnable) {
        final Thread t = new Thread(runnable, name + "-Thread_" + counter);
        counter++;
        stats.add(String.format("Created thread %d with name %s on %s \n", t.getId(), t.getName(), new Date()));
        return t;
    }

    public String getStats() {
        final StringBuffer buffer = new StringBuffer();
        final Iterator<String> it = stats.iterator();
        while (it.hasNext()) {
            buffer.append(it.next());
        }
        return buffer.toString();
    }
}

class SingleRowMultipy implements Runnable {
    private final int row;
    private final int col;
    private final int[][] A;
    private final int[][] B;
    private final int[][] ans;

    public SingleRowMultipy(final int row, final int col, final int[][] A, final int[][] B, final int[][] C) {
        this.row = row;
        this.col = col;
        this.A = A;
        this.B = B;
        this.ans = C;
    }

    @Override
    public void run() {
        int sum = 0;
        for (int i = 0; i < A[row].length; i++) {
            sum += (A[row][i] * B[i][col]);
        }
        ans[row][col] = sum;
    }
}
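
The listing above still queues one task per result cell, which is a million small tasks for the matrices in the question. As a hedged variation on the same idea, the sketch below uses one task per output row and sizes the pool from Runtime.getRuntime().availableProcessors(). Both choices are assumptions made for illustration, not something the original answer measured, and the class name RowTaskMatrixMultiply is hypothetical. The sketch also assumes rectangular matrices with matching dimensions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical class; the names here are not from the original post.
public class RowTaskMatrixMultiply {

    // Multiplies a (n x m) by b (m x p) using one task per output row.
    // Assumes both matrices are rectangular and a[0].length == b.length.
    static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = b[0].length;
        int[][] ans = new int[rows][cols];

        // Assumption: one worker per available processor.
        int workers = Runtime.getRuntime().availableProcessors();
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        try {
            List<Callable<Void>> tasks = new ArrayList<>();
            for (int i = 0; i < rows; i++) {
                final int row = i;
                tasks.add(() -> {
                    // Each task writes only to its own row of ans.
                    for (int j = 0; j < cols; j++) {
                        int sum = 0;
                        for (int k = 0; k < inner; k++) {
                            sum += a[row][k] * b[k][j];
                        }
                        ans[row][j] = sum;
                    }
                    return null;
                });
            }
            // invokeAll blocks until every task has completed.
            for (Future<Void> result : pool.invokeAll(tasks)) {
                result.get(); // rethrows any exception thrown inside a task
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
        return ans;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiply(a, b))); // prints [[19, 22], [43, 50]]
    }
}
```

With 1000 rows there are 1000 tasks instead of 1,000,000, and each one does enough arithmetic to outweigh the cost of scheduling it. Whether this actually beats the single-threaded version on a given machine would need to be measured.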