Matrixmult. Parallel mit split und merge
Verfasst: So Dez 03, 2017 7:58 am
Ich habe derzeit ein Problem, einen Algorithmus zu realisieren, und zwar würde ich gerne Matrizen parallel multiplizieren. So soll das ganze als Beispiel für eine 2x2 Matrix aussehen:
Die beiden Matrizen A und B sollen dabei also vorher in viele kleine 2x2 Matrizen gesplittet werden und dann sollen gemäß dem Bild die einzelnen Elemente der Ergebnismatrix C berechnet werden, was parallel ablaufen soll und am Ende soll auch die Ergebnismatrix, die ja aus vielen Untermatrizen besteht wieder "gemerged" werden, ganz nach dem Prinzip split und merge (Teile und Herrsche).
Ich bin bei der Implementierung aber auf ein Problem gestoßen und find den Fehler aber einfach nicht.
Das hier ist ein Junit Test, der leider nur für den ersten Testfall ein grün gibt. Die anderen beiden sind rot, bzw. nur der este, da jUnit ja beim ersten fehlgeschlagenem Test in einer Methode ja aufhört. bzw. abbricht.
Da ihr vielleicht so das ganze auch ganz einfach zum laufen bringen könnt, hier mal meine Klasse Matrix, mit den entsprechend benötigten Methoden. Ich hoffe mal, dass es so passt, die ganze Klasse wäre mittlerweile ein wenig zu umfangreich.
So könnt ihr auch meinen kleinen einzelnen Test laufen lassen. Ich hoffe wirklich, dass mir hier jemand helfen kann. Ich suche jetzt mittlerweile schon sehr lange nach dem Fehler aber ich finde ihn einfach nicht. Ich hoffe jemand kann helfen.
Die beiden Matrizen A und B sollen dabei also vorher in viele kleine 2x2 Matrizen gesplittet werden und dann sollen gemäß dem Bild die einzelnen Elemente der Ergebnismatrix C berechnet werden, was parallel ablaufen soll und am Ende soll auch die Ergebnismatrix, die ja aus vielen Untermatrizen besteht wieder "gemerged" werden, ganz nach dem Prinzip split und merge (Teile und Herrsche).
Ich bin bei der Implementierung aber auf ein Problem gestoßen und find den Fehler aber einfach nicht.
Code: Alles auswählen
public static Matrix parallelMult(final Matrix matA, final Matrix matB) {
Matrix[][] matASplit = split(matA);
Matrix[][] matBSplit = split(matB);
final int matASplitRows = matASplit.length;
final int matBSplitCols = matBSplit[0].length;
final int cpuCount = Runtime.getRuntime().availableProcessors();
Matrix[][] matCSplit = new Matrix[matASplitRows][matBSplitCols];
Matrix.initializeEmpty(matCSplit);
BlockingQueue<Runnable> workQueue = new ArrayBlockingQueue<Runnable>(matASplitRows * matBSplitCols + 1);
ThreadPoolExecutor threadpool = new ThreadPoolExecutor(cpuCount, cpuCount + 2, 10, TimeUnit.MINUTES, workQueue);
for(int i = 0; i < matASplitRows; i++) {
final int iF = i;
for(int l = 0; l < matBSplitCols; l++) {
final int lF = l;
threadpool.execute(new Runnable() {
int j, k;
@Override
public void run() {
for(j = 0; j < 2; j++) {
for (k = 0; k < 2; k++) {
matCSplit[iF][lF].data[j][k] += matASplit[iF][lF].data[j][0] * matBSplit[iF][lF].data[0][k] +
matASplit[iF][lF].data[j][1] * matBSplit[iF][lF].data[1][k];
}
}
}
});
}
}
try {
threadpool.shutdown();
if(!threadpool.awaitTermination(200, TimeUnit.MILLISECONDS)) {
threadpool.shutdownNow();
threadpool.awaitTermination(200, TimeUnit.MILLISECONDS);
}
} catch(InterruptedException e) {
e.printStackTrace();
}
return Matrix.merge(matCSplit);
}
private static Matrix[][] split(Matrix mat) {
int currentRowIndex = 0;
int currentColIndex = 0;
int currentNewMatRow = 0;
int currentNewMatCol = 0;
final double splittedMatRows = Math.round(((double) mat.getNumberOfRows() / 2));
final double splittedMatCols = Math.round(((double) mat.getNumberOfColumns() / 2));
Matrix[][] splittedMatrix = new Matrix[(int) splittedMatRows][(int) splittedMatCols];
double[][] newMatData;
while(currentNewMatRow < splittedMatRows && currentRowIndex < (mat.getNumberOfRows() - 1)) {
while(currentNewMatCol < splittedMatCols && currentColIndex < (mat.getNumberOfColumns() - 1)) {
newMatData = new double[2][2];
for(int j = 0; j < 2; j++) {
for(int k = 0; k < 2; k++) {
newMatData[j][k] = mat.data[j + currentRowIndex][k + currentColIndex];
}
}
currentColIndex += 2;
splittedMatrix[currentNewMatRow][currentNewMatCol++] = new Matrix(newMatData);
}
currentColIndex = 0;
currentRowIndex += 2;
currentNewMatRow++;
currentNewMatCol = 0;
}
return splittedMatrix;
}
private static Matrix merge(Matrix[][] matToMerge) {
final int splittedMatRows = matToMerge.length;
final int splittedMatCols = matToMerge[0].length;
final int mergedMatRows = splittedMatRows * 2;
final int mergedMatCols = splittedMatCols * 2;
double[][] mergedMatrix = new double[mergedMatRows][mergedMatCols];
for(int i = 0; i < splittedMatRows; i++) {
for(int j = 0; j < splittedMatCols; j++) {
for(int k = 0; k < 2; k++) {
for(int l = 0; l < 2; l++) {
mergedMatrix[k + i][j + l] = matToMerge[i][j].data[k][l];
}
}
}
}
return new Matrix(mergedMatrix);
}
Das hier ist ein Junit Test, der leider nur für den ersten Testfall ein grün gibt. Die anderen beiden sind rot, bzw. nur der este, da jUnit ja beim ersten fehlgeschlagenem Test in einer Methode ja aufhört. bzw. abbricht.
Code: Alles auswählen
@Test
public final void testParallelMult() throws IllegalArgumentException {
Matrix matA;
Matrix matB;
Matrix resultMat;
Matrix expectedResultMat;
matA = new Matrix(new double[][] {{1,2},{3,4}});
matB = new Matrix(new double[][] {{2,3},{4,5}});
expectedResultMat = new Matrix(new double[][] {{10,13},{22,29}});
resultMat = Matrix.parallelMult(matA, matB);
assertTrue("Wrong result!\nExpected:\n" + expectedResultMat.toString() + "Got:\n" + resultMat.toString(), expectedResultMat.equals(resultMat));
matA = new Matrix(new double[][] {{1,2},{3,4}});
matB = new Matrix(new double[][] {{2,3,4},{5,6,7}});
expectedResultMat = new Matrix(new double[][] {{12,15,18},{26,33,40}});
resultMat = Matrix.parallelMult(matA, matB);
assertTrue("Wrong result!\nExpected:\n" + expectedResultMat.toString() + "Got:\n" + resultMat.toString(), expectedResultMat.equals(resultMat));
matA = new Matrix(new double[][] {{1,2},{3,4},{5,6}});
matB = new Matrix(new double[][] {{2,3},{4,5}});
expectedResultMat = new Matrix(new double[][] {{10,13},{22,29},{34,45}});
}
Code: Alles auswählen
public class Matrix {
private final int rows;
private final int columns;
private double data[][];
public Matrix(double matrix[][]) {
this.data = matrix;
this.rows = matrix.length;
this.columns = matrix[0].length;
}
@Override
public boolean equals(Object other) {
if(other == null) return false;
if(other == this) return true;
int i, j;
final Matrix otherMatrix = (Matrix) other;
final int otherMatrixRows = otherMatrix.rows;
final int otherMatrixColumns = otherMatrix.columns;
if(this.getNumberOfRows() != otherMatrixRows || this.getNumberOfColumns() != otherMatrixColumns)
return false;
for(i = 0; i < otherMatrixRows; i++) {
for(j = 0; j < otherMatrixColumns; j++) {
if(this.data[i][j] != otherMatrix.data[i][j])
return false;
}
}
return true;
}
@Override
public String toString() {
int i, j;
String matWrapper = "";
String retString = "";
final String minusAColumn = "------------";
for(i = 0; i < this.getNumberOfColumns(); i++)
matWrapper += minusAColumn;
matWrapper += "-";
retString += matWrapper + "\n";
for(i = 0; i < this.getNumberOfRows(); i++) {
for(j = 0; j < this.getNumberOfColumns(); j++) {
retString += String.format("| %10.0f", this.data[i][j]); // 10.2f
}
retString += "|\n";
}
retString += matWrapper + "\n";
return retString;
}
public static Matrix parallelMult(final Matrix matA, final Matrix matB) {
Matrix[][] matASplit = split(matA);
Matrix[][] matBSplit = split(matB);
final int matASplitRows = matASplit.length;
final int matBSplitCols = matBSplit[0].length;
final int cpuCount = Runtime.getRuntime().availableProcessors();
Matrix[][] matCSplit = new Matrix[matASplitRows][matBSplitCols];
Matrix.initializeEmpty(matCSplit);
BlockingQueue<Runnable> workQueue = new ArrayBlockingQueue<Runnable>(matASplitRows * matBSplitCols + 1);
ThreadPoolExecutor threadpool = new ThreadPoolExecutor(cpuCount, cpuCount + 2, 10, TimeUnit.MINUTES, workQueue);
for(int i = 0; i < matASplitRows; i++) {
final int iF = i;
for(int l = 0; l < matBSplitCols; l++) {
final int lF = l;
threadpool.execute(new Runnable() {
int j, k;
@Override
public void run() {
for(j = 0; j < 2; j++) {
for (k = 0; k < 2; k++) {
matCSplit[iF][lF].data[j][k] += matASplit[iF][lF].data[j][0] * matBSplit[iF][lF].data[0][k] +
matASplit[iF][lF].data[j][1] * matBSplit[iF][lF].data[1][k];
}
}
}
});
}
}
try {
threadpool.shutdown();
if(!threadpool.awaitTermination(200, TimeUnit.MILLISECONDS)) {
threadpool.shutdownNow();
threadpool.awaitTermination(200, TimeUnit.MILLISECONDS);
}
} catch(InterruptedException e) {
e.printStackTrace();
}
return Matrix.merge(matCSplit);
}
private static Matrix[][] split(Matrix mat) {
int currentRowIndex = 0;
int currentColIndex = 0;
int currentNewMatRow = 0;
int currentNewMatCol = 0;
final double splittedMatRows = Math.round(((double) mat.getNumberOfRows() / 2));
final double splittedMatCols = Math.round(((double) mat.getNumberOfColumns() / 2));
Matrix[][] splittedMatrix = new Matrix[(int) splittedMatRows][(int) splittedMatCols];
double[][] newMatData;
while(currentNewMatRow < splittedMatRows && currentRowIndex < (mat.getNumberOfRows() - 1)) {
while(currentNewMatCol < splittedMatCols && currentColIndex < (mat.getNumberOfColumns() - 1)) {
newMatData = new double[2][2];
for(int j = 0; j < 2; j++) {
for(int k = 0; k < 2; k++) {
newMatData[j][k] = mat.data[j + currentRowIndex][k + currentColIndex];
}
}
currentColIndex += 2;
splittedMatrix[currentNewMatRow][currentNewMatCol++] = new Matrix(newMatData);
}
currentColIndex = 0;
currentRowIndex += 2;
currentNewMatRow++;
currentNewMatCol = 0;
}
return splittedMatrix;
}
private static Matrix merge(Matrix[][] matToMerge) {
final int splittedMatRows = matToMerge.length;
final int splittedMatCols = matToMerge[0].length;
final int mergedMatRows = splittedMatRows * 2;
final int mergedMatCols = splittedMatCols * 2;
double[][] mergedMatrix = new double[mergedMatRows][mergedMatCols];
for(int i = 0; i < splittedMatRows; i++) {
for(int j = 0; j < splittedMatCols; j++) {
for(int k = 0; k < 2; k++) {
for(int l = 0; l < 2; l++) {
mergedMatrix[k + i][j + l] = matToMerge[i][j].data[k][l];
}
}
}
}
return new Matrix(mergedMatrix);
}
}