fit_learning_model: General function to run Bayesian models using cmdstanr

View source: R/fit_learning_model.R

fit_learning_model    R Documentation

General function to run Bayesian models using cmdstanr

Description

fit_learning_model uses the cmdstanr package, a lightweight R interface to CmdStan. Note that although the function checks whether the C++ toolchain is correctly configured, it will not install CmdStan itself. Installing CmdStan may be as simple as running cmdstanr::install_cmdstan(), but may require some extra effort (e.g., pointing R to the install location via cmdstanr::set_cmdstan_path()); see the cmdstanr vignette for more detail.
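As a rough pre-flight sketch (not part of pstpipeline), the setup described above could be checked before fitting. The helper name cmdstan_ready() is hypothetical, and the cmdstanr calls are only reached if cmdstanr is installed.

```r
# Hypothetical helper: returns TRUE only if cmdstanr is installed and can
# locate a CmdStan installation.
cmdstan_ready <- function() {
  if (!requireNamespace("cmdstanr", quietly = TRUE)) {
    return(FALSE)
  }
  # With error_on_NA = FALSE, cmdstan_version() returns NULL instead of
  # erroring when no CmdStan installation can be found.
  version <- tryCatch(
    cmdstanr::cmdstan_version(error_on_NA = FALSE),
    error = function(e) NULL
  )
  !is.null(version)
}

if (!cmdstan_ready()) {
  message(
    "CmdStan not found. Try cmdstanr::install_cmdstan(), or point R to an ",
    "existing install with cmdstanr::set_cmdstan_path()."
  )
}
```

The check deliberately does not install anything; it only reports whether a fit is likely to start.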

Usage

fit_learning_model(
  df_all,
  model,
  exp_part,
  affect = FALSE,
  affect_sfx = c("3wt", "4wt_trial", "4wt_block", "4wt_time", "5wt_time", "delta"),
  adj_order = c("happy", "confident", "engaged"),
  vb = TRUE,
  ppc = vb,
  par_recovery = FALSE,
  task_excl = TRUE,
  accuracy_excl = FALSE,
  model_checks = !vb,
  save_model_as = "",
  out_dir = "outputs/cmdstan",
  outputs = c("raw_df", "summary", "draws_list"),
  save_outputs = TRUE,
  cores = getOption("mc.cores", 4),
  ...
)

Arguments

df_all

Raw data output by import_multiple().

model

Learning model to use; choose from "1a" (single learning rate) or "2a" (dual learning rate).

exp_part

Fit to "training" or "test" data?

affect

Fit extended Q-learning model with affect ratings?

affect_sfx

String suffix identifying the specific affect model; ignored if affect == FALSE. Defaults to model with trial-wise passage-of-time.

adj_order

Vector of affect adjectives which is used to define their numerical order in the model output.

vb

Use variational inference to get the approximate posterior? Default is TRUE for computational efficiency.

ppc

Generate quantities including mean parameters, log likelihood, and posterior predictions? Intended for use with variational algorithm; for MCMC it is recommended to run the separate generate_posterior_quantities() function, as this is far less memory intensive.

par_recovery

Fit the model to simulated data (i.e., from simulate_QL())?

task_excl

Apply task-related exclusion criteria (catch questions, digit span = 0)?

accuracy_excl

Apply accuracy-based exclusion criteria (final block AB accuracy >= 0.6)? This is not recommended and is deprecated.

model_checks

Runs check_learning_models(), returning plots of the group-level posterior densities for the free parameters, and some visual model checks (traceplots of the chains, and rank histograms). Note the visual checks will only be returned if !vb, as they are only relevant for MCMC fits, and require the bayesplot package.

save_model_as

Name to give to saved model and used to name the .csv files and outputs. Defaults to the Stan model name.

out_dir

Output directory for model fit environment, plus all specified outputs if save_outputs = TRUE.

outputs

Specific outputs to return (and save, if save_outputs). In addition to the defaults, other options are "model_env" (note this is saved automatically, regardless of save_outputs), and "loo_obj". The latter includes the theoretical expected log-predictive density (ELPD) for a new dataset, plus the leave-one-out information criterion (LOOIC), a fully Bayesian metric for model comparison; this requires the loo package.

save_outputs

Save the specified outputs to the disk? Will save to out_dir.

cores

Maximum number of chains to run in parallel. Defaults to getOption("mc.cores"), or 4 if this option is not set; the value used is then set via options(mc.cores = cores) and applies for the rest of the session.

...

Other arguments passed to cmdstanr::sample() and/or check_learning_models(). See the CmdStan user guide for full details and defaults.
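Several defaults in the usage above depend on other arguments (ppc = vb, model_checks = !vb, cores = getOption("mc.cores", 4)). The sketch below illustrates how such linked defaults resolve in R; resolve_defaults() is a hypothetical stand-in that mirrors only these defaults and does not fit any model.

```r
# Hypothetical stand-in mirroring only the linked defaults from the usage;
# it returns the resolved values rather than fitting anything.
resolve_defaults <- function(vb = TRUE, ppc = vb, model_checks = !vb,
                             cores = getOption("mc.cores", 4)) {
  list(vb = vb, ppc = ppc, model_checks = model_checks, cores = cores)
}

old <- options(mc.cores = NULL)                # unset the option for the demo
vi_defaults   <- resolve_defaults()            # variational inference
mcmc_defaults <- resolve_defaults(vb = FALSE)  # MCMC
options(mc.cores = 2)
two_cores <- resolve_defaults()$cores          # picks up the session option
options(old)                                   # restore the previous setting
```

Because R evaluates default arguments lazily, switching vb alone flips both ppc and model_checks unless they are set explicitly.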

Details

fit_learning_model leans heavily on helper functions from the hBayesDM package but is less flexible; it is designed primarily to be less memory-intensive for our specific use case and to return only the relevant output.

Value

List containing a cmdstanr::CmdStanVB or cmdstanr::CmdStanMCMC fit object, plus any other outputs requested via outputs.
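For readers unfamiliar with the quantities in a "loo_obj" output, the sketch below computes ELPD and LOOIC with the loo package from a simulated log-likelihood matrix. The data and the simple normal model are assumptions made purely for illustration; this is not how pstpipeline builds its own loo object.

```r
set.seed(1)
n_obs <- 20
n_draws <- 500

# Simulated observations and (approximate) posterior draws of a normal mean
y <- rnorm(n_obs)
mu_draws <- rnorm(n_draws, mean = mean(y), sd = 1 / sqrt(n_obs))

# Draws x observations matrix of pointwise log-likelihoods
log_lik <- sapply(y, function(y_i) {
  dnorm(y_i, mean = mu_draws, sd = 1, log = TRUE)
})

if (requireNamespace("loo", quietly = TRUE)) {
  # Warnings about unspecified relative effective sample sizes are
  # suppressed, as the draws here are independent by construction.
  loo_obj <- suppressWarnings(loo::loo(log_lik))
  est <- loo_obj$estimates
  elpd <- est["elpd_loo", "Estimate"]
  looic <- est["looic", "Estimate"]  # LOOIC is -2 times the ELPD estimate
}
```

Higher ELPD (equivalently, lower LOOIC) indicates better expected out-of-sample predictive accuracy when comparing models fit to the same data.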

Examples

## Not run: 
# Single learning rate Q-learning model fit to training data with MCMC

data(example_data)
fit1 <- fit_learning_model(
  example_data$nd,
  model = "1a",
  vb = FALSE,
  exp_part = "training",
  iter_warmup = 1000, # default
  iter_sampling = 1000, # default
  chains = 4 # default
)

# Dual learning rate Q-learning model fit to training plus test data with
# variational inference

data(example_data)
fit2 <- fit_learning_model(
  example_data$nd,
  model = "2a",
  exp_part = "test",
  vb = TRUE
)

# Simplest affect model with three weights, fit with variational inference

fit3 <- fit_learning_model(
  example_data$nd,
  model = "2a",
  affect = TRUE,
  affect_sfx = "3wt",
  exp_part = "training",
  algorithm = "fullrank"
)

## End(Not run)


qdercon/pstpipeline documentation built on June 1, 2025, 1:11 p.m.