R/csv.R

Defines functions variable_dims remaining_columns_to_read unrepair_variable_names repair_variable_names check_csv_metadata_matches read_csv_metadata error_unavailable_CmdStanFit_CSV as_cmdstan_fit read_sample_csv read_cmdstan_csv

Documented in as_cmdstan_fit read_cmdstan_csv read_sample_csv

#' Read CmdStan CSV files into R
#'
#' @description `read_cmdstan_csv()` is used internally by CmdStanR to read
#'   CmdStan's output CSV files into \R. It can also be used by CmdStan users as
#'   a more flexible and efficient alternative to `rstan::read_stan_csv()`. See
#'   the **Value** section for details on the structure of the returned list.
#'
#'   It is also possible to create CmdStanR's fitted model objects directly from
#'   CmdStan CSV files using the `as_cmdstan_fit()` function.
#'
#' @export
#' @param files (character vector) The paths to the CmdStan CSV files. These can
#'   be files generated by running CmdStanR or running CmdStan directly.
#' @param variables (character vector) Optionally, the names of the variables
#'   (parameters, transformed parameters, and generated quantities) to read in.
#'   * If `NULL` (the default) then all variables are included.
#'   * If an empty string (`variables=""`) then none are included.
#'   * For non-scalar variables all elements or specific elements can be selected:
#'     - `variables = "theta"` selects all elements of `theta`;
#'     - `variables = c("theta[1]", "theta[3]")` selects only the 1st and 3rd elements.
#' @param sampler_diagnostics (character vector) Works the same way as
#'   `variables` but for sampler diagnostic variables (e.g., `"treedepth__"`,
#'   `"accept_stat__"`, etc.). Ignored if the model was not fit using MCMC.
#' @param format (string) The format for storing the draws or point estimates.
#'   The default depends on the method used to fit the model. See
#'   [draws][fit-method-draws] for details, in particular the note about speed
#'   and memory for models with many parameters.
#'
#' @return
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], [CmdStanLaplace],
#' [CmdStanVB] or [CmdStanPathfinder] object. Some methods typically defined for
#' those objects will not work (e.g. `save_data_file()`) but the important
#' methods like `$summary()`, `$draws()`, `$sampler_diagnostics()` and others
#' will work fine.
#'
#' `read_cmdstan_csv()` returns a named list with the following components:
#'
#' * `metadata`: A list of the meta information from the run that produced the
#' CSV file(s). See **Examples** below.
#'
#' The other components in the returned list depend on the method that produced
#' the CSV file(s).
#'
#' For [sampling][model-method-sample] the returned list also includes the
#' following components:
#'
#' * `time`: Run time information for the individual chains. The returned object
#' is the same as for the [$time()][fit-method-time] method except the total run
#' time can't be inferred from the CSV files (the chains may have been run in
#' parallel) and is therefore `NA`.
#' * `inv_metric`: A list (one element per chain) of inverse mass matrices
#' or their diagonals, depending on the type of metric used.
#' * `step_size`: A list (one element per chain) of the step sizes used.
#' * `warmup_draws`:  If `save_warmup` was `TRUE` when fitting the model then a
#' [`draws_array`][posterior::draws_array] (or different format if `format` is
#' specified) of warmup draws.
#' * `post_warmup_draws`: A [`draws_array`][posterior::draws_array] (or
#' different format if `format` is specified) of post-warmup draws.
#' * `warmup_sampler_diagnostics`:  If `save_warmup` was `TRUE` when fitting the
#' model then a [`draws_array`][posterior::draws_array] (or different format if
#' `format` is specified) of warmup draws of the sampler diagnostic variables.
#' * `post_warmup_sampler_diagnostics`: A
#' [`draws_array`][posterior::draws_array] (or different format if `format` is
#' specified) of post-warmup draws of the sampler diagnostic variables.
#'
#' For [optimization][model-method-optimize] the returned list also includes the
#' following components:
#'
#' * `point_estimates`: Point estimates for the model parameters.
#'
#' For [laplace][model-method-laplace],
#' [variational inference][model-method-variational] and
#' [pathfinder][model-method-pathfinder] the returned list also
#' includes the following components:
#'
#' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
#' if `format` is specified) of draws from the approximate posterior
#' distribution.
#'
#' For [standalone generated quantities][model-method-generate-quantities] the
#' returned list also includes the following components:
#'
#' * `generated_quantities`: A [`draws_array`][posterior::draws_array] of
#' the generated quantities.
#'
#' @examples
#' \dontrun{
#' # Generate some CSV files to use for demonstration
#' fit1 <- cmdstanr_example("logistic", method = "sample", save_warmup = TRUE)
#' csv_files <- fit1$output_files()
#' print(csv_files)
#'
#' # Creating fitted model objects
#'
#' # Create a CmdStanMCMC object from the CSV files
#' fit2 <- as_cmdstan_fit(csv_files)
#' fit2$print("beta")
#'
#' # Using read_cmdstan_csv
#' #
#' # Read in everything
#' x <- read_cmdstan_csv(csv_files)
#' str(x)
#'
#' # Don't read in any of the sampler diagnostic variables
#' x <- read_cmdstan_csv(csv_files, sampler_diagnostics = "")
#'
#' # Don't read in any of the parameters or generated quantities
#' x <- read_cmdstan_csv(csv_files, variables = "")
#'
#' # Read in only specific parameters and sampler diagnostics
#' x <- read_cmdstan_csv(
#'   csv_files,
#'   variables = c("alpha", "beta[2]"),
#'   sampler_diagnostics = c("n_leapfrog__", "accept_stat__")
#' )
#'
#' # For non-scalar parameters all elements can be selected or only some elements,
#' # e.g. all of the vector "beta" but only one element of the vector "log_lik"
#' x <- read_cmdstan_csv(
#'   csv_files,
#'   variables = c("beta", "log_lik[3]")
#' )
#' }
#'
read_cmdstan_csv <- function(files,
                             variables = NULL,
                             sampler_diagnostics = NULL,
                             format = getOption("cmdstanr_draws_format", NULL)) {
  # If the CSV files are stored in the WSL filesystem then it is significantly
  # faster (~4x) to first copy them (via WSL) to a Windows tempdir before reading
  if (os_is_wsl() && any(grepl("^//wsl", files))) {
    wsl_files <- sapply(files, wsl_safe_path)
    temp_storage <- tempdir(check = TRUE)
    # The exit status is deliberately not checked here: a failed copy is caught
    # by the file-existence assertion on the copied paths below.
    processx::run(
      "wsl", c("cp", wsl_files, wsl_safe_path(temp_storage)),
      error_on_status = FALSE
    )

    files <- file.path(temp_storage, basename(files))
  }
  format <- assert_valid_draws_format(format)
  assert_file_exists(files, access = "r", extension = "csv")

  # Accumulators, one list element per CSV file
  metadata <- NULL
  warmup_draws <- list()
  draws <- list()
  warmup_sampler_diagnostics <- list()
  post_warmup_sampler_diagnostics <- list()
  inv_metric <- list()
  step_size <- list()
  csv_metadata <- list()
  time <- data.frame()

  # ---- 1. Read and cross-check the metadata of every file -----------------
  file_idx <- 0
  for (output_file in files) {
    file_idx <- length(csv_metadata) + 1
    csv_metadata[[file_idx]] <- read_csv_metadata(output_file)
  }
  if (file_idx > 1) {
    check_csv_metadata_matches(csv_metadata)
  }

  # The metadata of the first file is the base into which the per-file fields
  # (id, seed, init, step sizes, fitted_params) of the other files are merged.
  # Inverse metric and adapted step size are keyed by the file's `id`.
  id <- csv_metadata[[1]]$id
  if (!is.null(csv_metadata[[1]]$inv_metric)) {
    inv_metric[[as.character(id)]] <- csv_metadata[[1]]$inv_metric
  }
  if (!is.null(csv_metadata[[1]]$step_size_adaptation)) {
    step_size[[as.character(id)]] <- csv_metadata[[1]]$step_size_adaptation
  }
  if (!is.null(csv_metadata[[1]]$time)) {
    time <- rbind(time, csv_metadata[[1]]$time)
  }
  if (length(csv_metadata) > 1) {
    for (file_id in seq_along(csv_metadata)[-1]) {
      file_metadata <- csv_metadata[[file_id]]
      id <- file_metadata$id
      csv_metadata[[1]]$id <- c(csv_metadata[[1]]$id, id)
      csv_metadata[[1]]$seed <- c(csv_metadata[[1]]$seed, file_metadata$seed)
      csv_metadata[[1]]$init <- c(csv_metadata[[1]]$init, file_metadata$init)
      csv_metadata[[1]]$step_size <- c(csv_metadata[[1]]$step_size, file_metadata$step_size)
      csv_metadata[[1]]$step_size_adaptation <- c(csv_metadata[[1]]$step_size_adaptation, file_metadata$step_size_adaptation)
      csv_metadata[[1]]$fitted_params <- c(csv_metadata[[1]]$fitted_params, file_metadata$fitted_params)
      if (!is.null(file_metadata$inv_metric)) {
        inv_metric[[as.character(id)]] <- file_metadata$inv_metric
      }
      if (!is.null(file_metadata$step_size_adaptation)) {
        step_size[[as.character(id)]] <- file_metadata$step_size_adaptation
      }
      if (!is.null(file_metadata$time)) {
        time <- rbind(time, file_metadata$time)
      }
    }
  }
  metadata <- csv_metadata[[1]]
  # Collapse the seed to a scalar when all files share the same one
  uniq_seed <- unique(metadata$seed)
  if (length(uniq_seed) == 1) {
    metadata$seed <- uniq_seed
  }
  metadata$time <- time

  # The diagnose method has no draws: everything is already in the metadata
  if (metadata$method == "diagnose") {
    gradients <- metadata$gradients
    metadata$gradients <- NULL
    lp <- metadata$lp
    metadata$lp <- NULL
    return(list(
      metadata = metadata,
      gradients = gradients,
      lp = lp
    ))
  }

  # ---- 2. Resolve which columns to read ------------------------------------
  if (is.null(variables)) { # variables = NULL returns all
    variables <- metadata$variables
  } else if (!any(nzchar(variables))) { # if variables = "" returns none
    variables <- NULL
  } else { # filter using variables
    # Matching is done on the repaired names; the result is converted back to
    # the names used in the CSV header so it can be passed to fread()
    res <- matching_variables(variables, repair_variable_names(metadata$variables))
    if (length(res$not_found)) {
      stop("Can't find the following variable(s) in the output: ",
            paste(res$not_found, collapse = ", "), call. = FALSE)
    }
    variables <- unrepair_variable_names(res$matching)
  }
  if (is.null(sampler_diagnostics)) {
    sampler_diagnostics <- metadata$sampler_diagnostics
  } else if (!any(nzchar(sampler_diagnostics))) { # if sampler_diagnostics = "" returns none
    sampler_diagnostics <- NULL
  } else {
    selected_sampler_diag <- rep(FALSE, length(metadata$sampler_diagnostics))
    not_found <- NULL
    for (p in sampler_diagnostics) {
      # exact name, or any column whose name starts with "<name>."
      matches <- metadata$sampler_diagnostics == p | startsWith(metadata$sampler_diagnostics, paste0(p, "."))
      if (!any(matches)) {
        not_found <- c(not_found, p)
      }
      selected_sampler_diag <- selected_sampler_diag | matches
    }
    if (length(not_found)) {
      stop("Can't find the following sampler diagnostic(s) in the output: ",
            paste(not_found, collapse = ", "), call. = FALSE)
    }
    sampler_diagnostics <- metadata$sampler_diagnostics[selected_sampler_diag]
  }

  # Number of saved rows per phase after thinning (only used for sampling)
  num_warmup_draws <- ceiling(metadata$iter_warmup / metadata$thin)
  num_post_warmup_draws <- ceiling(metadata$iter_sampling / metadata$thin)

  # ---- 3. Read the draws of every file -------------------------------------
  for (output_file in files) {
    # Comment lines are stripped with grep before the file is handed to fread()
    if (os_is_windows()) {
      grep_path_repaired <- withr::with_path(
        c(
          toolchain_PATH_env_var()
        ),
        repair_path(Sys.which("grep.exe"))
      )
      grep_path_quotes <- paste0('"', grep_path_repaired, '"')
      fread_cmd <- paste0(
        grep_path_quotes,
        " -v \"^#\" --color=never \"",
        wsl_safe_path(output_file, revert = TRUE),
        "\""
      )
    } else {
      # The path is single-quoted, so the shell will not expand a leading `~`;
      # expand it here, as read_csv_metadata() does for its own grep command.
      fread_cmd <- paste0("grep -v '^#' --color=never '", path.expand(output_file), "'")
    }
    if (length(sampler_diagnostics) > 0) {
      post_warmup_sd_id <- length(post_warmup_sampler_diagnostics) + 1
      warmup_sd_id <- length(warmup_sampler_diagnostics) + 1
      suppressWarnings(
        post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <- data.table::fread(
          cmd = fread_cmd,
          select = sampler_diagnostics,
          data.table = FALSE
        )
      )
      # With saved warmup the first `num_warmup_draws` rows are warmup and the
      # remaining rows are post-warmup; split them apart
      if (metadata$method == "sample" && metadata$save_warmup == 1 && num_warmup_draws > 0) {
        warmup_sampler_diagnostics[[warmup_sd_id]] <-
          post_warmup_sampler_diagnostics[[post_warmup_sd_id]][seq_len(num_warmup_draws), , drop = FALSE]
        if (num_post_warmup_draws > 0) {
          post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <-
            post_warmup_sampler_diagnostics[[post_warmup_sd_id]][(num_warmup_draws + 1):(num_warmup_draws + num_post_warmup_draws), , drop = FALSE]
        } else {
          post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <- NULL
        }
      }
    }
    if (length(variables) > 0) {
      draws_list_id <- length(draws) + 1
      warmup_draws_list_id <- length(warmup_draws) + 1
      if (metadata$method == "pathfinder") {
        # For pathfinder the "__" columns are read as part of the draws
        # (union() makes this idempotent across loop iterations)
        metadata$variables <- union(metadata$sampler_diagnostics, metadata$variables)
        variables <- union(metadata$sampler_diagnostics, variables)
      }
      suppressWarnings(
        draws[[draws_list_id]] <- data.table::fread(
          cmd = fread_cmd,
          select = variables,
          data.table = FALSE
        )
      )
      if (metadata$method == "sample" && metadata$save_warmup == 1 && num_warmup_draws > 0) {
        warmup_draws[[warmup_draws_list_id]] <-
          draws[[draws_list_id]][seq_len(num_warmup_draws), , drop = FALSE]
        if (num_post_warmup_draws > 0) {
          draws[[draws_list_id]] <- draws[[draws_list_id]][(num_warmup_draws + 1):(num_warmup_draws + num_post_warmup_draws), , drop = FALSE]
        } else {
          draws[[draws_list_id]] <- NULL
        }
      }
    }
  }

  # ---- 4. Finalize the metadata --------------------------------------------
  # inv_metric is returned as its own component for sampling, not in metadata
  metadata$inv_metric <- NULL
  metadata$variables <- repair_variable_names(metadata$variables)
  repaired_variables <- repair_variable_names(variables)
  # Variational and laplace output use different column names for the log
  # densities; map them to lp__ / lp_approx__
  if (metadata$method == "variational") {
    metadata$variables <- metadata$variables[metadata$variables != "lp__"]
    metadata$variables <- gsub("log_p__", "lp__", metadata$variables)
    metadata$variables <- gsub("log_g__", "lp_approx__", metadata$variables)
    repaired_variables <- repaired_variables[repaired_variables != "lp__"]
    repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
    repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
  } else if (metadata$method == "laplace") {
    metadata$variables <- gsub("log_p__", "lp__", metadata$variables)
    metadata$variables <- gsub("log_q__", "lp_approx__", metadata$variables)
    repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
    repaired_variables <- gsub("log_q__", "lp_approx__", repaired_variables)
  }
  model_param_dims <- variable_dims(metadata$variables)
  metadata$stan_variable_sizes <- model_param_dims
  metadata$stan_variables <- names(model_param_dims)
  # $model_params is deprecated, remove for release 1.0
  metadata$model_params <- metadata$variables

  # ---- 5. Convert to the requested draws format, per method ----------------
  if (metadata$method == "sample") {
    if (is.null(format)) {
      format <- "draws_array"
    }
    as_draws_format <- as_draws_format_fun(format)
    # Each component becomes NULL when nothing was read or it has 0 iterations
    if (length(warmup_draws) > 0) {
      warmup_draws <- do.call(as_draws_format, list(warmup_draws))
      posterior::variables(warmup_draws) <- repaired_variables
      if (posterior::niterations(warmup_draws) == 0) {
        warmup_draws <- NULL
      }
    } else {
      warmup_draws <- NULL
    }
    if (length(draws) > 0) {
      draws <-  do.call(as_draws_format, list(draws))
      posterior::variables(draws) <- repaired_variables
      if (posterior::niterations(draws) == 0) {
        draws <- NULL
      }
    } else {
      draws <- NULL
    }
    if (length(warmup_sampler_diagnostics) > 0) {
      warmup_sampler_diagnostics <- do.call(as_draws_format, list(warmup_sampler_diagnostics))
      if (posterior::niterations(warmup_sampler_diagnostics) == 0) {
        warmup_sampler_diagnostics <- NULL
      }
    } else {
      warmup_sampler_diagnostics <- NULL
    }
    if (length(post_warmup_sampler_diagnostics) > 0) {
      post_warmup_sampler_diagnostics <- do.call(as_draws_format, list(post_warmup_sampler_diagnostics))
      if (posterior::niterations(post_warmup_sampler_diagnostics) == 0) {
        post_warmup_sampler_diagnostics <- NULL
      }
    } else {
      post_warmup_sampler_diagnostics <- NULL
    }
    list(
      metadata = metadata,
      # total run time can't be inferred from the CSV files, hence NA
      time = list(total = NA_integer_, chains = time),
      inv_metric = inv_metric,
      step_size = step_size,
      warmup_draws = warmup_draws,
      post_warmup_draws = draws,
      warmup_sampler_diagnostics = warmup_sampler_diagnostics,
      post_warmup_sampler_diagnostics = post_warmup_sampler_diagnostics
    )
  } else if (metadata$method == "variational") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      variational_draws <- NULL
    } else {
      # Only the first file is used; its first row and the "lp__" column are
      # dropped (presumably the first row is the approximation's mean rather
      # than a draw -- TODO confirm against the CmdStan output format)
      variational_draws <- do.call(as_draws_format, list(draws[[1]][-1, colnames(draws[[1]]) != "lp__", drop = FALSE]))
    }
    if (!is.null(variational_draws)) {
      if ("log_p__" %in% posterior::variables(variational_draws)) {
        variational_draws <- posterior::rename_variables(variational_draws, lp__ = "log_p__")
      }
      if ("log_g__" %in% posterior::variables(variational_draws)) {
        variational_draws <- posterior::rename_variables(variational_draws, lp_approx__ = "log_g__")
      }
      posterior::variables(variational_draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      draws = variational_draws
    )
  } else if (metadata$method == "laplace") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      laplace_draws <- NULL
    } else {
      laplace_draws <- do.call(as_draws_format, list(draws[[1]]))
    }
    if (!is.null(laplace_draws)) {
      if ("log_p__" %in% posterior::variables(laplace_draws)) {
        laplace_draws <- posterior::rename_variables(laplace_draws, lp__ = "log_p__")
      }
      if ("log_q__" %in% posterior::variables(laplace_draws)) {
        laplace_draws <- posterior::rename_variables(laplace_draws, lp_approx__ = "log_q__")
      }
      posterior::variables(laplace_draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      draws = laplace_draws
    )
  } else if (metadata$method == "optimize") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      point_estimates <- NULL
    } else {
      # The point estimate is the first row of the first file
      point_estimates <- do.call(as_draws_format, list(draws[[1]][1, , drop = FALSE]))
      point_estimates <- posterior::subset_draws(point_estimates, variable = variables)
    }
    if (!is.null(point_estimates)) {
      posterior::variables(point_estimates) <- repaired_variables
    }
    list(
      metadata = metadata,
      point_estimates = point_estimates
    )
  } else if (metadata$method == "generate_quantities") {
    if (is.null(format)) {
      format <- "draws_array"
    }
    as_draws_format <- as_draws_format_fun(format)
    draws <- do.call(as_draws_format, list(draws))
    if (!is.null(draws)) {
      posterior::variables(draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      generated_quantities = draws
    )
  } else if (metadata$method == "pathfinder") {
    if (is.null(format)) {
      format <- "draws_matrix"
    }
    as_draws_format <- as_draws_format_fun(format)
    if (length(draws) == 0) {
      pathfinder_draws <- NULL
    } else {
      pathfinder_draws <- do.call(as_draws_format, list(draws[[1]][, colnames(draws[[1]]), drop = FALSE]))
      posterior::variables(pathfinder_draws) <- repaired_variables
    }
    list(
      metadata = metadata,
      draws = pathfinder_draws
    )
  }
}

