Nothing
#' Build input for Amortized Bayesian Inference (ABI)
#'
#' Prepares simulation output for Amortized Bayesian Inference (ABI) analysis
#' using the \code{NeuralEstimators} package. Extracts parameters and summary statistics
#' from simulation results, splits data into training and validation sets, and formats
#' them into matrices suitable for neural network training.
#'
#' @details
#' This function provides a streamlined workflow for preparing ABI inputs. It
#' requires that \code{simulation_output} be created by \code{\link{run_simulation}}
#' or \code{\link{load_simulation_output}}. The function automatically handles
#' missing trials and ranks by filling with zeros to ensure complete data matrices.
#'
#' The output format is optimized for the \code{abi} package's training functions,
#' with parameters formatted as matrices (each column is a condition) and summary
#' statistics formatted as lists of matrices (one per condition, with trials as columns).
#'
#' @param simulation_output A eam_simulation_output object from
#' \code{\link{run_simulation}} or \code{\link{load_simulation_output}}.
#' @param theta Character vector of parameter names to extract from simulation_output.
#' These parameters will be used as the target variables for inference.
#' @param Z Character vector of summary statistic column names to extract from
#' the simulation output dataset (e.g., "rt", "item_idx", "choice").
#' @param train_ratio Numeric value between 0 and 1 specifying the proportion
#' of non-test conditions to use for training, with the remainder used for
#' validation (default: 0.8).
#' @param n_test Integer specifying the number of conditions to use for testing
#' (default: 10).
#' @param rank_levels Numeric vector specifying which rank indices to include.
#' If NULL (default), uses all ranks from 1 to n_items from simulation config.
#'
#' @return A list with components suitable for \code{abi} package training:
#' \describe{
#' \item{theta_train}{Matrix of training parameters (parameters × conditions)}
#' \item{theta_val}{Matrix of validation parameters (parameters × conditions)}
#' \item{theta_test}{Matrix of test parameters (parameters × conditions)}
#' \item{Z_train}{List of matrices, one per training condition (ranks*Z × trials)}
#' \item{Z_val}{List of matrices, one per validation condition (ranks*Z × trials)}
#' \item{Z_test}{List of matrices, one per test condition (ranks*Z × trials)}
#' \item{train_idx}{Vector of condition indices used for training}
#' \item{val_idx}{Vector of condition indices used for validation}
#' \item{test_idx}{Vector of condition indices used for testing}
#' \item{train_ratio}{The training ratio used (train / non-test conditions)}
#' \item{n_test}{The number of test conditions used}
#' \item{rank_levels}{The rank levels included in Z matrices}
#' \item{theta}{Character vector of parameter names used}
#' \item{Z}{Character vector of summary statistic names used}
#' }
#'
#' @examples
#' # Load the example dataset
#' rdm_minimal_example <- system.file("extdata", "rdm_minimal", package = "eam")
#' sim_output <- load_simulation_output(file.path(rdm_minimal_example, "simulation"))
#'
#' # build the ABI input
#' abi_input <- build_abi_input(
#' sim_output,
#' c(
#' "V_beta_1",
#' "V_beta_group"
#' ),
#' c(
#' "item_idx",
#' "rt",
#' "choice"
#' )
#' )
#'
#' # view the structure of the ABI input
#' str(abi_input)
#'
#' \dontrun{
#' # Example of using the ABI input for training
#' # (requires NeuralEstimators package and build your estimator first, see our tutorials)
#' train(
#' estimator,
#' theta_train = abi_input$theta_train,
#' theta_val = abi_input$theta_val,
#' Z_train = abi_input$Z_train,
#' Z_val = abi_input$Z_val,
#' epochs = 500,
#' stopping_epochs = 200
#' )
#' }
#' @export
build_abi_input <- function(
simulation_output,
theta,
Z,
train_ratio = 0.8,
n_test = 100,
rank_levels = NULL) {
# TODO: totally re-design current api for abi, only pass the index of conditions,
# and load data in julia rather than r.
# Validate inputs
if (!inherits(simulation_output, "eam_simulation_output")) {
stop("simulation_output must be a eam_simulation_output object")
}
if (!is.character(theta) || length(theta) == 0) {
stop("theta must be a non-empty character vector")
}
if (!is.character(Z) || length(Z) == 0) {
stop("Z must be a non-empty character vector")
}
if (!is.numeric(train_ratio) || length(train_ratio) != 1) {
stop("train_ratio must be a single numeric value")
}
if (train_ratio <= 0 || train_ratio >= 1) {
stop("train_ratio must be between 0 and 1")
}
if (!is.numeric(n_test) || length(n_test) != 1) {
stop("n_test must be a single numeric value")
}
if (n_test < 1 || n_test != floor(n_test)) {
stop("n_test must be a positive integer")
}
if (!is.null(rank_levels) && !is.numeric(rank_levels)) {
stop("rank_levels must be NULL or a numeric vector")
}
# Open the arrow datasets
output <- simulation_output$open_dataset()
conditions <- simulation_output$open_evaluated_conditions()
# Set default rank_levels if NULL
if (is.null(rank_levels)) {
max_rank <- simulation_output$simulation_config$n_items
rank_levels <- seq_len(max_rank)
} else if (!is.numeric(rank_levels)) {
stop("rank_levels must be a numeric vector")
} else {
rank_levels <- sort(unique(rank_levels))
}
# collect unique condition_idx
# NSE variable bindings for R CMD check
condition_idx <- conditions |>
dplyr::select(condition_idx) |>
dplyr::distinct() |>
dplyr::arrange(condition_idx) |>
dplyr::collect() |>
dplyr::pull(condition_idx)
# Split into train, test, and validation sets
n_total <- length(condition_idx)
# Validate n_test
if (n_test >= n_total) {
stop("n_test must be less than the total number of conditions")
}
# First, reserve test set
test_idx <- sample(condition_idx, n_test, replace = FALSE)
# Split remaining conditions by train_ratio
remaining_idx <- setdiff(condition_idx, test_idx)
n_remaining <- length(remaining_idx)
n_train <- floor(n_remaining * train_ratio)
train_idx <- sample(remaining_idx, n_train, replace = FALSE)
val_idx <- setdiff(remaining_idx, train_idx)
# build theta matrices
theta_train <- build_abi_input.theta(
conditions,
theta,
train_idx
)
theta_val <- build_abi_input.theta(
conditions,
theta,
val_idx
)
theta_test <- build_abi_input.theta(
conditions,
theta,
test_idx
)
z_train <- build_abi_input.Z(
output,
Z,
rank_levels,
train_idx,
simulation_output$simulation_config$n_trials_per_condition
)
z_val <- build_abi_input.Z(
output,
Z,
rank_levels,
val_idx,
simulation_output$simulation_config$n_trials_per_condition
)
z_test <- build_abi_input.Z(
output,
Z,
rank_levels,
test_idx,
simulation_output$simulation_config$n_trials_per_condition
)
# build output list
abi_input <- list(
theta_train = theta_train,
theta_val = theta_val,
theta_test = theta_test,
Z_train = z_train,
Z_val = z_val,
Z_test = z_test,
train_idx = train_idx,
val_idx = val_idx,
test_idx = test_idx,
train_ratio = train_ratio,
n_test = n_test,
rank_levels = rank_levels,
theta = theta,
Z = Z
)
class(abi_input) <- c("eam_abi_input", "list")
return(abi_input)
}
#' Extract and format parameter matrix for ABI
#'
#' @param conditions Arrow dataset of evaluated conditions
#' @param theta Character vector of parameter names
#' @param select_idx Vector of condition indices to extract
#' @return Matrix with parameters as rows and conditions as columns
#' @keywords internal
build_abi_input.theta <- function(
conditions,
theta,
select_idx) {
# NSE variable bindings for R CMD check
condition_idx <- NULL
# Filter by train_idx and select theta columns
theta_df <- conditions |>
dplyr::filter(condition_idx %in% select_idx) |>
dplyr::select(condition_idx, dplyr::all_of(theta)) |>
dplyr::arrange(condition_idx) |>
dplyr::collect()
# Convert to matrix (excluding condition_idx column)
# Each column is a trial (Fortran style)
theta_matrix <- as.matrix(theta_df[, theta, drop = FALSE])
theta_matrix <- t(theta_matrix)
return(theta_matrix)
}
#' Extract and format summary statistics for ABI
#'
#' @param output Arrow dataset of simulation output
#' @param Z Character vector of summary statistic column names
#' @param rank_levels Numeric vector of rank indices to include
#' @param select_idx Vector of condition indices to extract
#' @param n_trials_per_condition Number of trials per condition
#' @return List of matrices, one per condition (ranks*Z rows × trials columns)
#' @keywords internal
build_abi_input.Z <- function(
output,
Z,
rank_levels,
select_idx,
n_trials_per_condition) {
# NSE variable bindings for R CMD check
condition_idx <- trial_idx <- rank_idx <- NULL
fill_list <- rlang::set_names(
as.list(rep(0, length(Z))),
Z
)
# Generate all trial indices
all_trial_idx <- seq_len(n_trials_per_condition)
# Filter by select_idx and rank_levels, then select Z columns
z_complete <- output |>
dplyr::filter(condition_idx %in% select_idx, rank_idx %in% rank_levels) |>
dplyr::select(condition_idx, trial_idx, rank_idx, dplyr::all_of(Z)) |>
dplyr::collect() |>
tidyr::complete(
condition_idx = select_idx,
trial_idx = all_trial_idx,
rank_idx = rank_levels,
fill = fill_list
) |>
dplyr::arrange(condition_idx, trial_idx, rank_idx) |>
tidyr::pivot_wider(
names_from = rank_idx,
values_from = dplyr::all_of(Z),
names_glue = "rank_{rank_idx}_{.value}"
)
# Convert to list of matrices, one per condition
# Each matrix has trials as columns and (rank * Z) as rows
z_list <- z_complete |>
dplyr::select(-trial_idx) |>
dplyr::group_by(condition_idx) |>
dplyr::group_split(.keep = FALSE) |>
purrr::map(function(cond_df) {
cond_df |>
as.matrix() |>
t()
})
return(z_list)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.