Skip to content

Commit e728aa7

Browse files
authored
Add tests, remove main, enhance docs in MatrixChainMultiplication (#5658)
1 parent bd3b754 commit e728aa7

File tree

3 files changed

+163
-86
lines changed

3 files changed

+163
-86
lines changed

DIRECTORY.md

+1
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@
814814
* [LongestIncreasingSubsequenceTests](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestIncreasingSubsequenceTests.java)
815815
* [LongestPalindromicSubstringTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestPalindromicSubstringTest.java)
816816
* [LongestValidParenthesesTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestValidParenthesesTest.java)
817+
* [MatrixChainMultiplicationTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MatrixChainMultiplicationTest.java)
817818
* [MinimumPathSumTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumPathSumTest.java)
818819
* [MinimumSumPartitionTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumSumPartitionTest.java)
819820
* [OptimalJobSchedulingTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/OptimalJobSchedulingTest.java)

src/main/java/com/thealgorithms/dynamicprogramming/MatrixChainMultiplication.java

+108-86
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,32 @@
22

33
import java.util.ArrayList;
44
import java.util.Arrays;
5-
import java.util.Scanner;
65

6+
/**
7+
* The MatrixChainMultiplication class provides functionality to compute the
8+
* optimal way to multiply a sequence of matrices. The optimal multiplication
9+
* order is determined using dynamic programming, which minimizes the total
10+
* number of scalar multiplications required.
11+
*/
712
public final class MatrixChainMultiplication {
813
private MatrixChainMultiplication() {
914
}
1015

11-
private static final Scanner SCANNER = new Scanner(System.in);
12-
private static final ArrayList<Matrix> MATRICES = new ArrayList<>();
13-
private static int size;
16+
// Matrices to store minimum multiplication costs and split points
1417
private static int[][] m;
1518
private static int[][] s;
1619
private static int[] p;
1720

18-
public static void main(String[] args) {
19-
int count = 1;
20-
while (true) {
21-
String[] mSize = input("input size of matrix A(" + count + ") ( ex. 10 20 ) : ");
22-
int col = Integer.parseInt(mSize[0]);
23-
if (col == 0) {
24-
break;
25-
}
26-
int row = Integer.parseInt(mSize[1]);
27-
28-
Matrix matrix = new Matrix(count, col, row);
29-
MATRICES.add(matrix);
30-
count++;
31-
}
32-
for (Matrix m : MATRICES) {
33-
System.out.format("A(%d) = %2d x %2d%n", m.count(), m.col(), m.row());
34-
}
35-
36-
size = MATRICES.size();
21+
/**
22+
* Calculates the optimal order for multiplying a given list of matrices.
23+
*
24+
* @param matrices an ArrayList of Matrix objects representing the matrices
25+
* to be multiplied.
26+
* @return a Result object containing the matrices of minimum costs and
27+
* optimal splits.
28+
*/
29+
public static Result calculateMatrixChainOrder(ArrayList<Matrix> matrices) {
30+
int size = matrices.size();
3731
m = new int[size + 1][size + 1];
3832
s = new int[size + 1][size + 1];
3933
p = new int[size + 1];
@@ -44,51 +38,20 @@ public static void main(String[] args) {
4438
}
4539

4640
for (int i = 0; i < p.length; i++) {
47-
p[i] = i == 0 ? MATRICES.get(i).col() : MATRICES.get(i - 1).row();
41+
p[i] = i == 0 ? matrices.get(i).col() : matrices.get(i - 1).row();
4842
}
4943

50-
matrixChainOrder();
51-
for (int i = 0; i < size; i++) {
52-
System.out.print("-------");
53-
}
54-
System.out.println();
55-
printArray(m);
56-
for (int i = 0; i < size; i++) {
57-
System.out.print("-------");
58-
}
59-
System.out.println();
60-
printArray(s);
61-
for (int i = 0; i < size; i++) {
62-
System.out.print("-------");
63-
}
64-
System.out.println();
65-
66-
System.out.println("Optimal solution : " + m[1][size]);
67-
System.out.print("Optimal parens : ");
68-
printOptimalParens(1, size);
69-
}
70-
71-
private static void printOptimalParens(int i, int j) {
72-
if (i == j) {
73-
System.out.print("A" + i);
74-
} else {
75-
System.out.print("(");
76-
printOptimalParens(i, s[i][j]);
77-
printOptimalParens(s[i][j] + 1, j);
78-
System.out.print(")");
79-
}
80-
}
81-
82-
private static void printArray(int[][] array) {
83-
for (int i = 1; i < size + 1; i++) {
84-
for (int j = 1; j < size + 1; j++) {
85-
System.out.printf("%7d", array[i][j]);
86-
}
87-
System.out.println();
88-
}
44+
matrixChainOrder(size);
45+
return new Result(m, s);
8946
}
9047

91-
private static void matrixChainOrder() {
48+
/**
49+
* A helper method that computes the minimum cost of multiplying
50+
* the matrices using dynamic programming.
51+
*
52+
* @param size the number of matrices in the multiplication sequence.
53+
*/
54+
private static void matrixChainOrder(int size) {
9255
for (int i = 1; i < size + 1; i++) {
9356
m[i][i] = 0;
9457
}
@@ -109,33 +72,92 @@ private static void matrixChainOrder() {
10972
}
11073
}
11174

112-
private static String[] input(String string) {
113-
System.out.print(string);
114-
return (SCANNER.nextLine().split(" "));
115-
}
116-
}
117-
118-
class Matrix {
75+
/**
76+
* The Result class holds the results of the matrix chain multiplication
77+
* calculation, including the matrix of minimum costs and split points.
78+
*/
79+
public static class Result {
80+
private final int[][] m;
81+
private final int[][] s;
82+
83+
/**
84+
* Constructs a Result object with the specified matrices of minimum
85+
* costs and split points.
86+
*
87+
* @param m the matrix of minimum multiplication costs.
88+
* @param s the matrix of optimal split points.
89+
*/
90+
public Result(int[][] m, int[][] s) {
91+
this.m = m;
92+
this.s = s;
93+
}
11994

120-
private final int count;
121-
private final int col;
122-
private final int row;
95+
/**
96+
* Returns the matrix of minimum multiplication costs.
97+
*
98+
* @return the matrix of minimum multiplication costs.
99+
*/
100+
public int[][] getM() {
101+
return m;
102+
}
123103

124-
Matrix(int count, int col, int row) {
125-
this.count = count;
126-
this.col = col;
127-
this.row = row;
104+
/**
105+
* Returns the matrix of optimal split points.
106+
*
107+
* @return the matrix of optimal split points.
108+
*/
109+
public int[][] getS() {
110+
return s;
111+
}
128112
}
129113

130-
int count() {
131-
return count;
132-
}
114+
/**
115+
* The Matrix class represents a matrix with its dimensions and count.
116+
*/
117+
public static class Matrix {
118+
private final int count;
119+
private final int col;
120+
private final int row;
121+
122+
/**
123+
* Constructs a Matrix object with the specified count, number of columns,
124+
* and number of rows.
125+
*
126+
* @param count the identifier for the matrix.
127+
* @param col the number of columns in the matrix.
128+
* @param row the number of rows in the matrix.
129+
*/
130+
public Matrix(int count, int col, int row) {
131+
this.count = count;
132+
this.col = col;
133+
this.row = row;
134+
}
133135

134-
int col() {
135-
return col;
136-
}
136+
/**
137+
* Returns the identifier of the matrix.
138+
*
139+
* @return the identifier of the matrix.
140+
*/
141+
public int count() {
142+
return count;
143+
}
137144

138-
int row() {
139-
return row;
145+
/**
146+
* Returns the number of columns in the matrix.
147+
*
148+
* @return the number of columns in the matrix.
149+
*/
150+
public int col() {
151+
return col;
152+
}
153+
154+
/**
155+
* Returns the number of rows in the matrix.
156+
*
157+
* @return the number of rows in the matrix.
158+
*/
159+
public int row() {
160+
return row;
161+
}
140162
}
141163
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package com.thealgorithms.dynamicprogramming;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import java.util.ArrayList;
6+
import org.junit.jupiter.api.Test;
7+
8+
class MatrixChainMultiplicationTest {
9+
10+
@Test
11+
void testMatrixCreation() {
12+
MatrixChainMultiplication.Matrix matrix1 = new MatrixChainMultiplication.Matrix(1, 10, 20);
13+
MatrixChainMultiplication.Matrix matrix2 = new MatrixChainMultiplication.Matrix(2, 20, 30);
14+
15+
assertEquals(1, matrix1.count());
16+
assertEquals(10, matrix1.col());
17+
assertEquals(20, matrix1.row());
18+
19+
assertEquals(2, matrix2.count());
20+
assertEquals(20, matrix2.col());
21+
assertEquals(30, matrix2.row());
22+
}
23+
24+
@Test
25+
void testMatrixChainOrder() {
26+
// Create a list of matrices to be multiplied
27+
ArrayList<MatrixChainMultiplication.Matrix> matrices = new ArrayList<>();
28+
matrices.add(new MatrixChainMultiplication.Matrix(1, 10, 20)); // A(1) = 10 x 20
29+
matrices.add(new MatrixChainMultiplication.Matrix(2, 20, 30)); // A(2) = 20 x 30
30+
31+
// Calculate matrix chain order
32+
MatrixChainMultiplication.Result result = MatrixChainMultiplication.calculateMatrixChainOrder(matrices);
33+
34+
// Expected cost of multiplying A(1) and A(2)
35+
int expectedCost = 6000; // The expected optimal cost of multiplying A(1)(10x20) and A(2)(20x30)
36+
int actualCost = result.getM()[1][2];
37+
38+
assertEquals(expectedCost, actualCost);
39+
}
40+
41+
@Test
42+
void testOptimalParentheses() {
43+
// Create a list of matrices to be multiplied
44+
ArrayList<MatrixChainMultiplication.Matrix> matrices = new ArrayList<>();
45+
matrices.add(new MatrixChainMultiplication.Matrix(1, 10, 20)); // A(1) = 10 x 20
46+
matrices.add(new MatrixChainMultiplication.Matrix(2, 20, 30)); // A(2) = 20 x 30
47+
48+
// Calculate matrix chain order
49+
MatrixChainMultiplication.Result result = MatrixChainMultiplication.calculateMatrixChainOrder(matrices);
50+
51+
// Check the optimal split for parentheses
52+
assertEquals(1, result.getS()[1][2]); // s[1][2] should point to the optimal split
53+
}
54+
}

0 commit comments

Comments
 (0)