Jason Pekos

Reparametrization (and some Divergence Diagnostics) in Turing

A full implementation of the code here can be found at this gist.

What's going on in Stan?

The Stan user's guide has a helpful section on reparametrization. It's not immediately clear how to implement some of these examples in Turing, since a Stan model is structured as follows:

data {
  
}
transformed data {
  
}
parameters {
  
}
transformed parameters { // Look here!
  
}
model {
  
}

Stan has a specific block for tracking transformed variables. As far as I can tell, no identical option exists in Turing. Instead, the best approach is to return any parameters of interest from the model and then call generated_quantities.
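The pattern looks roughly like this. This is a minimal sketch with a toy model of my own (the names demo and z are not from the post); the funnel example below follows the same shape:

```julia
using Turing

# a toy model illustrating the return-then-extract pattern
@model function demo()
    z ~ Normal(0, 1)
    return exp(z)          # a transformed quantity we want to track
end

chn  = sample(demo(), NUTS(), 1000)
prms = Turing.MCMCChains.get_sections(chn, :parameters)
vals = generated_quantities(demo(), prms)   # one transformed value per draw
```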

Turing implementation

We can set up the classic Neal's funnel example in Turing like this:

using Turing

@model function Neal()
    y ~ Normal(0, 3)
    x ~ arraydist([Normal(0, exp(y / 2)) for i in 1:9])
end

Sampling with the default NUTS() option produces a chain with many divergent transitions.
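The sampling call itself is not shown above; a minimal version, assuming 5000 draws to mirror the reparametrized run later in the post, is:

```julia
using Turing

# sample the naive (centered) parametrization with default NUTS settings
simple_chain = sample(Neal(), NUTS(), 5000)
```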

Divergence Diagnostics in Turing

Turing tags divergent transitions with numerical_error == 1. We can quickly check if our chain has any divergences with:

sum(simple_chain[:numerical_error])

Indicating that we have:

145.0

divergences. We can add divergence indicators to our posterior sample plot, pointing us towards areas of our geometry that are problematic:

using DataFrames, Plots

df_naive = DataFrame(simple_chain)

# base plot of the posterior draws
scatter(df_naive[!, "x[1]"], df_naive[!, :y],
        opacity = 0.3,
        label = "posterior draws")

# overlay the divergent transitions
divergences_naive_param = filter(row -> row.numerical_error == 1, df_naive)

scatter!(divergences_naive_param[!, "x[1]"],
         divergences_naive_param[!, :y],
         markershape = :x,
         color = :aquamarine,
         opacity = 0.8,
         label = "divergence")

Ways Forward

Divergent transitions are usually symptoms of some deeper degeneracy; they can indicate biased inference in your MCMC sampler, non-identifiability in your model, and so on.

A guide to taming divergences can be found here.

As discussed in the Stan user's guide, the most powerful solution is usually model reparametrization, which can be done in Turing as follows:

@model function Neal2()
    # raw draws
    y_raw ~ Normal(0,1)
    x_raw ~ arraydist([Normal(0, 1) for i in 1:9])

    # transform:
    y = 3*y_raw
    x = exp.(y./2) .* x_raw

    # return:
    return [x; y]
end

We then use the generated_quantities() function to pull out the transformed variables:

rawer_chain = sample(Neal2(), NUTS(), 5000)

raw_chain = Turing.MCMCChains.get_sections(rawer_chain,
                                           :parameters)

reparam_chain = reduce(hcat, generated_quantities(Neal2(), raw_chain))

We can check the number of divergences in our reparametrized chain:

div = sum(rawer_chain[:numerical_error])

Our new parametrization has 0.0 divergences.

Plotting, we can see that this allows much better exploration of the funnel:
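A minimal version of that plot, assuming reparam_chain is the 10×N matrix built above with the nine x values in rows 1 through 9 and y in row 10:

```julia
using Plots

# each column of reparam_chain is assumed to be [x[1], ..., x[9], y]
scatter(reparam_chain[1, :],    # x[1]
        reparam_chain[10, :],   # y
        opacity = 0.3,
        xlabel = "x[1]", ylabel = "y",
        label = "reparametrized draws")
```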


JasonPekos. Last modified: July 22, 2024. Website built with Franklin.jl.