Given a sequence of matrices, what parenthesization gives the fewest multiplications in computing the matrix chain product?

This is a tricky problem. In a time-honored tradition, let's try solving an easier problem first, a problem we'll call the "Fewest Multiplications" problem.
Given a sequence of matrices, what is the smallest number of multiplications required to compute the matrix chain product?

The Fewest Multiplications problem is easier, primarily because its answer is a number rather than a parenthesization. It's not even clear how a parenthesization would be represented! So we'll concentrate on the easier problem, and then try to expand our solution to solve the more difficult problem.
Suppose we want to compute the product

    Ai Ai+1 ... Aj

The basic idea is to think of choosing an "outermost parenthesization" first. Like this:
    (Ai Ai+1 ... Ak) (Ak+1 ... Aj)

Given that outermost parenthesization, the fewest multiplications required to compute the product is the fewest multiplications required to compute the left-hand product, plus the fewest multiplications required to compute the right-hand product, plus the multiplications required to compute the product of the two resulting matrices. The fewest multiplications for the left-hand and right-hand products can be computed recursively, and we know the dimensions of the resulting matrices will be rows(Ai) x cols(Ak) and rows(Ak+1) x cols(Aj). So, given that outermost parenthesization, the fewest number of multiplications required to compute the product is:
    fmm(A,i,k) + fmm(A,k+1,j) + rows(Ai)*cols(Ak)*cols(Aj)

Now all we need to do is try every possible outermost parenthesization and see which ultimately yields the fewest multiplications in computing the matrix product.
// Return the fewest mults required to compute Ai Ai+1 ... Aj
int fmm(A,i,j)
{
  // Base Case
  if (i == j) return 0;

  // Recursive Case
  int m = -1;  // m is the fewest mults we've seen so far, or -1 if none seen so far

  // Try every outermost parenthesization (Ai Ai+1 ... Ak) (Ak+1 ... Aj)
  for(int k = i; k < j; ++k)
  {
    int t = fmm(A,i,k) + fmm(A,k+1,j) + rows(Ai)*cols(Ak)*cols(Aj);
    if (m == -1 || t < m)
      m = t;
  }
  return m;
}
To generate test problems, we'll use a procedure

    (random-mcm-problem n M)

that generates a random vector of n matrices to be multiplied together, each dimension falling in the range 1..M. The vector it returns actually contains just the dimensions of the matrices. For example:
> (random-mcm-problem 5 20)
#5((8 3) (3 15) (15 17) (17 14) (14 17))

Now we can test a Scheme implementation of fmm quite easily. So ... here it is:
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; (fmm A i j)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(define (fmm A i j)
  (if (= i j)
      0
      (do ((m -1)           ; m tracks the fewest mults seen so far (-1 means we haven't seen any yet)
           (k i (+ k 1)))   ; k is the current split i..k and (k+1)..j
          ((= k j) m)       ; exit loop when k = j and return m
        (let ((t (+ (fmm A i k)                  ; fewest mults to compute the left product
                    (fmm A (+ k 1) j)            ; fewest mults to compute the right product
                    (* (car (vector-ref A i))    ; # of mults to compute the outermost product
                       (cadr (vector-ref A k))
                       (cadr (vector-ref A j))))))
          (if (or (= m -1) (< t m)) (set! m t)))))) ; if current k gives fewer mults we have a new min!

Now we can run some tests to see how it performs. The "time" function is helpful for this: (time expression) returns the result computed by expression and displays the time required to compute it:
> (define prob1 (random-mcm-problem 10 20))
> prob1
#10((13 5) (5 12) (12 10) (10 4) (4 12) (12 6) (6 18) (18 13) (13 19) (19 2))
> (time (fmm prob1 0 9))
cpu time: 150 real time: 144 gc time: 0
1988
> (define prob2 (random-mcm-problem 15 20))
> prob2
#15((15 1) (1 3) (3 2) (2 5) (5 14) (14 15) (15 1) (1 1) (1 17) (17 4) (4 11) (11 13) (13 16) (16 10) (10 16))
> (time (fmm prob2 0 14))
cpu time: 36160 real time: 36905 gc time: 2060
1326
> (define prob3 (random-mcm-problem 20 20))
> prob3
#20((10 7) (7 9) (9 18) (18 11) (11 5) (5 9) (9 20) (20 20) (20 20) (20 18) (18 18) (18 4) (4 8) (8 19) (19 10) (10 9) (9 13) (13 11) (11 5) (5 18))
> (time (fmm prob3 0 19))
??? it never stopped!

You can see that as the number of matrices grows, the time requirements rapidly get out of hand! It turns out that we're computing the same thing over and over. For example, in computing (fmm prob2 0 14) we compute (fmm prob2 5 7) ??? times. Hopefully you've made the connection: this is a perfect candidate for memoization.
Original Recursive Algorithm:

int fmm(A,i,j)
{
  int m = -1;
  if (i == j)
    m = 0;
  else
  {
    for(int k = i; k < j; ++k)
    {
      int t = fmm(A,i,k) + fmm(A,k+1,j)
            + rows(Ai)*cols(Ak)*cols(Aj);
      if (m == -1 || t < m)
        m = t;
    }
  }
  return m;
}

Memoized Version:

int fmm(A,i,j)
{
  if (T[i,j] == -1)
  {
    int m = -1;
    if (i == j)
      m = 0;
    else
    {
      for(int k = i; k < j; ++k)
      {
        int t = fmm(A,i,k) + fmm(A,k+1,j)
              + rows(Ai)*cols(Ak)*cols(Aj);
        if (m == -1 || t < m)
          m = t;
      }
    }
    T[i,j] = m;
  }
  return T[i,j];
}

Notice how the original fmm code pops up unchanged in the memoized version.
All that's changed is that we check to see if fmm(A,i,j) is already in the table T before we go through the work of computing it. What about the time complexity of the resulting algorithm? Well, this is a tricky algorithm to analyze. It's recursive, so you might be tempted to try to derive a recurrence relation. On the other hand, we can never say for sure when a call takes constant time because the answer is already in the table and when it won't, so deriving that recurrence relation will be tough. On the other-other hand, it's not an iterative algorithm, so those analysis techniques won't work either. The trick is to rely on the table.
The memoized algorithm fills in Θ(n^2) entries in the table T. We could analyze the algorithm by summing up for each table entry the time required by the call to fmm that filled in that entry. When determining the time to fill in a table entry, we can assume that the cost of calling fmm recursively is constant, because either it will literally be constant because the result is already in the table, or it can be treated as constant because its cost will be accounted for when we get to its entry in the table. So, if each recursive call is constant time, fmm(A,i,j) takes time O(n) [and Ω(1)]. Why? Because we go through our for-loop j-i times, and j-i < n. So, Θ(n^2) table entries, each taking O(n) [and Ω(1)] time to fill in, gives us a total of O(n^3) [and Ω(n^2)] for the whole algorithm. That might not seem really fast, but the original fmm took exponential time!
> (define A (random-mcm-problem 5 10))
> A
#5((8 5) (5 1) (1 8) (8 10) (10 10))
> (fmm-memo A)
300
> (table-print fewT)
 0  40  104  200  300
-1   0   40  130  230
-1  -1    0   80  180
-1  -1   -1    0  800
-1  -1   -1   -1    0
> (time (fmm-memo (random-mcm-problem 20 20)))
cpu time: 130 real time: 146 gc time: 0
1869
>