I'm a Java beginner.
Recently I've been writing a program that computes matrix multiplication, so I wrote a class for it:
public class MultiThreadsMatrixMultipy {
    public int[][] multipy(int[][] matrix1, int[][] matrix2) throws InterruptedException {
        if (!utils.CheckDimension(matrix1, matrix2)) {
            return null;
        }
        int row1 = matrix1.length;
        int col1 = matrix1[0].length;
        int row2 = matrix2.length;
        int col2 = matrix2[0].length;
        int[][] ans = new int[row1][col2];
        // One thread per output cell ans[i][j]
        Thread[][] threads = new Thread[row1][col2];
        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                threads[i][j] = new SingleRowMultipy(i, j, matrix1, matrix2, ans);
                threads[i][j].start();
            }
        }
        // Wait for every thread to finish before returning the result
        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                threads[i][j].join();
            }
        }
        return ans;
    }
}
public class SingleRowMultipy extends Thread {
    private int row;
    private int col;
    private int[][] A;
    private int[][] B;
    private int[][] ans;

    public SingleRowMultipy(int row, int col, int[][] A, int[][] B, int[][] C) {
        this.row = row;
        this.col = col;
        this.A = A;
        this.B = B;
        this.ans = C;
    }

    // Dot product of row `row` of A with column `col` of B
    @Override
    public void run() {
        int sum = 0;
        for (int i = 0; i < A[row].length; i++) {
            sum += (A[row][i] * B[i][col]);
        }
        ans[row][col] = sum;
    }
}
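I want each thread to compute one product matrix1[i][:] * matrix2[:][j]. The matrices are 1000*5000 and 5000*1000, so that is 1000 * 1000 = 1,000,000 threads. The program runs very slowly, taking about 38 s. If I compute the result with a single thread, it takes only 17 s. The single-threaded code is as follows: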
public class SimpleMatrixMultipy {
    public int[][] multipy(int[][] matrix1, int[][] matrix2) {
        int row1 = matrix1.length;
        int col1 = matrix1[0].length;
        int row2 = matrix2.length;
        int col2 = matrix2[0].length;
        int[][] ans = new int[row1][col2];
        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                for (int k = 0; k < col1; k++) {
                    ans[i][j] += matrix1[i][k] * matrix2[k][j];
                }
            }
        }
        return ans;
    }
}
How can I make the program run faster?
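One change that is often suggested for the single-threaded loop, and that is not discussed in this thread, is reordering the loops to i-k-j. Java stores an int[][] as an array of row arrays, so reading matrix2[k][j] with k changing in the innermost loop jumps to a different row array at every step, while the i-k-j order walks matrix2[k] and the result row sequentially. This is a sketch only: the class name LoopOrderMatrixMultiply is hypothetical and the speedup has not been measured here.

```java
import java.util.Arrays;

public class LoopOrderMatrixMultiply {
    // Multiplies a (n x m) by b (m x p) using the i-k-j loop order.
    // The innermost loop walks b[k] and result[i] element by element,
    // which is sequential in memory for each row array.
    public static int[][] multiply(int[][] a, int[][] b) {
        int n = a.length;
        int m = b.length;     // must equal a[0].length
        int p = b[0].length;
        int[][] result = new int[n][p];
        for (int i = 0; i < n; i++) {
            int[] resultRow = result[i];
            for (int k = 0; k < m; k++) {
                int aik = a[i][k];  // reused across the whole inner loop
                int[] bRow = b[k];
                for (int j = 0; j < p; j++) {
                    resultRow[j] += aik * bRow[j];
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(multiply(a, b))); // prints [[58, 64], [139, 154]]
    }
}
```

The arithmetic is identical to the original triple loop; only the order in which the additions happen changes.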
Answer (score: 0)
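As @Turing85 said, you need to manage the number of threads. There are two ways to do that: Executors.newFixedThreadPool gives you a fixed number of threads, while Executors.newCachedThreadPool reuses existing idle threads (it still creates a new thread whenever none is free, so it does not put an upper bound on the thread count).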
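The advice about bounding the thread count can be sketched with a fixed pool sized to the number of available processors (an assumption; the answer's code below uses 20 threads) and one task per output row instead of one per cell, so a 1000-row result queues 1,000 tasks rather than 1,000,000. The class name RowTaskMatrixMultiply is hypothetical; treat this as a sketch, not a tuned implementation.

```java
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class RowTaskMatrixMultiply {
    // Computes a (n x m) times b (m x p) with a bounded pool and one task per row.
    public static int[][] multiply(int[][] a, int[][] b) throws InterruptedException {
        int n = a.length;
        int m = b.length;
        int p = b[0].length;
        int[][] result = new int[n][p];
        // Assumption: one worker per available core is a sensible cap for CPU-bound work.
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        for (int i = 0; i < n; i++) {
            final int row = i;
            // Each task fills a whole row, so no two tasks write the same element.
            pool.execute(() -> {
                for (int j = 0; j < p; j++) {
                    int sum = 0;
                    for (int k = 0; k < m; k++) {
                        sum += a[row][k] * b[k][j];
                    }
                    result[row][j] = sum;
                }
            });
        }
        pool.shutdown();                                             // accept no new tasks
        pool.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS); // block until every row is done
        return result;
    }

    public static void main(String[] args) throws InterruptedException {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(multiply(a, b))); // prints [[58, 64], [139, 154]]
    }
}
```

Calling shutdown() and awaitTermination() before returning matters: without them the method could hand back the array while rows are still being computed, and the pool's non-daemon threads would keep the JVM alive.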
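Another important point is to avoid extending the Thread class directly and to implement Runnable instead, so the same task object can be handed to an executor.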
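A minimal illustration of the Runnable approach: the work lives in a class that implements Runnable, and a plain Thread (or an executor) runs it. RunnableDemo and SumTask are hypothetical names used only for this example.

```java
public class RunnableDemo {
    // The work is a Runnable, not a Thread subclass, so the same object can be
    // passed to new Thread(...) or to an ExecutorService's execute(...).
    public static class SumTask implements Runnable {
        private final int[] values;
        private long sum; // written by the worker, read by the caller after join()

        public SumTask(int[] values) {
            this.values = values;
        }

        @Override
        public void run() {
            long s = 0;
            for (int v : values) {
                s += v;
            }
            sum = s;
        }

        public long getSum() {
            return sum;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        SumTask task = new SumTask(new int[] {1, 2, 3, 4});
        Thread worker = new Thread(task); // wrap the Runnable instead of subclassing Thread
        worker.start();
        worker.join();                     // join() also makes the write to sum visible here
        System.out.println(task.getSum()); // prints 10
    }
}
```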
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

public class MultiThreadsMatrixMultipy {
    public static void main(final String[] args) {
    }

    public int[][] multipy(final int[][] matrix1, final int[][] matrix2) throws InterruptedException {
        if (!utils.CheckDimension(matrix1, matrix2)) {
            return null;
        }
        final int row1 = matrix1.length;
        final int col2 = matrix2[0].length;
        final int[][] ans = new int[row1][col2];
        // final ExecutorService executor = Executors.newCachedThreadPool(new CustomThreadFactory("Multiplier"));
        final ExecutorService executor = Executors.newFixedThreadPool(20, new CustomThreadFactory("Multiplier"));
        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                executor.execute(new SingleRowMultipy(i, j, matrix1, matrix2, ans));
            }
        }
        // Accept no new tasks, then wait for the queued ones to finish;
        // otherwise ans could be returned before every cell is computed.
        executor.shutdown();
        executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        return ans;
    }
}
class CustomThreadFactory implements ThreadFactory {
    private int counter;
    private final String name;
    private final List<String> stats;

    public CustomThreadFactory(final String name) {
        counter = 1;
        this.name = name;
        stats = new ArrayList<>();
    }

    // synchronized: the pool may request threads from more than one thread,
    // and neither counter nor the ArrayList is thread-safe on its own
    @Override
    public synchronized Thread newThread(final Runnable runnable) {
        final Thread t = new Thread(runnable, name + "-Thread_" + counter);
        counter++;
        stats.add(String.format("Created thread %d with name %s on %s \n", t.getId(), t.getName(), new Date()));
        return t;
    }

    // Returns one line per thread this factory has created
    public synchronized String getStats() {
        final StringBuilder buffer = new StringBuilder();
        final Iterator<String> it = stats.iterator();
        while (it.hasNext()) {
            buffer.append(it.next());
        }
        return buffer.toString();
    }
}
class SingleRowMultipy implements Runnable {
    private final int row;
    private final int col;
    private final int[][] A;
    private final int[][] B;
    private final int[][] ans;

    public SingleRowMultipy(final int row, final int col, final int[][] A, final int[][] B, final int[][] C) {
        this.row = row;
        this.col = col;
        this.A = A;
        this.B = B;
        this.ans = C;
    }

    // Dot product of row `row` of A with column `col` of B, stored in ans[row][col]
    @Override
    public void run() {
        int sum = 0;
        for (int i = 0; i < A[row].length; i++) {
            sum += (A[row][i] * B[i][col]);
        }
        ans[row][col] = sum;
    }
}
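The answer's version still queues one task per output cell (1,000,000 tasks for the sizes in the question). A different technique, not used in the answer, is a parallel stream over the row indices: by default it runs on the shared ForkJoinPool.commonPool(), whose parallelism is derived from the number of available processors, so there is no executor to create or shut down. The class name ParallelStreamMatrixMultiply is hypothetical, and this is a sketch rather than a benchmarked implementation.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class ParallelStreamMatrixMultiply {
    // Splits the output rows across the common ForkJoinPool; each row is
    // written by exactly one task, so no synchronization is needed.
    public static int[][] multiply(int[][] a, int[][] b) {
        int m = b.length;
        int p = b[0].length;
        int[][] result = new int[a.length][p];
        IntStream.range(0, a.length).parallel().forEach(row -> {
            for (int j = 0; j < p; j++) {
                int sum = 0;
                for (int k = 0; k < m; k++) {
                    sum += a[row][k] * b[k][j];
                }
                result[row][j] = sum;
            }
        });
        // forEach returns only after every row has been processed
        return result;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(multiply(a, b))); // prints [[58, 64], [139, 154]]
    }
}
```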