Foundations of the Giry Monad

The Giry monad is the canonical probability monad, operating on the level of measures - the abstract constructs that canonically represent probability distributions. It’s sort of the baseline by which all other probability monads can be judged.

In this article I’m going to go through the categorical and measure-theoretic foundations of the Giry monad. In another article, I’ll describe how you can implement it in a very faithful sense in Haskell.

I was putting some notes together for another project and wound up writing things up in a somewhat blog-friendly style, but this isn’t intended to be a tutorial per se. Really this isn’t the kind of content I’d usually post here, but since I’ve written everything up anyway, I figured I may as well. If you like extremely dry mathematics and computer science, you’re in the right place.

I won’t define everything under the sun here - for properties or coherence conditions or other things that I’ve elided details on, check out something like Mac Lane or Aliprantis & Border. I’ll include some references at the end.

This is the game plan we’re working with:

  • Define monads and their supporting machinery in a categorical sense.
  • Define probability measures and some required background around that.
  • Construct the functor that maps a measurable space to the collection of all probability measures on that space.
  • Demonstrate that it’s a monad.

Let’s get started.

Categorical Foundations

A category is a collection of objects and morphisms between them. So if $W$, $X$, $Y$, and $Z$ are objects in a category $\mathcal{C}$, then $f : W \to X$, $g : X \to Y$, and $h : Y \to Z$ are examples of morphisms. These morphisms can be composed in the obvious associative way, i.e.

$h \circ (g \circ f) = (h \circ g) \circ f$

and there exist identity morphisms that simply map objects to themselves.

A functor is a mapping between categories (equivalently, it’s a morphism in the category of so-called ‘small’ categories). A functor $F : \mathcal{C} \to \mathcal{D}$ takes every object in $\mathcal{C}$ to some object in $\mathcal{D}$, and every morphism in $\mathcal{C}$ to some morphism in $\mathcal{D}$, such that the structure of morphism composition is preserved. An endofunctor is a functor from a category to itself, and a bifunctor is a functor from a pair of categories to another category, i.e. $F : \mathcal{A} \times \mathcal{B} \to \mathcal{C}$.

A natural transformation is a mapping between functors. So for two functors $F, G : \mathcal{C} \to \mathcal{D}$, a natural transformation $\epsilon : F \Rightarrow G$ associates to every object $c$ in $\mathcal{C}$ a morphism $\epsilon_c : F(c) \to G(c)$ in $\mathcal{D}$.

A monoidal category is a category $\mathcal{C}$ with some additional monoidal structure, namely an identity object $I$ and a bifunctor $\otimes : \mathcal{C} \times \mathcal{C} \to \mathcal{C}$ called the tensor product, plus several natural isomorphisms that provide the associativity of the tensor product and its right and left identity with the identity object $I$.

A monoid $(M, \mu, \eta)$ in a monoidal category $\mathcal{C}$ is an object $M$ in $\mathcal{C}$ together with two morphisms (obeying the standard associativity and identity properties) that make use of the category’s monoidal structure: the associative binary operator $\mu : M \otimes M \to M$, and the identity $\eta : I \to M$.

A monad is (infamously) a ‘monoid in the category of endofunctors’. So take the category of endofunctors $\mathcal{C}^{\mathcal{C}}$, whose objects are endofunctors on $\mathcal{C}$ and whose morphisms are natural transformations between them. This is a monoidal category: there exists an identity endofunctor $1_{\mathcal{C}}$ with $1_{\mathcal{C}} \otimes F = F = F \otimes 1_{\mathcal{C}}$ for all $F$ in $\mathcal{C}^{\mathcal{C}}$, plus a tensor product $\otimes$ defined by functor composition, such that the required associativity and identity properties hold. $\mathcal{C}^{\mathcal{C}}$ is thus a monoidal category, and any specific monoid $(T, \mu, \eta)$ we construct on it is a specific monad.
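
In Haskell terms the correspondence is direct: the monoid multiplication is ‘join’ and the identity is ‘return’. A minimal sketch (eta and mu are my names for illustration, not class methods):

import Control.Monad (join)

-- eta plays the role of the monoid identity, a natural transformation
-- from the identity endofunctor to the monad
eta :: Monad m => a -> m a
eta = return

-- mu plays the role of the monoid multiplication; the tensor product
-- here is functor composition
mu :: Monad m => m (m a) -> m a
mu = join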

Probabilistic Foundations

A measurable space $(X, \mathcal{X})$ is a set $X$ equipped with a topology-like structure called a $\sigma$-algebra $\mathcal{X}$ that essentially contains every well-behaved subset of $X$ in some sense. A measure is a particular kind of set function from the $\sigma$-algebra to the nonnegative real line. A measure just assigns a generalized notion of area or volume to well-behaved subsets of $X$. In particular, if the total possible area or volume of the underlying set is 1 then we’re dealing with a probability measure. A measurable space completed with a measure, e.g. $(X, \mathcal{X}, \nu)$, is called a measure space, and a measurable space completed with a probability measure is called a probability space.

There is a lot of overloaded lingo around the word ‘measurable’. A ‘measurable set’ is an element of a $\sigma$-algebra in a measurable space. A measurable mapping is a mapping between measurable spaces. Given a ‘source’ measurable space $(X, \mathcal{X})$ and ‘target’ measurable space $(Y, \mathcal{Y})$, a measurable mapping $T : (X, \mathcal{X}) \to (Y, \mathcal{Y})$ is a map $T : X \to Y$ with the property that, for any measurable set in the target, the inverse image is measurable in the source. Or, formally, for any $B$ in $\mathcal{Y}$, you have that $T^{-1}(B)$ is in $\mathcal{X}$.

The Space of Probability Measures on a Measurable Space

If you consider the collection of all measurable spaces and measurable mappings between them, you get a category. Define $\mathbf{Meas}$ to be the category of measurable spaces. So, objects are measurable spaces and morphisms are the measurable mappings between them.

For any specific measurable space $M$ in $\mathbf{Meas}$, we can consider the space of all possible probability measures that could be placed on it and denote that $\mathcal{P}(M)$. To be clear, $\mathcal{P}(M)$ is a space of measures - that is, a space in which the points themselves are probability measures.

What’s remarkable about $\mathcal{P}(M)$ is that it is itself a measurable space. Let me explain.

As a probability measure, any element of $\mathcal{P}(M)$ is a function from measurable subsets of $M$ to the interval $[0, 1]$ in $\mathbb{R}$. That is: if $M$ is the measurable space $(X, \mathcal{X})$, then a point in $\mathcal{P}(M)$ is a function $\mathcal{X} \to [0, 1]$. For any measurable $A$ in $\mathcal{X}$, there just naturally exists a sort of ‘evaluation’ mapping I’ll call $\tau_A : \mathcal{P}(M) \to [0, 1]$ that takes a measure on $M$ and evaluates it on the set $A$. To be explicit: if $\nu$ is a measure in $\mathcal{P}(M)$, then $\tau_A(\nu)$ simply evaluates $\nu(A)$. It ‘runs’ the measure in a sense; in Haskell, $\tau_A$ would be analogous to a function like \f -> f a for some a.

This evaluation map corresponds to an integral. If you have a measurable space $M = (X, \mathcal{X})$, then for any $A$ a subset in $\mathcal{X}$, we have $\tau_A(\nu) = \int_X \chi_A \, d\nu$, for $\chi_A$ the characteristic or indicator function of $A$ (where $\chi_A(x)$ is $1$ if $x$ is in $A$, and is $0$ otherwise). And we can actually extend $\tau$ to operate over measurable mappings from $M$ to $(\mathbb{R}, \mathcal{B}(\mathbb{R}))$, where $\mathcal{B}(\mathbb{R})$ is a suitable $\sigma$-algebra on $\mathbb{R}$. Here we typically use what’s called the Borel $\sigma$-algebra, which takes a topology on the set and then generates a $\sigma$-algebra from the open sets in the topology (for $\mathbb{R}$ we can just use the ‘usual’ topology generated by the Euclidean metric). For $f : X \to \mathbb{R}$ a measurable function, we can define the evaluation mapping $\tau_f : \mathcal{P}(M) \to \mathbb{R}$ as $\tau_f(\nu) = \int_X f \, d\nu$.

We can abuse notation here a bit and just use $\tau$ to refer to ‘duck typed’ mappings that evaluate measures over measurable sets or measurable functions depending on context. If we treat $\tau$ as a function of measurable sets, $A \mapsto \tau_A$, then $\tau$ has type $\mathcal{X} \to (\mathcal{P}(M) \to [0, 1])$. If we treat $\tau$ as a function of measurable functions, $f \mapsto \tau_f$, then $\tau$ has type $(X \to \mathbb{R}) \to (\mathcal{P}(M) \to \mathbb{R})$. I’ll just use $\tau$ to refer to the mappings that accept either measurable sets or functions.
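
As a hedged preview of the Haskell treatment to come in the next article: one faithful encoding represents a measure directly by its integration operator - that is, by the action of $\tau$ on measurable functions. The Measure name and field here are illustrative, not final:

-- a measure on a, represented by the integrator it induces on
-- measurable functions a -> Double
newtype Measure a = Measure { integrate :: (a -> Double) -> Double }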

In any case. For a measurable space $M$, there exists a topology on $\mathcal{P}(M)$ called the weak-* topology that makes all the evaluation mappings $\tau_A$ and $\tau_f$ continuous for any measurable set $A$ or measurable function $f$. From there, we can generate the Borel $\sigma$-algebra $\mathcal{B}(\mathcal{P}(M))$ that makes the evaluation functions measurable. The result is that $\mathcal{P}(M)$ is itself a measurable space, and thus an object in $\mathbf{Meas}$.

The space $\mathcal{P}(M)$ actually has all sorts of insane properties that one wouldn’t expect - there are implications on convexity, completeness, compactness and such that carry over from $M$. But I digress.

$\mathcal{P}$ is a Functor

So: for any $M$ an object in $\mathbf{Meas}$, we have that $\mathcal{P}(M)$ is also an object in $\mathbf{Meas}$. And if you look at $\mathcal{P}$ like a functor, you notice that it takes objects of $\mathbf{Meas}$ to objects of $\mathbf{Meas}$. Indeed, you can define an analogous procedure on morphisms in $\mathbf{Meas}$ as follows. Take $N$ to be another object (read: measurable space) in $\mathbf{Meas}$ and $T : M \to N$ to be a morphism (read: measurable mapping) between them. Now, for any measure $\nu$ in $\mathcal{P}(M)$ we can define $\mathcal{P}(T)(\nu) = \nu \circ T^{-1}$ (this is called the image, distribution, or pushforward of $\nu$ under $T$). For some $\nu$ and $T$, $\mathcal{P}(T)(\nu)$ thus takes measurable sets in $N$ to a value in the interval $[0, 1]$ - that is, it is a measure on $N$. So we have that:

$\mathcal{P}(T) : \mathcal{P}(M) \to \mathcal{P}(N)$

and so $\mathcal{P}$ is an endofunctor on $\mathbf{Meas}$.

$\mathcal{P}$ is a Monad

See where we’re going here? If we can define natural transformations $\mu$ and $\eta$ such that $(\mathcal{P}, \mu, \eta)$ is a monoid in the category of endofunctors on $\mathbf{Meas}$, we’ll have defined a monad. We thus need to come up with a suitable monoidal structure, et voilà.

First the identity. We want a natural transformation $\eta$ between the identity functor $1_{\mathbf{Meas}}$ and the functor $\mathcal{P}$ such that $\eta_M : 1_{\mathbf{Meas}}(M) \to \mathcal{P}(M)$ for any measurable space $M$ in $\mathbf{Meas}$. Evaluating the identity functor simplifies things to $\eta_M : M \to \mathcal{P}(M)$.

We can define this concretely as follows. Grab a measurable space $M = (X, \mathcal{X})$ in $\mathbf{Meas}$ and define $\eta(x)(A) = \chi_A(x)$ for any point $x$ in $X$ and any measurable set $A$ in $\mathcal{X}$. $\eta(x)$ is thus a probability measure on $M$ - we assign $1$ to measurable sets that contain $x$, and 0 to those that don’t. If we peel away another argument, we have that $\eta_M : M \to \mathcal{P}(M)$, as required.

So $\eta$ takes points in measurable spaces to probability measures on those spaces. In technical parlance, it takes a point $x$ to the Dirac measure at $x$ - the probability measure that places the entirety of its mass at $x$.
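
In the integration-based Measure sketch from above, $\eta$ is about as small as functions get - integrating a function against the Dirac measure at a point just evaluates the function there:

-- the Dirac measure at x, in the Measure sketch above
dirac :: a -> Measure a
dirac x = Measure (\f -> f x)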

Now for the other part of the monoidal structure, $\mu$. I initially found this next part to be a bit of a mind fuck, but let me see what I can do about that.

Recall that the category of endofunctors, $\mathbf{Meas}^{\mathbf{Meas}}$, is monoidal, so there exists a tensor product $\otimes$ that we can deal with, which here just corresponds to functor composition. We’re looking for a natural transformation:

$\mu : \mathcal{P} \otimes \mathcal{P} \to \mathcal{P}$

which is often written as:

$\mu : \mathcal{P}^2 \to \mathcal{P}$
Take a measurable space $M$ in $\mathbf{Meas}$ and then consider the space of probability measures over it, $\mathcal{P}(M)$. Then take the space of probability measures over the space of probability measures on $M$, $\mathcal{P}(\mathcal{P}(M))$. Since $\mathcal{P}$ is an endofunctor, this is again a measurable space, and for any measurable subset $A$ of $M$ we again have a family of mappings $\tau_A$ that take a probability measure in $\mathcal{P}(M)$ and evaluate it on $A$. We want $\mu$ to be the thing that turns a measure over measures into a plain old probability measure on $M$.

In the context of probability theory, this kind of semigroup action is a marginalizing operator. We’re taking the ‘uncertainty’ captured in $\mathcal{P}(\mathcal{P}(M))$ via the probability measure $\rho$ and smearing it into the probability measures in $\mathcal{P}(M)$.

Take $\rho$ in $\mathcal{P}(\mathcal{P}(M))$ and some $A$ a measurable subset of $M$. We can define $\mu$ as follows:

$\mu(\rho)(A) = \int_{\mathcal{P}(M)} \tau_A \, d\rho$

Using some lambda calculus notation to see the argument for $\tau_A$, we can expand the integrals to get the following gnarly expression:

$\mu(\rho)(A) = \int_{\mathcal{P}(M)} \left( \lambda \nu . \int_X \chi_A \, d\nu \right) d\rho$

Notice what’s happening here. For $M = (X, \mathcal{X})$ a measurable space, we’re integrating over the space of probability measures on $M$ - that is, $\mathcal{P}(M)$ - with respect to the probability measure $\rho$, which itself is a point in the space of probability measures over probability measures on $M$, $\mathcal{P}(\mathcal{P}(M))$. Whew.

The spaces we’re integrating over here are unusual, but $\mu(\rho)$ is still a probability measure, so when applied to a measurable set $A$ in $\mathcal{X}$ it results in a probability in $[0, 1]$. So, peeling back an argument, we have that $\mu(\rho)$ has type $\mathcal{X} \to [0, 1]$. In other words, it’s a probability measure on $M$, and thus it is in $\mathcal{P}(M)$. And if we peel back another argument, we find that:

$\mu_M : \mathcal{P}(\mathcal{P}(M)) \to \mathcal{P}(M)$

so we have, as required, that

$\mu : \mathcal{P}^2 \to \mathcal{P}$

It’s also worth noting that we can overload the notation for $\mu$ in the same way we did for $\tau$, i.e. to supply measurable functions in addition to measurable sets:

$\mu(\rho)(f) = \int_{\mathcal{P}(M)} \tau_f \, d\rho$

Combining the three components, we get $(\mathcal{P}, \mu, \eta)$, the canonical Giry monad.

In Haskell, when we’re dealing with monads we typically use the bind operator >>= instead of manually dealing with the functorial structure and $\mu$ (called ‘join’). Bind has the type:

$(>\!\!>\!\!=) : \mathcal{P}(M) \to (M \to \mathcal{P}(N)) \to \mathcal{P}(N)$

and for illustration, we can define >>= for the Giry monad like so:

$\nu >\!\!>\!\!= g = \mu(\mathcal{P}(g)(\nu)) = \lambda B . \int_X \left( \lambda x . g(x)(B) \right) d\nu$
Here $\nu$ is in $\mathcal{P}(M)$, $g$ is a measurable mapping in $M \to \mathcal{P}(N)$, and $B$ is a measurable set in $N$, so note that we potentially simplify the outermost integral enormously. It now operates over a general measurable space, rather than a space of measures in particular, and this will come in handy when we get to implementation details in the next post.
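
Continuing the hedged Measure sketch from earlier, bind transcribes the formula above almost verbatim:

-- bind in the Measure sketch: note that the outer integral runs over
-- the underlying space (via nu), not over a space of measures
bind :: Measure a -> (a -> Measure b) -> Measure b
bind nu g = Measure (\f -> integrate nu (\x -> integrate (g x) f))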

Wrapping Up

That’s about it for now. It’s worth noting as a kind of footnote here that the existence of the Giry monad also obviously implies the existence of a Giry applicative functor. But the situation for applicative functors seems kind of weird in this context, and I’m not yet up to the task of dealing with it formally.

Intuitively, one should be able to define the binary applicative operator characteristic of its lax monoidal structure as follows:

$\rho \, \langle \ast \rangle \, \nu = \lambda B . \int_{N^M} \left( \lambda T . \int_X \left( \lambda x . \chi_B(T(x)) \right) d\nu \right) d\rho$

But this has some really weird measure-theoretic implications - namely, that it assumes the existence of a space of probability measures over the space $N^M$ of all measurable functions $M \to N$, which is not trivial to define and indeed may not even exist. It seems like some people are looking into this problem, as I just happened to stumble on this paper on the arXiv while doing some googling. I notice that some people on e.g. nLab require categories with additional structure beyond $\mathbf{Meas}$ for the development of the Giry monad as well, for example the category $\mathbf{Pol}$ of Polish (separable, completely metrizable) spaces, so maybe the extra structure there takes care of the quirks.

Anyway. Applicatives are neat here because applicative probability measures are independent probability measures. And the existence of applicativeness means you can do all the things with independent probability measures that you might be used to. Measure convolution and friends are good examples. Given a measurable space $M$ that supports some notion of addition and two probability measures $\nu$ and $\zeta$ in $\mathcal{P}(M)$, we can add measures together via:

$\nu + \zeta = \lambda B . \int_M \left( \lambda x . \int_M \left( \lambda y . \chi_B(x + y) \right) d\zeta \right) d\nu$

where $x$ and $y$ are both points in $M$. Subtraction and multiplication translate trivially as well.
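
In the Measure sketch, convolution transcribes just as directly:

-- measure convolution: integrate f against both measures, evaluating
-- it at the sum of the sampled points
convolve :: Num a => Measure a -> Measure a -> Measure a
convolve nu zeta =
  Measure (\f -> integrate nu (\x -> integrate zeta (\y -> f (x + y))))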

In another article I’ll detail how the Giry monad can be implemented in Haskell and point out some neat extensions. There are some cool connections to continuations and codensity monads, and seemingly de Finetti’s theorem and exchangeability. That kind of thing. It’d also be worth trying to justify independence of probability measures from a categorical perspective, which seems easier than resolving the nitty-gritty measurability qualms I mentioned above.

‘Til then! Thanks to Jason Forbes for sifting through this stuff and providing some great comments.

References:

  • Mac Lane, Categories for the Working Mathematician. Springer.
  • Aliprantis & Border, Infinite Dimensional Analysis: A Hitchhiker’s Guide. Springer.

Rotating Squares

Here’s a short one.

I use Colin Percival’s Hacker News Daily to catch the top ten articles of the day on Hacker News. Today an article called Why Recursive Data Structures? popped up, which illustrates that recursive algorithms can become both intuitive and borderline trivial when a suitable data structure is used to implement them. This is exactly the motivation for using recursion schemes.

In the above article, Reginald rotates squares by representing them via a quadtree. If we have a square of bits, something like:

.x..
..x.
xxx.
....

then we want to be able to easily rotate it 90 degrees clockwise, for example. So let’s define a quadtree in Haskell:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

import Data.Functor.Foldable
import Data.List.Split

data QuadTreeF a r =
    NodeF r r r r
  | LeafF a
  | EmptyF
  deriving (Show, Functor)

type QuadTree a = Fix (QuadTreeF a)

The four fields of the ‘NodeF’ constructor correspond to the upper left, upper right, lower right, and lower left quadrants of the tree respectively.

Gimme some embedded language terms:

node :: QuadTree a -> QuadTree a -> QuadTree a -> QuadTree a -> QuadTree a
node ul ur lr ll = Fix (NodeF ul ur lr ll)

leaf :: a -> QuadTree a
leaf = Fix . LeafF

empty :: QuadTree a
empty = Fix EmptyF

That lets us define quadtrees easily. Here’s the tree that the previous diagram corresponds to:

tree :: QuadTree Bool
tree = node ul ur lr ll where
  ul = node (leaf False) (leaf True) (leaf False) (leaf False)
  ur = node (leaf False) (leaf False) (leaf False) (leaf True)
  lr = node (leaf True) (leaf False) (leaf False) (leaf False)
  ll = node (leaf True) (leaf True) (leaf False) (leaf False)

Rotating is then really easy - we rotate each quadrant recursively. Just reach for a catamorphism:

rotate :: QuadTree a -> QuadTree a
rotate = cata $ \case
  NodeF ul ur lr ll -> node ll ul ur lr
  LeafF a           -> leaf a
  EmptyF            -> empty

Notice that you just have to shift each field of ‘NodeF’ rightward, with wraparound. Then if you rotate and render the original tree you get:

.x..
.x.x
.xx.
....

Rotating things more times yields predictable results.
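
The post doesn’t show a renderer, but here’s a minimal sketch of one (render is my own, hypothetical name) that collapses a quadtree of bits back into rows of characters, assuming a ‘complete’ tree whose leaves all sit at the same depth:

render :: QuadTree Bool -> [String]
render = cata $ \case
  EmptyF            -> []
  LeafF b           -> [if b then "x" else "."]
  NodeF ul ur lr ll -> zipWith (++) ul ur ++ zipWith (++) ll lr

Something like putStr (unlines (render (rotate tree))) then prints the rotated diagram above, and rotating four times gets you back to the original.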

If you want to rotate another structure - say, a flat list - you can go through a quadtree as an intermediate representation using the same pattern I described in Sorting with Style. Build yourself a coalgebra and algebra pair:

builder :: [a] -> QuadTreeF a [a]
builder = \case
  []  -> EmptyF
  [x] -> LeafF x
  xs  -> NodeF a b c d where
    [a, b, c, d] = chunksOf (length xs `div` 4) xs

consumer :: QuadTreeF a [a] -> [a]
consumer = \case
  EmptyF            -> []
  LeafF a           -> [a]
  NodeF ul ur lr ll -> concat [ll, ul, ur, lr]

and then glue them together with a hylomorphism:

rotateList :: [a] -> [a]
rotateList = hylo consumer builder
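
As a quick sanity check (hedging a bit: builder only behaves for lists whose length is a power of four, since it repeatedly chunks into four equal parts):

> rotateList [1, 2, 3, 4]
[4,1,2,3]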

Neato.

For a recent recursion scheme resource I’ve spotted on the Twitters, check out Pascal Hartig’s compendium in progress.

Promorphisms, Pre and Post

To the.. uh, ‘layperson’, pre- and postpromorphisms are probably well into the WTF category of recursion schemes. This is a mistake - they’re simple and useful, and I’m going to try and convince you of this in short order.

Preliminaries:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

import Data.Functor.Foldable
import Prelude hiding (sum)

For simplicity, let’s take a couple of standard interpreters on lists. We’ll define ‘sumAlg’ as an interpreter for adding up list contents and ‘lenAlg’ for just counting the number of elements present:

sumAlg :: Num a => ListF a a -> a
sumAlg = \case
  Cons h t -> h + t
  Nil      -> 0

lenAlg :: ListF a Int -> Int
lenAlg = \case
  Cons h t -> 1 + t
  Nil      -> 0

Easy-peasy. We can use cata to make these things useful:

sum :: Num a => [a] -> a
sum = cata sumAlg

len :: [a] -> Int
len = cata lenAlg

Nothing new there; ‘sum [1..10]’ will give you 55 and ‘len [1..10]’ will give you 10.

An interesting twist is to consider only small elements in some sense; say, we only want to add or count elements that are less than or equal to 10, and ignore any others.

We could rewrite the previous interpreters, manually checking for the condition we’re interested in and handling it accordingly:

smallSumAlg :: (Ord a, Num a) => ListF a a -> a
smallSumAlg = \case
  Cons h t ->
    if   h <= 10
    then h + t
    else 0
  Nil      -> 0

smallLenAlg :: (Ord a, Num a) => ListF a Int -> Int
smallLenAlg = \case
  Cons h t ->
    if   h <= 10
    then 1 + t
    else 0
  Nil      -> 0

And you get ‘smallSum’ and ‘smallLen’ by using ‘cata’ on them respectively. They work like you’d expect - ‘smallLen [1, 5, 20]’ ignores the 20 and just returns 2, for example.

You can do better though. Enter the prepromorphism.

Instead of writing additional special-case interpreters for the ‘small’ case, consider the following natural transformation on the list base functor. It maps the list base functor to itself, without needing to inspect the carrier type:

small :: (Ord a, Num a) => ListF a b -> ListF a b
small Nil = Nil
small term@(Cons h t)
  | h <= 10   = term
  | otherwise = Nil

A prepromorphism is a ‘cata’-like recursion scheme that proceeds by first applying a natural transformation before interpreting via a supplied algebra. That’s.. surprisingly simple. Here are ‘smallSum’ and ‘smallLen’, defined without needing to clumsily create new special-case algebras:

smallSum :: (Ord a, Num a) => [a] -> a
smallSum = prepro small sumAlg

smallLen :: (Ord a, Num a) => [a] -> Int
smallLen = prepro small lenAlg

They work great:

> smallSum [1..100]
55
> smallLen [1..100]
10

In pseudo category-theoretic notation, you can visualize how a prepromorphism works via the following commutative diagram:

The only difference, when compared to a standard catamorphism, is the presence of the natural transformation applied via the looping arrow in the top left. The natural transformation ‘h’ has type ‘forall r. Base t r -> Base t r’, and ‘embed’ has type ‘Base t t -> t’, so their composition gets you exactly the type you need for an algebra, which is then the input to ‘cata’ there. Mapping the catamorphism over the type ‘Base t t’ brings it right back to ‘Base t t’.
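
In code, the diagram translates to something like the definition found in recursion-schemes (this needs RankNTypes, and I’m hedging on the exact class constraints, which have shifted between versions of the library):

prepro
  :: (Recursive t, Corecursive t)
  => (forall r. Base t r -> Base t r)  -- natural transformation
  -> (Base t a -> a)                   -- algebra
  -> t
  -> a
prepro e f = c where
  -- apply the natural transformation beneath each layer before folding
  c = f . fmap (c . cata (embed . e)) . project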

A postpromorphism is dual to a prepromorphism. It’s ‘ana’-like; proceed with your corecursive production, applying natural transformations as you go.

Here’s a streaming coalgebra:

streamCoalg :: Enum a => a -> ListF a a
streamCoalg n = Cons n (succ n)

A normal anamorphism would just send this thing shooting off into infinity, but we can use the existing ‘small’ natural transformation to cap it at 10:

smallStream :: (Ord a, Num a, Enum a) => a -> [a]
smallStream = postpro small streamCoalg

You get what you might expect:

> smallStream 3
[3,4,5,6,7,8,9,10]

And similarly, you can visualize a postpromorphism like so:

In this case the natural transformation is applied after mapping the postpromorphism over the base functor (hence the ‘post’ namesake).
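
Dually, the library definition looks essentially like this (same hedging as before):

postpro
  :: (Recursive t, Corecursive t)
  => (forall r. Base t r -> Base t r)  -- natural transformation
  -> (a -> Base t a)                   -- coalgebra
  -> a
  -> t
postpro e g = a where
  -- unfold one layer, then apply the transformation to everything below
  a = embed . fmap (ana (e . project) . a) . g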

Comonadic Markov Chain Monte Carlo

Some time ago I came across a way to in-principle perform inference on certain probabilistic programs using comonadic structures and operations.

I decided to dig it up and try to use it to extend the simple probabilistic programming language I talked about a few days ago with a stateful, experimental inference backend. In this post we’ll

  • Represent probabilistic programs as recursive types parameterized by a terminating instruction set.
  • Represent execution traces of probabilistic programs via a simple transformation of our program representation.
  • Implement the Metropolis-Hastings algorithm over this space of execution traces and thus do some inference.

Let’s get started!

Representing Programs That Terminate

I like thinking of embedded languages in terms of instruction sets. That is: I want to be able to construct my embedded language by first defining a collection of abstract instructions and then using some appropriate recursive structure to represent programs over that set.

In the case of probabilistic programs, our instructions are probability distributions. Last time we used the following simple instruction set to define our embedded language:

data ModelF r =
    BernoulliF Double (Bool -> r)
  | BetaF Double Double (Double -> r)
  deriving Functor

We then created an embedded language by just wrapping it up in the higher-kinded Free type to denote programs of type Model.

data Free f a =
    Pure a
  | Free (f (Free f a))

type Model = Free ModelF

Recall that Free represents programs that can terminate, either by some instruction in the underlying instruction set, or via the Pure constructor of the Free type itself. The language defined by Free ModelF is expressive enough to easily construct a ‘forward-sampling’ interpreter, as well as a simple rejection sampler for performing inference.

Notice that we don’t have a terminating instruction in ModelF itself - if we’re using it, then we need to rely on the Pure constructor of Free to terminate programs. Otherwise they’d just have to recurse forever. This can be a bit limiting if we want to transform a program of type Free ModelF to something else that doesn’t have a notion of termination baked-in (Fix, for example).

Let’s tweak the ModelF type to get the following:

data ModelF a r =
    BernoulliF Double (Bool -> r)
  | BetaF Double Double (Double -> r)
  | NormalF Double Double (Double -> r)
  | DiracF a
  deriving Functor

Aside from adding another foundational distribution - NormalF - we’ve also added a new constructor, DiracF, which carries a parameter with type a. We need to incorporate this carrier type in the overall type of ModelF as well, so ModelF itself also gets a new type parameter to carry around.

The DiracF instruction is a terminating instruction; it has no recursive point and just terminates with a value of type a when reached. It’s structurally equivalent to the Pure a branch of Free that we were relying on to terminate our programs previously - the only thing we’ve done is add it to our instruction set proper.

Why DiracF? A Dirac distribution places the entirety of its probability mass on a single point, and this is the exact probabilistic interpretation of the applicative pure or monadic return that one encounters with an appropriate probability type. Intuitively, if I sample a value x from a uniform distribution, then that is indistinguishable from sampling x from said uniform distribution and then sampling from a Dirac distribution with parameter x.

Make sense? If not, it might be helpful to note that there is no difference between any of the following (to which uniform and dirac are analogous):

> action :: m a
> action >>= return :: m a
> action >>= return >>= return >>= return :: m a

Wrapping ModelF a up in Free, we get the following general type for our programs:

type Program a = Free (ModelF a)

And we can construct a bunch of embedded language terms in the standard way:

beta :: Double -> Double -> Program a Double
beta a b = liftF (BetaF a b id)

bernoulli :: Double -> Program a Bool
bernoulli p = liftF (BernoulliF p id)

normal :: Double -> Double -> Program a Double
normal m s = liftF (NormalF m s id)

dirac :: a -> Program a b
dirac x = liftF (DiracF x)

Program is a general type, capturing both terminating and nonterminating programs via its type parameters. What do I mean by this? Note that in Program a b, the a type parameter can only be concretely instantiated via use of the terminating dirac term. On the other hand, the b type parameter is unaffected by the dirac term; it can only be instantiated by the other nonterminating terms: beta, bernoulli, normal, or compound expressions of these.

We can thus distinguish between terminating and nonterminating programs at the type level, like so:

type Terminating a = Program a Void

type Model b = forall a. Program a b

Void is the uninhabited type, brought into scope via Data.Void or simply defined via data Void = Void Void. Any program that ends via a dirac instruction must be Terminating, and any program that doesn’t end with a dirac instruction can not be Terminating. We’ll just continue to call a nonterminating program a Model, as before.
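
For example (coin is a made-up name for illustration), any model becomes a terminating program once we tack a dirac onto the end of it:

-- flip a coin, then halt with the result
coin :: Terminating Bool
coin = bernoulli 0.5 >>= dirac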

Good. So if it’s not clear: from a user’s perspective, nothing has changed. We still write probabilistic programs using simple monadic language terms. Here’s a Gaussian mixture model where the mixing parameter follows a beta distribution, for example:

mixture :: Double -> Double -> Model Double
mixture a b = do
  prob   <- beta a b
  accept <- bernoulli prob
  if   accept
  then normal (negate 2) 0.5
  else normal 2 0.5

Meanwhile the syntax tree generated looks something like the following. It’s more or less a traditional probabilistic graphical model description of our program:

It’s important to note that in this embedded framework, the only pieces of the syntax tree that we can observe are those related directly to our primitive instructions. For our purposes this is excellent - we can focus on programs entirely at the level of their probabilistic components, and ignore the deterministic parts that would otherwise be distractions.

To collect samples from mixture, we can first interpret it into a sampling function, and then simulate from it. The toSampler function from last time doesn’t change much:

toSampler :: Program a a -> Prob IO a
toSampler = iterM $ \case
  BernoulliF p f -> Prob.bernoulli p >>= f
  BetaF a b f    -> Prob.beta a b >>= f
  NormalF m s f  -> Prob.normal m s >>= f
  DiracF x       -> return x

Sampling from mixture 2 3 a thousand times yields the following

> simulate (toSampler (mixture 2 3))

Note that the rightmost component gets more traffic due to the hyperparameter combination of 2 and 3 that we provided to mixture.

Also, a note - since we have general recursion in Haskell, so-called ‘terminating’ programs here can actually.. uh, fail to terminate. They must only terminate as far as we can express the sentiment at the embedded language level. Consider the following, for example:

foo :: Terminating a
foo = (loop 1) >>= dirac where
  loop a = do
    p <- beta a 1
    loop p

foo here doesn’t actually terminate. But at least this kind of weird case can be picked up in the types:

> :t simulate (toSampler foo)
simulate (toSampler foo) :: IO Void

If you try to sample from a distribution over Void or forall a. a then I can’t be held responsible for what you get up to. But there are other cases, sadly, where we’re also out of luck:

trollGeometric :: Double -> Model Int
trollGeometric p = loop where
  loop = do
    accept <- return False
    if   accept
    then return 1
    else fmap succ loop

A geometric distribution that actually used its argument p, for 0 < p <= 1, could be guaranteed to terminate with probability 1. This one doesn’t, so trollGeometric undefined >>= dirac won’t.
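
For comparison, here’s a sketch of a geometric model that does use its argument; for 0 < p <= 1 it terminates with probability 1, though nothing in the host language can enforce that:

geometric :: Double -> Model Int
geometric p = loop where
  loop = do
    accept <- bernoulli p
    if   accept
    then return 1
    else fmap succ loop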

At the end of the day we’re stuck with what our host language offers us. So, take the termination guarantees for our embedded language with a grain of salt.

Stateful Inference

In the previous post we used a simple rejection sampler to sample from a conditional distribution. ‘Vanilla’ Monte Carlo algorithms like rejection and importance sampling are stateless. This makes them nice in some ways - they tend to be simple to implement and are embarrassingly parallel, for example. But the curse of dimensionality prevents them from scaling well to larger problems. I won’t go into detail on that here - for a deep dive on the topic, you probably won’t find anything better than this phenomenal couple of talks on MCMC that Iain Murray gave at an MLSS session in Cambridge in 2009. I think they’re unparalleled to this day.

The point is that in higher dimensions we tend to get a lot out of state. Essentially, if one finds an interesting region of high-dimensional parameter space, then it’s better to remember where that is, rather than forgetting it exists as soon as one stumbles onto it. The manifold hypothesis conjectures that interesting regions of space tend to be near other interesting regions of space, so exploring neighbourhoods of interesting places tends to pay off. Stateful Monte Carlo methods - namely, the family of Markov chain Monte Carlo algorithms - handle exactly this, by using a Markov chain to wander over parameter space. I’ve written on MCMC in the past - you can check out some of those articles if you’re interested.

In the stateless rejection sampler we just performed conditional inference via the following algorithm:

  • Sample from a parameter model.
  • Sample from a data model, using the sample from the parameter model as input.
  • If the sample from the data model matches the provided observations, return the sample from the parameter model.

By repeating this many times we get a sample of arbitrary size from the appropriate conditional, inverse, or posterior distribution (whatever you want to call it).

In a stateful inference routine - here, the good old Metropolis-Hastings algorithm - we’re instead going to do the following repeatedly:

  • Sample from a parameter model, recording the way the program executed in order to return the sample that it did.
  • Compute the cost, in some sense, of generating the provided observations, using the sample from the parameter model as input.
  • Propose a new sample from the parameter model by perturbing the way the program executed and recording the new sample the program outputs.
  • Compute the cost of generating the provided observations using this new sample from the parameter model as input.
  • Compare the costs of generating the provided observations under the respective samples from the parameter models.
  • With probability depending on the ratio of the costs, flip a coin. If you see a head, then move to the new, proposed execution trace of the program. Otherwise, stay at the old execution trace.

This procedure generates a Markov chain over the space of possible execution traces of the program - essentially, plausible ways that the program could have executed in order to generate the supplied observations.

Implementations of Church use variations of this method to do inference, the most famous of which is a low-overhead transformational compilation procedure described in a great and influential 2011 paper by David Wingate et al.

Representing Running Programs

To perform inference on probabilistic programs according to the aforementioned Metropolis-Hastings algorithm, we need to represent executing programs somehow, in a form that enables us to examine and modify their internal state.

How can we do that? We’ll pluck another useful recursive structure from our repertoire and consider the humble Cofree:

data Cofree f a = a :< f (Cofree f a)

Recall that Cofree allows one to annotate programs with arbitrary information at each internal node. This is a great feature; if we can annotate each internal node with important information about its state - its current value, the current state of its generator, the ‘cost’ associated with it - then we can walk through the program and examine it as required. So, it can capture a ‘running’ program in exactly the way we need.

Let’s describe running programs as values having the following Execution type:

type Execution a = Cofree (ModelF a) Node

The Node type is what we’ll use to describe the internal state of each node on the program. I’ll define it like so:

data Node = Node {
    nodeCost    :: Double
  , nodeValue   :: Dynamic
  , nodeSeed    :: MWC.Seed
  , nodeHistory :: [Dynamic]
  } deriving Show

I’ll elaborate on this type below, but you can see that it captures a bunch of information about the state of each node.

One can mechanically transform any Free-encoded program into a Cofree-encoded program, so long as the original Free-encoded program can terminate of its own accord, i.e. on the level of its own instructions. Hence the need for our Terminating type and all that.

In our case, setting everything up just right takes a bit of code, mainly around handling pseudo-random number generators in a pure fashion. So I won’t talk about every little detail of it right here. The general idea is to write a function that takes instructions to the appropriate state captured by a Node value, like so:

initialize :: Typeable a => MWC.Seed -> ModelF a b -> Node
initialize seed = \case
  BernoulliF p _ -> runST $ do
    (nodeValue, nodeSeed) <- samplePurely (Prob.bernoulli p) seed
    let nodeCost    = logDensityBernoulli p (unsafeFromDyn nodeValue)
        nodeHistory = mempty
    return Node {..}

  BetaF a b _ -> runST $ do
    (nodeValue, nodeSeed) <- samplePurely (Prob.beta a b) seed
    let nodeCost    = logDensityBeta a b (unsafeFromDyn nodeValue)
        nodeHistory = mempty
    return Node {..}

  ...

You can see that for each node, I sample from it, calculate its cost, and then initialize its ‘history’ as an empty list.

Here it’s worth going into a brief aside.

There are two mildly annoying things we have to deal with in this situation. First, individual nodes in the program typically sample values at different types, and second, we can’t easily use effects when annotating. This means that we have to pack heterogeneously-typed things into a homogeneously-typed container, and also use pure random number generation facilities to sample them.

A quick-and-dirty answer for the first case is to just use dynamic typing when storing the values. It works and is easy, but of course is subject to the standard caveats. I use a function called unsafeFromDyn to convert dynamically-typed values back to a typed form, so you can gauge the safety of all this for yourself.

For the second case, I just use the ST monad, along with manual state snapshotting, to execute and iterate a random number generator. Pretty simple.

Also: in terms of efficiency, keeping a node’s history on-site at each execution falls into the ‘completely insane’ category, but let’s not worry much about efficiency right now. Prototypes gonna prototype and all that.

Anyway.

Given this initialize function, we can transform a terminating program into a running program by simple recursion. Again, we can only transform programs with type Terminating a because we need to rule out the case of ever visiting the Pure constructor of Free. We handle that by the absurd function provided by Data.Void:

execute :: Typeable a => Terminating a -> Execution a
execute = annotate defaultSeed where
  defaultSeed         = (42, 108512)
  annotate seeds term = case term of
    Pure r -> absurd r
    Free instruction ->
      let (nextSeeds, generator) = xorshift seeds
          seed  = MWC.toSeed (V.singleton generator)
          node  = initialize seed instruction
      in  node :< fmap (annotate nextSeeds) instruction

And there you have it - execute takes a terminating program as input and returns a running program - an execution trace - as output. The syntax tree we had previously gets turned into something like this:

Perturbing Running Programs

Given an execution trace, we’re able to step through it sequentially and investigate the program’s internal state. But to do inference we also need to modify it as well. What’s the answer here?

Just as Free has a monadic structure that allows us to write embedded programs using built-in monadic combinators and do-notation, Cofree has a comonadic structure that is amenable to use with the various comonadic combinators found in Control.Comonad. The most important for our purposes is the comonadic ‘extend’ operation that’s dual to monad’s ‘bind’:

extend :: Comonad w => (w a -> b) -> w a -> w b
extend f = fmap f . duplicate

To perturb a running program, we can thus write a function that perturbs any given annotated node, and then extend it over the entire execution trace.

The perturbNode function can be similar to the initialize function from earlier; it describes how to perturb every node based on the instruction found there:

perturbNode :: Execution a -> Node
perturbNode (node@Node {..} :< cons) = case cons of
  BernoulliF p _ -> runST $ do
    (nvalue, nseed) <- samplePurely (Prob.bernoulli p) nodeSeed
    let nscore   = logDensityBernoulli p (unsafeFromDyn nvalue)
    return $! Node nscore nvalue nseed nodeHistory

  BetaF a b _ -> runST $ do
    (nvalue, nseed) <- samplePurely (Prob.beta a b) nodeSeed
    let nscore   = logDensityBeta a b (unsafeFromDyn nvalue)
    return $! Node nscore nvalue nseed nodeHistory

  ...

Note that this is a very crude way to perturb nodes - we’re just sampling from whatever distribution we find at each one. A more refined procedure would sample from each node on a more local basis, sampling from its respective domain in a neighbourhood of its current location. For example, to perturb a BetaF node we might sample from a tiny Gaussian bubble around its current location, repeating the process if we happen to ‘fall off’ the support. I’ll leave matters like that for another post.

Perturbing an entire trace is then as easy as I claimed it to be:

perturb :: Execution a -> Execution a
perturb = extend perturbNode

For some comonadic intuition: when we ‘extend’ a function over an execution, the trace itself gets ‘duplicated’ in a comonadic context. Each node in the program becomes annotated with a view of the rest of the execution trace from that point forward. It can be difficult to visualize at first, but I reckon the following image is pretty faithful:

Each annotation then has perturbNode applied to it, which reduces the trace back to the standard annotated version we saw before.

Iterating the Markov Chain

So: to move around in parameter space, we’ll propose state changes by perturbing the current state, and then accept or reject proposals according to local economic conditions.

If you already have no idea what I’m talking about, then the phrase ‘local economic conditions’ probably didn’t help you much. But it’s a useful analogy to have in one’s head. Each state in parameter space has a cost associated with it - the cost of generating the observations that we’re conditioning on while doing inference. If certain parameter values yield a data model that is unlikely to generate the provided observations, then those observations will be expensive to generate when measured in terms of log-likelihood. Parameter values that yield data models more likely to generate the supplied observations will be comparatively cheaper.

If a proposed execution trace is significantly cheaper than the trace we’re currently at, then we usually want to move to it. We allow some randomness in our decision to keep everything nice and measure-preserving.

We can thus construct the conditional distribution over execution traces using the following invert function, using the same nomenclature as the rejection sampler we used previously. To focus on the main points, I’ll elide some of its body:

invert
  :: (Eq a, Typeable a, Typeable b)
  => Int -> [a] -> Model b -> (b -> a -> Double)
  -> Model (Execution b)
invert epochs obs prior ll = loop epochs (execute terminated) where
  terminated = prior >>= dirac
  loop n current
    | n == 0    = return current
    | otherwise = do
        let proposal = perturb current

            -- calculate costs and movement probability here

        accept <- bernoulli prob
        let next = if accept then proposal else stepGenerators current
        loop (pred n) (snapshot next)

There are a few things to comment on here.

First, notice how the return type of invert is Model (Execution b)? Using the semantics of our embedded language, it’s literally a standard model over execution traces. The above function returns a first-class value that is completely uninterpreted and abstract. Cool.

We’re also dealing with things a little differently from the rejection sampler that we built previously. Here, the data model is expressed by a cost function; that is, a function that takes a parameter value and observation as input, and returns the cost of generating the observation (conditional on the supplied parameter value) as output. This is the approach used in the excellent Practical Probabilistic Programming with Monads paper by Adam Scibior et al and also mentioned by Dan Roy in his recent talk at the Simons Institute. Ideally we’d just reify the cost function here from the description of a model directly (to keep the interface similar to the one used in the rejection sampler implementation), but I haven’t yet found a way to do this in a type-safe fashion.

Regardless of whether or not we accept a proposed move, we need to snapshot the current value of each node and add it to that node’s history. This can be done using another comonadic extend:

snapshotValue :: Cofree f Node -> Node
snapshotValue (Node {..} :< cons) = Node { nodeHistory = history, .. } where
  history = nodeValue : nodeHistory

snapshot :: Functor f => Cofree f Node -> Cofree f Node
snapshot = extend snapshotValue

The other point of note is minor, but an extremely easy detail to overlook. Since we’re handling random value generation at each node purely, using on-site PRNGs, we need to iterate the generators forward a step in the event that we don’t accept a proposal. Otherwise we’d propose a new execution based on the same generator states that we’d used previously! For now I’ll just iterate the generators by forcing a sample of a uniform variate at each node, and then throwing away the result. To do this we can use the now-standard comonadic pattern:

stepGenerator :: Cofree f Node -> Node
stepGenerator (Node {..} :< cons) = runST $ do
  (nval, nseed) <- samplePurely (Prob.beta 1 1) nodeSeed
  return Node {nodeSeed = nseed, ..}

stepGenerators :: Functor f => Cofree f Node -> Cofree f Node
stepGenerators = extend stepGenerator

Inspecting Execution Traces

Alright so let’s see how this all works. Let’s write a model, condition it on some observations, and do inference.

We’ll choose our simple Gaussian mixture model from earlier, where the mixing probability follows a beta distribution, and cluster assignment itself follows a Bernoulli distribution. We thus choose the ‘leftmost’ component of the mixture with the appropriate mixture probability.

We can break the mixture model up as follows:

prior :: Double -> Double -> Model Bool
prior a b = do
  p <- beta a b
  bernoulli p

likelihood :: Bool -> Model Double
likelihood left
  | left      = normal (negate 2) 0.5
  | otherwise = normal 2 0.5

Let’s take a look at some samples from the marginal distribution. This time I’ll flip things and assign hyperparameters of 3 and 2 for the prior:

> simulate (toSampler (prior 3 2 >>= likelihood))

It looks like we’re slightly more likely to sample from the left mixture component than the right one. Again, this makes sense - the mean of a beta(3, 2) distribution is 0.6.

Now, what about inference? I’ll define the conditional model as follows:

posterior :: Model (Execution Bool)
posterior = invert 1000 obs prior ll where
  obs = [ -1.7, -1.8, -2.01, -2.4
        , 1.9, 1.8
        ]

  ll left
    | left      = logDensityNormal (negate 2) 0.5
    | otherwise = logDensityNormal 2 0.5

Here we have four observations that presumably arise from the leftmost component, and only two that match up with the rightmost. Note also that I’ve replaced the likelihood model with its appropriate cost function, for the reasons I mentioned in the last section. (It would be easy to reify this model as its cost function, but doing it for general models is trickier.)

Anyway, let’s sample from the conditional distribution:

> simulate (toSampler posterior)

Sampling returns a running program, of course, and we can step through it to examine its structure. We can use the supplied values recorded at each node to ‘automatically’ step through execution, or we can supply our own values to investigate arbitrary branches.

The conditional distribution we’ve found over the mixing probability is as follows:

Looks like we’re in the right ballpark.

We can examine the traces of other elements of the program as well. Here’s the recorded distribution over component assignments, for example - note that the rightmost bar here corresponds to the leftmost component in the mixture:

You can see that whenever we wandered into the rightmost component, we’d swiftly wind up jumping back out of it:

Comments

This is a fun take on probabilistic programming. In particular I find a few aspects of the whole setup to be pretty attractive:

We use a primitive, limited instruction set to parameterize both programs - via Free - and running programs - via Cofree. These off-the-shelf recursive types are used to wrap things up and provide most of our required control flow automatically. It’s easy to transparently add structure to embedded programs built in this way; for example, we can statically encode independence by replacing our ModelF a type with something like:

type InstructionF a = Sum (ModelF a) (Ap (ModelF a))

(Here Sum is the functor coproduct from Data.Functor.Sum and Ap is the free applicative from Control.Applicative.Free.)

This can be hidden from the user so that we’re left with the same simple monadic syntax we presently enjoy, but we also get to take independence into account when performing inference, or any other structural interpretation for that matter.

When it comes to inference, the program representation is completely separate from whatever inference backend we choose to augment it with. We can deal with traces as first-class values that can be directly stored, inspected, manipulated, and so on. And everything is done in a typed and purely-functional framework. I’ve used dynamic typing functionality from Data.Dynamic to store values in execution traces here, but we could similarly just define a concrete Value type with the appropriate constructors for integers, doubles, bools, etc., and use that to store everything.

At the same time, this is a pretty early concept - doing inference efficiently in this setting is another matter, and there are a couple of computational and statistical issues here that need to be ironed out to make further progress.

The current way I’ve organized Markov chain generation and iteration is just woefully inefficient. Storing the history of each node on-site is needlessly costly and I’m sure results in a ton of unnecessary allocation. On a semantic level, it also ‘complects’ state and identity: why, after all, should a single execution trace know anything about traces that preceded it? Clearly this should be accumulated in another data structure. There is a lot of other low-hanging fruit around strictness and PRNG management as well.

From a more statistical angle, the present implementation does a poor job when it comes to perturbing execution traces. Some changes - such as improving the proposal mechanism for a given instruction - are easy to implement, and representing distributions as instructions indeed makes it easy to tailor local proposal distributions in a context-independent way. But another problem is that, by using a ‘blunt’ comonadic extend, we perturb an execution by perturbing every node in it. In general it’s better to make small perturbations rather than large ones to ensure a reasonable acceptance ratio, but to do that we’d need to perturb single nodes (or at least subsets of nodes) at a time.

There may be some inroads here via comonad transformers like StoreT or lenses that would allow us to zoom in on a particular node and perturb it, rather than perturbing everything at once. But my comonad-fu is not yet quite at the required level to evaluate this, so I’ll come back to that idea some other time.

I’m interested in playing with this concept some more in the future, though I’m not yet sure how much I expect it to be a tenable way to do inference at scale. If you’re interested in playing with it, I’ve dumped the code from this post into this gist.

Thanks to Niffe Hermansson and Fredrik Olsen for reviewing a draft of this post and providing helpful comments.

A Simple Embedded Probabilistic Programming Language

(This article is also published at Medium)

What does a dead-simple probabilistic programming language look like? The simplest thing I can imagine involves three components:

  • A representation for probabilistic models.
  • A way to simulate from those models (‘forward’ sampling).
  • A way to sample from a conditional model (‘backward’ sampling).

Rob Zinkov wrote an article on this type of thing around a year ago, and Dan Roy recently gave a talk on the topic as well. In the spirit of unabashed unoriginality, I’ll give a sort of composite example of the two. Most of the material here comes directly from Dan’s talk; definitely check it out if you’re curious about this whole probabilistic programming mumbojumbo.

Let’s whip together a highly-structured, typed, embedded probabilistic programming language - the core of which will encompass a tiny amount of code.

Some preliminaries - note that you’ll need my simple little mwc-probability library handy for when it comes time to do sampling:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

import Control.Monad
import Control.Monad.Free
import qualified System.Random.MWC.Probability as MWC

Representing Probabilistic Models

Step one is to represent the fundamental constructs found in probabilistic programs. These are abstract probability distributions; I like to call them models:

data ModelF r =
    BernoulliF Double (Bool -> r)
  | BetaF Double Double (Double -> r)
  deriving Functor

type Model = Free ModelF

Each foundational probability distribution we want to consider is represented as a constructor of the ModelF type. You can think of them as probabilistic instructions, in a sense. A Model itself is a program parameterized by this probabilistic instruction set.

In a more sophisticated implementation you’d probably want to add more primitives, but you can get pretty far with the beta and Bernoulli distributions alone. Here are some embedded language terms, only two of which correspond one-to-one with the constructors themselves:

bernoulli :: Double -> Model Bool
bernoulli p = liftF (BernoulliF p id)

beta :: Double -> Double -> Model Double
beta a b = liftF (BetaF a b id)

uniform :: Model Double
uniform = beta 1 1

binomial :: Int -> Double -> Model Int
binomial n p = fmap count coins where
  count = length . filter id
  coins = replicateM n (bernoulli p)

betaBinomial :: Int -> Double -> Double -> Model Int
betaBinomial n a b = do
  p <- beta a b
  binomial n p

You can build a lot of other useful distributions by just starting from the beta and Bernoulli as well. And technically I guess the more foundational distributions to use here would be the Dirichlet and categorical, of which the beta and Bernoulli are special cases. But I digress. The point is that other distributions are easy to construct from a set of reliable primitives; you can check out the old lambda-naught paper by Park et al for more examples.
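
As one more illustration (bernoulli' is a made-up name), the Bernoulli can itself be recovered from the uniform by thresholding:

-- True with probability p, assuming 0 <= p <= 1
bernoulli' :: Double -> Model Bool
bernoulli' p = fmap (< p) uniform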

See how binomial and betaBinomial are defined? In the case of binomial we’re using the property that models have a functorial structure by just mapping a counting function over the result of a bunch of Bernoulli random variables. For betaBinomial we’re directly making use of our monadic structure, first describing a weight parameter via a beta distribution and then using it as an input to a binomial distribution.

Note in particular that we’ve expressed betaBinomial by binding a parameter model to a data model. This is a foundational pattern in Bayesian statistics; in the more usual lingo, the parameter model corresponds to the prior distribution, and the data model is the likelihood.

Forward-Mode Sampling

So we have our representation. Next up, we want to simulate from these models. Thus far they’re purely abstract, and don’t encode any information about probability or sampling or what have you. We have to ascribe that ourselves.

mwc-probability defines a monadic sampling-based probability distribution type called Prob, and we can use a basic recursion scheme on free monads to adapt our own model type to that:

toSampler :: Model a -> MWC.Prob IO a
toSampler = iterM $ \case
  BernoulliF p f -> MWC.bernoulli p >>= f
  BetaF a b f    -> MWC.beta a b >>= f

We can glue that around the relevant mwc-probability functionality to simulate from models directly:

simulate :: Model a -> IO a
simulate model = MWC.withSystemRandom . MWC.asGenIO $
  MWC.sample (toSampler model)

And this can be used with standard monadic combinators like replicateM to collect larger samples:

> replicateM 10 $ simulate (betaBinomial 10 1 4)
[5,7,1,4,4,1,1,0,4,2]

Reverse-Mode Sampling

Now. Here we want to condition our model on some observations and then recover the conditional distribution over its internal parameters.

This part - inference - is what makes probabilistic programming hard, and doing it really well remains an unsolved problem. One of the neat theoretical results in this space due to Ackerman, Freer, and Roy is that in the general case the problem is actually unsolvable, in that one can encode as a probabilistic program a conditional distribution that computes the halting problem. Similarly, in general it’s impossible to do this sort of thing efficiently even for computable conditional distributions. Consider the case of a program that returns the hash of a random n-long binary string, and then try to infer the distribution over strings given some hashes, for example. This is never going to be a tractable problem.

For now let’s use a simple rejection sampler to encode a conditional distribution. We’ll require some observations, a proposal distribution, and the model that we want to invert:

invert :: (Monad m, Eq b) => m a -> (a -> m b) -> [b] -> m a
invert proposal model observed = loop where
  loop = do
    parameters <- proposal
    generated  <- replicateM (length observed) (model parameters)
    if   generated == observed
    then return parameters
    else loop

Let’s use it to compute the posterior or inverse model of an (apparently) biased coin, given a few observations. We’ll just use a uniform distribution as our proposal:

posterior :: Model Double
posterior = invert uniform bernoulli [True, True, False, True]

Let’s grab some samples from the posterior distribution:

> replicateM 1000 (simulate posterior)

The central tendency of the posterior floats about 0.75, which is what we’d expect, given our observations. This has been inferred from only four points; let’s try adding a few more. But before we do that, note that the present way the rejection sampling algorithm works is:

  • Propose a parameter value according to the supplied proposal distribution.
  • Generate a sample from the model, of equal size to the supplied observations.
  • Compare the collected sample to the supplied observations. If they’re equal, then return the proposed parameter value. Otherwise start over.

Rejection sampling isn’t exactly efficient in nontrivial settings anyway, but it’s supremely inefficient for our present case. The random variables we’re interested in are exchangeable, so what we’re concerned about is the total number of True or False values observed - not any specific order they appear in.

We can add an ‘assistance’ function to the rejection sampler to help us out in this case:

invertWithAssistance
  :: (Monad m, Eq c) => ([a] -> c) -> m b -> (b -> m a) -> [a] -> m b
invertWithAssistance assister proposal model observed = loop where
  loop = do
    parameters <- proposal
    generated  <- replicateM (length observed) (model parameters)
    if   assister generated == assister observed
    then return parameters
    else loop

The assister summarizes both our observations and collected sample to ensure they’re efficiently comparable. In our situation, we can use a simple counting function to tally up the number of True values we observe:

count :: [Bool] -> Int
count = length . filter id

Now let’s create another posterior by conditioning on a few more observations:

posterior0 :: Model Double
posterior0 = invertWithAssistance count uniform bernoulli obs where
  obs =
    [True, True, True, False, True, True, False, True, True, True, True, False]

and collect another thousand samples from it. This would likely take an annoying amount of time without the use of our count function for assistance above:

> replicateM 1000 (simulate posterior0)

Note that with more information to condition on, we get a more informative posterior.

Conclusion

This is a really basic formulation - too basic to be useful in any meaningful way - but it illustrates some of the most important concepts in probabilistic programming. Representation, simulation, and inference.

I think it’s also particularly nice to do this in Haskell, rather than something like Python (which Dan used in his talk) - it provides us with a lot of extensible structure in a familiar framework for language hacking. It sort of demands you’re a fan of all these higher-kinded types and structured recursions and all that, but if you’re reading this blog then you’re probably in that camp anyway.

I’ll probably write a few more little articles like this over time. There are a ton of improvements that we can make to this basic setup - encoding independence, sampling via MCMC, etc. - and it might be fun to grow everything out piece by piece.

I’ve dropped the code from this post into this gist.