Yo Dawg We Heard You Like Derivatives

I noticed this article by Tom Ellis today that provides an excellent ‘demystified’ introduction to automatic differentiation. His exposition is exceptionally clear and simple.

Hopefully not in the spirit of re-mystifying things too much, I wanted to demonstrate that this kind of forward-mode automatic differentiation can be implemented using a catamorphism, which cleans up the various let bindings found in Tom’s version (at the expense of slightly more pattern matching).

Let me first duplicate his setup using the standard recursion scheme machinery:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

import Data.Functor.Foldable

data ExprF r =
    VarF
  | ZeroF
  | OneF
  | NegateF r
  | SumF r r
  | ProductF r r
  | ExpF r
  deriving (Show, Functor)

type Expr = Fix ExprF
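
To see what terms of this type look like in the raw, here’s the term 1 + x written out with explicit Fix wrappers (a quick illustration of mine, assuming the definitions above are in scope):

```haskell
-- the term 1 + x, with every layer wrapped in Fix by hand
onePlusX :: Expr
onePlusX = Fix (SumF (Fix OneF) (Fix VarF))
```

Wrapping every layer by hand like this gets tedious fast, which is what the helper terms below avoid.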

Since my expression type uses a fixed-point wrapper, I’ll define my own embedded language terms to hide it:

var :: Expr
var = Fix VarF

zero :: Expr
zero = Fix ZeroF

one :: Expr
one = Fix OneF

neg :: Expr -> Expr
neg x = Fix (NegateF x)

add :: Expr -> Expr -> Expr
add a b = Fix (SumF a b)

prod :: Expr -> Expr -> Expr
prod a b = Fix (ProductF a b)

e :: Expr -> Expr
e x = Fix (ExpF x)

To implement a corresponding eval function we can use a catamorphism:

eval :: Double -> Expr -> Double
eval x = cata $ \case
  VarF         -> x
  ZeroF        -> 0
  OneF         -> 1
  NegateF a    -> negate a
  SumF a b     -> a + b
  ProductF a b -> a * b
  ExpF a       -> exp a

Very clear. We just match things mechanically.
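
A quick sanity check of my own in GHCi, evaluating x · x at 2 and e^x at 0:

```haskell
*Main> eval 2 (prod var var)
4.0
*Main> eval 0 (e var)
1.0
```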

Now, symbolic differentiation. If you refer to the original diff function you’ll notice that in cases like Product or Exp there are uses of both an original expression and its derivative. This can be captured simply by a paramorphism:

diff :: Expr -> Expr
diff = para $ \case
  VarF                     -> one
  ZeroF                    -> zero
  OneF                     -> zero
  NegateF (_, x')          -> neg x'
  SumF (_, x') (_, y')     -> add x' y'
  ProductF (x, x') (y, y') -> add (prod x y') (prod x' y)
  ExpF (x, x')             -> prod (e x) x'

Here the primes indicate derivatives in the usual fashion, and the standard rules of differentiation are self-explanatory.
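
As a quick check (mine, not from Tom’s article): diff (prod var var) produces add (prod var one) (prod one var), i.e. x · 1 + 1 · x, and evaluating that derivative at 3 should give 6:

```haskell
*Main> eval 3 (diff (prod var var))
6.0
```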

For automatic differentiation we do much the same thing, except we’re also going to lug around the evaluated function value itself at each point, and we’ll evaluate to doubles instead of other expressions.

It’s worth noting here: why doubles? Because the expression type that we’ve defined has no notion of sharing, and thus the expressions will blow up à la diff (to see what I mean, try printing the analogue of diff bigExpression in GHCi). This could probably be mitigated by incorporating sharing into the embedded language in some way, but that’s a topic for another post.
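
To make the blow-up concrete, here’s a little node-counting catamorphism (a sketch of mine, assuming the definitions above are in scope):

```haskell
-- count the nodes in an expression
size :: Expr -> Int
size = cata $ \case
  NegateF a    -> 1 + a
  SumF a b     -> 1 + a + b
  ProductF a b -> 1 + a + b
  ExpF a       -> 1 + a
  _            -> 1        -- VarF, ZeroF, and OneF are leaves
```

Repeatedly differentiating even the tiny term x · x roughly doubles the expression size at each step: map size (take 4 (iterate diff (prod var var))) gives [3,7,15,31].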

Anyway, a catamorphism will do the trick:

ad :: Double -> Expr -> (Double, Double)
ad x = cata $ \case
  VarF                     -> (x, 1)
  ZeroF                    -> (0, 0)
  OneF                     -> (1, 0)
  NegateF (x, x')          -> (negate x, negate x')
  SumF (x, x') (y, y')     -> (x + y, x' + y')
  ProductF (x, x') (y, y') -> (x * y, x * y' + x' * y)
  ExpF (x, x')             -> (exp x, exp x * x')

Take a look at the pairs to the right of the pattern matches; the first element in each is just the corresponding term from eval, and the second is just the corresponding term from diff (made ‘Double’-friendly). The catamorphism gives us access to all the terms we need, and we can avoid a lot of work on the right-hand side by doing some more pattern matching on the left.
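
A quick GHCi check of my own before comparing against Tom’s results: for x · x at 3 we expect value 9 and derivative 6, and for e^x at 0 we expect (1, 1):

```haskell
*Main> ad 3 (prod var var)
(9.0,6.0)
*Main> ad 0 (e var)
(1.0,1.0)
```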

Some sanity checks to make sure that these functions match up with Tom’s:

*Main> map (snd . (`ad` testSmall)) [0.0009, 1.0, 1.0001]
[0.12254834896191881,1.0,1.0003000600100016]
*Main> map (snd . (`ad` testBig)) [0.00009, 1.0, 1.00001]
[3.2478565715996756e-6,1.0,1.0100754777229357]

UPDATE:

I had originally defined ad using a paramorphism but noticed that we can get by just fine with cata.