get_preds_by_chain: Get and store posterior predictions for training data

View source: R/get_preds_by_chain.R

Description

get_preds_by_chain is a helper function that automates loading posterior predictions from the output .csv files and summarising them by chain. It has an optional memory-saving mode to help avoid running out of memory (and the resulting crashes).

Usage

get_preds_by_chain(
  out_files,
  out_dir = "",
  obs_df,
  n_draws_chain,
  save_dir = out_dir,
  test = FALSE,
  prefix = "",
  splits = list(blocks = 1:6, sum_blks = list(c(1, 3), c(4, 6))),
  exclude = NULL,
  memory_save = TRUE,
  ...
)

Arguments

out_files

Character vector of .csv file names containing posterior predictions (e.g., as returned by generate_posterior_quantities()).

out_dir

Path to the output directory containing out_files (defaults to the current working directory).

obs_df

Raw training data, e.g., as output by fit_learning_model(). This is recommended because it ensures individuals are matched with the correct predictions.

n_draws_chain

Number of MCMC sampling iterations per chain.

save_dir

Directory in which to save outputs; it will be created if it does not exist. Defaults to the directory of the output files (out_dir).

test

Logical indicating whether the posterior samples are from the test phase (defaults to FALSE).

prefix

Optional prefix to add to the saved objects.

splits

A list specifying the individual blocks (blocks) and groups of summed blocks (sum_blks) over which to average draws. This is useful, for example, for checking whether predictions become more accurate later in the task.

exclude

ID numbers of individuals to exclude from the output (e.g., if their parameters showed insufficient mixing).

memory_save

Logical; if TRUE (the default), predictions are loaded with an alternative method that reads each individual's predictions (across all chains) one at a time, rather than importing all draws for all individuals at once. This is significantly slower but allows the function to run on systems with limited RAM.

...

Other arguments that rarely need changing: n_trials (default = 360), vars (default = "y_pred"), and pred_types (default = c("AB", "CD", "EF")).
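The trade-off described for memory_save can be illustrated with a minimal base-R sketch. This is not the package's implementation. The column naming scheme (y_pred.<individual>.<trial>), the toy files, and the helper read_individual() are all hypothetical. The idea it shows is that only one individual's columns are held in memory at any time. The cost is that every chain file is re-read once per individual, which is why this mode is slower.

```r
# Hypothetical sketch (not the package's code) of the memory_save idea:
# read only one individual's columns from each chain's .csv file at a time.
# Assumes prediction columns are named "y_pred.<individual>.<trial>".
set.seed(42)
n_ids <- 2
n_trials <- 4

# Write two toy "chain" files, each with 3 draws of 0/1 predicted choices
files <- vapply(1:2, function(chain) {
  m <- matrix(rbinom(3 * n_ids * n_trials, 1, 0.5), nrow = 3)
  colnames(m) <- paste0("y_pred.", rep(seq_len(n_ids), each = n_trials), ".",
                        rep(seq_len(n_trials), times = n_ids))
  path <- tempfile(fileext = ".csv")
  write.csv(m, path, row.names = FALSE)
  path
}, character(1))

# Read only the columns belonging to individual `id` from one file
read_individual <- function(path, id) {
  header <- names(read.csv(path, nrows = 1, check.names = FALSE))
  keep <- startsWith(header, paste0("y_pred.", id, "."))
  # colClasses = "NULL" makes read.csv skip a column without storing it
  read.csv(path, colClasses = ifelse(keep, "numeric", "NULL"),
           check.names = FALSE)
}

# Process individuals one by one, pooling each individual's draws across chains
prop_by_id <- vapply(seq_len(n_ids), function(id) {
  draws <- do.call(rbind, lapply(files, read_individual, id = id))
  mean(as.matrix(draws))  # proportion of draws predicting a 1
}, numeric(1))
```

The trailing "." in the prefix passed to startsWith() prevents, for example, individual 1 from also matching the columns of individual 10.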

Value

An updated tibble with summed choices per chain and their overall proportion.
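The per-chain summary described above, together with the block averaging controlled by splits, might look something like the following minimal sketch for a single individual. It is not the package's implementation. It makes several assumptions:

- Draws are stacked chain by chain, so n_draws_chain is what assigns each draw to its chain.
- The 360 training trials fall into six equal, contiguous blocks of 60.
- Each element of sum_blks, such as c(1, 3), is an inclusive range of blocks.

The toy data and variable names are hypothetical.

```r
# Hypothetical sketch (not the package's code) for one individual and one
# pair type. Assumes draws are stacked chain by chain, 360 trials split into
# 6 equal contiguous blocks, and sum_blks elements are inclusive block ranges.
set.seed(7)
n_draws_chain <- 100
n_chains <- 4
n_trials <- 360
splits <- list(blocks = 1:6, sum_blks = list(c(1, 3), c(4, 6)))

# Toy posterior predictive choices (1/0): one row per draw, one column per trial
draws <- matrix(rbinom(n_draws_chain * n_chains * n_trials, 1, 0.6),
                nrow = n_draws_chain * n_chains)
chain_id <- rep(seq_len(n_chains), each = n_draws_chain)

# Summed choices per chain: one row per chain, one column per trial
sum_by_chain <- rowsum(draws, chain_id)

# Overall proportion of draws predicting a 1, per trial
prop_overall <- colSums(sum_by_chain) / nrow(draws)

# Average the per-trial proportions within each block and each block range
block_of_trial <- rep(splits$blocks, each = n_trials / length(splits$blocks))
block_means <- tapply(prop_overall, block_of_trial, mean)
sum_blk_means <- vapply(
  splits$sum_blks,
  function(rng) mean(prop_overall[block_of_trial %in% seq(rng[1], rng[2])]),
  numeric(1)
)
```

Keeping the per-chain sums, rather than only the pooled proportion, makes it possible to check whether the chains agree before averaging over them.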

Examples

## Not run: 
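data(example_data)
# Create the directory for prediction outputs (recursive = TRUE also creates missing parents)
dir.create("outputs/cmdstan/predictions", recursive = TRUE)

# Fit model "2a" to the training data with MCMC sampling (vb = FALSE)
fit <- fit_learning_model(
  example_data$nd,
  model = "2a",
  vb = FALSE,
  exp_part = "training",
  iter_sampling = 2000,
  outputs = c("model_env", "raw_df", "stan_datalist")
)

# Generate posterior predictions, returning the paths to the output .csv files
pred_paths <- generate_posterior_quantities(
  fit_mcmc = fit,
  data_list = fit$stan_datalist,
  return_type = "paths"
)

# Load and summarise the predictions by chain; n_draws_chain matches iter_sampling above
obs_df_preds <- get_preds_by_chain(
  out_files = pred_paths,
  out_dir = "outputs/cmdstan/predictions",
  obs_df = fit$raw_df,
  n_draws_chain = 2000
)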

## End(Not run)


qdercon/pstpipeline documentation built on June 1, 2025, 1:11 p.m.