View source: R/get_preds_by_chain.R
get_preds_by_chain | R Documentation |
get_preds_by_chain is a helper function that automates the loading of
posterior predictions, with an optional memory-saving method to help avoid
running out of RAM (and crashing).
get_preds_by_chain(
  out_files,
  out_dir = "",
  obs_df,
  n_draws_chain,
  save_dir = out_dir,
  test = FALSE,
  prefix = "",
  splits = list(blocks = 1:6, sum_blks = list(c(1, 3), c(4, 6))),
  exclude = NULL,
  memory_save = TRUE,
  ...
)
out_files |
Vector of .csv file names which contain posterior
predictions (e.g., outputted from |
out_dir |
Path to output directory (defaults to current working directory). |
obs_df |
Raw training data, e.g., outputted from |
n_draws_chain |
Number of MCMC sampling iterations per chain. |
save_dir |
Directory to save items to; it will be created if it does not exist. Defaults to the directory of the output files (out_dir). |
test |
Boolean indicating whether the posterior samples are from the test phase. |
prefix |
Optional prefix to add to the saved objects. |
splits |
A list specifying which individual and summed blocks to average draws over. This can be useful, e.g., to see whether predictions become more accurate later in the task. |
exclude |
ID numbers to exclude from output (e.g., if there was insufficient mixing for their parameters). |
memory_save |
Whether to use an alternative, memory-saving method to obtain predictions, which loads the predictions for each individual (across all chains) one by one, rather than importing all the draws for all individuals at once. This is significantly slower but enables the function to run on systems with limited RAM. |
... |
Other arguments which are unlikely to be necessary to change:
|
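To make the splits argument more concrete, the following is a minimal base R sketch of averaging over individual and summed blocks. It assumes that each element of sum_blks gives the first and last block of an inclusive range (so c(1, 3) means blocks 1 to 3); that interpretation, the toy data frame toy, and its columns block and pred are illustrative assumptions and are not taken from the package source.

```r
# Default value of the `splits` argument, copied from the usage above
splits <- list(blocks = 1:6, sum_blks = list(c(1, 3), c(4, 6)))

# Hypothetical toy data: one mean prediction per block
toy <- data.frame(
  block = 1:6,
  pred = c(0.50, 0.55, 0.60, 0.70, 0.75, 0.80)
)

# Average within each individual block
by_block <- sapply(splits$blocks, function(b) mean(toy$pred[toy$block == b]))
names(by_block) <- paste0("block_", splits$blocks)

# Average within each summed range
# ASSUMPTION: c(first, last) denotes an inclusive range of blocks
by_range <- sapply(splits$sum_blks, function(r) {
  mean(toy$pred[toy$block >= r[1] & toy$block <= r[2]])
})
names(by_range) <- sapply(
  splits$sum_blks,
  function(r) paste0("blocks_", r[1], "_to_", r[2])
)

by_block
by_range
```

Comparing the two entries of by_range is one way to check whether predictions differ between the early and late parts of the task.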
An updated tibble
with summed choices per chain and their
overall proportion.
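As a rough illustration of what "summed choices per chain and their overall proportion" could look like for a single trial, here is a small base R sketch. The data frame draws, its columns chain and choice, and the simulated values are hypothetical; the actual column names and layout of the returned tibble may differ.

```r
set.seed(1)
n_chains <- 4
n_draws_chain <- 2000

# Hypothetical predicted binary choices for one trial: one row per draw
draws <- data.frame(
  chain = rep(seq_len(n_chains), each = n_draws_chain),
  choice = rbinom(n_chains * n_draws_chain, size = 1, prob = 0.7)
)

# Sum of predicted choices within each chain
sum_by_chain <- tapply(draws$choice, draws$chain, sum)

# Overall proportion across all chains and draws
prop_overall <- sum(sum_by_chain) / (n_chains * n_draws_chain)

sum_by_chain
prop_overall
```

Here n_draws_chain plays the same role as the argument of that name: it is the denominator needed to turn per-chain sums into a proportion.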
## Not run:
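data(example_data)

# recursive = TRUE creates any missing parent directories as well
dir.create("outputs/cmdstan/predictions", recursive = TRUE)

# Fit the model to the training data with MCMC
fit <- fit_learning_model(
  example_data$nd,
  model = "2a",
  vb = FALSE,
  exp_part = "training",
  iter_sampling = 2000,
  outputs = c("model_env", "raw_df", "stan_datalist")
)

# Generate posterior predictions and return the paths to the .csv files
pred_paths <- generate_posterior_quantities(
  fit_mcmc = fit,
  data_list = fit$stan_datalist,
  return_type = "paths"
)

# Load the predictions; n_draws_chain matches iter_sampling above
obs_df_preds <- get_preds_by_chain(
  out_files = pred_paths,
  out_dir = "outputs/cmdstan/predictions",
  obs_df = fit$raw_df,
  n_draws_chain = 2000
)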
## End(Not run)
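The idea behind memory_save (loading predictions one individual at a time instead of importing every draw at once) can be sketched in base R as follows. This is not the package's implementation: the temporary file, the column naming scheme "<id>_t<trial>", and the summary computed are all hypothetical, chosen only to show the general technique of reading a subset of columns per pass.

```r
# Hypothetical CSV of draws: columns named "<id>_t<trial>", one row per draw
tmp <- tempfile(fileext = ".csv")
toy <- data.frame(
  id1_t1 = c(1, 0, 1), id1_t2 = c(1, 1, 1),
  id2_t1 = c(0, 0, 1), id2_t2 = c(0, 1, 0)
)
write.csv(toy, tmp, row.names = FALSE)

# Read a single row only, to find out which columns exist
cols <- names(read.csv(tmp, nrows = 1))
ids <- unique(sub("_t[0-9]+$", "", cols))

# Load one individual's columns at a time;
# a colClasses entry of "NULL" tells read.csv to skip that column
means_by_id <- sapply(ids, function(id) {
  keep <- startsWith(cols, paste0(id, "_"))
  one_id <- read.csv(tmp, colClasses = ifelse(keep, "numeric", "NULL"))
  mean(as.matrix(one_id))
})

means_by_id
unlink(tmp)
```

The file is re-read once per individual, which is why an approach like this trades speed for a smaller peak memory footprint.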