#' Read CmdStan CSV files into R
#'
#' @description `read_cmdstan_csv()` is used internally by CmdStanR to read
#' CmdStan's output CSV files into \R. It can also be used by CmdStan users as
#' a more flexible and efficient alternative to `rstan::read_stan_csv()`. See
#' the **Value** section for details on the structure of the returned list.
#'
#' It is also possible to create CmdStanR's fitted model objects directly from
#' CmdStan CSV files using the `as_cmdstan_fit()` function.
#'
#' @export
#' @param files (character vector) The paths to the CmdStan CSV files. These can
#' be files generated by running CmdStanR or running CmdStan directly.
#' @param variables (character vector) Optionally, the names of the variables
#' (parameters, transformed parameters, and generated quantities) to read in.
#' * If `NULL` (the default) then all variables are included.
#' * If an empty string (`variables=""`) then none are included.
#' * For non-scalar variables all elements or specific elements can be selected:
#' - `variables = "theta"` selects all elements of `theta`;
#' - `variables = c("theta[1]", "theta[3]")` selects only the 1st and 3rd elements.
#' @param sampler_diagnostics (character vector) Works the same way as
#' `variables` but for sampler diagnostic variables (e.g., `"treedepth__"`,
#' `"accept_stat__"`, etc.). Ignored if the model was not fit using MCMC.
#' @param format (string) The format for storing the draws or point estimates.
#' The default depends on the method used to fit the model. See
#' [draws][fit-method-draws] for details, in particular the note about speed
#' and memory for models with many parameters.
#'
#' @return
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], [CmdStanLaplace],
#' [CmdStanVB] or [CmdStanPathfinder] object. Some methods typically defined for
#' those objects will not
#' work (e.g. `save_data_file()`) but the important methods like `$summary()`,
#' `$draws()`, `$sampler_diagnostics()` and others will work fine.
#'
#' `read_cmdstan_csv()` returns a named list with the following components:
#'
#' * `metadata`: A list of the meta information from the run that produced the
#' CSV file(s). See **Examples** below.
#'
#' The other components in the returned list depend on the method that produced
#' the CSV file(s).
#'
#' For [sampling][model-method-sample] the returned list also includes the
#' following components:
#'
#' * `time`: Run time information for the individual chains. The returned object
#' is the same as for the [$time()][fit-method-time] method except the total run
#' time can't be inferred from the CSV files (the chains may have been run in
#' parallel) and is therefore `NA`.
#' * `inv_metric`: A list (one element per chain) of inverse mass matrices
#' or their diagonals, depending on the type of metric used.
#' * `step_size`: A list (one element per chain) of the step sizes used.
#' * `warmup_draws`: If `save_warmup` was `TRUE` when fitting the model then a
#' [`draws_array`][posterior::draws_array] (or different format if `format` is
#' specified) of warmup draws.
#' * `post_warmup_draws`: A [`draws_array`][posterior::draws_array] (or
#' different format if `format` is specified) of post-warmup draws.
#' * `warmup_sampler_diagnostics`: If `save_warmup` was `TRUE` when fitting the
#' model then a [`draws_array`][posterior::draws_array] (or different format if
#' `format` is specified) of warmup draws of the sampler diagnostic variables.
#' * `post_warmup_sampler_diagnostics`: A
#' [`draws_array`][posterior::draws_array] (or different format if `format` is
#' specified) of post-warmup draws of the sampler diagnostic variables.
#'
#' For [optimization][model-method-optimize] the returned list also includes the
#' following components:
#'
#' * `point_estimates`: Point estimates for the model parameters.
#'
#' For [laplace][model-method-laplace],
#' [variational inference][model-method-variational] and pathfinder the
#' returned list also includes the following components:
#'
#' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
#' if `format` is specified) of draws from the approximate posterior
#' distribution.
#'
#' For [standalone generated quantities][model-method-generate-quantities] the
#' returned list also includes the following components:
#'
#' * `generated_quantities`: A [`draws_array`][posterior::draws_array] of
#' the generated quantities.
#'
#' @examples
#' \dontrun{
#' # Generate some CSV files to use for demonstration
#' fit1 <- cmdstanr_example("logistic", method = "sample", save_warmup = TRUE)
#' csv_files <- fit1$output_files()
#' print(csv_files)
#'
#' # Creating fitted model objects
#'
#' # Create a CmdStanMCMC object from the CSV files
#' fit2 <- as_cmdstan_fit(csv_files)
#' fit2$print("beta")
#'
#' # Using read_cmdstan_csv
#' #
#' # Read in everything
#' x <- read_cmdstan_csv(csv_files)
#' str(x)
#'
#' # Don't read in any of the sampler diagnostic variables
#' x <- read_cmdstan_csv(csv_files, sampler_diagnostics = "")
#'
#' # Don't read in any of the parameters or generated quantities
#' x <- read_cmdstan_csv(csv_files, variables = "")
#'
#' # Read in only specific parameters and sampler diagnostics
#' x <- read_cmdstan_csv(
#' csv_files,
#' variables = c("alpha", "beta[2]"),
#' sampler_diagnostics = c("n_leapfrog__", "accept_stat__")
#' )
#'
#' # For non-scalar parameters all elements can be selected or only some elements,
#' # e.g. all of the vector "beta" but only one element of the vector "log_lik"
#' x <- read_cmdstan_csv(
#' csv_files,
#' variables = c("beta", "log_lik[3]")
#' )
#' }
#'
read_cmdstan_csv <- function(files,
                             variables = NULL,
                             sampler_diagnostics = NULL,
                             format = getOption("cmdstanr_draws_format", NULL)) {
  # If the CSV files are stored in the WSL filesystem then it is significantly
  # faster (~4x) to first copy them (via WSL) to a Windows tempdir before
  # reading. The copy is best-effort: if `cp` fails we keep reading the
  # original paths instead of switching to temp paths that were never written.
  if (os_is_wsl() && any(grepl("^//wsl", files))) {
    wsl_files <- sapply(files, wsl_safe_path)
    temp_storage <- tempdir(check = TRUE)
    csv_copy <- processx::run(
      "wsl", c("cp", wsl_files, wsl_safe_path(temp_storage)),
      error_on_status = FALSE
    )
    if (isTRUE(csv_copy$status == 0)) {
      files <- file.path(temp_storage, basename(files))
    }
  }
  format <- assert_valid_draws_format(format)
  assert_file_exists(files, access = "r", extension = "csv")

  # Parse the comment/header metadata of every file and make sure the files
  # come from compatible runs before combining them.
  csv_metadata <- lapply(unname(files), read_csv_metadata)
  if (length(csv_metadata) > 1) {
    check_csv_metadata_matches(csv_metadata)
  }

  # Merge the per-file metadata into the first element: per-chain settings are
  # concatenated, while adaptation results and timings are collected
  # separately (keyed by chain id where applicable).
  inv_metric <- list()
  step_size <- list()
  time <- data.frame()
  for (file_id in seq_along(csv_metadata)) {
    file_metadata <- csv_metadata[[file_id]]
    id <- file_metadata$id
    if (file_id > 1) {
      csv_metadata[[1]]$id <- c(csv_metadata[[1]]$id, id)
      csv_metadata[[1]]$seed <- c(csv_metadata[[1]]$seed, file_metadata$seed)
      csv_metadata[[1]]$init <- c(csv_metadata[[1]]$init, file_metadata$init)
      csv_metadata[[1]]$step_size <- c(csv_metadata[[1]]$step_size, file_metadata$step_size)
      csv_metadata[[1]]$step_size_adaptation <- c(csv_metadata[[1]]$step_size_adaptation, file_metadata$step_size_adaptation)
      csv_metadata[[1]]$fitted_params <- c(csv_metadata[[1]]$fitted_params, file_metadata$fitted_params)
    }
    if (!is.null(file_metadata$inv_metric)) {
      inv_metric[[as.character(id)]] <- file_metadata$inv_metric
    }
    if (!is.null(file_metadata$step_size_adaptation)) {
      step_size[[as.character(id)]] <- file_metadata$step_size_adaptation
    }
    if (!is.null(file_metadata$time)) {
      time <- rbind(time, file_metadata$time)
    }
  }
  metadata <- csv_metadata[[1]]
  uniq_seed <- unique(metadata$seed)
  if (length(uniq_seed) == 1) {
    metadata$seed <- uniq_seed
  }
  metadata$time <- time

  # Diagnose output has no draws, only the gradient table and lp.
  if (metadata$method == "diagnose") {
    gradients <- metadata$gradients
    metadata$gradients <- NULL
    lp <- metadata$lp
    metadata$lp <- NULL
    return(list(
      metadata = metadata,
      gradients = gradients,
      lp = lp
    ))
  }

  # Resolve the requested variables to CSV column names.
  if (is.null(variables)) { # variables = NULL returns all
    variables <- metadata$variables
  } else if (!any(nzchar(variables))) { # if variables = "" returns none
    variables <- NULL
  } else { # filter using variables
    res <- matching_variables(variables, repair_variable_names(metadata$variables))
    if (length(res$not_found)) {
      stop("Can't find the following variable(s) in the output: ",
           paste(res$not_found, collapse = ", "), call. = FALSE)
    }
    variables <- unrepair_variable_names(res$matching)
  }

  # Resolve the requested sampler diagnostics to CSV column names.
  if (is.null(sampler_diagnostics)) {
    sampler_diagnostics <- metadata$sampler_diagnostics
  } else if (!any(nzchar(sampler_diagnostics))) { # if sampler_diagnostics = "" returns none
    sampler_diagnostics <- NULL
  } else {
    selected_sampler_diag <- rep(FALSE, length(metadata$sampler_diagnostics))
    not_found <- NULL
    for (p in sampler_diagnostics) {
      matches <- metadata$sampler_diagnostics == p | startsWith(metadata$sampler_diagnostics, paste0(p, "."))
      if (!any(matches)) {
        not_found <- c(not_found, p)
      }
      selected_sampler_diag <- selected_sampler_diag | matches
    }
    if (length(not_found)) {
      stop("Can't find the following sampler diagnostic(s) in the output: ",
           paste(not_found, collapse = ", "), call. = FALSE)
    }
    sampler_diagnostics <- metadata$sampler_diagnostics[selected_sampler_diag]
  }

  # Builds the shell command that streams a CSV file to fread() without its
  # comment lines. The grep.exe lookup is done once rather than once per file.
  if (os_is_windows()) {
    grep_path_repaired <- withr::with_path(
      c(
        toolchain_PATH_env_var()
      ),
      repair_path(Sys.which("grep.exe"))
    )
    grep_path_quotes <- paste0('"', grep_path_repaired, '"')
    make_fread_cmd <- function(path) {
      paste0(
        grep_path_quotes,
        " -v \"^#\" --color=never \"",
        wsl_safe_path(path, revert = TRUE),
        "\""
      )
    }
  } else {
    make_fread_cmd <- function(path) {
      # shQuote() keeps paths containing quotes intact, and path.expand()
      # resolves a leading "~", which the shell does not expand inside quotes.
      paste0("grep -v '^#' --color=never ", shQuote(path.expand(path)))
    }
  }

  # For pathfinder the sampler-diagnostic columns are read and returned
  # together with the draws, so they are added to the variables.
  if (metadata$method == "pathfinder" && length(variables) > 0) {
    metadata$variables <- union(metadata$sampler_diagnostics, metadata$variables)
    variables <- union(metadata$sampler_diagnostics, variables)
  }

  num_warmup_draws <- ceiling(metadata$iter_warmup / metadata$thin)
  num_post_warmup_draws <- ceiling(metadata$iter_sampling / metadata$thin)
  # Only MCMC output written with save_warmup = 1 has leading warmup rows.
  has_warmup_rows <- identical(metadata$method, "sample") &&
    isTRUE(metadata$save_warmup == 1) &&
    isTRUE(num_warmup_draws > 0)

  # Splits one file's rows into its warmup and post-warmup parts; the
  # post-warmup part is NULL when no post-warmup draws were requested.
  split_warmup <- function(file_rows) {
    post_warmup <- NULL
    if (num_post_warmup_draws > 0) {
      post_rows <- num_warmup_draws + seq_len(num_post_warmup_draws)
      post_warmup <- file_rows[post_rows, , drop = FALSE]
    }
    list(
      warmup = file_rows[seq_len(num_warmup_draws), , drop = FALSE],
      post_warmup = post_warmup
    )
  }

  warmup_draws <- list()
  draws <- list()
  warmup_sampler_diagnostics <- list()
  post_warmup_sampler_diagnostics <- list()
  for (output_file in files) {
    fread_cmd <- make_fread_cmd(output_file)
    if (length(sampler_diagnostics) > 0) {
      file_diagnostics <- suppressWarnings(data.table::fread(
        cmd = fread_cmd,
        select = sampler_diagnostics,
        data.table = FALSE
      ))
      if (has_warmup_rows) {
        parts <- split_warmup(file_diagnostics)
        warmup_sampler_diagnostics[[length(warmup_sampler_diagnostics) + 1]] <- parts$warmup
        file_diagnostics <- parts$post_warmup
      }
      if (!is.null(file_diagnostics)) {
        post_warmup_sampler_diagnostics[[length(post_warmup_sampler_diagnostics) + 1]] <- file_diagnostics
      }
    }
    if (length(variables) > 0) {
      file_draws <- suppressWarnings(data.table::fread(
        cmd = fread_cmd,
        select = variables,
        data.table = FALSE
      ))
      if (has_warmup_rows) {
        parts <- split_warmup(file_draws)
        warmup_draws[[length(warmup_draws) + 1]] <- parts$warmup
        file_draws <- parts$post_warmup
      }
      if (!is.null(file_draws)) {
        draws[[length(draws) + 1]] <- file_draws
      }
    }
  }

  metadata$inv_metric <- NULL
  metadata$variables <- repair_variable_names(metadata$variables)
  repaired_variables <- repair_variable_names(variables)
  # Rename CmdStan's approximation columns to CmdStanR's names.
  if (metadata$method == "variational") {
    metadata$variables <- metadata$variables[metadata$variables != "lp__"]
    metadata$variables <- gsub("log_p__", "lp__", metadata$variables)
    metadata$variables <- gsub("log_g__", "lp_approx__", metadata$variables)
    repaired_variables <- repaired_variables[repaired_variables != "lp__"]
    repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
    repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
  } else if (metadata$method == "laplace") {
    metadata$variables <- gsub("log_p__", "lp__", metadata$variables)
    metadata$variables <- gsub("log_q__", "lp_approx__", metadata$variables)
    repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
    repaired_variables <- gsub("log_q__", "lp_approx__", repaired_variables)
  }
  model_param_dims <- variable_dims(metadata$variables)
  metadata$stan_variable_sizes <- model_param_dims
  metadata$stan_variables <- names(model_param_dims)
  # $model_params is deprecated, remove for release 1.0
  metadata$model_params <- metadata$variables

  if (metadata$method == "sample") {
    if (is.null(format)) {
      format <- "draws_array"
    }
    as_draws_format <- as_draws_format_fun(format)
    # Converts a list of per-chain data frames to the requested draws format.
    # Returns NULL when there is nothing to convert or no iterations remain.
    to_draws <- function(chains, rename) {
      if (length(chains) == 0) {
        return(NULL)
      }
      converted <- do.call(as_draws_format, list(chains))
      if (rename) {
        posterior::variables(converted) <- repaired_variables
      }
      if (posterior::niterations(converted) == 0) NULL else converted
    }
    list(
      metadata = metadata,
      time = list(total = NA_integer_, chains = time),
      inv_metric = inv_metric,
      step_size = step_size,
      warmup_draws = to_draws(warmup_draws, rename = TRUE),
      post_warmup_draws = to_draws(draws, rename = TRUE),
      warmup_sampler_diagnostics = to_draws(warmup_sampler_diagnostics, rename = FALSE),
      post_warmup_sampler_diagnostics = to_draws(post_warmup_sampler_diagnostics, rename = FALSE)
    )
  } else if (metadata$method == "variational") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      variational_draws <- NULL
    } else {
      # The first row is dropped, as is the lp__ column.
      variational_draws <- do.call(as_draws_format, list(draws[[1]][-1, colnames(draws[[1]]) != "lp__", drop = FALSE]))
    }
    if (!is.null(variational_draws)) {
      if ("log_p__" %in% posterior::variables(variational_draws)) {
        variational_draws <- posterior::rename_variables(variational_draws, lp__ = "log_p__")
      }
      if ("log_g__" %in% posterior::variables(variational_draws)) {
        variational_draws <- posterior::rename_variables(variational_draws, lp_approx__ = "log_g__")
      }
      posterior::variables(variational_draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      draws = variational_draws
    )
  } else if (metadata$method == "laplace") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      laplace_draws <- NULL
    } else {
      laplace_draws <- do.call(as_draws_format, list(draws[[1]]))
    }
    if (!is.null(laplace_draws)) {
      if ("log_p__" %in% posterior::variables(laplace_draws)) {
        laplace_draws <- posterior::rename_variables(laplace_draws, lp__ = "log_p__")
      }
      if ("log_q__" %in% posterior::variables(laplace_draws)) {
        laplace_draws <- posterior::rename_variables(laplace_draws, lp_approx__ = "log_q__")
      }
      posterior::variables(laplace_draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      draws = laplace_draws
    )
  } else if (metadata$method == "optimize") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      point_estimates <- NULL
    } else {
      point_estimates <- do.call(as_draws_format, list(draws[[1]][1, , drop = FALSE]))
      point_estimates <- posterior::subset_draws(point_estimates, variable = variables)
    }
    if (!is.null(point_estimates)) {
      posterior::variables(point_estimates) <- repaired_variables
    }
    list(
      metadata = metadata,
      point_estimates = point_estimates
    )
  } else if (metadata$method == "generate_quantities") {
    if (is.null(format)) {
      format <- "draws_array"
    }
    as_draws_format <- as_draws_format_fun(format)
    # Nothing was read when variables = "", so there is nothing to convert.
    if (length(draws) == 0) {
      gq_draws <- NULL
    } else {
      gq_draws <- do.call(as_draws_format, list(draws))
    }
    if (!is.null(gq_draws)) {
      posterior::variables(gq_draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      generated_quantities = gq_draws
    )
  } else if (metadata$method == "pathfinder") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      pathfinder_draws <- NULL
    } else {
      pathfinder_draws <- do.call(as_draws_format, list(draws[[1]]))
      posterior::variables(pathfinder_draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      draws = pathfinder_draws
    )
  }
}
#' Read CmdStan CSV files from sampling into \R
#'
#' Deprecated. Use [read_cmdstan_csv()] instead.
#' @keywords internal
#' @export
#' @param files,variables,sampler_diagnostics Deprecated. Use
#' [read_cmdstan_csv()] instead.
#'
read_sample_csv <- function(files,
                            variables = NULL,
                            sampler_diagnostics = NULL) {
  # Deprecated shim: warn, then delegate everything to read_cmdstan_csv().
  warning("read_sample_csv() is deprecated. Please use read_cmdstan_csv().")
  read_cmdstan_csv(
    files = files,
    variables = variables,
    sampler_diagnostics = sampler_diagnostics
  )
}
#' @rdname read_cmdstan_csv
#' @export
#' @param check_diagnostics (logical) For models fit using MCMC, should
#' diagnostic checks be performed after reading in the files? The default is
#' `TRUE` but set to `FALSE` to avoid checking for problems with divergences
#' and treedepth.
#'
as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("cmdstanr_draws_format")) {
  csv_contents <- read_cmdstan_csv(files, format = format)
  fit_method <- csv_contents$metadata$method
  # Pick the CSV-backed fit class matching the method that wrote the files.
  # Other methods (e.g. generate_quantities, diagnose) have no such class, so
  # fail loudly instead of silently returning NULL.
  switch(
    fit_method,
    "sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
    "optimize" = CmdStanMLE_CSV$new(csv_contents, files),
    "variational" = CmdStanVB_CSV$new(csv_contents, files),
    "pathfinder" = CmdStanPathfinder_CSV$new(csv_contents, files),
    "laplace" = CmdStanLaplace_CSV$new(csv_contents, files),
    stop("as_cmdstan_fit() does not support CSV files produced by method '",
         fit_method, "'. Use read_cmdstan_csv() to read them instead.",
         call. = FALSE)
  )
}
# internal ----------------------------------------------------------------
# CmdStanFit_CSV -------------------------------------------------------------
#' Create CmdStanMCMC/MLE/VB-ish objects from `read_cmdstan_csv()` output
#' instead of from a CmdStanRun object
#'
#' The resulting object has fewer methods than a CmdStanMCMC/MLE/VB object
#' because it doesn't have access to a CmdStanRun object.
#'
#' @noRd
#'
CmdStanMCMC_CSV <- R6::R6Class(
  classname = "CmdStanMCMC_CSV",
  inherit = CmdStanMCMC,
  public = list(
    # Fill the private fields that CmdStanMCMC's methods read, using the list
    # returned by read_cmdstan_csv(); optionally run the diagnostic checks.
    initialize = function(csv_contents, files, check_diagnostics = TRUE) {
      private$metadata_ <- csv_contents$metadata
      private$output_files_ <- files
      private$time_ <- csv_contents$time
      private$inv_metric_ <- csv_contents$inv_metric
      private$warmup_draws_ <- csv_contents$warmup_draws
      private$draws_ <- csv_contents$post_warmup_draws
      private$warmup_sampler_diagnostics_ <- csv_contents$warmup_sampler_diagnostics
      private$sampler_diagnostics_ <- csv_contents$post_warmup_sampler_diagnostics
      if (check_diagnostics) {
        self$diagnostic_summary()
      }
      invisible(self)
    },
    # The overrides below work without a CmdStanRun object.
    # Paths of the CSV files this object was created from.
    output_files = function(...) {
      private$output_files_
    },
    # Timing list built by read_cmdstan_csv() (total is NA; per-chain rows).
    time = function() {
      private$time_
    },
    # Number of chains, taken from the stored draws.
    num_chains = function() {
      posterior::nchains(self$draws())
    }
  ),
  private = list(
    output_files_ = NULL,
    time_ = NULL
  )
)
CmdStanMLE_CSV <- R6::R6Class(
  classname = "CmdStanMLE_CSV",
  inherit = CmdStanMLE,
  public = list(
    # Store the parsed CSV contents where CmdStanMLE's methods expect them;
    # the point estimates play the role of the draws.
    initialize = function(csv_contents, files) {
      private$metadata_ <- csv_contents$metadata
      private$draws_ <- csv_contents$point_estimates
      private$output_files_ <- files
      invisible(self)
    },
    # Paths of the CSV files this object was created from.
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(
    output_files_ = NULL
  )
)
CmdStanLaplace_CSV <- R6::R6Class(
  classname = "CmdStanLaplace_CSV",
  inherit = CmdStanLaplace,
  public = list(
    # Store the parsed CSV contents where CmdStanLaplace's methods expect them.
    initialize = function(csv_contents, files) {
      private$metadata_ <- csv_contents$metadata
      private$draws_ <- csv_contents$draws
      private$output_files_ <- files
      invisible(self)
    },
    # Paths of the CSV files this object was created from.
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(
    output_files_ = NULL
  )
)
CmdStanVB_CSV <- R6::R6Class(
  classname = "CmdStanVB_CSV",
  inherit = CmdStanVB,
  public = list(
    # Store the parsed CSV contents where CmdStanVB's methods expect them.
    initialize = function(csv_contents, files) {
      private$metadata_ <- csv_contents$metadata
      private$draws_ <- csv_contents$draws
      private$output_files_ <- files
      invisible(self)
    },
    # Paths of the CSV files this object was created from.
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(
    output_files_ = NULL
  )
)
CmdStanPathfinder_CSV <- R6::R6Class(
  classname = "CmdStanPathfinder_CSV",
  inherit = CmdStanPathfinder,
  public = list(
    # Store the parsed CSV contents where CmdStanPathfinder's methods expect
    # them. Returns self invisibly, consistent with the other *_CSV classes.
    initialize = function(csv_contents, files) {
      private$output_files_ <- files
      private$draws_ <- csv_contents$draws
      private$metadata_ <- csv_contents$metadata
      invisible(self)
    },
    # Paths of the CSV files this object was created from.
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(output_files_ = NULL)
)
# Methods that cannot work on fits rebuilt from CSV files by as_cmdstan_fit(),
# because such objects have no CmdStanRun object behind them. "time" is the
# exception for CmdStanMCMC_CSV, which defines its own $time().
unavailable_methods_CmdStanFit_CSV <- c(
  "cmdstan_diagnose",
  "cmdstan_summary",
  "save_data_file",
  "data_file",
  "save_latent_dynamics_files",
  "latent_dynamics_files",
  "save_output_files",
  "init",
  "output",
  "return_codes",
  "code",
  "num_procs",
  "save_profile_files",
  "profile_files",
  "profiles",
  "time" # available for MCMC not others
)
error_unavailable_CmdStanFit_CSV <- function(...) {
  # Stub installed in place of methods that need a CmdStanRun object; any
  # arguments are accepted and ignored.
  msg <- "This method is not available for objects created using as_cmdstan_fit()."
  stop(msg, call. = FALSE)
}
# Replace every method that needs a CmdStanRun object with a stub that raises
# an informative error. All CSV-backed classes get the stubs, including
# CmdStanPathfinder_CSV; CmdStanMCMC_CSV keeps its own $time().
for (method in unavailable_methods_CmdStanFit_CSV) {
  if (method != "time") {
    CmdStanMCMC_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  }
  CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  CmdStanPathfinder_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
}
# csv reading internals ---------------------------------------------------
#' Read the metadata stored in the comment lines and header row of a CmdStan
#' CSV file.
#'
#' Parses the `key = value` settings in the comment lines, splits the CSV
#' header into variables and sampler diagnostics, and collects the adapted
#' inverse metric, the adapted step size, the elapsed times and (for diagnose)
#' the gradient table.
#'
#' @noRd
#' @param csv_file Path to a CSV file containing results from CmdStan.
#' @return A named list of the CmdStan settings, with several entries renamed
#'   to CmdStanR's argument names (e.g. `num_samples` becomes
#'   `iter_sampling`). For sampling it also contains the inverse metric (a
#'   vector, or a matrix when the metric is `dense_e`) and a one-row `time`
#'   data frame.
#'
read_csv_metadata <- function(csv_file) {
  assert_file_exists(csv_file, access = "r", extension = "csv")
  inv_metric_next <- FALSE
  csv_file_info <- list()
  csv_file_info$inv_metric <- NULL
  inv_metric_rows_to_read <- -1
  inv_metric_rows <- -1
  dense_inv_metric <- FALSE
  diagnose_gradients <- FALSE
  gradients <- data.frame()
  warmup_time <- 0
  sampling_time <- 0
  total_time <- 0
  # Keep only lines starting with "#" or a letter, i.e. the comment lines and
  # the header row.
  if (os_is_windows()) {
    grep_path_repaired <- withr::with_path(
      c(
        toolchain_PATH_env_var()
      ),
      repair_path(Sys.which("grep.exe"))
    )
    grep_path_quotes <- paste0('"', grep_path_repaired, '"')
    fread_cmd <- paste0(
      grep_path_quotes,
      " \"^[#a-zA-Z]\" --color=never \"",
      wsl_safe_path(csv_file, revert = TRUE),
      "\""
    )
  } else {
    # shQuote() keeps paths containing quotes intact, and path.expand()
    # resolves a leading "~", which the shell does not expand inside quotes.
    fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never ", shQuote(path.expand(csv_file)))
  }
  suppressWarnings(
    metadata <- data.table::fread(
      cmd = fread_cmd,
      colClasses = "character",
      stringsAsFactors = FALSE,
      fill = TRUE,
      sep = "",
      header = FALSE
    )
  )
  if (is.null(metadata) || length(metadata) == 0) {
    stop("Supplied CSV file is corrupt!", call. = FALSE)
  }
  for (line in metadata[[1]]) {
    if (!startsWith(line, "#") && is.null(csv_file_info[["variables"]])) {
      # if no # at the start of line, the line is the CSV header
      all_names <- strsplit(line, ",")[[1]]
      if (all(csv_file_info$algorithm != "fixed_param")) {
        csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
        csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))]
        csv_file_info[["variables"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
      } else {
        csv_file_info[["variables"]] <- all_names[!endsWith(all_names, "__")]
      }
    } else {
      parse_key_val <- TRUE
      if (grepl("# Diagonal elements of inverse mass matrix:", line, perl = TRUE)) {
        # Diagonal metric: a single row of values follows.
        inv_metric_next <- TRUE
        parse_key_val <- FALSE
        inv_metric_rows <- 1
        inv_metric_rows_to_read <- 1
        dense_inv_metric <- FALSE
      } else if (grepl("# Elements of inverse mass matrix:", line, perl = TRUE)) {
        # Dense metric: the number of rows is taken from the first row's length.
        inv_metric_next <- TRUE
        parse_key_val <- FALSE
        dense_inv_metric <- TRUE
      } else if (inv_metric_next) {
        inv_metric_split <- strsplit(gsub("# ", "", line), ",")
        numeric_inv_metric_split <- rapply(inv_metric_split, as.numeric)
        if (inv_metric_rows == -1 && dense_inv_metric) {
          inv_metric_rows <- length(inv_metric_split[[1]])
          inv_metric_rows_to_read <- inv_metric_rows
        }
        csv_file_info$inv_metric <- c(csv_file_info$inv_metric, numeric_inv_metric_split)
        inv_metric_rows_to_read <- inv_metric_rows_to_read - 1
        if (inv_metric_rows_to_read == 0) {
          inv_metric_next <- FALSE
        }
        parse_key_val <- FALSE
      } else if (diagnose_gradients) {
        # Rows of the diagnose gradient table (the "param" header row is skipped).
        parse_key_val <- FALSE
        tmp <- gsub("#", "", line, fixed = TRUE)
        if (nzchar(tmp)) {
          tmp <- strsplit(tmp, split = " ", fixed = TRUE)[[1]]
          if (!("param" %in% tmp)) {
            tmp <- as.numeric(tmp[nzchar(tmp)])
            gradients <- rbind(gradients, tmp)
            if (dim(gradients)[1] == 1) {
              colnames(gradients) <- c("param_idx", "value", "model", "finite_diff", "error")
            }
          }
        }
      }
      if (parse_key_val) {
        tmp <- gsub("#", "", line, fixed = TRUE)
        tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
        key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
        key_val <- strsplit(key_val, split = "=", fixed = TRUE)
        key_val <- rapply(key_val, trimws)
        if (any(key_val[1] == "Step size")) {
          key_val[1] <- "step_size_adaptation"
        }
        if (any(key_val[1] == "Log probability")) {
          key_val[1] <- "lp"
        }
        if (length(key_val) == 2) {
          numeric_val <- suppressWarnings(as.numeric(key_val[2]))
          if (!is.na(numeric_val)) {
            csv_file_info[[key_val[1]]] <- numeric_val
          } else {
            if (nzchar(key_val[2])) {
              csv_file_info[[key_val[1]]] <- key_val[2]
            }
          }
        } else if (grepl("(Warm-up)", tmp, fixed = TRUE)) {
          tmp <- gsub("Elapsed Time:", "", tmp, fixed = TRUE)
          tmp <- gsub("seconds (Warm-up)", "", tmp, fixed = TRUE)
          warmup_time <- as.numeric(tmp)
        } else if (grepl("(Sampling)", tmp, fixed = TRUE)) {
          tmp <- gsub("seconds (Sampling)", "", tmp, fixed = TRUE)
          sampling_time <- as.numeric(tmp)
        } else if (grepl("(Total)", tmp, fixed = TRUE)) {
          tmp <- gsub("seconds (Total)", "", tmp, fixed = TRUE)
          tmp <- trimws(gsub(" Elapsed Time: ", "", tmp, fixed = TRUE))
          total_time <- as.numeric(tmp)
        }
        # In diagnose output the gradient table follows the lp line.
        if (!is.null(csv_file_info$method) &&
            csv_file_info$method == "diagnose" &&
            any(key_val[1] == "lp")) {
          diagnose_gradients <- TRUE
        }
      }
    }
  }
  # Every check below dispatches on the method, so fail clearly without it.
  if (is.null(csv_file_info$method)) {
    stop("Supplied CSV file does not specify the CmdStan method that produced it!",
         call. = FALSE)
  }
  if (csv_file_info$method != "diagnose" &&
      length(csv_file_info$sampler_diagnostics) == 0 &&
      length(csv_file_info$variables) == 0) {
    stop("Supplied CSV file does not contain any variable names or data!", call. = FALSE)
  }
  if (inv_metric_rows > 0 && csv_file_info$metric == "dense_e") {
    rows <- inv_metric_rows
    cols <- length(csv_file_info$inv_metric) / inv_metric_rows
    dim(csv_file_info$inv_metric) <- c(rows, cols)
  }
  # rename from old cmdstan names to new cmdstanX names
  csv_file_info$model_name <- csv_file_info$model
  csv_file_info$adapt_engaged <- csv_file_info$engaged
  csv_file_info$adapt_delta <- csv_file_info$delta
  csv_file_info$max_treedepth <- csv_file_info$max_depth
  csv_file_info$step_size <- csv_file_info$stepsize
  csv_file_info$iter_warmup <- csv_file_info$num_warmup
  csv_file_info$iter_sampling <- csv_file_info$num_samples
  if (csv_file_info$method %in% c("variational", "optimize", "laplace")) {
    csv_file_info$threads <- csv_file_info$num_threads
  } else {
    csv_file_info$threads_per_chain <- csv_file_info$num_threads
  }
  if (csv_file_info$method == "sample") {
    csv_file_info$time <- data.frame(
      chain_id = csv_file_info$id,
      warmup = warmup_time,
      sampling = sampling_time,
      total = total_time
    )
  }
  csv_file_info$model <- NULL
  csv_file_info$engaged <- NULL
  csv_file_info$delta <- NULL
  csv_file_info$max_depth <- NULL
  csv_file_info$stepsize <- NULL
  csv_file_info$num_warmup <- NULL
  csv_file_info$num_samples <- NULL
  csv_file_info$file <- NULL
  csv_file_info$diagnostic_file <- NULL
  csv_file_info$metric_file <- NULL
  csv_file_info$num_threads <- NULL
  if (length(gradients) > 0) {
    csv_file_info$gradients <- gradients
  }
  # Revert any WSL-updated paths before returning the metadata
  if (os_is_wsl()) {
    csv_file_info$init <- wsl_safe_path(csv_file_info$init, revert = TRUE)
    csv_file_info$profile_file <- wsl_safe_path(csv_file_info$profile_file,
                                                revert = TRUE)
    csv_file_info$fitted_params <- wsl_safe_path(csv_file_info$fitted_params,
                                                 revert = TRUE)
  }
  csv_file_info
}
#' Check that the metadata read from several CmdStan CSV files is compatible.
#'
#' Throws an error if the files come from different models or methods, contain
#' different variables, or (for sampling and variational inference) would
#' produce different numbers of draws. Differences in the settings listed in
#' `match_list` only raise a warning.
#'
#' @noRd
#' @param csv_metadata A list with one element per CSV file, each element being
#'   the list returned by `read_csv_metadata()` for that file.
#' @return `NULL`; the function is called for its errors and warnings.
#'
check_csv_metadata_matches <- function(csv_metadata) {
  model_name <- sapply(csv_metadata, function(x) x$model_name)
  if (!all(model_name == model_name[1])) {
    stop("Supplied CSV files were not generated with the same model!", call. = FALSE)
  }
  method <- sapply(csv_metadata, function(x) x$method)
  if (!all(method == method[1])) {
    stop("Supplied CSV files were produced by different methods and need to be read in separately!", call. = FALSE)
  }
  # Compare each later file's variables with the first file's. seq_along()[-1]
  # is empty for a single file, whereas 2:length() would count down to 1.
  for (i in seq_along(csv_metadata)[-1]) {
    if (length(csv_metadata[[1]]$variables) != length(csv_metadata[[i]]$variables) ||
        !all(csv_metadata[[1]]$variables == csv_metadata[[i]]$variables)) {
      stop("Supplied CSV files have samples for different variables!", call. = FALSE)
    }
  }
  if (method[1] == "sample") {
    iter_sampling <- sapply(csv_metadata, function(x) x$iter_sampling)
    thin <- sapply(csv_metadata, function(x) x$thin)
    save_warmup <- sapply(csv_metadata, function(x) x$save_warmup)
    iter_warmup <- sapply(csv_metadata, function(x) x$iter_warmup)
    if (!all(iter_sampling == iter_sampling[1]) ||
        !all(thin == thin[1]) ||
        !all(save_warmup == save_warmup[1]) ||
        (save_warmup[1] == 1 && !all(iter_warmup == iter_warmup[1]))) {
      stop("Supplied CSV files do not match in the number of output samples!", call. = FALSE)
    }
  } else if (method[1] == "variational") {
    output_samples <- sapply(csv_metadata, function(x) x$output_samples)
    if (!all(output_samples == output_samples[1])) {
      stop("Supplied CSV files do not match in the number of output samples!", call. = FALSE)
    }
  }
  # Settings that should agree across files; a mismatch (or a value missing
  # from some file) only produces a warning.
  match_list <- c("stan_version_major", "stan_version_minor", "stan_version_patch", "gamma", "kappa",
                  "t0", "init_buffer", "term_buffer", "window", "algorithm", "engine", "max_treedepth",
                  "metric", "stepsize_jitter", "adapt_engaged", "adapt_delta", "iter_warmup")
  not_matching <- c()
  for (name in intersect(names(csv_metadata[[1]]), match_list)) {
    values <- sapply(csv_metadata, function(x) x[[name]])
    if (any(vapply(values, is.null, logical(1)))) {
      not_matching <- c(not_matching, name)
      next
    }
    if (!all(values == values[1])) {
      not_matching <- c(not_matching, name)
    }
  }
  if (length(not_matching) > 0) {
    warning("Supplied CSV files do not match in the following arguments: ",
            paste(unique(not_matching), collapse = ", "), call. = FALSE)
  }
  NULL
}
# convert names like beta.1.1 to beta[1,1]
repair_variable_names <- function(names) {
  # The first "." opens the index bracket; any further "." separate indices.
  bracketed <- sub(".", "[", names, fixed = TRUE)
  bracketed <- gsub(".", ",", bracketed, fixed = TRUE)
  # Close the bracket on every name that now contains one.
  has_index <- grepl("[", bracketed, fixed = TRUE)
  bracketed[has_index] <- paste0(bracketed[has_index], "]")
  bracketed
}
# convert names like beta[1,1] to beta.1.1
unrepair_variable_names <- function(names) {
  # Inverse of repair_variable_names(): "[" and "," become ".", "]" is dropped.
  dotted <- sub("[", ".", names, fixed = TRUE)
  dotted <- gsub(",", ".", dotted, fixed = TRUE)
  gsub("]", "", dotted, fixed = TRUE)
}
remaining_columns_to_read <- function(requested, currently_read, all) {
  # NULL request means "everything available" (or NULL if nothing is known).
  if (is.null(requested)) {
    if (is.null(all)) {
      return(NULL)
    }
    requested <- all
  }
  # An empty request ("") means "read nothing" and is passed through as is.
  if (!any(nzchar(requested))) {
    return(requested)
  }
  if (!is.null(all)) {
    candidates <- all[!all %in% currently_read]
    # Each request matches a column exactly, or else every element "req[...]".
    picks <- lapply(requested, function(req) {
      exact <- match(req, candidates)
      if (is.na(exact)) {
        which(startsWith(candidates, paste0(req, "[")))
      } else {
        exact
      }
    })
    unread <- candidates[unlist(picks)]
  } else {
    unread <- requested[!requested %in% currently_read]
  }
  if (length(unread) == 0) {
    return("")
  }
  unique(unread)
}
#' Returns a list of dimensions for the input variables.
#'
#' @noRd
#' @param variable_names A character vector of variable names including all
#' individual elements (e.g., `c("beta[1]", "beta[2]")`, not just `"beta"`).
#' @return A list giving the dimensions of the variables. The equivalent of the
#' `par_dims` slot of RStan's stanfit objects, except that scalars have
#' dimension `1` instead of `0`.
#' @note For this function to return the correct dimensions the input must be
#' already sorted in ascending order. Since CmdStan always has the variables
#' sorted correctly we avoid a sort by not sorting again here.
#'
variable_dims <- function(variable_names = NULL) {
  if (is.null(variable_names)) {
    return(NULL)
  }
  base_names <- unique(gsub("\\[.*\\]", "", variable_names))
  if (length(base_names) == 0) {
    return(list())
  }
  without_close <- gsub("\\]", "", variable_names)
  # The last element of a variable carries its maximal indices (inputs are
  # assumed sorted, as CmdStan writes them); scalars get dimension 1.
  dims_for <- function(base) {
    prefix <- paste0("^", base, "\\[")
    index_strings <- gsub(prefix, "", grep(prefix, without_close, value = TRUE))
    if (length(index_strings) == 0) {
      return(1)
    }
    as.numeric(strsplit(index_strings[length(index_strings)], ",")[[1]])
  }
  dims <- lapply(base_names, dims_for)
  names(dims) <- base_names
  dims
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.