#' Read CmdStan CSV files from sampling into \R
#'
#' Deprecated. Use [read_cmdstan_csv()] instead.
#'
#' This is a thin wrapper that emits a deprecation warning and then forwards
#' its arguments unchanged to [read_cmdstan_csv()].
#' @keywords internal
#' @export
#' @param files,variables,sampler_diagnostics Deprecated. Use
#'   [read_cmdstan_csv()] instead.
#'
read_sample_csv <- function(files,
                            variables = NULL,
                            sampler_diagnostics = NULL) {
  warning("read_sample_csv() is deprecated. Please use read_cmdstan_csv().")
  read_cmdstan_csv(
    files = files,
    variables = variables,
    sampler_diagnostics = sampler_diagnostics
  )
}

#' @rdname read_cmdstan_csv
#' @export
#' @param check_diagnostics (logical) For models fit using MCMC, should
#'   diagnostic checks be performed after reading in the files? The default is
#'   `TRUE` but set to `FALSE` to avoid checking for problems with divergences
#'   and treedepth.
#'
as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("cmdstanr_draws_format")) {
  contents <- read_cmdstan_csv(files, format = format)
  # The method recorded in the CSV metadata decides which CSV-backed fit class
  # is instantiated. Methods without a matching branch fall through the
  # switch, which then yields an invisible NULL.
  fit_method <- contents$metadata$method
  switch(
    fit_method,
    sample = CmdStanMCMC_CSV$new(contents, files, check_diagnostics),
    optimize = CmdStanMLE_CSV$new(contents, files),
    variational = CmdStanVB_CSV$new(contents, files),
    pathfinder = CmdStanPathfinder_CSV$new(contents, files),
    laplace = CmdStanLaplace_CSV$new(contents, files)
  )
}


