Parameter Estimation with Turing

It is possible to use Turing.jl to perform Bayesian parameter estimation on models defined in SequentialSamplingModels.jl. Below, we show you how to estimate the parameters for the Linear Ballistic Accumulator (LBA).

Example

Load Packages

The first step is to load the required packages. You will need to install each package in your local environment in order to run the code locally.

using Turing
using SequentialSamplingModels
using Random
using LinearAlgebra
using StatsPlots   # provides the plot recipes for MCMC chains used at the end
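If the packages are not yet installed, the following minimal sketch adds them to the active environment with Julia's built-in package manager. It assumes StatsPlots as the plotting package for the trace plots at the end. Random and LinearAlgebra are standard libraries, so they do not need to be added.

```julia
using Pkg

# Random and LinearAlgebra ship with Julia, so only the registered
# third-party packages need to be added to the active environment.
# StatsPlots is assumed here as the plotting package for plot(chain, ...).
Pkg.add(["Turing", "SequentialSamplingModels", "StatsPlots"])
```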

Define Turing Model

The code snippet below defines a model in Turing. The model function accepts a named tuple containing a vector of choices and a vector of reaction times. The sampling statements define the prior distributions for each parameter. The non-decision time parameter $\tau$ must be bounded above by the minimum reaction time, min_rt, because no observed reaction time can be shorter than the non-decision time. The last sampling statement defines the likelihood of the data given the sampled parameter values.

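@model model(data) = begin
    # data is a named tuple (choice, rt); data[2] is the vector of reaction times
    min_rt = minimum(data[2])
    # priors
    ν ~ MvNormal(zeros(2), I * 2)              # drift rates, one per accumulator
    A ~ truncated(Normal(.8, .4), 0.0, Inf)    # maximum starting point
    k ~ truncated(Normal(.2, .2), 0.0, Inf)    # threshold minus maximum starting point
    τ ~ Uniform(0.0, min_rt)                   # non-decision time
    # likelihood
    data ~ LBA(; ν, A, k, τ)
end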
model (generic function with 2 methods)
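To make the priors concrete, the sketch below builds the same prior distributions outside the model. It relies on Turing re-exporting Distributions. MvNormal(zeros(2), I * 2) uses the covariance matrix 2I, so each drift rate has variance 2 (standard deviation about 1.41). Truncating the normal priors at 0 keeps A and k positive, and it also shifts their prior means slightly above the location parameters .8 and .2. The variable names prior_ν, prior_A and prior_k are illustrative only.

```julia
using Turing          # re-exports Distributions: MvNormal, Normal, truncated, Uniform
using LinearAlgebra   # provides the identity operator I
using Statistics

# Drift-rate prior: covariance 2I, i.e. independent normals with mean 0
# and variance 2 for each accumulator
prior_ν = MvNormal(zeros(2), I * 2)
println("var(ν): ", collect(var(prior_ν)))
println("std(ν): ", sqrt.(collect(var(prior_ν))))

# Truncation at 0 restricts A and k to positive values
prior_A = truncated(Normal(0.8, 0.4), 0.0, Inf)
prior_k = truncated(Normal(0.2, 0.2), 0.0, Inf)
println("A: lower bound ", minimum(prior_A), ", prior mean ", round(mean(prior_A), digits=3))
println("k: lower bound ", minimum(prior_k), ", prior mean ", round(mean(prior_k), digits=3))
```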

Generate Simulated Data

In the code snippet below, we set a seed for the random number generator and generate $100$ simulated trials from an LBA with known parameters, which we will then attempt to recover.

# generate some data
Random.seed!(45461)
dist = LBA(ν=[3.0,2.0], A = .8, k = .2, τ = .3)
data = rand(dist, 100)
(choice = [2, 1, 1, 1, 1, 1, 1, 1, 1, 2  …  2, 1, 1, 1, 1, 1, 2, 2, 1, 1], rt = [0.6290010134562214, 0.507511686077583, 0.4077035786951601, 0.5845853642627061, 0.5518606313501111, 0.357077395905047, 0.4430796909532061, 0.35525992231554643, 0.4858235685699115, 0.37197001383636874  …  0.8595469187781852, 0.5705118298266227, 0.5348699092588236, 0.4397605741138838, 0.48381613889237696, 0.4576847624125857, 0.4835375184039913, 0.5544902383424269, 0.583537266351565, 0.36048725612183097])
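The model reads reaction times by position with data[2]. The following sketch uses the same seed and LBA parameters as above. It checks two things. First, rand returns a named tuple whose second field is data.rt. Second, every reaction time exceeds the true τ = 0.3, which is why min_rt is a safe upper bound for the τ prior. Exact simulated values may differ across package versions.

```julia
using SequentialSamplingModels
using Random
using Statistics

# Regenerate the simulated data set as above
Random.seed!(45461)
dist = LBA(ν=[3.0, 2.0], A=0.8, k=0.2, τ=0.3)
data = rand(dist, 100)

# rand returns a named tuple, so data[2] (used in the model) is data.rt
@assert data[2] === data.rt

# Each reaction time is τ plus a positive decision time, so min_rt > τ
min_rt = minimum(data.rt)
println("minimum rt: ", round(min_rt, digits=3))

# Proportion of trials won by the first accumulator
println("P(choice = 1): ", mean(data.choice .== 1))
```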

Estimate the Parameters

Finally, we estimate the parameters with sample, which accepts the following inputs:

  1. model(data): the Turing model conditioned on the data
  2. NUTS(1000, .85): a No-U-Turn Sampler with 1000 warmup (adaptation) samples and a target acceptance rate of .85; the warmup samples are discarded, which is why the iterations below start at 1001
  3. MCMCThreads(): instructs Turing to run each chain on a separate thread
  4. n_iterations: the number of samples drawn per chain after warmup (1000 here)
  5. n_chains: the number of chains (4 here)
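The inputs listed above can be written out with named variables, as in this sketch. The names n_adapts and δ are our own labels for the two positional arguments of NUTS. MCMCThreads() can run the chains in parallel only when Julia was started with enough threads, for example with `julia --threads 4`, so the sketch also warns when fewer threads than chains are available.

```julia
using Turing

# NUTS(n_adapts, δ): warmup length and target acceptance rate
n_adapts = 1000   # adaptation iterations, dropped from the returned chain by default
δ = 0.85          # target acceptance rate used to tune the step size
sampler = NUTS(n_adapts, δ)

# Post-warmup iterations per chain and number of chains
n_iterations = 1000
n_chains = 4

# MCMCThreads() needs one thread per chain to run all chains in parallel
if Threads.nthreads() < n_chains
    @warn "Only $(Threads.nthreads()) thread(s) available; chains will not all run in parallel."
end
```

With these names, the call below is equivalent to sample(model(data), sampler, MCMCThreads(), n_iterations, n_chains).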
# estimate parameters
chain = sample(model(data), NUTS(1000, .85), MCMCThreads(), 1000, 4)
Chains MCMC chain (1000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 44.75 seconds
Compute duration  = 44.7 seconds
parameters        = ν[1], ν[2], A, k, τ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

        ν[1]    2.8365    0.4287    0.0125   1180.2188   1798.1497    1.0027   ⋯
        ν[2]    1.6636    0.3752    0.0111   1150.0528   1649.2331    1.0025   ⋯
           A    0.7300    0.1720    0.0053   1063.8510   1249.1151    1.0048   ⋯
           k    0.2462    0.1178    0.0037    924.9209    844.9503    1.0058   ⋯
           τ    0.2792    0.0284    0.0009    994.4414    882.1512    1.0041   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        ν[1]    2.0535    2.5440    2.8197    3.1200    3.7132
        ν[2]    0.9614    1.4106    1.6566    1.9104    2.4265
           A    0.4154    0.6119    0.7215    0.8407    1.0977
           k    0.0471    0.1631    0.2332    0.3192    0.5171
           τ    0.2177    0.2614    0.2815    0.2988    0.3289

Evaluation

It is important to verify that the chains converged. All $\hat{r}$ values in the summary table are below 1.01, well within the common criterion $\hat{r} \leq 1.05$, and the trace plots below look like "hairy caterpillars", which indicates that the chains mixed well and did not get stuck. As expected, the posterior distributions are close to the data-generating parameter values: each true value lies within its 95% credible interval (the 2.5% and 97.5% quantiles).

plot(chain, grid=false)
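The convergence and recovery checks can also be written as explicit comparisons. The sketch below copies the $\hat{r}$ values and the 2.5% and 97.5% quantiles by hand from the Summary Statistics and Quantiles tables printed above, and tests them against the data-generating values. With a live chain, the same numbers could be read from MCMCChains' summarystats(chain) and quantile(chain) instead of being typed in.

```julia
# Data-generating values from LBA(ν=[3.0, 2.0], A=.8, k=.2, τ=.3)
truth = Dict("ν[1]" => 3.0, "ν[2]" => 2.0, "A" => 0.8, "k" => 0.2, "τ" => 0.3)

# rhat column of the Summary Statistics table
rhats = Dict("ν[1]" => 1.0027, "ν[2]" => 1.0025, "A" => 1.0048, "k" => 1.0058, "τ" => 1.0041)

# 2.5% and 97.5% columns of the Quantiles table
interval = Dict(
    "ν[1]" => (2.0535, 3.7132),
    "ν[2]" => (0.9614, 2.4265),
    "A"    => (0.4154, 1.0977),
    "k"    => (0.0471, 0.5171),
    "τ"    => (0.2177, 0.3289),
)

for p in ["ν[1]", "ν[2]", "A", "k", "τ"]
    lo, hi = interval[p]
    converged = rhats[p] <= 1.05          # common convergence threshold
    recovered = lo <= truth[p] <= hi      # true value inside 95% interval
    println(rpad(p, 6), "rhat <= 1.05: ", converged, "   truth in 95% interval: ", recovered)
end
```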