A full implementation of the code here can be found at this gist.
The Stan User's Guide has a helpful section on reparametrization, but it's not immediately obvious how to implement some of its examples in Turing. A Stan model is structured as follows:
data {
}
transformed data {
}
parameters {
}
transformed parameters { // look here!
}
model {
}
Stan has a dedicated block for tracking transformed variables. As far as I can tell, Turing has no direct equivalent; instead, the best approach is to return the quantities of interest from the model and then call generated_quantities.
We can set up the classic Neal's funnel example in Turing like this:
using Turing

@model function Neal()
    y ~ Normal(0, 3)
    # each x[i] has a scale that depends on y, which produces the funnel
    x ~ arraydist([Normal(0, exp(y/2)) for _ in 1:9])
end
Sampling this model with the default NUTS() settings gives a chain (called simple_chain below) whose draws trace out the funnel.
Turing tags divergent transitions with numerical_error == 1. We can quickly check whether our chain has any divergences with:
sum(simple_chain[:numerical_error])
This returns 145.0, i.e. 145 divergent transitions. We can add divergence indicators to our posterior scatter plot, pointing us towards the regions of the posterior geometry that are problematic:
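using DataFrames, Plots

# keep only the draws that Turing flagged as divergent
divergences_naive_param = filter(row -> row.numerical_error == 1,
                                 DataFrame(simple_chain))
# overlay them on the existing scatter plot of x[1] against y
scatter!(divergences_naive_param[!, "x[1]"],
         divergences_naive_param[!, :y],
         markershape = :x,
         color = :aquamarine,
         opacity = 0.8,
         label = "divergence")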
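To see where the trouble comes from in this model, it helps to look at how the conditional scale of each x[i], exp(y/2), changes over the typical range of y. The sketch below uses only Base Julia; `funnel_scale` is a hypothetical helper introduced for illustration, not part of Turing.

```julia
# In Neal's funnel, each x[i] given y is Normal(0, exp(y/2)).
# funnel_scale is a hypothetical helper returning that conditional standard deviation.
funnel_scale(y) = exp(y / 2)

# y ~ Normal(0, 3), so y = ±6 is two prior standard deviations from the mean
for y in (-6.0, -3.0, 0.0, 3.0, 6.0)
    println("y = ", y, ": sd(x[i] | y) = ", round(funnel_scale(y); digits = 3))
end

# The scale changes by a factor of exp(6), roughly 400, across that range
ratio = funnel_scale(6.0) / funnel_scale(-6.0)
println("ratio of widest to narrowest scale: ", round(ratio; digits = 1))
```

A step size small enough for the narrow neck (negative y) is wastefully small in the wide mouth, while one tuned for the mouth overshoots in the neck; that overshooting is what shows up as divergences.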
Divergent transitions are usually a symptom of some deeper degeneracy in the posterior geometry, so they can point to biased inference from your MCMC sampler, non-identifiability in your model, and so on.
A guide to taming divergences can be found here.
As discussed in the Stan User's Guide, the most powerful solution is usually to reparametrize the model, which can be done in Turing as follows:
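@model function Neal2()
    # raw (non-centred) draws: independent standard normals
    y_raw ~ Normal(0, 1)
    x_raw ~ arraydist([Normal(0, 1) for _ in 1:9])
    # transform back to the funnel's scale
    y = 3 * y_raw
    x = exp.(y ./ 2) .* x_raw
    # return the transformed variables so generated_quantities can recover them
    return [x; y]
end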
We then use the generated_quantities() function to pull out the transformed variables:
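rawer_chain = sample(Neal2(), NUTS(), 5000)
# keep only the parameter columns (drop sampler internals such as numerical_error)
raw_chain = Turing.MCMCChains.get_sections(rawer_chain, :parameters)
# stack the returned [x; y] vectors column-wise: one column per draw
reparam_chain = reduce(hcat, generated_quantities(Neal2(), raw_chain))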
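The layout of reparam_chain is easy to lose track of, so here is a Base-Julia sketch of it. It assumes that generated_quantities returns an iterations × chains matrix whose elements are the model's return values (here the 10-element vector [x; y]); the stand-in `gq` is filled with random numbers purely to illustrate the shapes and is not real sampler output.

```julia
# Hypothetical stand-in for generated_quantities(Neal2(), raw_chain):
# a 5000 × 1 (iterations × chains) matrix of 10-element vectors [x; y].
gq = reshape([randn(10) for _ in 1:5000], 5000, 1)

# Stack the vectors column-wise: one column per draw
reparam_chain = reduce(hcat, gq)

x_draws = reparam_chain[1:9, :]   # 9 × 5000: rows are x[1], ..., x[9]
y_draws = reparam_chain[10, :]    # y is the last element of [x; y]

println(size(reparam_chain), " ", size(x_draws), " ", length(y_draws))
```

With this layout, plotting x_draws[1, :] against y_draws gives the same x[1] versus y view used for the original chain.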
We can check the number of divergences in our reparametrized chain:
n_div = sum(rawer_chain[:numerical_error])
Our new parametrization has 0 divergences (the sum returns 0.0).
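Zero divergences only help if Neal2 still describes the same distribution as Neal. It does: if ỹ ~ Normal(0, 1) then 3ỹ ~ Normal(0, 3), and given y, exp(y/2)·x̃ ~ Normal(0, exp(y/2)). The Base-Julia sketch below checks this by simulation, independently of Turing; the seed, sample size, and slice width are arbitrary choices for illustration.

```julia
using Random, Statistics

rng = MersenneTwister(42)
n = 200_000

# Non-centred draws, mirroring Neal2: standard normals, then a deterministic transform
y_raw = randn(rng, n)
x_raw = randn(rng, n)          # only x[1] here; the other eight behave identically given y
y = 3 .* y_raw                 # y ~ Normal(0, 3)
x = exp.(y ./ 2) .* x_raw      # x | y ~ Normal(0, exp(y/2))

println("mean(y) = ", round(mean(y); digits = 2), ", std(y) = ", round(std(y); digits = 2))

# Conditional check: within a thin slice around y0, the spread of x should match exp(y0/2)
for y0 in (-3.0, 0.0, 3.0)
    in_slice = abs.(y .- y0) .< 0.1
    println("y ≈ ", y0, ": empirical sd(x) = ", round(std(x[in_slice]); digits = 2),
            ", exp(y0/2) = ", round(exp(y0 / 2); digits = 2))
end
```

Because the sampler only ever sees y_raw and x_raw, which are independent standard normals, there is no funnel for NUTS to navigate; the funnel is reintroduced only by the deterministic transform.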
Plotting the transformed draws shows that this parametrization explores the funnel much more thoroughly.