# internal ----------------------------------------------------------------

# CmdStanFit_CSV -------------------------------------------------------------
#' Build CmdStanMCMC/MLE/VB-like objects from the output of
#' `read_cmdstan_csv()` rather than from a CmdStanRun object
#'
#' Because no CmdStanRun object is available, the resulting object supports
#' fewer methods than a regular CmdStanMCMC/MLE/VB object.
#'
#' @noRd
#'
CmdStanMCMC_CSV <- R6::R6Class(
  classname = "CmdStanMCMC_CSV",
  inherit = CmdStanMCMC,
  public = list(
    # Fill the private fields from the list returned by `read_cmdstan_csv()`
    # and, if requested, run the diagnostic summary once all fields are set.
    initialize = function(csv_contents, files, check_diagnostics = TRUE) {
      # draws
      private$draws_ <- csv_contents$post_warmup_draws
      private$warmup_draws_ <- csv_contents$warmup_draws
      # sampler diagnostics
      private$sampler_diagnostics_ <- csv_contents$post_warmup_sampler_diagnostics
      private$warmup_sampler_diagnostics_ <- csv_contents$warmup_sampler_diagnostics
      # run information
      private$inv_metric_ <- csv_contents$inv_metric
      private$time_ <- csv_contents$time
      private$metadata_ <- csv_contents$metadata
      private$output_files_ <- files
      if (check_diagnostics) {
        self$diagnostic_summary()
      }
      invisible(self)
    },
    # The methods below are overridden so that they work without a CmdStanRun
    # object: they return what was stored at construction time.
    output_files = function(...) {
      private$output_files_
    },
    time = function() {
      private$time_
    },
    num_chains = function() {
      fit_draws <- self$draws()
      posterior::nchains(fit_draws)
    }
  ),
  private = list(
    time_ = NULL,
    output_files_ = NULL
  )
)
# CSV-backed counterpart of CmdStanMLE: stores the point estimates and
# metadata returned by `read_cmdstan_csv()` plus the paths of the CSV files.
CmdStanMLE_CSV <- R6::R6Class(
  classname = "CmdStanMLE_CSV",
  inherit = CmdStanMLE,
  public = list(
    initialize = function(csv_contents, files) {
      private$metadata_ <- csv_contents$metadata
      private$draws_ <- csv_contents$point_estimates
      private$output_files_ <- files
      invisible(self)
    },
    # overridden so it works without a CmdStanRun object
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(
    output_files_ = NULL
  )
)
# CSV-backed counterpart of CmdStanLaplace: stores the draws and metadata
# returned by `read_cmdstan_csv()` plus the paths of the CSV files.
CmdStanLaplace_CSV <- R6::R6Class(
  classname = "CmdStanLaplace_CSV",
  inherit = CmdStanLaplace,
  public = list(
    initialize = function(csv_contents, files) {
      private$metadata_ <- csv_contents$metadata
      private$draws_ <- csv_contents$draws
      private$output_files_ <- files
      invisible(self)
    },
    # overridden so it works without a CmdStanRun object
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(
    output_files_ = NULL
  )
)
# CSV-backed counterpart of CmdStanVB: stores the draws and metadata returned
# by `read_cmdstan_csv()` plus the paths of the CSV files.
CmdStanVB_CSV <- R6::R6Class(
  classname = "CmdStanVB_CSV",
  inherit = CmdStanVB,
  public = list(
    initialize = function(csv_contents, files) {
      private$metadata_ <- csv_contents$metadata
      private$draws_ <- csv_contents$draws
      private$output_files_ <- files
      invisible(self)
    },
    # overridden so it works without a CmdStanRun object
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(
    output_files_ = NULL
  )
)

# CSV-backed counterpart of CmdStanPathfinder: stores the draws and metadata
# returned by `read_cmdstan_csv()` plus the paths of the CSV files.
CmdStanPathfinder_CSV <- R6::R6Class(
  classname = "CmdStanPathfinder_CSV",
  inherit = CmdStanPathfinder,
  public = list(
    initialize = function(csv_contents, files) {
      private$output_files_ <- files
      private$draws_ <- csv_contents$draws
      private$metadata_ <- csv_contents$metadata
      # return the object invisibly, like the other *_CSV initializers
      invisible(self)
    },
    # overridden so it works without a CmdStanRun object
    output_files = function(...) {
      private$output_files_
    }
  ),
  private = list(output_files_ = NULL)
)


# Names of the fit-object methods that cannot work on CSV-backed objects
# because there's no CmdStanRun object behind them
unavailable_methods_CmdStanFit_CSV <- c(
  "cmdstan_diagnose",
  "cmdstan_summary",
  "save_data_file",
  "data_file",
  "save_latent_dynamics_files",
  "latent_dynamics_files",
  "save_output_files",
  "init",
  "output",
  "return_codes",
  "code",
  "num_procs",
  "save_profile_files",
  "profile_files",
  "profiles",
  "time" # available for MCMC not others
)
# Stand-in for methods that need a CmdStanRun object: accepts (and ignores)
# any arguments and always raises the same error, without the call attached.
error_unavailable_CmdStanFit_CSV <- function(...) {
  msg <- "This method is not available for objects created using as_cmdstan_fit()."
  stop(msg, call. = FALSE)
}
# Replace every method that needs a CmdStanRun object with a stub that raises
# an informative error. CmdStanMCMC_CSV keeps its own `time()` method, so
# "time" is skipped for that class only. CmdStanPathfinder_CSV is included so
# that it fails with the same message as the other CSV-backed classes.
for (method in unavailable_methods_CmdStanFit_CSV) {
  if (method != "time") {
    CmdStanMCMC_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  }
  CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
  CmdStanPathfinder_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
}


# csv reading internals ---------------------------------------------------

#' Reads the CmdStan settings, the inverse mass matrix (or its diagonal),
#' timing information and the column names from the comment and header lines
#' of a CSV file.
#'
#' The file is scanned line by line. Lines of the form `# key = value` become
#' list elements named `key`; a few keys are renamed afterwards to the names
#' used elsewhere in this package (e.g. `num_warmup` -> `iter_warmup`).
#'
#' @noRd
#' @param csv_file A CSV file containing results from CmdStan.
#' @return A list containing all CmdStan settings and, for sampling, the inverse
#'   mass matrix (or its diagonal depending on the metric). The list also holds
#'   `variables` and (unless the algorithm is `fixed_param`)
#'   `sampler_diagnostics` taken from the CSV header, `time` (a one-row data
#'   frame) when the method is `"sample"`, and `gradients` when gradient rows
#'   were parsed.
#'
read_csv_metadata <- function(csv_file) {
  assert_file_exists(csv_file, access = "r", extension = "csv")
  # --- parser state ---
  # TRUE while the following line(s) hold rows of the inverse mass matrix
  inv_metric_next <- FALSE
  csv_file_info <- list()
  csv_file_info$inv_metric <- NULL
  # countdown of inverse-metric rows still to be read; -1 means "not known yet"
  inv_metric_rows_to_read <- -1
  inv_metric_rows <- -1
  dense_inv_metric <- FALSE
  # TRUE once the lines that follow are rows of the diagnose gradient table
  diagnose_gradients <- FALSE
  gradients <- data.frame()
  warmup_time <- 0
  sampling_time <- 0
  total_time <- 0
  # Build a grep command that keeps only the lines starting with `#` or an
  # ASCII letter, i.e. it drops lines that start with anything else (such as
  # the numeric data rows).
  if (os_is_windows()) {
    grep_path_repaired <- withr::with_path(
      c(
        toolchain_PATH_env_var()
      ),
      repair_path(Sys.which("grep.exe"))
    )
    grep_path_quotes <- paste0('"', grep_path_repaired, '"')
    fread_cmd <- paste0(
      grep_path_quotes,
      " \"^[#a-zA-Z]\" --color=never \"",
      wsl_safe_path(csv_file, revert = TRUE),
      "\""
    )
  } else {
    # path.expand() because the path is single-quoted in the command
    fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never '", path.expand(csv_file), "'")
  }
  # The grep output is read as character data; only its first column is used
  # below, one element per line.
  suppressWarnings(
    metadata <- data.table::fread(
      cmd = fread_cmd,
      colClasses = "character",
      stringsAsFactors = FALSE,
      fill = TRUE,
      sep = "",
      header = FALSE
    )
  )
  if (is.null(metadata) || length(metadata) == 0) {
    stop("Supplied CSV file is corrupt!", call. = FALSE)
  }
  for (line in metadata[[1]]) {
    if (!startsWith(line, "#") && is.null(csv_file_info[["variables"]])) {
      # if no # at the start of line, the line is the CSV header
      # (only the first such line is used, since `variables` is set here)
      all_names <- strsplit(line, ",")[[1]]
      if (all(csv_file_info$algorithm != "fixed_param")) {
        # columns ending in "__" are sampler diagnostics, except for the
        # log-density columns listed below, which count as variables
        csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
        csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))]
        csv_file_info[["variables"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
      } else {
        # fixed_param: all "__" columns are dropped and no sampler
        # diagnostics are recorded
        csv_file_info[["variables"]] <- all_names[!endsWith(all_names, "__")]
      }
    } else {
      # Decide whether this line belongs to a special section (inverse metric,
      # gradient table); if not it is parsed as `key = value` / timing below.
      parse_key_val <- TRUE
      if (grepl("# Diagonal elements of inverse mass matrix:", line, perl = TRUE)) {
        # diagonal metric: exactly one line of values follows
        inv_metric_next <- TRUE
        parse_key_val <- FALSE
        inv_metric_rows <- 1
        inv_metric_rows_to_read <- 1
        dense_inv_metric <- FALSE
      } else if (grepl("# Elements of inverse mass matrix:", line, perl = TRUE)) {
        # dense metric: the number of rows is taken from the first row's length
        inv_metric_next <- TRUE
        parse_key_val <- FALSE
        dense_inv_metric <- TRUE
      } else if (inv_metric_next) {
        inv_metric_split <- strsplit(gsub("# ", "", line), ",")
        numeric_inv_metric_split <- rapply(inv_metric_split, as.numeric)
        if (inv_metric_rows == -1 && dense_inv_metric) {
          inv_metric_rows <- length(inv_metric_split[[1]])
          inv_metric_rows_to_read <- inv_metric_rows
        }
        # values are appended row after row into one flat vector
        csv_file_info$inv_metric <- c(csv_file_info$inv_metric, numeric_inv_metric_split)
        inv_metric_rows_to_read <- inv_metric_rows_to_read - 1
        if (inv_metric_rows_to_read == 0) {
          inv_metric_next <- FALSE
        }
        parse_key_val <- FALSE
      } else if (diagnose_gradients) {
        # rows of the gradient table; the row containing "param" (the table's
        # header) and empty lines are skipped
        parse_key_val <- FALSE
        tmp <- gsub("#", "", line, fixed = TRUE)
        if (nzchar(tmp)) {
          tmp <- strsplit(tmp, split = " ", fixed = TRUE)[[1]]
          if (!("param" %in% tmp)) {
            tmp <- as.numeric(tmp[nzchar(tmp)])
            gradients <- rbind(gradients, tmp)
            # name the columns once, when the first row has been added
            if (dim(gradients)[1] == 1) {
              colnames(gradients) <- c("param_idx", "value", "model", "finite_diff", "error")
            }
          }
        }
      }
      if (parse_key_val) {
        # strip the comment marker and the "(Default)" annotation, then split
        # on "=" into a trimmed (key, value) pair
        tmp <- gsub("#", "", line, fixed = TRUE)
        tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
        key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
        key_val <- strsplit(key_val, split = "=", fixed = TRUE)
        key_val <- rapply(key_val, trimws)
        # these two keys are stored under different names
        if (any(key_val[1] == "Step size")) {
          key_val[1] <- "step_size_adaptation"
        }
        if (any(key_val[1] == "Log probability")) {
          key_val[1] <- "lp"
        }
        if (length(key_val) == 2) {
          # store the value as a number when it parses as one, otherwise as a
          # string (empty strings are not stored)
          numeric_val <- suppressWarnings(as.numeric(key_val[2]))
          if (!is.na(numeric_val)) {
            csv_file_info[[key_val[1]]] <- numeric_val
          } else {
            if (nzchar(key_val[2])) {
              csv_file_info[[key_val[1]]] <- key_val[2]
            }
          }
        } else if (grepl("(Warm-up)", tmp, fixed = TRUE)) {
          # timing lines: strip the fixed text and keep the number of seconds
          tmp <- gsub("Elapsed Time:", "", tmp, fixed = TRUE)
          tmp <- gsub("seconds (Warm-up)", "", tmp, fixed = TRUE)
          warmup_time <- as.numeric(tmp)
        } else if (grepl("(Sampling)", tmp, fixed = TRUE)) {
          tmp <- gsub("seconds (Sampling)", "", tmp, fixed = TRUE)
          sampling_time <- as.numeric(tmp)
        } else if (grepl("(Total)", tmp, fixed = TRUE)) {
          tmp <- gsub("seconds (Total)", "", tmp, fixed = TRUE)
          tmp <- trimws(gsub(" Elapsed Time: ", "", tmp, fixed = TRUE))
          total_time <- as.numeric(tmp)
        }
        # for the diagnose method, the lines after the "lp" entry are treated
        # as the gradient table
        if (!is.null(csv_file_info$method) &&
            csv_file_info$method == "diagnose" &&
            any(key_val[1] == "lp")) {
          diagnose_gradients <- TRUE
        }
      }
    }
  }
  # NOTE(review): if no `method = ...` line was parsed, `csv_file_info$method`
  # is NULL and this condition errors before the message below is reached.
  if (csv_file_info$method != "diagnose" &&
      length(csv_file_info$sampler_diagnostics) == 0 &&
      length(csv_file_info$variables) == 0) {
    stop("Supplied CSV file does not contain any variable names or data!", call. = FALSE)
  }
  # Reshape the flat vector of a dense inverse metric into a matrix.
  # NOTE(review): values were appended row by row while `dim<-` fills
  # column-major, so this equals the file's layout only for a symmetric
  # matrix -- presumably always the case for an inverse mass matrix; confirm.
  if (inv_metric_rows > 0 && csv_file_info$metric == "dense_e") {
    rows <- inv_metric_rows
    cols <- length(csv_file_info$inv_metric) / inv_metric_rows
    dim(csv_file_info$inv_metric) <- c(rows, cols)
  }

  # rename from old cmdstan names to new cmdstanX names
  csv_file_info$model_name <- csv_file_info$model
  csv_file_info$adapt_engaged <- csv_file_info$engaged
  csv_file_info$adapt_delta <- csv_file_info$delta
  csv_file_info$max_treedepth <- csv_file_info$max_depth
  csv_file_info$step_size <- csv_file_info$stepsize
  csv_file_info$iter_warmup <- csv_file_info$num_warmup
  csv_file_info$iter_sampling <- csv_file_info$num_samples
  # `num_threads` is exposed as `threads` for variational/optimize/laplace and
  # as `threads_per_chain` for all other methods
  if (csv_file_info$method %in% c("variational", "optimize", "laplace")) {
    csv_file_info$threads <- csv_file_info$num_threads
  } else {
    csv_file_info$threads_per_chain <- csv_file_info$num_threads
  }
  if (csv_file_info$method == "sample") {
    # one row of timing information for this file's chain
    csv_file_info$time <- data.frame(
      chain_id = csv_file_info$id,
      warmup = warmup_time,
      sampling = sampling_time,
      total = total_time
    )
  }
  # drop the old names (now available under their new names) and the
  # file-path settings that are not returned
  csv_file_info$model <- NULL
  csv_file_info$engaged <- NULL
  csv_file_info$delta <- NULL
  csv_file_info$max_depth <- NULL
  csv_file_info$stepsize <- NULL
  csv_file_info$num_warmup <- NULL
  csv_file_info$num_samples <- NULL
  csv_file_info$file <- NULL
  csv_file_info$diagnostic_file <- NULL
  csv_file_info$metric_file <- NULL
  csv_file_info$num_threads <- NULL
  # `gradients` is a data frame, so length() is its number of columns; it is
  # only non-zero if at least one gradient row was parsed
  if (length(gradients) > 0) {
    csv_file_info$gradients <- gradients
  }

  # Revert any WSL-updated paths before returning the metadata
  if (os_is_wsl()) {
    csv_file_info$init <- wsl_safe_path(csv_file_info$init, revert = TRUE)
    csv_file_info$profile_file <- wsl_safe_path(csv_file_info$profile_file,
                                                revert = TRUE)
    csv_file_info$fitted_params <- wsl_safe_path(csv_file_info$fitted_params,
                                                  revert = TRUE)
  }
  csv_file_info
}

#' Check that the metadata from a set of CSV files matches.
#'
#' Throws an error if the files were generated with different models or
#' methods, contain different variables, or do not match in the number of
#' output samples. Issues a warning (without failing) if any of a fixed set of
#' other arguments differ between the files. If it returns, the metadata
#' matches.
#'
#' @noRd
#' @param csv_metadata A list with one element per CSV file, each element being
#'   a list returned by `read_csv_metadata()`.
#' @return `NULL`.
#'
check_csv_metadata_matches <- function(csv_metadata) {
  model_name <- sapply(csv_metadata, function(x) x$model_name)
  if (!all(model_name == model_name[1])) {
    stop("Supplied CSV files were not generated with the same model!",
         call. = FALSE)
  }
  method <- sapply(csv_metadata, function(x) x$method)
  if (!all(method == method[1])) {
    stop(
      "Supplied CSV files were produced by different methods and need to be read in separately!",
      call. = FALSE
    )
  }
  # Compare every file after the first against the first one.
  # seq_along(x)[-1] is empty for a single file, whereas 2:length(x) would
  # count down to c(2, 1) and index past the end of the list.
  reference_variables <- csv_metadata[[1]]$variables
  for (i in seq_along(csv_metadata)[-1]) {
    current_variables <- csv_metadata[[i]]$variables
    if (length(reference_variables) != length(current_variables) ||
        !all(reference_variables == current_variables)) {
      stop("Supplied CSV files have samples for different variables!",
           call. = FALSE)
    }
  }
  if (method[1] == "sample") {
    iter_sampling <- sapply(csv_metadata, function(x) x$iter_sampling)
    thin <- sapply(csv_metadata, function(x) x$thin)
    save_warmup <- sapply(csv_metadata, function(x) x$save_warmup)
    iter_warmup <- sapply(csv_metadata, function(x) x$iter_warmup)
    # The number of warmup iterations only affects the number of stored draws
    # when the warmup draws were saved.
    if (!all(iter_sampling == iter_sampling[1]) ||
        !all(thin == thin[1]) ||
        !all(save_warmup == save_warmup[1]) ||
        (save_warmup[1] == 1 && !all(iter_warmup == iter_warmup[1]))) {
      stop("Supplied CSV files do not match in the number of output samples!",
           call. = FALSE)
    }
  } else if (method[1] == "variational") {
    output_samples <- sapply(csv_metadata, function(x) x$output_samples)
    if (!all(output_samples == output_samples[1])) {
      stop("Supplied CSV files do not match in the number of output samples!",
           call. = FALSE)
    }
  }
  # Arguments for which a mismatch is only worth a warning.
  match_list <- c("stan_version_major", "stan_version_minor",
                  "stan_version_patch", "gamma", "kappa", "t0", "init_buffer",
                  "term_buffer", "window", "algorithm", "engine",
                  "max_treedepth", "metric", "stepsize_jitter",
                  "adapt_engaged", "adapt_delta", "iter_warmup")
  not_matching <- c()
  # Only arguments present in the first file are checked.
  for (name in intersect(names(csv_metadata[[1]]), match_list)) {
    values <- sapply(csv_metadata, function(x) x[[name]])
    # An argument missing from any of the other files counts as a mismatch.
    if (any(vapply(values, is.null, logical(1)))) {
      not_matching <- c(not_matching, name)
      next
    }
    if (!all(values == values[1])) {
      not_matching <- c(not_matching, name)
    }
  }
  if (length(not_matching) > 0) {
    warning("Supplied CSV files do not match in the following arguments: ",
            paste(unique(not_matching), collapse = ", "), call. = FALSE)
  }
  NULL
}

# Convert dot-separated element names (e.g. beta.1.1) to bracket notation
# (e.g. beta[1,1]). Names without a dot are returned unchanged.
repair_variable_names <- function(names) {
  # the first dot opens the index, every remaining dot separates two indices
  with_open <- sub(".", "[", names, fixed = TRUE)
  repaired <- gsub(".", ",", with_open, fixed = TRUE)
  # close the bracket on every name that received an opening one
  is_indexed <- grepl("[", repaired, fixed = TRUE)
  repaired[is_indexed] <- paste0(repaired[is_indexed], "]")
  repaired
}

# Convert bracket notation (e.g. beta[1,1]) back to dot-separated element
# names (e.g. beta.1.1). Names without brackets are returned unchanged.
unrepair_variable_names <- function(names) {
  # only the first opening bracket becomes a dot
  dotted <- sub("[", ".", names, fixed = TRUE)
  # index separators become dots, closing brackets are dropped
  dotted <- gsub(",", ".", dotted, fixed = TRUE)
  gsub("]", "", dotted, fixed = TRUE)
}

# Work out which of the requested columns still have to be read.
#
# `requested` is a character vector of names, or NULL to request everything in
# `all`. `currently_read` holds the names that have already been read. `all`
# holds every available name, or NULL if that is not known.
#
# Returns NULL if both `requested` and `all` are NULL, and `requested` itself
# if it contains no non-empty string. Otherwise returns the unique unread
# names, or "" if nothing is left to read.
remaining_columns_to_read <- function(requested, currently_read, all) {
  if (is.null(requested)) {
    if (is.null(all)) {
      return(NULL)
    }
    requested <- all
  }
  if (!any(nzchar(requested))) {
    return(requested)
  }

  if (is.null(all)) {
    # without the full list of names only exact names can be compared
    pending <- requested[is.na(match(requested, currently_read))]
  } else {
    candidates <- all[is.na(match(all, currently_read))]
    positions <- lapply(seq_along(requested), function(k) {
      exact <- match(requested[k], candidates)
      if (!is.na(exact)) {
        return(exact)
      }
      # no exact match: treat the request as the name of a non-scalar
      # variable and select all of its elements
      which(startsWith(candidates, paste0(requested[k], "[")))
    })
    pending <- candidates[unlist(positions)]
  }

  if (length(pending) > 0) {
    return(unique(pending))
  }
  ""
}

#' Returns a list of dimensions for the input variables.
#'
#' @noRd
#' @param variable_names A character vector of variable names including all
#'   individual elements (e.g., `c("beta[1]", "beta[2]")`, not just `"beta"`).
#' @return A list giving the dimensions of the variables. The equivalent of the
#'   `par_dims` slot of RStan's stanfit objects, except that scalars have
#'   dimension `1` instead of `0`. Returns `NULL` if `variable_names` is `NULL`.
#' @note For this function to return the correct dimensions the input must be
#'   already sorted in ascending order. Since CmdStan always has the variables
#'   sorted correctly we avoid a sort by not sorting again here.
#'
variable_dims <- function(variable_names = NULL) {
  if (is.null(variable_names)) {
    return(NULL)
  }
  dims <- list()
  uniq_variable_names <- unique(gsub("\\[.*\\]", "", variable_names))
  # drop the closing bracket so that what follows "name[" is just the indices
  var_names <- gsub("\\]", "", variable_names)
  for (var in uniq_variable_names) {
    # Match the prefix literally rather than through a regular expression, so
    # that a name containing a regex metacharacter (e.g. ".") cannot pick up
    # the elements of a different variable.
    prefix <- paste0(var, "[")
    elements <- var_names[which(startsWith(var_names, prefix))]
    if (length(elements) > 0) {
      # the input is sorted, so the last element carries the largest indices
      last_element <- elements[length(elements)]
      last_indices <- substring(last_element, nchar(prefix) + 1)
      dims[[var]] <- as.numeric(strsplit(last_indices, ",")[[1]])
    } else {
      dims[[var]] <- 1
    }
  }
  dims
}
stan-dev/cmdstanr documentation built on May 16, 2024, 12:58 a.m.