08 Jan 2016
I noticed this article
by Tom Ellis today that provides an excellent ‘demystified’ introduction to
automatic differentiation. His exposition is exceptionally clear and simple.
Hopefully not in the spirit of re-mystifying things too much, I wanted to
demonstrate that this kind of forward-mode automatic differentiation can be
implemented using a catamorphism, which cleans up the various let statements
found in Tom’s version (at the expense of slightly more pattern matching).
Let me first duplicate his setup using the standard recursion
scheme machinery:
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
import Data.Functor.Foldable
data ExprF r =
    VarF
  | ZeroF
  | OneF
  | NegateF r
  | SumF r r
  | ProductF r r
  | ExpF r
  deriving (Show, Functor)
type Expr = Fix ExprF
Since my expression type uses a fixed-point wrapper I’ll define my own
embedded language terms to get around it:
var :: Expr
var = Fix VarF
zero :: Expr
zero = Fix ZeroF
one :: Expr
one = Fix OneF
neg :: Expr -> Expr
neg x = Fix (NegateF x)
add :: Expr -> Expr -> Expr
add a b = Fix (SumF a b)
prod :: Expr -> Expr -> Expr
prod a b = Fix (ProductF a b)
e :: Expr -> Expr
e x = Fix (ExpF x)
To implement a corresponding eval function we can use a catamorphism:
eval :: Double -> Expr -> Double
eval x = cata $ \case
  VarF         -> x
  ZeroF        -> 0
  OneF         -> 1
  NegateF a    -> negate a
  SumF a b     -> a + b
  ProductF a b -> a * b
  ExpF a       -> exp a
Very clear. We just match things mechanically.
Now, symbolic differentiation. If you refer to the original diff function
you’ll notice that in cases like Product or Exp there are uses of both an
original expression and also its derivative. This can be captured simply by a
paramorphism:
diff :: Expr -> Expr
diff = para $ \case
  VarF                     -> one
  ZeroF                    -> zero
  OneF                     -> zero
  NegateF (_, x')          -> neg x'
  SumF (_, x') (_, y')     -> add x' y'
  ProductF (x, x') (y, y') -> add (prod x y') (prod x' y)
  ExpF (x, x')             -> prod (e x) x'
Here the primes indicate derivatives in the usual fashion, and the standard
rules of differentiation are self-explanatory.
For automatic differentiation we do much the same thing, except we’re also
going to lug around the evaluated function value itself at each point and
evaluate to doubles instead of other expressions.
It’s worth noting here: why doubles? Because the expression type that we’ve
defined has no notion of sharing, and thus the expressions will blow up à la
diff (to see what I mean, try printing the analogue of diff bigExpression in
GHCi). This could probably be mitigated by incorporating sharing into the
embedded language in some way, but that’s a topic for another post.
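To make the blow-up concrete, here is a minimal sketch that counts constructors after repeated symbolic differentiation. It assumes hand-rolled versions of Fix, cata, and para rather than the recursion-schemes ones (purely to stay dependency-free), and the names size, square, and sizes are hypothetical helpers that don't appear in Tom’s article.

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

-- Hand-rolled stand-ins for the recursion-schemes machinery (an assumption
-- made to keep this sketch free of dependencies).
newtype Fix f = Fix (f (Fix f))

cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix f) = alg (fmap (cata alg) f)

-- A paramorphism also hands each step the original subterm.
para :: Functor f => (f (Fix f, a) -> a) -> Fix f -> a
para alg (Fix f) = alg (fmap (\t -> (t, para alg t)) f)

data ExprF r =
    VarF
  | ZeroF
  | OneF
  | NegateF r
  | SumF r r
  | ProductF r r
  | ExpF r
  deriving (Show, Functor)

type Expr = Fix ExprF

diff :: Expr -> Expr
diff = para $ \case
  VarF                     -> Fix OneF
  ZeroF                    -> Fix ZeroF
  OneF                     -> Fix ZeroF
  NegateF (_, x')          -> Fix (NegateF x')
  SumF (_, x') (_, y')     -> Fix (SumF x' y')
  ProductF (x, x') (y, y') ->
    Fix (SumF (Fix (ProductF x y')) (Fix (ProductF x' y)))
  ExpF (x, x')             -> Fix (ProductF (Fix (ExpF x)) x')

-- Hypothetical helper: count the constructors in an expression.
size :: Expr -> Int
size = cata $ \case
  NegateF a    -> 1 + a
  SumF a b     -> 1 + a + b
  ProductF a b -> 1 + a + b
  ExpF a       -> 1 + a
  _            -> 1

-- The expression x * x.
square :: Expr
square = Fix (ProductF (Fix VarF) (Fix VarF))

-- Sizes of square, diff square, diff (diff square), and so on.
sizes :: [Int]
sizes = map size (take 4 (iterate diff square))
```

Evaluating sizes in GHCi should give [3,7,15,31]: each pass through the product rule roughly doubles the term, since nothing is shared or simplified.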
Anyway, a catamorphism will do the trick:
ad :: Double -> Expr -> (Double, Double)
ad x = cata $ \case
  VarF                     -> (x, 1)
  ZeroF                    -> (0, 0)
  OneF                     -> (1, 0)
  NegateF (x, x')          -> (negate x, negate x')
  SumF (x, x') (y, y')     -> (x + y, x' + y')
  ProductF (x, x') (y, y') -> (x * y, x * y' + x' * y)
  ExpF (x, x')             -> (exp x, exp x * x')
Take a look at the pairs to the right of the pattern matches; the first element
in each is just the corresponding term from eval, and the second is just the
corresponding term from diff (made ‘Double’-friendly). The catamorphism gives us
access to all the terms we need, and we can avoid a lot of work on the
right-hand side by doing some more pattern matching on the left.
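As a small worked example, the sketch below runs the same algebra on f(x) = x * exp x, whose derivative is exp x + x * exp x. It assumes a hand-rolled Fix and cata in place of the recursion-schemes ones, renames the pattern variables to avoid shadowing the argument, and introduces xExpX as a hypothetical test expression.

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

-- Hand-rolled stand-ins for Fix and cata (an assumption, to avoid
-- depending on recursion-schemes).
newtype Fix f = Fix (f (Fix f))

cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix f) = alg (fmap (cata alg) f)

data ExprF r =
    VarF
  | ZeroF
  | OneF
  | NegateF r
  | SumF r r
  | ProductF r r
  | ExpF r
  deriving (Show, Functor)

type Expr = Fix ExprF

-- Each pair is (value, derivative) at the point x.
ad :: Double -> Expr -> (Double, Double)
ad x = cata $ \case
  VarF                     -> (x, 1)
  ZeroF                    -> (0, 0)
  OneF                     -> (1, 0)
  NegateF (a, a')          -> (negate a, negate a')
  SumF (a, a') (b, b')     -> (a + b, a' + b')
  ProductF (a, a') (b, b') -> (a * b, a * b' + a' * b)
  ExpF (a, a')             -> (exp a, exp a * a')

-- Hypothetical test expression: f(x) = x * exp x.
xExpX :: Expr
xExpX = Fix (ProductF (Fix VarF) (Fix (ExpF (Fix VarF))))
```

At x = 0 the value is 0 and the derivative is exp 0 = 1, so ad 0 xExpX should evaluate to (0.0,1.0).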
Some sanity checks to make sure that these functions match up with Tom’s:
*Main> map (snd . (`ad` testSmall)) [0.0009, 1.0, 1.0001]
[0.12254834896191881,1.0,1.0003000600100016]
*Main> map (snd . (`ad` testBig)) [0.00009, 1.0, 1.00001]
[3.2478565715996756e-6,1.0,1.0100754777229357]
UPDATE: I had originally defined ad using a paramorphism but noticed that we can
get by just fine with cata.
09 Dec 2015
I’m presently at NIPS and so felt like writing about some
appropriate machine learning topic, but along the way I wound up talking about
parameterized recursive types, and here we are. Enjoy!
One starts to see common ‘shapes’ in algebraic data types after working with
them for a while. Take the natural numbers and a standard linked list, for
example:
data Natural =
    One
  | Succ Natural
data List a =
    Empty
  | Cons a (List a)
These are similar in some sense. There are some differences - a list has an
additional type parameter, and each recursive point in the list is tagged with
a value of that type - but the nature of the recursion in each is the same.
There is a single recursive point wrapped up in a single constructor, plus a
single base case.
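One way to see the shared shape is to write a function that forgets the list’s elements and keeps only the recursion. This is just a sketch; skeleton is a hypothetical name, and the Eq and Show instances are added only so the result can be inspected.

```haskell
data Natural =
    One
  | Succ Natural
  deriving (Eq, Show)

data List a =
    Empty
  | Cons a (List a)

-- Forget the elements of a list, keeping only its recursive skeleton:
-- the base case maps to the base case, and each recursive point maps to
-- a recursive point.
skeleton :: List a -> Natural
skeleton Empty       = One
skeleton (Cons _ xs) = Succ (skeleton xs)
```

For example, skeleton (Cons 'a' (Cons 'b' Empty)) evaluates to Succ (Succ One).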
Consider a recursive type that is parameterized by a functor with kind ‘* ->
*’, such that the kind of the resulting type is something like ‘(* -> *) ->
*’ or ‘(* -> *) -> * -> *’ or so on. It’s interesting to look at the
‘shapes’ of some useful types like this and see what similarities and
differences in recursive structure we can find.
In this article we’ll look at three such recursive types: ‘Fix’, ‘Free’, and
‘Cofree’. I’ll demonstrate that each can be viewed as a kind of program
parameterized by some underlying instruction set.
Fix
To start, let’s review the famous fixed-point type ‘Fix’. I’ve talked about it
before, but will go into a bit more detail here.
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
newtype Fix f = Fix (f (Fix f))
deriving instance (Show (f (Fix f))) => Show (Fix f)
Note: I’ll omit interpreter output for examples throughout this article, but
feel free to try the code yourself in GHCi. I’ll post some gists at the
bottom. The above code block also contains some pragmas that you can ignore;
they’re just there to help GHC derive some instances for us.
Anyway. ‘Fix’ is in some sense a template recursive structure. It relies on
some underlying functor ‘f’ to define the scope of recursion that you can
expect a value with type ‘Fix f’ to support. There is the degenerate constant
case, for example, which supports no recursion:
data DegenerateF r = DegenerateF
  deriving (Functor, Show)
type Degenerate = Fix DegenerateF
degenerate :: Degenerate
degenerate = Fix DegenerateF
Then you have the case like the one below, where only an infinitely recursive
value exists:
newtype InfiniteF r = InfiniteF r
  deriving (Functor, Show)
type Infinite = Fix InfiniteF
infinite :: Infinite
infinite = Fix (InfiniteF infinite)
But in practice you’ll have something in between; a type with at least one
recursive point or ‘running’ case and also at least one base or ‘terminating’
case. Take the natural numbers, for example:
data NatF r =
    OneF
  | SuccF r
  deriving (Functor, Show)
type Nat = Fix NatF
one :: Nat
one = Fix OneF
succ :: Nat -> Nat
succ = Fix . SuccF
Here ‘NatF’ provides both a ‘running’ case - ‘SuccF’ - and a ‘terminating’ case
- ‘OneF’. ‘Fix’ just lets ‘NatF’ do whatever it wants, having no say of its
own about termination. In fact, we could have defined ‘Fix’ like this:
data Program f = Running (f (Program f))
Indeed, you can think of ‘Fix’ as defining a program that runs until ‘f’
decides to terminate. In turn, you can think of ‘f’ as an instruction
set for the program. The whole shebang of ‘Fix f’ may only terminate if ‘f’
contains a terminating instruction.
Here’s a simple set of instructions, for example:
data Instruction r =
    Increment r
  | Decrement r
  | Terminate
  deriving (Functor, Show)
increment :: Program Instruction -> Program Instruction
increment = Running . Increment
decrement :: Program Instruction -> Program Instruction
decrement = Running . Decrement
terminate :: Program Instruction
terminate = Running Terminate
And we can write a sort of stack-based program like so:
program :: Program Instruction
program =
    increment
  . increment
  . decrement
  $ terminate
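To see such a program do something, here is a sketch of one possible interpreter. The semantics are an assumption on my part: run is a hypothetical function that threads an integer counter through the instructions, outermost first, and returns it once Terminate is reached.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Program f = Running (f (Program f))

data Instruction r =
    Increment r
  | Decrement r
  | Terminate
  deriving (Functor, Show)

increment, decrement :: Program Instruction -> Program Instruction
increment = Running . Increment
decrement = Running . Decrement

terminate :: Program Instruction
terminate = Running Terminate

program :: Program Instruction
program =
    increment
  . increment
  . decrement
  $ terminate

-- Hypothetical interpreter: apply each instruction to a counter and stop
-- only when the instruction set itself says to terminate.
run :: Int -> Program Instruction -> Int
run n (Running instruction) = case instruction of
  Increment p -> run (n + 1) p
  Decrement p -> run (n - 1) p
  Terminate   -> n
```

Starting from zero, run 0 program evaluates to 1. Note that run has no case of its own for stopping; termination comes entirely from the Terminate instruction.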
Richness of ‘Fix’
It’s worthwhile to review two functions that are useful for working with ‘Fix’,
unimaginatively named ‘fix’ and ‘unfix’:
fix :: f (Fix f) -> Fix f
fix = Fix
unfix :: Fix f -> f (Fix f)
unfix (Fix f) = f
You can think of them like so: ‘fix’ embeds a value of type ‘f (Fix f)’ into a
recursive structure by adding a new layer of recursion, while ‘unfix’ projects
a value of type ‘f (Fix f)’ out of a recursive structure by peeling back a
layer of recursion.
This is a pretty rich recursive structure - we have a guarantee that we can
always embed into or project out of something with type ‘Fix f’, no matter
what ‘f’ is.
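This pair is exactly what generic folds and unfolds are built from. The following sketch defines a fold from ‘unfix’ and an unfold from ‘fix’, under the names cata and ana used by recursion-schemes (defined by hand here, as an assumption, rather than imported). The natToInt and intToNat conversions are hypothetical examples.

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype Fix f = Fix (f (Fix f))

fix :: f (Fix f) -> Fix f
fix = Fix

unfix :: Fix f -> f (Fix f)
unfix (Fix f) = f

-- Tear a structure down: project a layer out, recurse, then combine.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg = alg . fmap (cata alg) . unfix

-- Build a structure up: produce a layer, recurse, then embed it.
ana :: Functor f => (a -> f a) -> a -> Fix f
ana coalg = fix . fmap (ana coalg) . coalg

data NatF r =
    OneF
  | SuccF r
  deriving (Functor, Show)

type Nat = Fix NatF

natToInt :: Nat -> Int
natToInt = cata alg where
  alg OneF      = 1
  alg (SuccF n) = n + 1

-- Anything below 2 maps to the base case, since this Nat starts at one.
intToNat :: Int -> Nat
intToNat = ana coalg where
  coalg n
    | n <= 1    = OneF
    | otherwise = SuccF (n - 1)
```

A round trip such as natToInt (intToNat 5) gives back 5. The fold needs only ‘unfix’ and the unfold needs only ‘fix’.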
Free
Next up is ‘Free’, which is really just ‘Fix’ with some added structure. It is
defined as follows:
data Free f a =
    Free (f (Free f a))
  | Pure a
  deriving Functor
deriving instance (Show a, Show (f (Free f a))) => Show (Free f a)
The ‘Free’ constructor has an analogous definition to the ‘Fix’ constructor,
meaning we can use ‘Free’ to implement the same things we did previously. Here
are the natural numbers redux, for example:
type NatFree = Free NatF
oneFree :: NatFree a
oneFree = Free OneF
succFree :: NatFree a -> NatFree a
succFree = Free . SuccF
There’s also another branch here called ‘Pure’, though, that just bluntly wraps
a value of type ‘a’, and has nothing to do with the parameter ‘f’. This has an
interesting consequence: it means that ‘Free’ can have an opinion of its own
about termination, regardless of what ‘f’ might decree:
type NotSoInfinite = Free InfiniteF
notSoInfinite :: NotSoInfinite ()
notSoInfinite = Free (InfiniteF (Free (InfiniteF (Pure ()))))
(Note that here I’ve returned the value of type unit when terminating under the
‘Pure’ branch, but you could pick whatever else you’d like.)
You’ll recall that ‘InfiniteF’ provides no terminating instruction,
and left to its own devices will just recurse endlessly.
So: instead of being forced to choose a branch of the underlying functor to
recurse on, ‘Free’ can just bail out on a whim and return some value wrapped up
in ‘Pure’. We could have defined the whole type like this:
data Program f a =
    Running (f (Program f a))
  | Terminated a
  deriving Functor
Again, it’s ‘Fix’ with more structure. It’s a program that runs until ‘f’
decides to terminate, or that terminates and returns a value of type ‘a’.
As a quick illustration, take our simple stack-based instruction set again. We
can define the following embedded language terms:
increment :: Program Instruction a -> Program Instruction a
increment = Running . Increment
decrement :: Program Instruction a -> Program Instruction a
decrement = Running . Decrement
terminate :: Program Instruction a
terminate = Running Terminate
sigkill :: Program f Int
sigkill = Terminated 1
So note that ‘sigkill’ is independent of whatever instruction set we’re working
with. We can thus write another simple program like before, except this time
have ‘sigkill’ terminate it:
program :: Program Instruction Int
program =
    increment
  . increment
  . decrement
  $ sigkill
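The two ways of stopping show up clearly in an interpreter. In the sketch below, run is a hypothetical function (its counter semantics are my assumption) that returns the final counter together with Just the returned value when the program ends in ‘Terminated’, or Nothing when the instruction set’s own Terminate ends it. The names killed and finished are illustrative.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Program f a =
    Running (f (Program f a))
  | Terminated a
  deriving Functor

data Instruction r =
    Increment r
  | Decrement r
  | Terminate
  deriving (Functor, Show)

increment, decrement :: Program Instruction a -> Program Instruction a
increment = Running . Increment
decrement = Running . Decrement

terminate :: Program Instruction a
terminate = Running Terminate

sigkill :: Program f Int
sigkill = Terminated 1

-- Hypothetical interpreter: report how the program stopped.
run :: Int -> Program Instruction a -> (Int, Maybe a)
run n (Terminated a)        = (n, Just a)
run n (Running instruction) = case instruction of
  Increment p -> run (n + 1) p
  Decrement p -> run (n - 1) p
  Terminate   -> (n, Nothing)

-- Stopped by the program structure itself, returning a value.
killed :: Program Instruction Int
killed = increment . increment . decrement $ sigkill

-- Stopped by the instruction set, returning nothing.
finished :: Program Instruction Int
finished = increment . increment . decrement $ terminate
```

Here run 0 killed evaluates to (1,Just 1), while run 0 finished evaluates to (1,Nothing).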
Richness of ‘Free’
Try to define the equivalent versions of ‘fix’ and ‘unfix’ for ‘Free’. The
equivalent to ‘fix’ is easy:
free :: f (Free f a) -> Free f a
free = Free
You’ll hit a wall, though, if you want to implement the (total) analogue to
‘unfix’. One wants a function of type ‘Free f a -> f (Free f a)’, but the
existence of the ‘Pure’ branch makes this impossible to implement totally. In
general there is not going to be an ‘f’ to pluck out:
unfree :: Free f a -> f (Free f a)
unfree (Free f) = f
unfree (Pure a) = error "kaboom"
The recursion provided by ‘Free’ is thus a little less rich than that provided
by ‘Fix’. With ‘Fix’ one can always project a value out of its recursive
structure - but that’s not the case with ‘Free’.
It’s well-known that ‘Free’ is monadic, and indeed it’s usually called the
‘free monad’.
The name ‘free’ comes from an algebraic definition; roughly, a free ‘foo’
is a ‘foo’ that satisfies the minimum possible constraints to make it a ‘foo’,
and nothing else. Check out the slides from Dan Piponi’s excellent talk from
BayHac a few years back for a deeper dive on algebraic freeness.
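For reference, here is a sketch of what the monad instance can look like, along with a small program written in do-notation. The instances are written by hand for this article’s ‘Free’ type; liftF is modeled on the function of the same name in the free library but defined locally, and run is a hypothetical interpreter with assumed counter semantics.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Free f a =
    Free (f (Free f a))
  | Pure a
  deriving Functor

instance Functor f => Applicative (Free f) where
  pure = Pure
  Pure g  <*> x = fmap g x
  Free fg <*> x = Free (fmap (<*> x) fg)

-- Bind grafts the continuation onto every Pure leaf.
instance Functor f => Monad (Free f) where
  Pure a  >>= k = k a
  Free fa >>= k = Free (fmap (>>= k) fa)

data Instruction r =
    Increment r
  | Decrement r
  | Terminate
  deriving (Functor, Show)

-- Local helper: lift a single instruction into a one-step program.
liftF :: Functor f => f a -> Free f a
liftF = Free . fmap Pure

increment, decrement :: Free Instruction ()
increment = liftF (Increment ())
decrement = liftF (Decrement ())

program :: Free Instruction String
program = do
  increment
  increment
  decrement
  pure "done"

-- Hypothetical interpreter: thread a counter and report the result, if any.
run :: Int -> Free Instruction a -> (Int, Maybe a)
run n (Pure a)           = (n, Just a)
run n (Free instruction) = case instruction of
  Increment p -> run (n + 1) p
  Decrement p -> run (n - 1) p
  Terminate   -> (n, Nothing)
```

With these definitions, run 0 program evaluates to (1,Just "done").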
Cofree
‘Cofree’ is also like ‘Fix’, but again with some extra structure. It can be
defined as follows:
data Cofree f a = Cofree a (f (Cofree f a))
  deriving Functor
deriving instance (Show a, Show (f (Cofree f a))) => Show (Cofree f a)
Again, part of the definition - the second field of the ‘Cofree’ constructor -
looks just like ‘Fix’. So predictably we can do a redux-redux of the natural
numbers using ‘Cofree’:
type NatCofree = Cofree NatF
oneCofree :: NatCofree ()
oneCofree = Cofree () OneF
succCofree :: NatCofree () -> NatCofree ()
succCofree f = Cofree () (SuccF f)
(Note that here I’ve again used unit to fill in the first field - you could
of course choose whatever you’d like.)
This looks a lot like ‘Free’, and in fact it’s the categorical dual of
‘Free’. Whereas ‘Free’ is a sum type with two branches, ‘Cofree’ is a
product type with two fields. In the case of ‘Free’, we could have a program
that either runs an instruction from a set ‘f’, or terminates with a value
having type ‘a’. In the case of ‘Cofree’, we have a program that runs an
instruction from a set ‘f’ and returns a value of type ‘a’.
A ‘Free’ value thus contains at most one recursive point wrapping the value
with type ‘a’, while a ‘Cofree’ value contains potentially infinitely many
recursive points - each one of which is tagged with a value of type ‘a’.
Rolling with the ‘Program’ analogy, we could have written this alternate
definition for ‘Cofree’:
data Program f a = Program {
    annotation :: a
  , running    :: f (Program f a)
  }
deriving instance (Show a, Show (f (Program f a))) => Show (Program f a)
A ‘Cofree’ value is thus a program in which every instruction is annotated with
a value of type ‘a’. This means that, unlike ‘Free’, it can’t have its own
opinion on termination. Like ‘Fix’, it has to let ‘f’ decide how to do that.
We’ll use the stack-based instruction set example to wrap up. Here we can
annotate instructions with progress about how many instructions remain to
execute. First our new embedded language terms:
increment :: Program Instruction Int -> Program Instruction Int
increment p = Program (remaining p) (Increment p)
decrement :: Program Instruction Int -> Program Instruction Int
decrement p = Program (remaining p) (Decrement p)
terminate :: Program Instruction Int
terminate = Program 0 Terminate
Notice that two of these terms use a helper function ‘remaining’ that counts
the number of instructions left in the program. It’s defined as follows:
remaining :: Program Instruction Int -> Int
remaining = loop where
  loop (Program a f) = case f of
    Increment p -> succ (loop p)
    Decrement p -> succ (loop p)
    Terminate   -> succ a
And we can write our toy program like so:
program :: Program Instruction Int
program =
    increment
  . increment
  . decrement
  $ terminate
Evaluate it in GHCi to see what the resulting value looks like.
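If you’d rather not read the nested output, the sketch below pulls out just the annotations. It repeats the definitions above so that it stands alone; annotations is a hypothetical helper that walks from the outermost instruction inward.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Program f a = Program {
    annotation :: a
  , running    :: f (Program f a)
  }

data Instruction r =
    Increment r
  | Decrement r
  | Terminate
  deriving (Functor, Show)

remaining :: Program Instruction Int -> Int
remaining = loop where
  loop (Program a f) = case f of
    Increment p -> succ (loop p)
    Decrement p -> succ (loop p)
    Terminate   -> succ a

increment, decrement :: Program Instruction Int -> Program Instruction Int
increment p = Program (remaining p) (Increment p)
decrement p = Program (remaining p) (Decrement p)

terminate :: Program Instruction Int
terminate = Program 0 Terminate

program :: Program Instruction Int
program =
    increment
  . increment
  . decrement
  $ terminate

-- Hypothetical helper: list the annotations, outermost instruction first.
annotations :: Program Instruction a -> [a]
annotations (Program a f) = a : case f of
  Increment p -> annotations p
  Decrement p -> annotations p
  Terminate   -> []
```

For the toy program, annotations program evaluates to [3,2,1,0]: three instructions remain after the first increment, and none after terminate.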
Richness of ‘Cofree’
If you try and implement the ‘fix’ and ‘unfix’ analogues for ‘Cofree’ you’ll
rapidly infer that we have the opposite situation to ‘Free’ here. Implementing
the ‘unfix’ analogue is easy:
uncofree :: Cofree f a -> f (Cofree f a)
uncofree (Cofree _ f) = f
But implementing a total function corresponding to ‘fix’ is impossible - we
can’t just come up with something of arbitrary type ‘a’ to tag an instruction
‘f’ with, so, like before, we can’t do any better than define something
partially:
cofree :: f (Cofree f a) -> Cofree f a
cofree f = Cofree (error "kaboom") f
Just as how ‘Free’ forms a monad, ‘Cofree’ forms a comonad. It’s thus known as
the ‘cofree comonad’, though I can’t claim to really have any idea what the
algebraic notion of ‘cofreeness’ captures, exactly.
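Operationally, at least, the comonad structure is easy to sketch. Base has no Comonad class, so the functions below are standalone; their names follow the class methods in the comonad package, but they are defined by hand here. The values, example, and suffixSums definitions are hypothetical illustrations using Maybe as the underlying functor.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Cofree f a = Cofree a (f (Cofree f a))
  deriving Functor

-- Read the annotation at the current point.
extract :: Cofree f a -> a
extract (Cofree a _) = a

-- Re-annotate every point with the entire substructure rooted there.
duplicate :: Functor f => Cofree f a -> Cofree f (Cofree f a)
duplicate w@(Cofree _ f) = Cofree w (fmap duplicate f)

-- Re-annotate every point with a summary of the substructure rooted there.
extend :: Functor f => (Cofree f a -> b) -> Cofree f a -> Cofree f b
extend g = fmap g . duplicate

-- With Maybe as the instruction set, Cofree is a non-empty list.
values :: Cofree Maybe a -> [a]
values (Cofree a rest) = a : maybe [] values rest

example :: Cofree Maybe Int
example = Cofree 1 (Just (Cofree 2 (Just (Cofree 3 Nothing))))

-- Annotate each point with the sum of everything from there onward.
suffixSums :: Cofree Maybe Int
suffixSums = extend (sum . values) example
```

Here values suffixSums evaluates to [6,5,3], and extract suffixSums is 6.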
Wrapping Up
So: ‘Fix’, ‘Free’, and ‘Cofree’ all share a similar sort of recursive structure
that makes them useful for encoding programs, given some instruction set. And
while their definitions are similar, ‘Fix’ supports the richest recursion of
the three in some sense - it can both ‘embed’ things into and ‘project’
things out of its recursive structure, while ‘Free’ supports only embedding and
‘Cofree’ supports only projecting.
This has a practical implication: it means one can’t make use of certain
recursion schemes for ‘Free’ and ‘Cofree’ in the same way that one can for
‘Fix’. There do exist analogues, but they’re sort of out-of-scope for this
post.
I haven’t actually mentioned any truly practical uses of ‘Free’ and ‘Cofree’
here, but they’re wonderful things to keep in your toolkit if you’re doing any
work with embedded languages, and I’ll likely write more about them in the
future. In the meantime, Dave Laing wrote an excellent series of posts on ‘Free’
and ‘Cofree’ that are more than worth reading. They go into much more
interesting detail than I’ve done here - in particular he details a nice pairing
that exists between ‘Free’ and ‘Cofree’ (also discussed by Dan Piponi), plus a
whack of examples.
You can also find industrial-strength infrastructure for both ‘Free’ and
‘Cofree’ in Edward Kmett’s excellent free library, and for ‘Fix’ in
recursion-schemes.
I’ve dumped the code for this article into a few gists. Here’s one of everything
excluding the running ‘Program’ examples, and here are the corresponding
‘Program’ examples for the Fix, Free, and Cofree cases respectively.
Thanks to Fredrik Olsen for review and great feedback.
02 Dec 2015
Merge sort is a famous
comparison-based sorting algorithm that starts by first recursively dividing a
collection of orderable elements into smaller subcollections, and then finishes
by recursively sorting and merging the smaller subcollections together to
reconstruct the (now sorted) original.
A clear implementation of mergesort should by definition be as faithful to that
high-level description as possible. We can get pretty close to that using the
whole recursion schemes business that I’ve talked about in the past. Near the
end of that article I briefly mentioned the idea of implementing mergesort via a
hylomorphism, and here I just want to elaborate on that a little.
Start with a collection of orderable elements. We can divide the collection
into a bunch of smaller collections by using a binary tree:
{-# LANGUAGE DeriveFunctor #-}
import Data.Functor.Foldable (hylo)
import Data.List.Ordered (merge)
data Tree a r =
    Empty
  | Leaf a
  | Node r r
  deriving Functor
The idea is that each node in the tree holds two subtrees, each of which
contains half of the remaining elements. We can build a tree like this from a
collection - say, a basic Haskell list. The following unfolder function
defines what part of a tree to build for any corresponding part of a list:
unfolder []  = Empty
unfolder [x] = Leaf x
unfolder xs  = Node l r where
  (l, r) = splitAt (length xs `div` 2) xs
On the other hand, we can also collapse an existing tree back into a list. The
following folder function defines how to collapse any given part of a tree
into the corresponding part of a list; again we just pattern match on whatever
part of the tree we’re looking at, and construct the complementary list:
folder Empty = []
folder (Leaf x) = [x]
folder (Node l r) = merge l r
Now to sort a list we can just glue these instructions together using
a hylomorphism:
mergesort :: Ord a => [a] -> [a]
mergesort = hylo folder unfolder
And it works just like you’d expect:
> mergesort [1,10,3,4,5]
[1,3,4,5,10]
> mergesort "aloha"
"aahlo"
> mergesort [True, False, False, True, False]
[False,False,False,True,True]
Pretty concise!
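If you want to try this without installing anything, here is a dependency-free sketch of the same program. It assumes hand-written stand-ins for hylo (from recursion-schemes) and merge (from Data.List.Ordered); the type signatures on unfolder and folder are my additions.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Tree a r =
    Empty
  | Leaf a
  | Node r r
  deriving Functor

-- Stand-in for hylo: unfold one layer, recurse into it, then fold it.
hylo :: Functor f => (f b -> b) -> (a -> f a) -> a -> b
hylo f g = f . fmap (hylo f g) . g

-- Stand-in for merge: combine two already-sorted lists into one.
merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge (x:xs) (y:ys)
  | x <= y    = x : merge xs (y:ys)
  | otherwise = y : merge (x:xs) ys

-- Decide which layer of the tree to build from a list.
unfolder :: [a] -> Tree a [a]
unfolder []  = Empty
unfolder [x] = Leaf x
unfolder xs  = Node l r where
  (l, r) = splitAt (length xs `div` 2) xs

-- Collapse one layer of the tree whose children are already sorted lists.
folder :: Ord a => Tree a [a] -> [a]
folder Empty      = []
folder (Leaf x)   = [x]
folder (Node l r) = merge l r

mergesort :: Ord a => [a] -> [a]
mergesort = hylo folder unfolder
```

Because hylo fuses the unfold and the fold, the intermediate tree is never built as a whole; each layer is consumed as soon as it is produced.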
The code is eminently clean and faithful to the high-level algorithm
description: first recursively divide a collection into smaller subcollections
- via a binary tree and unfolder - and then recursively sort and merge the
subcollections to reconstruct the (now sorted) original one - via folder.
A version of this post originally appeared on the Fugue blog.