Class 21: Remembering Matrix Chain Multiplication


Reading
Section 15.2 of Introduction to Algorithms.

Homework
To Appear

The Matrix Chain Multiplication Problem
Given an array A of matrices, and a range i..j of indices in the array, what is the parenthesization of
Ai Ai+1 ... Aj
that computes the matrix product using the fewest multiplications? (That is, the fewest scalar multiplications of matrix entries.)

This is called the matrix chain multiplication problem, and it's kind of a tough one. When we're faced with a difficult problem, the best approach is often to first solve a simplified version of it, and then ask how that solution can be augmented to solve the full original problem. The simplified problem we're going to consider is just to determine the fewest multiplications required to compute the matrix product, without yet keeping track of which parenthesization achieves it. We'll call this fmm(A,i,j).

Review: Multiplying Matrices
The dimension of a matrix is its number of rows and number of columns, so a 5x7 matrix has 5 rows and 7 columns. When we take the product of two matrices, the number of columns in the left-hand matrix must equal the number of rows in the right-hand matrix. This is because, as you will recall, we compute element i,j of the product by multiplying each element of the ith row of the left-hand matrix by the corresponding element of the jth column of the right-hand matrix, and summing all the results. So there had better be as many elements in each row of the left-hand matrix as in each column of the right-hand matrix.

The product of an rxc matrix and an r'xc' matrix (remember c=r') is an rxc' matrix. Computing each of the r*c' entries of this matrix requires c multiplications, so the total number of multiplications involved is r*c*c'.

| 4 -1 |                 | 4*-1 + -1*0   4*-1 + -1*2   4*3 + -1*1 |
|      |   |-1 -1  3 |   |                                        |
|-2  0 | x |         | = |-2*-1 +  0*0  -2*-1 +  0*2  -2*3 +  0*1 | 
|      |   | 0  2  1 |   |                                        |
| 2 -1 |                 | 2*-1 + -1*0   2*-1 + -1*2   2*3 + -1*1 |

                         | -4  -6  11 |
                         |            |
                       = |  2   2  -6 |
                         |            |
                         | -2  -4   5 |
	
See? Two multiplications for each of the 9 entries of the 3x3 product matrix, for a total of 3*2*3 = 18 multiplications, just as the r*c*c' formula predicts.
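To check the r*c*c' count concretely, here is a small C sketch (our own, not code from the course files) that multiplies two matrices stored row-major in flat arrays and counts scalar multiplications as it goes. The function name mat_mult_count and the flat-array layout are assumptions made for illustration.

```c
#include <assert.h>

/* Multiply the r x c matrix X by the c x c2 matrix Y, storing the r x c2
   result in Z. All three are stored row-major in flat arrays.
   Returns the number of scalar multiplications performed (always r*c*c2). */
long mat_mult_count(int r, int c, int c2, const int *X, const int *Y, int *Z)
{
    assert(r > 0 && c > 0 && c2 > 0);
    long mults = 0;
    for (int i = 0; i < r; ++i) {
        for (int j = 0; j < c2; ++j) {
            /* entry (i,j): row i of X times column j of Y, summed */
            int sum = 0;
            for (int t = 0; t < c; ++t) {
                sum += X[i * c + t] * Y[t * c2 + j];
                ++mults;                 /* one scalar multiplication */
            }
            Z[i * c2 + j] = sum;
        }
    }
    return mults;
}
```

Called on the 3x2 and 2x3 matrices from the example above (X = {4,-1, -2,0, 2,-1}, Y = {-1,-1,3, 0,2,1}), it fills Z with the product shown and returns 18.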

If you have a chain of matrices to multiply together, the number of columns in each matrix must equal the number of rows in the following matrix for the product to even make sense. As long as that is satisfied, you can compute the product. However, products are computed pairwise, i.e. two matrices at a time, so you have to completely parenthesize the chain of matrices in order to really describe how you'll compute the product. The interesting thing is this: the way you parenthesize does not affect the answer (matrix multiplication is associative), but it can have a dramatic effect on the number of multiplications performed in computing the answer.

Example: Compute A0 x A1 x A2 where:

 A0    A1   A2
10x2  2x8  8x3

/ A0    A1 \  / A2 \
| 10x2  2x8 | | 8x3 | 
\          /  \    /
 10*2*8 = 160
 mults,
\__________/
    10x8
Then it takes 10*8*3 = 240 mults to get the 10x3 result

Total is 400 multiplications

/ A0  \  / A1   A2 \
| 10x2 | | 2x8  8x3 | 
\      / \         /
           2*8*3=48
           mults,
         \__________/
            2x3
Then it takes 10*2*3 = 60 mults to get the 10x3 result

Total is 108 multiplications
So where you put the parentheses really matters: 400 versus 108 multiplications here. That's why we're trying to figure out the best place to put 'em.
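The two totals can be checked mechanically. This is a hedged C sketch, not code from the notes: the Dim struct mirrors the (rows cols) pairs that random-mcm-problem produces, and the names mult_cost, cost_left_first and cost_right_first are hypothetical.

```c
#include <assert.h>

/* Dimensions of one matrix, mirroring the (rows cols) pairs in the notes. */
typedef struct { int rows, cols; } Dim;

/* Scalar multiplications needed to multiply x by y (x.cols must equal y.rows). */
static long mult_cost(Dim x, Dim y)
{
    assert(x.cols == y.rows);              /* product must be defined */
    return (long)x.rows * x.cols * y.cols;
}

/* Cost of ((A0 A1) A2): compute A0 A1 first, then multiply by A2. */
long cost_left_first(const Dim A[3])
{
    Dim a01 = { A[0].rows, A[1].cols };    /* shape of A0 A1 */
    return mult_cost(A[0], A[1]) + mult_cost(a01, A[2]);
}

/* Cost of (A0 (A1 A2)): compute A1 A2 first, then multiply A0 by it. */
long cost_right_first(const Dim A[3])
{
    Dim a12 = { A[1].rows, A[2].cols };    /* shape of A1 A2 */
    return mult_cost(A[1], A[2]) + mult_cost(A[0], a12);
}
```

With A = {{10,2},{2,8},{8,3}}, cost_left_first returns 160 + 240 = 400 and cost_right_first returns 48 + 60 = 108.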

Finding the Fewest Multiplications
We want to create an algorithm fmm(A,i,j) that computes the fewest multiplications required to compute the product:
Ai Ai+1 ... Aj
The basic idea is to think of choosing an "outermost parenthesization" first. Like this:
(Ai Ai+1 ... Ak) (Ak+1 ... Aj)
Given that outermost parenthesization, the fewest multiplications required to compute the product is the fewest multiplications required to compute the left-hand product, plus the fewest multiplications required to compute the right-hand product, plus the multiplications required to compute the product of the two resulting matrices. The fewest multiplications for the left- and right-hand products can themselves be computed recursively, and we know the two resulting matrices will have dimensions rows(Ai)xcols(Ak) and rows(Ak+1)xcols(Aj), where cols(Ak) = rows(Ak+1). So, given that outermost parenthesization, the fewest number of multiplications required to compute the product is:
fmm(A,i,k) + fmm(A,k+1,j) + rows(Ai)*cols(Ak)*cols(Aj)
Now all we need to do is try every possible outermost parenthesization, i.e. every split point k with i <= k < j, and see which one ultimately yields the fewest multiplications in computing the matrix product.
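Written as a single recurrence in the same rows/cols notation, combining the base case (a single matrix needs no multiplications) with the "try every k" step, the plan is:

```latex
\mathrm{fmm}(A,i,j) =
\begin{cases}
0 & \text{if } i = j,\\[4pt]
\displaystyle\min_{i \le k < j} \Bigl( \mathrm{fmm}(A,i,k) + \mathrm{fmm}(A,k+1,j)
  + \mathrm{rows}(A_i)\,\mathrm{cols}(A_k)\,\mathrm{cols}(A_j) \Bigr) & \text{if } i < j.
\end{cases}
```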

An Algorithm
So here's an algorithm implementing the above plan:
// Return the fewest mults required to compute Ai Ai+1 ... Aj
int fmm(A,i,j)
{
  // Base Case
  if (i == j) return 0;

  // Recursive Case
  int m = -1; // m is the fewest mults we've seen so far, or -1 if none seen so far

  // Try every outermost parenthesization (Ai Ai+1 ... Ak) (Ak+1 ... Aj)
  for(int k = i; k < j; ++k)
  {
    int t = fmm(A,i,k) + fmm(A,k+1,j) + rows(Ai)*cols(Ak)*cols(Aj);
    if (m == -1 || t < m)
      m = t;
  }
  return m;
}
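For comparison, here is a sketch of the same algorithm as compilable C. It assumes the chain is given as an array of Dim structs (our own representation, mirroring the (rows cols) pairs the Scheme code below uses) and returns a long to leave headroom for large counts; otherwise it follows the pseudocode line for line.

```c
#include <assert.h>

/* Dimensions of one matrix: A[i] is rows x cols. */
typedef struct { int rows, cols; } Dim;

/* Fewest scalar multiplications needed to compute A[i] A[i+1] ... A[j].
   Plain recursion: tries every outermost split (A[i..k]) (A[k+1..j]). */
long fmm(const Dim A[], int i, int j)
{
    assert(i <= j);
    if (i == j)                    /* base case: a single matrix costs nothing */
        return 0;

    long m = -1;                   /* fewest mults seen so far; -1 = none yet */
    for (int k = i; k < j; ++k) {
        long t = fmm(A, i, k)                               /* left product  */
               + fmm(A, k + 1, j)                           /* right product */
               + (long)A[i].rows * A[k].cols * A[j].cols;   /* outermost product */
        if (m == -1 || t < m)
            m = t;
    }
    return m;
}
```

For the A0 A1 A2 example above, fmm(A, 0, 2) returns 108; for the six-matrix chain 30x35, 35x15, 15x5, 5x10, 10x20, 20x25 worked in Section 15.2 of the textbook it returns 15125.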
Trying it out
The file mcm.scm defines a function (random-mcm-problem n M) that generates a random vector of n matrices to be multiplied together, each dimension falling in the range 1..M. The vector it returns actually contains just the dimensions of the matrices, each as a (rows columns) pair. For example:
> (random-mcm-problem 5 20)
#5((8 3) (3 15) (15 17) (17 14) (14 17))
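If you'd rather experiment in C, a generator in the same spirit might look like the sketch below. random_mcm_problem is a hypothetical function, not part of mcm.scm; it uses the C library's rand() (seed it with srand()), and rand() % M is very slightly biased, which doesn't matter here. The point it illustrates is that each matrix's column count is reused as the next matrix's row count, so the chain can always be multiplied.

```c
#include <assert.h>
#include <stdlib.h>

/* Dimensions of one matrix, like a (rows cols) pair in the notes. */
typedef struct { int rows, cols; } Dim;

/* Hypothetical analogue of (random-mcm-problem n M): fill A[0..n-1] with
   random dimensions in 1..M such that A[i].cols == A[i+1].rows. */
void random_mcm_problem(int n, int M, Dim A[])
{
    assert(n >= 1 && M >= 1);
    int prev = 1 + rand() % M;          /* rows of the first matrix */
    for (int i = 0; i < n; ++i) {
        A[i].rows = prev;
        A[i].cols = 1 + rand() % M;     /* also the rows of the next matrix */
        prev = A[i].cols;
    }
}
```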
Now we can easily test a Scheme implementation of fmm. Here it is:
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; (fmm A i j) 
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(define (fmm A i j)
  (if (= i j)
      0
      (do ((m -1)            ; m tracks the fewest mults seen so far (-1 means none seen yet)
           (k i (+ k 1)))    ; k is the current split: i..k and (k+1)..j
          ((= k j) m)        ; exit the loop when k = j and return m
        (let ((t (+ (fmm A i k)                 ; fewest mults to compute the left product
                    (fmm A (+ k 1) j)           ; fewest mults to compute the right product
                    (* (car (vector-ref A i))   ; # of mults for the outermost product:
                       (cadr (vector-ref A k))  ;   rows(Ai) * cols(Ak) * cols(Aj)
                       (cadr (vector-ref A j))))))
          (if (or (= m -1) (< t m))  ; if the current k gives fewer mults,
              (set! m t))))))        ; we have a new min!
Now we can run some tests to see how it performs. The time form is helpful for this: (time expression) evaluates expression, displays the time required to compute it (in milliseconds), and returns the result:
> (define prob1 (random-mcm-problem 10 20))
> prob1
#10((13 5) (5 12) (12 10) (10 4) (4 12) (12 6) (6 18) (18 13) (13 19) (19 2))
> (time (fmm prob1 0 9))
cpu time: 150 real time: 144 gc time: 0
1988
> (define prob2 (random-mcm-problem 15 20))
> prob2
#15((15 1) (1 3) (3 2) (2 5) (5 14) (14 15) (15 1) (1 1) (1 17) (17 4) (4 11) (11 13) (13 16) (16 10) (10 16))
> (time (fmm prob2 0 14))
cpu time: 36160 real time: 36905 gc time: 2060
1326
> (define prob3 (random-mcm-problem 20 20))
> prob3
#20((10 7)
    (7 9)
    (9 18)
    (18 11)
    (11 5)
    (5 9)
    (9 20)
    (20 20)
    (20 20)
    (20 18)
    (18 18)
    (18 4)
    (4 8)
    (8 19)
    (19 10)
    (10 9)
    (9 13)
    (13 11)
    (11 5)
    (5 18))
> (time (fmm prob3 0 19))
[no result: it never stopped!]
You can see that as the number of matrices grows, the time requirements rapidly get out of hand: going from 10 to 15 matrices raised the CPU time from 150 ms to about 36 seconds. It turns out that we're computing the same thing over and over. For example, in computing (fmm prob2 0 14) we compute (fmm prob2 5 7) ??? times. Hopefully you've made the connection: this is a perfect candidate for memoization.
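We can measure the repeated work directly. Working through the recursion, the plain fmm makes exactly 3^(n-1) calls on a chain of n matrices: one call when n = 1, and otherwise C(n) = 1 + 2(C(1) + ... + C(n-1)), which gives C(n) = 3C(n-1). So each extra matrix triples the work, and going from 10 to 15 matrices multiplies the call count by 3^5 = 243, which lines up with the timings above. The C sketch below is our own, not the course's code; the names fmm_counted, fmm_memo, fmm_memo_rec and the global counter calls are hypothetical. It counts calls in the plain version and then applies one straightforward form of memoization: an n x n table, initialized to -1, that stores fmm(A,i,j) the first time it is computed.

```c
#include <assert.h>
#include <stdlib.h>

typedef struct { int rows, cols; } Dim;

static long calls;   /* number of calls made, for illustration only */

/* Plain recursive fmm, instrumented to count its calls. */
long fmm_counted(const Dim A[], int i, int j)
{
    ++calls;
    if (i == j) return 0;
    long m = -1;
    for (int k = i; k < j; ++k) {
        long t = fmm_counted(A, i, k) + fmm_counted(A, k + 1, j)
               + (long)A[i].rows * A[k].cols * A[j].cols;
        if (m == -1 || t < m) m = t;
    }
    return m;
}

/* Memoized helper: memo[i*n + j] holds fmm(A,i,j) once computed, -1 before. */
static long fmm_memo_rec(const Dim A[], int n, long memo[], int i, int j)
{
    ++calls;
    long *slot = &memo[i * n + j];
    if (*slot != -1) return *slot;       /* already computed: just reuse it */
    long m = 0;                          /* base case i == j */
    if (i < j) {
        m = -1;
        for (int k = i; k < j; ++k) {
            long t = fmm_memo_rec(A, n, memo, i, k)
                   + fmm_memo_rec(A, n, memo, k + 1, j)
                   + (long)A[i].rows * A[k].cols * A[j].cols;
            if (m == -1 || t < m) m = t;
        }
    }
    *slot = m;                           /* remember it for next time */
    return m;
}

/* Fewest mults for the whole chain A[0..n-1], using an n x n memo table. */
long fmm_memo(const Dim A[], int n)
{
    long *memo = malloc((size_t)n * n * sizeof *memo);
    assert(memo != NULL);
    for (int t = 0; t < n * n; ++t) memo[t] = -1;
    long result = fmm_memo_rec(A, n, memo, 0, n - 1);
    free(memo);
    return result;
}
```

On the six-matrix chain 30x35, 35x15, 15x5, 5x10, 10x20, 20x25 both versions return 15125, but the plain version makes 243 calls while the memoized one makes 71, and the gap widens quickly as n grows, since the memoized version does only polynomial work per chain.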


Christopher W Brown
Last modified: Mon Mar 8 17:02:12 EST 2004