Class 31: Memoizing Matrix Chain Multiplication


Reading
Section 15.2 of Introduction to Algorithms.

Homework
This homework need not be turned in. I'd like you to work on it, but don't spend too long if you don't see how to do it! Here is Problem C from this year's programming competition. Assume that Dorothy will play her cards in sorted (high-to-low) order, and that your cards are given in the input in sorted order too. Then try to write a nice little recursive Scheme program that figures out the maximum number of points Chester can receive. Dorothy's hand (sorted high-to-low, remember) and your hand will be given as vectors. If you're successful, try it on some big hands and see how it performs.

Solving Matrix Chain Multiplication
Recall from last class that in computing the matrix product A1 A2 ... An, the parenthesization we use can drastically affect the number of multiplications required. The "Matrix Chain Multiplication Problem" is this:
Given a sequence of matrices, what parenthesization gives the fewest multiplications in computing the matrix chain product?
This is a tricky problem. In a time-honored tradition, let's try solving an easier problem first, a problem we'll call the "Fewest Multiplications" problem.
Given a sequence of matrices, what is the smallest number of multiplications required in computing the matrix chain product?
The Fewest Multiplications problem is easier, primarily because its answer is a number rather than a parenthesization. It's not even clear how a parenthesization would be represented! So we'll concentrate on the easier problem and try to expand our solution to it to solve the more difficult problem.

Finding the Fewest Multiplications
We want to create an algorithm fmm(A,i,j) that computes the fewest multiplications required to compute the product:
Ai Ai+1 ... Aj
The basic idea is to think of choosing an "outermost parenthesization" first. Like this:
(Ai Ai+1 ... Ak) (Ak+1 ... Aj)
Given that outermost parenthesization, the fewest multiplications required to compute the product is the fewest multiplications required to compute the left-hand product, plus the fewest multiplications required to compute the right-hand product, plus the multiplications required to compute the product of the two result matrices. The fewest multiplications required to compute the left- and right-hand products can be computed recursively, and we know the dimensions of the resulting matrices will be rows(Ai)xcols(Ak) and rows(Ak+1)xcols(Aj). Since multiplying a pxq matrix by a qxr matrix takes p*q*r multiplications (and cols(Ak) = rows(Ak+1)), given that outermost parenthesization, the fewest number of multiplications required to compute the product is:
fmm(A,i,k) + fmm(A,k+1,j) + rows(Ai)*cols(Ak)*cols(Aj)
Now all we need to do is try every possible outermost parenthesization and see which ultimately yields the fewest multiplications in computing the matrix product.
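To make the cost formula concrete, suppose (picking some illustrative numbers, borrowed from the classic textbook example) A1 is 10x100, A2 is 100x5, and A3 is 5x50. The outermost parenthesization (A1 A2)(A3) costs fmm(A,1,2) + fmm(A,3,3) + 10*5*50 = 5000 + 0 + 2500 = 7500 multiplications, while (A1)(A2 A3) costs 0 + 25000 + 10*100*50 = 75000. A factor of ten, just from choosing the split point wisely!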

An Algorithm
So here's an algorithm implementing the above plan:
// Return the fewest mults required to compute Ai Ai+1 ... Aj
int fmm(A,i,j)
{
  // Base Case
  if (i == j) return 0;

  // Recursive Case
  int m = -1; // m is the fewest mults we've seen so far, or -1 if none seen so far

  // Try every outermost parenthesization (Ai Ai+1 ... Ak) (Ak+1 ... Aj)
  for(int k = i; k < j; ++k)
  {
    int t = fmm(A,i,k) + fmm(A,k+1,j) + rows(Ai)*cols(Ak)*cols(Aj);
    if (m == -1 || t < m)
      m = t;
  }
  return m;
}
Trying it out
The file mcm.scm defines a function (random-mcm-problem n M) that generates a random vector of n matrices to be multiplied together, each dimension falling in the range 1..M. The vector it returns actually contains just the dimensions of the matrices. For example:
> (random-mcm-problem 5 20)
#5((8 3) (3 15) (15 17) (17 14) (14 17))
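Incidentally, here's how such a generator might look. Since adjacent matrices must have matching dimensions (cols of Ai = rows of Ai+1), it suffices to pick n+1 random dimensions d0 ... dn and give matrix i the dimensions (di di+1). This is just a sketch of mine, not necessarily what mcm.scm actually does, and it assumes your Scheme provides (random M), returning an integer in 0..M-1:
;; A sketch of a random problem generator (not necessarily how mcm.scm does it).
(define (my-random-mcm-problem n M)
  (let ((dims (make-vector (+ n 1))))
    ;; pick n+1 random dimensions, each in 1..M
    (do ((i 0 (+ i 1)))
        ((> i n))
      (vector-set! dims i (+ 1 (random M))))
    ;; matrix i is dims[i] x dims[i+1], so adjacent dimensions match
    (let ((A (make-vector n)))
      (do ((i 0 (+ i 1)))
          ((= i n) A)
        (vector-set! A i (list (vector-ref dims i)
                               (vector-ref dims (+ i 1))))))))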
Now we can test a Scheme implementation of fmm quite easily. So ... here it is:
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; (fmm A i j) 
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(define (fmm A i j)
  (if (= i j)
      0
      (do ((m -1)         ; m tracks the fewest mults seen so far (-1 means we haven't seen any yet)
	   (k i (+ k 1))) ; k is the current split i..k and (k+1)..j
        
        ((= k j) m)       ; exit loop when k = j and return m
        
        (let ((t (+ (fmm A i k)                ;fewest mults to compute the left product
		    (fmm A (+ k 1) j)          ;fewest mults to compute the right product
		    (* (car (vector-ref A i))  ;# of mults to compute the outermost product
		       (cadr (vector-ref A k)) 
		       (cadr (vector-ref A j))))))
          (if (or (= m -1) (< t m)) 
              (set! m t))                      ;if current k gives fewer mults we have a new min!
	  ))))
Now we can run some tests to see how it performs. The "time" function is helpful for this. (time expression) returns the result computed by expression and displays the time required to compute it:
> (define prob1 (random-mcm-problem 10 20))
> prob1
#10((13 5) (5 12) (12 10) (10 4) (4 12) (12 6) (6 18) (18 13) (13 19) (19 2))
> (time (fmm prob1 0 9))
cpu time: 150 real time: 144 gc time: 0
1988
> (define prob2 (random-mcm-problem 15 20))
> prob2
#15((15 1) (1 3) (3 2) (2 5) (5 14) (14 15) (15 1) (1 1) (1 17) (17 4) (4 11) (11 13) (13 16) (16 10) (10 16))
> (time (fmm prob2 0 14))
cpu time: 36160 real time: 36905 gc time: 2060
1326
> (define prob3 (random-mcm-problem 20 20))
> prob3
#20((10 7)
    (7 9)
    (9 18)
    (18 11)
    (11 5)
    (5 9)
    (9 20)
    (20 20)
    (20 20)
    (20 18)
    (18 18)
    (18 4)
    (4 8)
    (8 19)
    (19 10)
    (10 9)
    (9 13)
    (13 11)
    (11 5)
    (5 18))
> (time (fmm prob3 0 19))
> ??? it never stopped!
You can see that as the number of matrices grows, the time requirements rapidly get out of hand! It turns out that we're computing the same thing over and over. For example, in computing (fmm prob2 0 14) we compute (fmm prob2 5 7) ??? times. Hopefully you've made the connection - this is a perfect candidate for memoization.
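If you want to see this redundancy for yourself, you can wrap fmm so that it counts how many times each (i,j) subproblem gets solved. Here's a quick sketch (count-fmm-calls is my own throwaway helper, not part of mcm.scm):
;; Sketch: count how many times each subproblem (i,j) is solved by naive fmm.
;; calls[i][j] is a counter; fmm-counted is otherwise identical to fmm above.
(define (count-fmm-calls A n)
  (let ((calls (make-vector n)))
    (define (fmm-counted i j)
      ;; bump the counter for subproblem (i,j) ...
      (vector-set! (vector-ref calls i) j
                   (+ 1 (vector-ref (vector-ref calls i) j)))
      ;; ... then proceed exactly as fmm does
      (if (= i j)
          0
          (do ((m -1)
               (k i (+ k 1)))
              ((= k j) m)
            (let ((t (+ (fmm-counted i k)
                        (fmm-counted (+ k 1) j)
                        (* (car (vector-ref A i))
                           (cadr (vector-ref A k))
                           (cadr (vector-ref A j))))))
              (if (or (= m -1) (< t m))
                  (set! m t))))))
    (do ((i 0 (+ i 1)))
        ((= i n))
      (vector-set! calls i (make-vector n 0)))
    (fmm-counted 0 (- n 1))
    calls))
Inspecting the table returned by (count-fmm-calls prob2 15) will show the same little subproblems being solved over and over - exactly the waste memoization eliminates.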
Memoizing fewest-matrix multiplications
Since fmm is so slow, and since it's slow on account of computing the same thing over and over again, we decided to memoize it. Memoization requires us to store the result of each fmm(A,i,j) call so we can look it up later instead of recomputing it. Since A always stays the same, we need to remember something for each i,j pair. When you have two indices like this, you typically use a table to store the information. So that's what we'll use. In our table, T[i,j] will be the result of fmm(A,i,j), or -1 if that result has yet to be computed.
Original Recursive Algorithm:
int fmm(A,i,j)
{  
  int m = -1;
  if (i == j) 
    m = 0;
  else
  {
    for(int k = i; k < j; ++k)
    {
      int t = fmm(A,i,k) + fmm(A,k+1,j) 
              + rows(Ai)*cols(Ak)*cols(Aj);
      if (m == -1 || t < m)
        m = t;
    }
  }
  return m;
}
Memoized Version:
int fmm(A,i,j)
{  
  if (T[i,j] == -1)
  {
    int m = -1;
    if (i == j) 
      m = 0;
    else
    {
      for(int k = i; k < j; ++k)
      {
        int t = fmm(A,i,k) + fmm(A,k+1,j) 
                + rows(Ai)*cols(Ak)*cols(Aj);
        if (m == -1 || t < m)
          m = t;
      }
    }
    T[i,j] = m;
  }
  return T[i,j];
}
Notice how the original fmm code pops up, essentially unchanged, inside the memoized version.

All that's changed is that we check to see if fmm(A,i,j) is in the table T before we go through the work of computing it. What about the time complexity of the resulting algorithm? Well, this is a tricky algorithm to analyze. It's recursive, so you might be tempted to try to derive a recurrence relation. On the other hand, we can never really say for sure when a call takes constant time because its answer is already in the table and when it doesn't, so deriving that recurrence relation will be tough. On the other-other hand, it's not an iterative algorithm, so those analysis techniques won't work either. The trick is to rely on the table.

The memoized algorithm fills in Θ(n^2) entries in the table T. We could analyze the algorithm by summing up for each table entry the time required by the call to fmm that filled in that entry. When determining the time to fill in a table entry, we can assume that the cost of calling fmm recursively is constant, because either it will literally be constant because the result is already in the table, or it can be treated as constant because its cost will be accounted for when we get to its entry in the table. So, if each recursive call is constant time, fmm(A,i,j) takes time O(n) [and Ω(1)]. Why? Because we go through our for-loop j-i times, and j-i < n. So, Θ(n^2) table entries, each taking O(n) [and Ω(1)] time to fill in, gives us a total of O(n^3) [and Ω(n^2)] for the whole algorithm. That might not seem really fast, but the original fmm took exponential time!

Memoizing fewest-matrix multiplications (a scheme implementation)
Using my implementation of tables for Scheme (table.scm) I've implemented the memoized fmm we did in class: fmm-memo.scm. You can test it using the random matrix chain multiplication generator in mcm.scm.
> (define A (random-mcm-problem 5 10))
> A
#5((8 5) (5 1) (1 8) (8 10) (10 10))
> (fmm-memo A)
300
> (table-print fewT)
0 40 104 200 300 
-1 0 40 130 230 
-1 -1 0 80 180 
-1 -1 -1 0 800 
-1 -1 -1 -1 0 
> (time (fmm-memo (random-mcm-problem 20 20)))
cpu time: 130 real time: 146 gc time: 0
1869
> 
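By the way, if you'd rather not pull in table.scm, the same idea works with a plain vector of vectors as the table. Here's a sketch of mine (the actual fmm-memo.scm differs in its details):
;; Sketch of memoized fmm using a vector-of-vectors for the table
;; (the real fmm-memo.scm uses table.scm instead).
;; fewT[i][j] holds the fewest mults for Ai ... Aj, or -1 if not yet computed.
(define (my-fmm-memo A)
  (let* ((n (vector-length A))
         (fewT (make-vector n)))
    (define (fmm i j)
      (if (= (vector-ref (vector-ref fewT i) j) -1)  ; only compute if not already in the table
          (let ((m (if (= i j) 0 -1)))
            (do ((k i (+ k 1)))                      ; try every split, just as before
                ((= k j))
              (let ((t (+ (fmm i k)
                          (fmm (+ k 1) j)
                          (* (car (vector-ref A i))
                             (cadr (vector-ref A k))
                             (cadr (vector-ref A j))))))
                (if (or (= m -1) (< t m))
                    (set! m t))))
            (vector-set! (vector-ref fewT i) j m)))  ; record the answer in the table
      (vector-ref (vector-ref fewT i) j))
    (do ((i 0 (+ i 1)))                              ; initialize every entry to -1
        ((= i n))
      (vector-set! fewT i (make-vector n -1)))
    (fmm 0 (- n 1))))
The table lookup in the last line of the inner fmm is what makes every repeated call constant time.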


Christopher W Brown
Last modified: Mon Mar 27 14:57:57 EST 2006