Tuesday, February 8, 2011

Factorial Function

I was going to include this in my post on functional looping alternatives, but I think this deserves its own post.  The factorial function, i.e. n!, is frequently used as an example function for a variety of things.  There are a number of wrong (er, poor) ways of implementing it, but only a few good ways. 

First, let's start with what is probably the most wrong way.  Ironically, this is often used to demonstrate recursion to intro-level CS students.  Might have something to do with why intro-level CS students often hate recursion.  The recursive definition is like so:
fac( 0 ) = 1
fac( 1 ) = 1
fac( n ) = n * fac( n - 1 )

Nice and simple.  In pure math land, beautiful.  In code land, not so much (using LISP):
(defun factorial (n)
  (if (or (= n 0)
          (= n 1))
      1
      (* n (factorial (- n 1)))))


Oh sure, it's short.  But it's brutally inefficient.  No multiplication actually happens until the recursion has descended from n all the way down to 1, leaving one pending call for every value in between.  Worse yet, all of that O(n) space is on the stack.  Ick.

So let's improve things a bit.  We can make this tail-recursive with an accumulator, or we can reach into the imperative toolbox.  I've actually shown the former elsewhere in my blog, so let's do the latter.  But instead of just generating the number and being done with it, let's go for something better.
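For comparison, the plain "generate the number and be done with it" version from the imperative toolbox might look roughly like this.  It's only a sketch: the class name IterativeFactorial is hypothetical, and the cap at n = 20 is an assumption added because 21! no longer fits in a signed 64-bit long.

```java
// Hypothetical class name; a sketch of the plain imperative approach.
final class IterativeFactorial {

  // Computes n! with a loop: O(n) time, O(1) stack, O(1) heap.
  // Assumption: 0 <= n <= 20, since 21! overflows a signed 64-bit long.
  static long factorial( int n ) {
    if( n < 0 || n > 20 ) {
      throw new IllegalArgumentException( "n must be in [0, 20]: " + n );
    }
    long result = 1L; // 0! = 1, and 1 is the multiplicative identity
    for( int i = 2; i <= n; i++ ) {
      result *= i;
    }
    return result;
  }

}
```

Calling IterativeFactorial.factorial( 5 ) gives 120, and no call ever needs more than one stack frame.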

Consider an actual program that frequently calls factorial for a variety of inputs.  Often we see the same inputs, but we must generate the factorial (O(n)) for every call.  Why not save the result after computing it?  That way, every repeated call after the first will be O(1).  (Amortized over n calls with arguments no larger than n, the total work is O(n) instead of O(n^2).  Cool beans.)  This general trick is known as memoization (not memorization).  I've used it a number of times on different problems, and it can seriously improve performance if you do everything right.

Without further ado, here is some Java code that utilizes the technique:
import java.util.*;

public class Factorial {

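  // memo.get( i ) holds i!; the list grows on demand and never shrinks.
  private List< Long > memo;
  public Factorial() {
    memo = new ArrayList< Long >();
    memo.add( 1L ); // 0! = 1
    memo.add( 1L ); // 1! = 1
  }
  // Note: the result overflows a long for n > 20.
  public long factorial( int n ) {
    int current = memo.size();
    // Extend the table only if n has not been reached by an earlier call.
    while( current <= n ) {
        memo.add( current * memo.get( current - 1 ).longValue() );
        current++;
    }
    return memo.get( n ).longValue();
  }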

}

The while loop within factorial only does anything when we haven't already generated a value for the given n.  In this way, we take advantage of previously computed values.  Note that in the general case of memoization, one uses a map instead of a list.  In such a case, inputs are keys and their results are the values.  Here, since the inputs are non-negative integers and computing the result for n requires the results for all smaller inputs anyway, we were able to get away with a list.
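To make the map-based general case concrete, here is a minimal sketch of the same function memoized with a HashMap.  The class name MapMemoFactorial and the range check are assumptions for illustration, not part of the code above.  It deliberately uses explicit get and put rather than computeIfAbsent, because the recursive call modifies the map while the outer call is still in progress.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical class name; a sketch of map-based memoization.
final class MapMemoFactorial {

  // Inputs are keys, previously computed results are values.
  private final Map< Integer, Long > memo = new HashMap< Integer, Long >();

  // Assumption: 0 <= n <= 20, since 21! overflows a signed 64-bit long.
  long factorial( int n ) {
    if( n < 0 || n > 20 ) {
      throw new IllegalArgumentException( "n must be in [0, 20]: " + n );
    }
    if( n <= 1 ) {
      return 1L; // base cases: 0! = 1! = 1
    }
    Long cached = memo.get( n );
    if( cached != null ) {
      return cached.longValue(); // seen before: a single lookup
    }
    long result = n * factorial( n - 1 ); // recursion stops at the first cached value
    memo.put( n, result );
    return result;
  }

}
```

The recursion here does use the stack, but only down to the nearest cached input.  For a function whose inputs aren't small consecutive integers, the map is what makes the lookup possible at all.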


So how does all of this relate to functional looping alternatives?  Well, for certain mathematical functions, they shine like none other in terms of conciseness.  Factorial is one of these.  The following gets the factorial of n, in Scala:
( 1 to n ).foldRight( 1 )( _ * _ )

...that's it.  It's shorter than the mathematical definition, for crying out loud.  This first generates all the numbers between 1 and n inclusive.  It then does a fold, going from the highest numbers at the right to the lowest numbers on the left, constantly multiplying them.  The given 1 acts as a default value, catching 0 (0! = 1).  In all other cases, it's a harmless no-op, since 1 * n = n.
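For readers following along in Java, the closest standard-library analogue I know of is a reduce over a LongStream.  This is only a sketch: the class name FoldFactorial is hypothetical, and reduce isn't a true foldRight (it makes no right-to-left promise), which is harmless here because multiplication is associative.  The identity argument 1L plays the same role as the 1 in the Scala version.

```java
import java.util.stream.LongStream;

// Hypothetical class name; a sketch of a fold-style factorial in Java.
final class FoldFactorial {

  // Assumption: callers keep n in [0, 20]; larger n silently overflows,
  // and a negative n yields an empty range and therefore 1.
  static long factorial( int n ) {
    // rangeClosed( 1, n ) is empty when n == 0, so reduce returns the identity 1.
    return LongStream.rangeClosed( 1, n ).reduce( 1L, ( a, b ) -> a * b );
  }

}
```

For n = 4 the reduction multiplies 1, 1, 2, 3 and 4 together, giving 24.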

This is just as bad as the original definition, if not more so.  Again, we explicitly generate all numbers between 1 and n.  Granted, now we generate them on the heap, but we still make them.  However, foldRight isn't tail recursive.  It has to go all the way to the end of the list before it can do any processing, so it has the same problem as the original definition.  Only now its memory usage is O(n) for BOTH the heap and the stack!  (Strictly speaking, 1 to n is a Range, which need not store every element, and whether foldRight recurses depends on the collection type and Scala version; the analysis here treats it like a plain linked list.)

Turns out we can do better.  Why start from the right when we can just reverse it and start from the left?  Reversing doesn't need the stack, and folding from the left is O(1) with respect to the stack (it's tail-recursive), so doing this is more efficient:
( 1 to n ).reverse.foldLeft( 1 )( _ * _ )

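Not as pretty as before, but now it's O(1) on the stack and O(n) on the heap, which is still better than the original definition, which is O(n) on the stack and O(1) on the heap.  (Because multiplication is commutative and associative, the reverse only preserves the order in which foldRight would multiply; it doesn't change the result.)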

Note that if you actually have to use the factorial in your code, memoization is probably the way to go, or at least a method that is O(1) for all memory.  This was just so short and so cute that I couldn't pass it up.
