R/fit.R

Defines functions as_draws.CmdStanPathfinder as_draws.CmdStanGQ as_draws.CmdStanVB as_draws.CmdStanLaplace as_draws.CmdStanMLE as_draws.CmdStanMCMC lp_diagnose gradients lp_approx mle num_chains inv_metric diagnostic_summary sampler_diagnostics loo code profiles return_codes metadata output time metric_files config_files data_file latent_dynamics_files profile_files output_files save_metric_files save_config_files save_data_file save_profile_files save_latent_dynamics_files save_output_files cmdstan_diagnose cmdstan_summary summary lp_approx lp constrain_variables variable_skeleton unconstrain_draws unconstrain_variables hessian grad_log_prob log_prob init_model_methods init draws save_object

Documented in as_draws.CmdStanGQ as_draws.CmdStanLaplace as_draws.CmdStanMCMC as_draws.CmdStanMLE as_draws.CmdStanPathfinder as_draws.CmdStanVB cmdstan_diagnose cmdstan_summary code config_files constrain_variables data_file diagnostic_summary draws gradients grad_log_prob hessian init init_model_methods inv_metric latent_dynamics_files log_prob loo lp lp_approx metadata metric_files mle num_chains output output_files profile_files profiles return_codes sampler_diagnostics save_config_files save_data_file save_latent_dynamics_files save_metric_files save_object save_output_files save_profile_files summary time unconstrain_draws unconstrain_variables variable_skeleton

# CmdStanFit ---------------------------------------------------
#' CmdStanFit superclass
#'
#' @noRd
#' @description CmdStanMCMC, CmdStanMLE, CmdStanLaplace, CmdStanVB, CmdStanGQ
#'   all share the methods of the superclass CmdStanFit and also have their own
#'   unique methods.
#'
CmdStanFit <- R6::R6Class(
  classname = "CmdStanFit",
  public = list(
    runset = NULL,
    functions = NULL,
    # Wrap a finished CmdStanRun. Caches the process return codes, takes
    # private copies of the run's model-method and standalone-function
    # environments, re-creates the model pointer when model methods are
    # already present, and maps the output directory back to a path that is
    # usable from Windows (via wsl_safe_path(revert = TRUE)).
    initialize = function(runset) {
      checkmate::assert_r6(runset, classes = "CmdStanRun")
      self$runset <- runset
      private$return_codes_ <- self$runset$procs$return_codes()

      # Copy all bindings (hidden ones included) of `from` into `into` and
      # return `into`; a NULL `from` leaves `into` empty.
      clone_bindings <- function(from, into) {
        if (is.null(from)) {
          return(into)
        }
        for (binding in ls(from, all.names = TRUE)) {
          assign(binding, get(binding, from), into)
        }
        into
      }
      private$model_methods_env_ <- clone_bindings(runset$model_methods_env(), new.env())
      self$functions <- clone_bindings(runset$standalone_env(), new.env())

      if (!is.null(private$model_methods_env_$model_ptr)) {
        initialize_model_pointer(private$model_methods_env_, self$data_file(), 0)
      }
      # Post-processing of results happens on the Windows side, so switch the
      # output directory back to a Windows-accessible path.
      self$runset$args$output_dir <- wsl_safe_path(self$runset$args$output_dir, revert = TRUE)
      invisible(self)
    },
    num_procs = function() {
      self$runset$num_procs()
    },
    # Print a compact summary table of (at most `max_rows`) variables and
    # return the fit object invisibly.
    print = function(variables = NULL, ..., digits = 2, max_rows = getOption("cmdstanr_max_rows", 10)) {
      if (is.null(private$draws_) &&
          length(self$output_files(include_failed = FALSE)) == 0) {
        stop("Fitting failed. Unable to print.", call. = FALSE)
      }

      # Narrow the variable set before summarising, so nothing that would be
      # cut off by `max_rows` gets computed.
      all_variables <- self$metadata()$variables
      if (is.null(variables)) {
        candidates <- all_variables
      } else {
        matches <- matching_variables(variables, all_variables)
        if (length(matches$not_found) > 0) {
          stop("Can't find the following variable(s): ",
               paste(matches$not_found, collapse = ", "), call. = FALSE)
        }
        candidates <- matches$matching
      }
      total_rows <- length(candidates)
      shown <- candidates[seq_len(max_rows)]
      # indexing past the end yields NA entries when max_rows > total_rows
      shown <- shown[!is.na(shown)]

      tbl <- as.data.frame(self$summary(shown, ...))
      tbl[, 1] <- base::format(tbl[, 1], justify = "left")
      tbl[, -1] <- base::format(round(tbl[, -1], digits = digits), nsmall = digits)
      for (ess_col in grep("ess_", colnames(tbl), value = TRUE)) {
        tbl[[ess_col]] <- as.integer(tbl[[ess_col]])
      }

      # Allow every cell to be printed, restoring the user's setting on exit.
      old_options <- options(max.print = prod(dim(tbl)))
      on.exit(options(old_options), add = TRUE)
      base::print(tbl, row.names = FALSE)
      if (total_rows > max_rows) {
        cat("\n # showing", max_rows, "of", total_rows,
            "rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)\n")
      }
      invisible(self)
    },
    expose_functions = function(global = FALSE, verbose = FALSE) {
      expose_stan_functions(self$functions, global, verbose)
      invisible(NULL)
    }
  ),
  private = list(
    draws_ = NULL,
    metadata_ = NULL,
    init_ = NULL,
    profiles_ = NULL,
    model_methods_env_ = NULL,
    return_codes_ = NULL
  )
)

#' Save fitted model object to a file
#'
#' @name fit-method-save_object
#' @aliases save_object
#' @description This method is a wrapper around [base::saveRDS()] that ensures
#'   that all posterior draws and diagnostics are saved when saving a fitted
#'   model object. Because the contents of the CmdStan output CSV files are only
#'   read into R lazily (i.e., as needed), the `$save_object()` method is the
#'   safest way to guarantee that everything has been read in before saving.
#'
#' @param file (string) Path where the file should be saved.
#' @param ... Other arguments to pass to [base::saveRDS()] besides `object` and `file`.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("logistic")
#'
#' temp_rds_file <- tempfile(fileext = ".RDS")
#' fit$save_object(file = temp_rds_file)
#' rm(fit)
#'
#' fit <- readRDS(temp_rds_file)
#' fit$summary()
#' }
#'
save_object <- function(file, ...) {
  # Reading the draws is mandatory: if it fails, nothing is saved.
  self$draws()
  # These readers are best-effort: a fit class may not provide them, or the
  # corresponding output may not exist, so any error is ignored.
  for (component in c("sampler_diagnostics", "init", "profiles")) {
    try(self[[component]](), silent = TRUE)
  }
  saveRDS(self, file = file, ...)
  invisible(self)
}
CmdStanFit$set("public", name = "save_object", value = save_object)

#' Extract posterior draws
#'
#' @name fit-method-draws
#' @aliases draws
#' @description Extract posterior draws after MCMC or approximate posterior
#'   draws after variational approximation using formats provided by the
#'   \pkg{posterior} package.
#'
#'   The variables include the parameters, transformed parameters, and
#'   generated quantities from the Stan program as well as `lp__`, the total
#'   log probability (`target`) accumulated in the model block.
#'
#' @inheritParams read_cmdstan_csv
#' @param inc_warmup (logical) Should warmup draws be included? Defaults to
#'   `FALSE`. Ignored except when used with [CmdStanMCMC] objects.
#' @param format (string) The format of the returned draws or point estimates.
#'   Must be a valid format from the \pkg{posterior} package. The defaults
#'   are the following.
#'
#'   * For sampling and generated quantities the default is
#'   [`"draws_array"`][posterior::draws_array]. This format keeps the chains
#'   separate. To combine the chains use any of the other formats (e.g.
#'   `"draws_matrix"`).
#'
#'   * For point estimates from optimization and approximate draws from
#'   variational inference the default is
#'   [`"draws_matrix"`][posterior::draws_matrix].
#'
#'   To use a different format it can be specified as the full name of the
#'   format from the \pkg{posterior} package (e.g. `format = "draws_df"`) or
#'   omitting the `"draws_"` prefix (e.g. `format = "df"`).
#'
#'   **Changing the default format**: To change the default format for an entire
#'   R session use `options(cmdstanr_draws_format = format)`, where `format` is
#'   the name (in quotes) of a valid format from the posterior package. For
#'   example `options(cmdstanr_draws_format = "draws_df")` will change the
#'   default to a data frame.
#'
#'   **Note about efficiency**: For models with a large number of parameters
#'   (20k+) we recommend using the `"draws_list"` format, which is the most
#'   efficient and RAM friendly when combining draws from multiple chains. If
#'   speed or memory is not a constraint we recommend selecting the format that
#'   most suits the coding style of the post processing phase.
#'
#' @return
#' Depends on the value of `format`. The defaults are:
#' * For [MCMC][model-method-sample], a 3-D
#' [`draws_array`][posterior::draws_array] object (iteration x chain x
#' variable).
#' * For standalone [generated quantities][model-method-generate-quantities], a
#' 3-D [`draws_array`][posterior::draws_array] object (iteration x chain x
#' variable).
#' * For [variational inference][model-method-variational], a 2-D
#' [`draws_matrix`][posterior::draws_matrix] object (draw x variable) because
#' there are no chains. An additional variable `lp_approx__` is also included,
#' which is the log density of the variational approximation to the posterior
#' evaluated at each of the draws.
#' * For [optimization][model-method-optimize], a 1-row
#' [`draws_matrix`][posterior::draws_matrix] with one column per variable. These
#' are *not* actually draws, just point estimates stored in the `draws_matrix`
#' format. See [`$mle()`][fit-method-mle] to extract them as a numeric vector.
#'
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' # logistic regression with intercept alpha and coefficients beta
#' fit <- cmdstanr_example("logistic", method = "sample")
#'
#' # returned as 3-D array (see ?posterior::draws_array)
#' draws <- fit$draws()
#' dim(draws)
#' str(draws)
#'
#' # can easily convert to other formats (data frame, matrix, list)
#' # using the posterior package
#' head(posterior::as_draws_matrix(draws))
#'
#' # or can specify 'format' argument to avoid manual conversion
#' # matrix format combines all chains
#' draws <- fit$draws(format = "matrix")
#' head(draws)
#'
#' # can select specific parameters
#' fit$draws("alpha")
#' fit$draws("beta")  # selects entire vector beta
#' fit$draws(c("alpha", "beta[2]"))
#'
#' # can be passed directly to bayesplot plotting functions
#' bayesplot::color_scheme_set("brightblue")
#' bayesplot::mcmc_dens(fit$draws(c("alpha", "beta")))
#' bayesplot::mcmc_scatter(fit$draws(c("beta[1]", "beta[2]")), alpha = 0.3)
#'
#'
#' # example using variational inference
#' fit <- cmdstanr_example("logistic", method = "variational")
#' head(fit$draws("beta")) # a matrix by default
#' head(fit$draws("beta", format = "df"))
#' }
#'
draws <- function(variables = NULL, inc_warmup = FALSE, format = getOption("cmdstanr_draws_format")) {
  # Shared implementation for fit classes that do not override `$draws()`
  # (CmdStanMCMC and CmdStanGQ define their own); defaults to draws_matrix.
  format <- if (is.null(format)) "draws_matrix" else assert_valid_draws_format(format)
  if (length(self$output_files(include_failed = FALSE)) == 0) {
    stop("Fitting failed. Unable to retrieve the draws.", call. = FALSE)
  }
  if (inc_warmup) {
    warning("'inc_warmup' is ignored except when used with CmdStanMCMC objects.",
            call. = FALSE)
  }
  # The CSV output is read lazily, the first time draws are requested.
  if (is.null(private$draws_)) {
    private$read_csv_(format = format)
  }
  # Cache the draws in the requested format before subsetting.
  private$draws_ <- maybe_convert_draws_format(private$draws_, format)
  posterior::subset_draws(private$draws_, variable = variables)
}
CmdStanFit$set("public", name = "draws", value = draws)

#' Extract user-specified initial values
#'
#' @name fit-method-init
#' @aliases init
#' @description Return user-specified initial values. If the user provided
#'   initial values files or \R objects (list of lists or function) via the
#'   `init` argument when fitting the model then these are returned (always in
#'   the list of lists format). Currently it is not possible to extract initial
#'   values generated automatically by CmdStan, although CmdStan may support
#'   this in the future.
#'
#' @return A list of lists. See **Examples**.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`]
#'
#' @examples
#' \dontrun{
#' init_fun <- function() list(alpha = rnorm(1), beta = rnorm(3))
#' fit <- cmdstanr_example("logistic", init = init_fun, chains = 2)
#' str(fit$init())
#'
#' # partial inits (only specifying for a subset of parameters)
#' init_list <- list(
#'   list(mu = 10, tau = 2),
#'   list(mu = -10, tau = 1)
#' )
#' fit <- cmdstanr_example("schools_ncp", init = init_list, chains = 2, adapt_delta = 0.9)
#'
#' # only user-specified inits returned
#' str(fit$init())
#' }
#'
init <- function() {
  # The user-supplied inits are read from the JSON files listed in the
  # metadata on first use and cached on the object afterwards.
  if (!is.null(private$init_)) {
    return(private$init_)
  }
  init_files <- self$metadata()$init
  files_ok <- is.character(init_files) && all(file.exists(init_files))
  if (!files_ok) {
    stop("Can't find initial values files.", call. = FALSE)
  }
  private$init_ <- lapply(init_files, jsonlite::read_json, simplifyVector = TRUE)
  private$init_
}
CmdStanFit$set("public", name = "init", value = init)

#' Compile additional methods for accessing the model log-probability function
#' and parameter constraining and unconstraining.
#'
#' @name fit-method-init_model_methods
#' @aliases init_model_methods
#'
#' @description The `$init_model_methods()` method compiles and initializes the
#'   `log_prob`, `grad_log_prob`, `constrain_variables`, `unconstrain_variables`,
#'   `unconstrain_draws` and `variable_skeleton` functions (plus `hessian` when
#'   `hessian = TRUE`). These are then available as methods of the fitted model
#'   object. This requires the additional `Rcpp` package, which is not required
#'   for fitting models using CmdStanR.
#'
#'   Note: there may be many compiler warnings emitted during compilation but
#'   these can be ignored so long as they are warnings and not errors.
#'
#' @param seed (integer) The random seed to use when initializing the model.
#' @param verbose (logical) Whether to show verbose logging during compilation.
#' @param hessian (logical) Whether to expose the (experimental) hessian method.
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' }
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
  if (os_is_wsl()) {
    stop("Additional model methods are not currently available with ",
          "WSL CmdStan and will not be compiled",
          call. = FALSE)
  }
  require_suggested_package("Rcpp")
  # The methods are built from the model's C++ source; a pre-compiled
  # executable does not carry it.
  if (length(private$model_methods_env_$hpp_code_) == 0) {
    stop("Model methods cannot be used with a pre-compiled Stan executable, ",
          "the model must be compiled again", call. = FALSE)
  }
  if (hessian) {
    message("The hessian method relies on higher-order autodiff ",
            "which is still experimental. Please report any compilation ",
            "errors that you encounter")
  }
  message("Compiling additional model methods...")
  # Compile when nothing has been compiled yet, and also when the hessian is
  # requested but an earlier compilation was done without it. Otherwise a
  # later `init_model_methods(hessian = TRUE)` would silently leave
  # `$hessian()` unusable.
  needs_compile <- is.null(private$model_methods_env_$model_ptr) ||
    (hessian && is.null(private$model_methods_env_$hessian))
  if (needs_compile) {
    expose_model_methods(private$model_methods_env_, verbose, hessian)
  }
  # (Re)create the model pointer from the fit's data and the requested seed.
  initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
  invisible(NULL)
}
CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods)

#' Calculate the log-probability given a provided vector of unconstrained parameters.
#'
#' @name fit-method-log_prob
#' @aliases log_prob
#' @description The `$log_prob()` method provides access to the Stan model's
#'   `log_prob` function.
#'
#' @param unconstrained_variables (numeric) A vector of unconstrained parameters.
#' @param jacobian (logical) Whether to include the log-density adjustments from
#'   un/constraining variables.
#' @param jacobian_adjustment Deprecated. Please use `jacobian` instead.
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$log_prob(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
#' }
#'
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
log_prob <- function(unconstrained_variables, jacobian = TRUE, jacobian_adjustment = NULL) {
  # Still honour the deprecated argument name, but tell the user to switch.
  if (!is.null(jacobian_adjustment)) {
    warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
    jacobian <- jacobian_adjustment
  }
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
        call. = FALSE)
  }
  n_expected <- methods_env$num_upars_
  n_given <- length(unconstrained_variables)
  if (n_given != n_expected) {
    stop("Model has ", n_expected, " unconstrained parameter(s), but ",
          n_given, " were provided!", call. = FALSE)
  }
  methods_env$log_prob(methods_env$model_ptr_, unconstrained_variables, jacobian)
}
CmdStanFit$set("public", name = "log_prob", value = log_prob)

#' Calculate the log-probability and the gradient w.r.t. each input for a
#' given vector of unconstrained parameters
#'
#' @name fit-method-grad_log_prob
#' @aliases grad_log_prob
#' @description The `$grad_log_prob()` method provides access to the Stan
#'   model's `log_prob` function and its derivative.
#' @inheritParams fit-method-log_prob
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$grad_log_prob(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
#' }
#'
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
grad_log_prob <- function(unconstrained_variables, jacobian = TRUE, jacobian_adjustment = NULL) {
  # Map the deprecated argument name onto `jacobian` (with a warning).
  if (!is.null(jacobian_adjustment)) {
    warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
    jacobian <- jacobian_adjustment
  }
  model_env <- private$model_methods_env_
  if (is.null(model_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
        call. = FALSE)
  }
  num_upars <- model_env$num_upars_
  if (length(unconstrained_variables) != num_upars) {
    stop("Model has ", num_upars, " unconstrained parameter(s), but ",
          length(unconstrained_variables), " were provided!", call. = FALSE)
  }
  model_env$grad_log_prob(model_env$model_ptr_, unconstrained_variables, jacobian)
}
CmdStanFit$set("public", name = "grad_log_prob", value = grad_log_prob)

#' Calculate the log-probability, the gradient w.r.t. each input, and the hessian
#' for a given vector of unconstrained parameters
#'
#' @name fit-method-hessian
#' @aliases hessian
#' @description The `$hessian()` method provides access to the Stan model's
#'   `log_prob`, its derivative, and its hessian.
#' @inheritParams fit-method-log_prob
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' # fit_mcmc$init_model_methods(hessian = TRUE)
#' # fit_mcmc$hessian(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
#' }
#'
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
hessian <- function(unconstrained_variables, jacobian = TRUE, jacobian_adjustment = NULL) {
  # Backwards compatibility with the deprecated argument name.
  if (!is.null(jacobian_adjustment)) {
    warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
    jacobian <- jacobian_adjustment
  }
  if (is.null(private$model_methods_env_$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
        call. = FALSE)
  }
  # The hessian is only exposed when the model methods were compiled with
  # `init_model_methods(hessian = TRUE)`. Without this check, calling the
  # missing (NULL) function below fails with an uninformative
  # "attempt to apply non-function" error.
  if (is.null(private$model_methods_env_$hessian)) {
    stop("The hessian method has not been compiled, please call ",
         "`init_model_methods(hessian = TRUE)` first", call. = FALSE)
  }
  if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
    stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
          length(unconstrained_variables), " were provided!", call. = FALSE)
  }
  private$model_methods_env_$hessian(private$model_methods_env_$model_ptr_,
                                      unconstrained_variables, jacobian)
}
CmdStanFit$set("public", name = "hessian", value = hessian)

#' Transform a set of parameter values to the unconstrained scale
#'
#' @name fit-method-unconstrain_variables
#' @aliases unconstrain_variables
#' @description The `$unconstrain_variables()` method transforms input
#'   parameters to the unconstrained scale.
#'
#' @param variables (list) A list of parameter values to transform, in the same
#'   format as provided to the `init` argument of the `$sample()` method.
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$unconstrain_variables(list(alpha = 0.5, beta = c(0.7, 1.1, 0.2)))
#' }
#'
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
unconstrain_variables <- function(variables) {
  if (is.null(private$model_methods_env_$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
        call. = FALSE)
  }
  # All Stan-level variable names known to the metadata, minus lp__.
  stan_variables <- self$metadata()$stan_variables
  stan_variables <- stan_variables[stan_variables != "lp__"]
  model_variables <- self$runset$args$model_variables

  # Keep only the declared parameters, in declaration order. Zero-length
  # parameters are listed in model_variables but not in the metadata, so the
  # intersection drops them.
  param_names <- names(model_variables$parameters)
  model_par_names <- param_names[param_names %in% stan_variables]

  missing_pars <- model_par_names[!(model_par_names %in% names(variables))]
  if (length(missing_pars) > 0) {
    stop("Model parameter(s): ", paste(missing_pars, collapse = ","),
         " not provided!", call. = FALSE)
  }

  # Flatten the supplied values in parameter order; any extraneous entries in
  # `variables` are ignored because only `model_par_names` are selected.
  variables_vector <- unlist(variables[model_par_names], recursive = TRUE)
  private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, variables_vector)
}
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)

#' Transform all parameter draws to the unconstrained scale
#'
#' @name fit-method-unconstrain_draws
#' @aliases unconstrain_draws
#' @description The `$unconstrain_draws()` method transforms all parameter draws
#'   to the unconstrained scale. The method returns a draws object (in the
#'   format given by the `format` argument) containing the parameter values from
#'   each draw on the unconstrained scale. If called with no arguments, then the
#'   draws within the fit object are unconstrained. Alternatively, either an
#'   existing draws object or a character vector of paths to CSV files can be
#'   passed, but not both.
#'
#' @param files (character vector) The paths to the CmdStan CSV files. These can
#'   be files generated by running CmdStanR or running CmdStan directly.
#' @param draws A `posterior::draws_*` object.
#' @param format (string) The format of the returned draws. Must be a valid
#'   format from the \pkg{posterior} package.
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#'
#' # Unconstrain all internal draws
#' unconstrained_internal_draws <- fit_mcmc$unconstrain_draws()
#'
#' # Unconstrain external CmdStan CSV files
#' unconstrained_csv <- fit_mcmc$unconstrain_draws(files = fit_mcmc$output_files())
#'
#' # Unconstrain existing draws object
#' unconstrained_draws <- fit_mcmc$unconstrain_draws(draws = fit_mcmc$draws())
#' }
#'
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
unconstrain_draws <- function(files = NULL, draws = NULL,
                              format = getOption("cmdstanr_draws_format", "draws_array")) {
  # Fail fast, before any CSV is read, when the model methods are missing.
  # This matches the guard used by the other model-method wrappers.
  if (is.null(private$model_methods_env_$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
        call. = FALSE)
  }
  if (!is.null(files) || !is.null(draws)) {
    if (!is.null(files) && !is.null(draws)) {
      stop("Either a list of CSV files or a draws object can be passed, not both",
          call. = FALSE)
    }
    if (!is.null(files)) {
      csv_contents <- read_cmdstan_csv(files = files, format = "draws_matrix")
      draws <- csv_contents$post_warmup_draws
    }
    if (!is.null(draws)) {
      draws <- maybe_convert_draws_format(draws, "draws_matrix")
    }
  } else {
    # No external input: use (and lazily read) the draws stored in this fit.
    if (is.null(private$draws_)) {
      if (!length(self$output_files(include_failed = FALSE))) {
        stop("Fitting failed. Unable to retrieve the draws.", call. = FALSE)
      }
      private$read_csv_(format = "draws_df")
    }
    draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
  }

  model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
  model_variables <- self$runset$args$model_variables

  # Zero-length parameters are listed in model_variables but not in the
  # metadata. They are excluded so that only parameters with values are passed on.
  nonzero_length_params <- names(model_variables$parameters) %in% model_par_names
  pars <- names(model_variables$parameters[nonzero_length_params])

  draws <- posterior::subset_draws(draws, variable = pars)
  unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
  uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
  names(unconstrained) <- repair_variable_names(uncon_names)
  # Return the result in the requested draws format, preserving the chain count.
  maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws))
}
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)

#' Return the variable skeleton for `relist`
#'
#' @name fit-method-variable_skeleton
#' @aliases variable_skeleton
#' @description The `$variable_skeleton()` method returns the variable skeleton
#'   needed by `utils::relist()` to re-structure a vector of constrained
#'   parameter values to a named list.
#' @param transformed_parameters (logical) Whether to include transformed
#'   parameters in the skeleton (defaults to `TRUE`).
#' @param generated_quantities (logical) Whether to include generated quantities
#'   in the skeleton (defaults to `TRUE`).
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$variable_skeleton()
#' }
#'
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
variable_skeleton <- function(transformed_parameters = TRUE, generated_quantities = TRUE) {
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
        call. = FALSE)
  }
  # Combine the parameter metadata cached on the model-methods environment
  # with the variable declarations recorded in the run arguments.
  create_skeleton(
    methods_env$param_metadata_,
    self$runset$args$model_variables,
    transformed_parameters,
    generated_quantities
  )
}
CmdStanFit$set("public", name = "variable_skeleton", value = variable_skeleton)

#' Transform a set of unconstrained parameter values to the constrained scale
#'
#' @name fit-method-constrain_variables
#' @aliases constrain_variables
#' @description The `$constrain_variables()` method transforms input parameters
#'   to the constrained scale.
#'
#' @param unconstrained_variables (numeric) A vector of unconstrained parameters
#'   to constrain.
#' @param transformed_parameters (logical) Whether to return transformed
#'   parameters implied by newly-constrained parameters (defaults to TRUE).
#' @param generated_quantities (logical) Whether to return generated quantities
#'   implied by newly-constrained parameters (defaults to TRUE).
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$constrain_variables(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
#' }
#'
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#'   [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#'   [hessian()]
#'
constrain_variables <- function(unconstrained_variables, transformed_parameters = TRUE,
                            generated_quantities = TRUE) {
  methods_env <- private$model_methods_env_
  if (is.null(methods_env$model_ptr)) {
    stop("The method has not been compiled, please call `init_model_methods()` first",
        call. = FALSE)
  }

  # Skeleton that utils::relist() uses to turn the flat constrained vector
  # back into a named list.
  skeleton <- self$variable_skeleton(transformed_parameters, generated_quantities)

  n_expected <- methods_env$num_upars_
  if (length(unconstrained_variables) != n_expected) {
    stop("Model has ", n_expected, " unconstrained parameter(s), but ",
          length(unconstrained_variables), " were provided!", call. = FALSE)
  }
  constrained <- methods_env$constrain_variables(
    methods_env$model_ptr_,
    methods_env$model_rng_,
    unconstrained_variables,
    transformed_parameters,
    generated_quantities
  )
  utils::relist(constrained, skeleton)
}
CmdStanFit$set("public", name = "constrain_variables", value = constrain_variables)

#' Extract log probability (target)
#'
#' @name fit-method-lp
#' @aliases lp lp_approx
#' @description The `$lp()` method extracts `lp__`, the total log probability
#'   (`target`) accumulated in the model block of the Stan program. For
#'   variational inference the log density of the variational approximation to
#'   the posterior is available via the `$lp_approx()` method. For
#'   Laplace approximation the unnormalized density of the approximation to
#'   the posterior is available via the `$lp_approx()` method.
#'
#'   See the [Log Probability Increment vs. Sampling
#'   Statement](https://mc-stan.org/docs/reference-manual/sampling-statements.html)
#'   section of the Stan Reference Manual for details on when normalizing
#'   constants are dropped from log probability calculations.
#'
#' @section Details:
#' `lp__` is the unnormalized log density on Stan's [unconstrained
#' space](https://mc-stan.org/docs/2_23/reference-manual/variable-transforms-chapter.html).
#' This will in general be different than the unnormalized model log density
#' evaluated at a posterior draw (which is on the constrained space). `lp__` is
#' intended to diagnose sampling efficiency and evaluate approximations.
#'
#' For variational inference `lp_approx__` is the log density of the variational
#' approximation to `lp__` (also on the unconstrained space). It is exposed in
#' the variational method for performing the checks described in Yao et al.
#' (2018) and implemented in the \pkg{loo} package.
#'
#' For Laplace approximation `lp_approx__` is the unnormalized density of the
#' Laplace approximation. It can be used to perform the same checks as in the
#' case of the variational method described in Yao et al. (2018).
#'
#' @return A numeric vector with length equal to the number of (post-warmup)
#'   draws or length equal to `1` for optimization.
#'
#' @references
#' Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018). Yes, but did it
#' work?: Evaluating variational inference. *Proceedings of the 35th
#' International Conference on Machine Learning*, PMLR 80:5581–5590.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanLaplace`], [`CmdStanVB`]
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic")
#' head(fit_mcmc$lp())
#'
#' fit_mle <- cmdstanr_example("logistic", method = "optimize")
#' fit_mle$lp()
#'
#' fit_vb <- cmdstanr_example("logistic", method = "variational")
#' plot(fit_vb$lp(), fit_vb$lp_approx())
#' }
#'
lp <- function() {
  # Request only the lp__ column, then flatten it into a draws_matrix.
  # For MCMC this stacks the chains into a single column before the
  # conversion to a plain numeric vector.
  lp_draws <- posterior::as_draws_matrix(self$draws(variables = "lp__"))
  as.numeric(lp_draws)
}
CmdStanFit$set("public", name = "lp", value = lp)

# will be used by a subset of fit objects below
#' @rdname fit-method-lp
lp_approx <- function() {
  # Request only the lp_approx__ column instead of reading every variable,
  # then convert explicitly to a draws_matrix (mirroring $lp()).
  # The previous `self$draws()[, "lp_approx__"]` relied on the draws object
  # returned by $draws() supporting matrix-style column indexing. That
  # depends on the draws format in use (e.g. it does not hold for
  # draws_list).
  lp_approx__ <- self$draws(variables = "lp_approx__")
  lp_approx__ <- posterior::as_draws_matrix(lp_approx__)
  as.numeric(lp_approx__)
}


#' Compute a summary table of estimates and diagnostics
#'
#' @name fit-method-summary
#' @aliases summary fit-method-print print.CmdStanMCMC print.CmdStanMLE print.CmdStanVB
#' @description The `$summary()` method runs
#'   [`summarise_draws()`][posterior::draws_summary] from the \pkg{posterior}
#'   package and returns the output. For MCMC, only post-warmup draws are
#'   included in the summary.
#'
#'   There is also a `$print()` method that prints the same summary stats but
#'   removes the extra formatting used for printing tibbles and returns the
#'   fitted model object itself. The `$print()` method may also be faster than
#'   `$summary()` because it is designed to only compute the summary statistics
#'   for the variables that will actually fit in the printed output whereas
#'   `$summary()` will compute them for all of the specified variables in order
#'   to be able to return them to the user. See **Examples**.
#'
#' @param variables (character vector) The variables to include.
#' @param ... Optional arguments to pass to [`posterior::summarise_draws()`][posterior::draws_summary].
#'
#' @return
#' The `$summary()` method returns the tibble data frame created by
#' [`posterior::summarise_draws()`][posterior::draws_summary].
#'
#' The `$print()` method returns the fitted model object itself (invisibly),
#' which is the standard behavior for print methods in \R.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanLaplace`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("logistic")
#' fit$summary()
#' fit$print()
#' fit$print(max_rows = 2) # same as print(fit, max_rows = 2)
#'
#' # include only certain variables
#' fit$summary("beta")
#' fit$print(c("alpha", "beta[2]"))
#'
#' # include all variables but only certain summaries
#' fit$summary(NULL, c("mean", "sd"))
#'
#' # can use functions created from formulas
#' # for example, calculate Pr(beta > 0)
#' fit$summary("beta", prob_gt_0 = ~ mean(. > 0))
#'
#' # can combine user-specified functions with
#' # the default summary functions
#' fit$summary(variables = c("alpha", "beta"),
#'   posterior::default_summary_measures()[1:4],
#'   quantiles = ~ quantile2(., probs = c(0.025, 0.975)),
#'   posterior::default_convergence_measures()
#'   )
#'
#' # the functions need to calculate the appropriate
#' # value for a matrix input
#' fit$summary(variables = "alpha", dim)
#'
#' # the usual [stats::var()] is therefore not directly suitable as it
#' # will produce a covariance matrix unless the data is converted to a vector
#' fit$print(c("alpha", "beta"), var2 = ~var(as.vector(.x)))
#'
#' }
#'
summary <- function(variables = NULL, ...) {
  draws <- self$draws(variables)
  fit_method <- self$metadata()$method

  # Non-MCMC fits with no user-supplied summary functions fall back to the
  # default summary measures, which exclude MCMC-specific diagnostics.
  # In every other case the arguments in `...` are passed through unchanged.
  use_default_measures <- fit_method != "sample" && length(list(...)) == 0
  summary <- if (use_default_measures) {
    posterior::summarise_draws(
      draws,
      posterior::default_summary_measures()
    )
  } else {
    posterior::summarise_draws(draws, ...)
  }

  # Optimization yields a point estimate, so only the "mean" column is kept
  # and it is reported as "estimate".
  if (fit_method == "optimize") {
    summary <- summary[, c("variable", "mean")]
    colnames(summary) <- c("variable", "estimate")
  }
  summary
}
CmdStanFit$set("public", name = "summary", value = summary)

#' Run CmdStan's `stansummary` and `diagnose` utilities
#'
#' @name fit-method-cmdstan_summary
#' @aliases fit-method-cmdstan_diagnose cmdstan_summary cmdstan_diagnose
#' @description Run CmdStan's `stansummary` and `diagnose` utilities. These are
#'   documented in the CmdStan Guide:
#'   * https://mc-stan.org/docs/cmdstan-guide/stansummary.html
#'   * https://mc-stan.org/docs/cmdstan-guide/diagnose.html
#'
#'   Although these methods can be used for models fit using the
#'   [`$variational()`][model-method-variational] method, much of the output is
#'   currently only relevant for models fit using the
#'   [`$sample()`][model-method-sample] method.
#'
#'   See [`$summary()`][fit-method-summary] for computing similar summaries in
#'   R rather than calling CmdStan's utilities.
#'
#' @param flags An optional character vector of flags (e.g.
#'   `flags = c("--sig_figs=1")`).
#'
#' @seealso [`CmdStanMCMC`], [fit-method-summary]
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("logistic")
#' fit$cmdstan_diagnose()
#' fit$cmdstan_summary()
#' }
#'
cmdstan_summary <- function(flags = NULL) {
  # The runset runs CmdStan's stansummary tool; optional flags are forwarded.
  runset <- self$runset
  runset$run_cmdstan_tool("stansummary", flags = flags)
}
CmdStanFit$set("public", name = "cmdstan_summary", value = cmdstan_summary)

#' @rdname fit-method-cmdstan_summary
cmdstan_diagnose <- function() {
  # The runset runs CmdStan's diagnose tool; no extra flags are passed.
  runset <- self$runset
  runset$run_cmdstan_tool("diagnose")
}
CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)

#' Save output and data files
#'
#' @name fit-method-save_output_files
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files
#'   fit-method-save_profile_files fit-method-output_files fit-method-data_file
#'   fit-method-latent_dynamics_files fit-method-profile_files
#'   fit-method-save_config_files fit-method-save_metric_files save_output_files
#'   save_data_file save_latent_dynamics_files save_profile_files
#'   save_config_files save_metric_files output_files data_file
#'   latent_dynamics_files profile_files config_files metric_files
#'
#' @description All fitted model objects have methods for saving (moving to a
#'   specified location) the files created by CmdStanR to hold CmdStan output
#'   csv files and input data files. These methods move the files from their
#'   current location (possibly the temporary directory) to a user-specified
#'   location. __The paths stored in the fitted model object will also be
#'   updated to point to the new file locations.__
#'
#'   The versions without the `save_` prefix (e.g., `$output_files()`) return
#'   the current file paths without moving any files.
#'
#' @param dir (string) Path to directory where the files should be saved.
#' @param basename (string) Base filename to use. See __Details__.
#' @param timestamp (logical) Should a timestamp be added to the file name(s)?
#'   Defaults to `TRUE`. See __Details__.
#' @param random (logical) Should random alphanumeric characters be added to the
#'   end of the file name(s)? Defaults to `TRUE`. See __Details__.
#'
#' @section Details:
#' For `$save_output_files()` the files moved to `dir` will have names of
#' the form `basename-timestamp-id-random`, where
#' * `basename` is the user's provided `basename` argument;
#' * `timestamp` is of the form `format(Sys.time(), "%Y%m%d%H%M")`;
#' * `id` is the MCMC chain id (or `1` for non-MCMC);
#' * `random` contains six random alphanumeric characters;
#'
#' For `$save_latent_dynamics_files()` everything is the same as for
#' `$save_output_files()` except `"-diagnostic-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_profile_files()` everything is the same as for
#' `$save_output_files()` except `"-profile-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_metric_files()` everything is the same as for
#' `$save_output_files()` except `"-metric-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_config_files()` everything is the same as for
#' `$save_output_files()` except `"-config-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_data_file()` no `id` is included in the file name because even
#' with multiple MCMC chains the data file is the same.
#'
#' @return
#' The `$save_*` methods print a message with the new file paths and (invisibly)
#' return a character vector of the new paths (or `NA` for any that couldn't be
#' copied). They also have the side effect of setting the internal paths in the
#' fitted model object to the new paths.
#'
#' The methods _without_ the `save_` prefix return character vectors of file
#' paths without moving any files.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example()
#' fit$output_files()
#' fit$data_file()
#'
#' # just using tempdir for the example
#' my_dir <- tempdir()
#' fit$save_output_files(dir = my_dir, basename = "banana")
#' fit$save_output_files(dir = my_dir, basename = "tomato", timestamp = FALSE)
#' fit$save_output_files(dir = my_dir, basename = "lettuce", timestamp = FALSE, random = FALSE)
#' }
#'
save_output_files <- function(dir = ".",
                              basename = NULL,
                              timestamp = TRUE,
                              random = TRUE) {
  # The runset moves the output CSVs and updates its stored paths; this
  # method only forwards the user's arguments.
  runset <- self$runset
  runset$save_output_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_output_files", value = save_output_files)

#' @rdname fit-method-save_output_files
save_latent_dynamics_files <- function(dir = ".",
                                       basename = NULL,
                                       timestamp = TRUE,
                                       random = TRUE) {
  # Forward the arguments to the runset, which moves the latent dynamics
  # (diagnostic) CSV files.
  runset <- self$runset
  runset$save_latent_dynamics_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_latent_dynamics_files", value = save_latent_dynamics_files)

#' @rdname fit-method-save_output_files
save_profile_files <- function(dir = ".",
                               basename = NULL,
                               timestamp = TRUE,
                               random = TRUE) {
  # Forward the arguments to the runset, which moves the profiling CSV files.
  runset <- self$runset
  runset$save_profile_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_profile_files", value = save_profile_files)

#' @rdname fit-method-save_output_files
save_data_file <- function(dir = ".",
                           basename = NULL,
                           timestamp = TRUE,
                           random = TRUE) {
  # Forward the arguments to the runset, which moves the data file.
  runset <- self$runset
  runset$save_data_file(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_data_file", value = save_data_file)

#' @rdname fit-method-save_output_files
save_config_files <- function(dir = ".",
                              basename = NULL,
                              timestamp = TRUE,
                              random = TRUE) {
  # Forward the arguments to the runset, which moves the config files.
  runset <- self$runset
  runset$save_config_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_config_files", value = save_config_files)

#' @rdname fit-method-save_output_files
save_metric_files <- function(dir = ".",
                              basename = NULL,
                              timestamp = TRUE,
                              random = TRUE) {
  # Forward the arguments to the runset, which moves the metric files.
  runset <- self$runset
  runset$save_metric_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_metric_files", value = save_metric_files)



#' @rdname fit-method-save_output_files
#' @param include_failed (logical) Should CmdStan runs that failed also be
#'   included? The default is `FALSE`.
output_files <- function(include_failed = FALSE) {
  # Paths are tracked by the runset; no files are moved here.
  runset <- self$runset
  runset$output_files(include_failed)
}
CmdStanFit$set("public", name = "output_files", value = output_files)

#' @rdname fit-method-save_output_files
profile_files <- function(include_failed = FALSE) {
  # Paths are tracked by the runset; no files are moved here.
  runset <- self$runset
  runset$profile_files(include_failed)
}
CmdStanFit$set("public", name = "profile_files", value = profile_files)

#' @rdname fit-method-save_output_files
latent_dynamics_files <- function(include_failed = FALSE) {
  # Paths are tracked by the runset; no files are moved here.
  runset <- self$runset
  runset$latent_dynamics_files(include_failed)
}
CmdStanFit$set("public", name = "latent_dynamics_files", value = latent_dynamics_files)

#' @rdname fit-method-save_output_files
data_file <- function() {
  # The data file path is tracked by the runset; nothing is moved here.
  runset <- self$runset
  runset$data_file()
}
CmdStanFit$set("public", name = "data_file", value = data_file)

#' @rdname fit-method-save_output_files
config_files <- function(include_failed = FALSE) {
  # Paths are tracked by the runset; no files are moved here.
  runset <- self$runset
  runset$config_files(include_failed)
}
CmdStanFit$set("public", name = "config_files", value = config_files)

#' @rdname fit-method-save_output_files
metric_files <- function(include_failed = FALSE) {
  # Paths are tracked by the runset; no files are moved here.
  runset <- self$runset
  runset$metric_files(include_failed)
}
CmdStanFit$set("public", name = "metric_files", value = metric_files)

#' Report timing of CmdStan runs
#'
#' @name fit-method-time
#' @aliases time
#' @description Report the run time in seconds. For MCMC additional information
#'   is provided about the run times of individual chains and the warmup and
#'   sampling phases. For Laplace approximation the time only includes the time
#'   for drawing the approximate sample and does not include the time
#'   taken to run the `$optimize()` method.
#'
#' @return
#' A list with elements
#' * `total`: (scalar) The total run time. For MCMC this may be different than
#' the sum of the chain run times if parallelization was used.
#' * `chains`: (data frame) For MCMC only, timing info for the individual
#' chains. The data frame has columns `"chain_id"`, `"warmup"`, `"sampling"`,
#' and `"total"`.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$time()
#'
#' fit_vb <- cmdstanr_example("logistic", method = "variational")
#' fit_vb$time()
#'
#' fit_mle <- cmdstanr_example("logistic", method = "optimize", jacobian = TRUE)
#' fit_mle$time()
#'
#' # use fit_mle to draw samples from laplace approximation
#' fit_laplace <- cmdstanr_example("logistic", method = "laplace", mode = fit_mle)
#' fit_laplace$time() # just time for drawing sample not for running optimize
#' fit_laplace$time()$total + fit_mle$time()$total # total time
#' }
#'
time <- function() {
  # Run-time reporting is delegated entirely to the runset.
  runset <- self$runset
  runset$time()
}
CmdStanFit$set("public", name = "time", value = time)

#' Access console output
#'
#' @name fit-method-output
#' @aliases output
#' @description For MCMC, the `$output()` method returns the stdout and stderr
#'   of all chains as a list of character vectors if `id=NULL`. If the `id`
#'   argument is specified it instead pretty prints the console output for a
#'   single chain.
#'
#'   For optimization and variational inference `$output()` just pretty prints
#'   the console output.
#'
#' @param id (integer) The chain id. Ignored if the model was not fit using
#'   MCMC.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' fit_mcmc$output(1)
#' out <- fit_mcmc$output()
#' str(out)
#'
#' fit_mle <- cmdstanr_example("logistic", method = "optimize")
#' fit_mle$output()
#'
#' fit_vb <- cmdstanr_example("logistic", method = "variational")
#' fit_vb$output()
#' }
#'
output <- function(id = NULL) {
  # CmdStanMCMC overrides this method; only the documentation is shared.
  # Non-MCMC fits use a single CmdStan process, so `id` is ignored and the
  # console output of process 1 is always pretty printed.
  console_lines <- self$runset$procs$proc_output(1)
  cat(paste(console_lines, collapse = "\n"))
}
CmdStanFit$set("public", name = "output", value = output)

#' Extract metadata from CmdStan CSV files
#'
#' @name fit-method-metadata
#' @aliases metadata
#' @description The `$metadata()` method returns a list of information gathered
#'   from the CSV output files, including the CmdStan configuration used when
#'   fitting the model. See **Examples** and [read_cmdstan_csv()].
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample")
#' str(fit_mcmc$metadata())
#'
#' fit_mle <- cmdstanr_example("logistic", method = "optimize")
#' str(fit_mle$metadata())
#'
#' fit_vb <- cmdstanr_example("logistic", method = "variational")
#' str(fit_vb$metadata())
#' }
#'
metadata <- function() {
  # Return the cached metadata when it has already been read.
  if (!is.null(private$metadata_)) {
    return(private$metadata_)
  }
  # Otherwise parse it from the CSV output, which requires at least one
  # successful run.
  if (length(self$output_files(include_failed = FALSE)) == 0) {
    stop("Fitting failed. Unable to retrieve the metadata.", call. = FALSE)
  }
  private$read_csv_()
  private$metadata_
}
CmdStanFit$set("public", name = "metadata", value = metadata)

#' Extract return codes from CmdStan
#'
#' @name fit-method-return_codes
#' @aliases return_codes
#' @description The `$return_codes()` method returns a vector of return codes
#'   from the CmdStan run(s). A return code of 0 indicates a successful run.
#' @return An integer vector of return codes with length equal to the number of
#'   CmdStan runs (number of chains for MCMC and one otherwise).
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#' \dontrun{
#' # example with return codes all zero
#' fit_mcmc <- cmdstanr_example("schools", method = "sample")
#' fit_mcmc$return_codes() # should be all zero
#'
#' # example of non-zero return code (optimization fails for hierarchical model)
#' fit_opt <- cmdstanr_example("schools", method = "optimize")
#' fit_opt$return_codes() # should be non-zero
#' }
#'
return_codes <- function() {
  # Return codes are stored in private$return_codes_ by initialize().
  codes <- private$return_codes_
  codes
}
CmdStanFit$set("public", name = "return_codes", value = return_codes)

#' Return profiling data
#'
#' @name fit-method-profiles
#' @aliases profiles
#' @description The `$profiles()` method returns a list of data frames with
#'   profiling data if any profiling data was written to the profile CSV files.
#'   See [save_profile_files()] to control where the files are saved.
#'
#'   Support for profiling Stan programs is available with CmdStan >= 2.26 and
#'   requires adding profiling statements to the Stan program.
#'
#' @return A list of data frames with profiling data if the profiling CSV files
#'   were created.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#'
#' \dontrun{
#' # first fit a model using MCMC
#' mcmc_program <- write_stan_file(
#'   'data {
#'     int<lower=0> N;
#'     array[N] int<lower=0,upper=1> y;
#'   }
#'   parameters {
#'     real<lower=0,upper=1> theta;
#'   }
#'   model {
#'     profile("likelihood") {
#'       y ~ bernoulli(theta);
#'     }
#'   }
#'   generated quantities {
#'     array[N] int y_rep;
#'     profile("gq") {
#'       y_rep = bernoulli_rng(rep_vector(theta, N));
#'     }
#'   }
#' '
#' )
#' mod_mcmc <- cmdstan_model(mcmc_program)
#'
#' data <- list(N = 10, y = c(1,1,0,0,0,1,0,1,0,0))
#' fit <- mod_mcmc$sample(data = data, seed = 123, refresh = 0)
#'
#' fit$profiles()
#' }
#'
profiles <- function() {
  # Profile CSV files are read on first access and cached in
  # private$profiles_. Later calls return the cached list.
  if (is.null(private$profiles_)) {
    # Build one data frame per profile file with lapply() instead of growing
    # a list inside an index loop. unname() keeps the result an unnamed list,
    # as the loop-based version produced. An empty or NULL set of files still
    # yields list().
    private$profiles_ <- lapply(
      unname(self$profile_files()),
      function(file) {
        data.table::fread(file, integer64 = "character", data.table = FALSE)
      }
    )
  }
  private$profiles_
}
CmdStanFit$set("public", name = "profiles", value = profiles)

#' Return Stan code
#'
#' @name fit-method-code
#' @aliases code
#' @return A character vector with one element per line of code.
#'
#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`]
#'
#' @examples
#'
#' \dontrun{
#' fit <- cmdstanr_example()
#' fit$code() # character vector
#' cat(fit$code(), sep = "\n") # pretty print
#' }
#'
code <- function() {
  stan_code <- self$runset$stan_code()
  if (!is.null(stan_code)) {
    return(stan_code)
  }
  # No Stan code is available, so warn and return NULL.
  warning("'$code()' will return NULL because the 'CmdStanModel' was not created with a Stan file.", call. = FALSE)
  stan_code
}
CmdStanFit$set("public", name = "code", value = code)

# CmdStanMCMC -------------------------------------------------------------
#' CmdStanMCMC objects
#'
#' @name CmdStanMCMC
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanMCMC` object is the fitted model object returned by
#'   the [`$sample()`][model-method-sample] method of a [`CmdStanModel`] object.
#'   Like `CmdStanModel` objects, `CmdStanMCMC` objects are [R6][R6::R6Class]
#'   objects.
#'
#' @section Methods: `CmdStanMCMC` objects have the following associated
#'   methods, all of which have their own (linked) documentation pages.
#'
#'  ## Extract contents of fitted model object
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$draws()`][fit-method-draws] |  Return posterior draws using formats from the \pkg{posterior} package. |
#'  [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] |  Return sampler diagnostics as a [`draws_array`][posterior::draws_array]. |
#'  [`$lp()`][fit-method-lp] |  Return the total log probability density (`target`). |
#'  [`$inv_metric()`][fit-method-inv_metric] |  Return the inverse metric for each chain. |
#'  [`$init()`][fit-method-init] |  Return user-specified initial values. |
#'  [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#'  [`$num_chains()`][fit-method-num_chains] | Return the number of MCMC chains. |
#'  [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#'  ## Summarize inferences and diagnostics
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$print()`][fit-method-print] |  Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#'  [`$summary()`][fit-method-summary] |  Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#'  [`$diagnostic_summary()`][fit-method-diagnostic_summary] |  Get summaries of sampler diagnostics and warning messages. |
#'  [`$cmdstan_summary()`][fit-method-cmdstan_summary] |  Run and print CmdStan's `bin/stansummary`. |
#'  [`$cmdstan_diagnose()`][fit-method-cmdstan_summary] |  Run and print CmdStan's `bin/diagnose`. |
#'  [`$loo()`][fit-method-loo]  |  Run [loo::loo.array()] for approximate LOO-CV |
#'
#'  ## Save fitted model object and temporary files
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$save_object()`][fit-method-save_object] |  Save fitted model object to a file. |
#'  [`$save_output_files()`][fit-method-save_output_files] |  Save output CSV files to a specified location. |
#'  [`$save_data_file()`][fit-method-save_data_file] |  Save JSON data file to a specified location. |
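#'  [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] |  Save diagnostic CSV files to a specified location. |
#'  [`$save_profile_files()`][fit-method-save_profile_files] |  Save profiling CSV files to a specified location. |
#'  [`$save_config_files()`][fit-method-save_config_files] |  Save config files to a specified location. |
#'  [`$save_metric_files()`][fit-method-save_metric_files] |  Save metric files to a specified location. |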
#'
#'  ## Report run times, console output, return codes
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$output()`][fit-method-output]  |  Return the stdout and stderr of all chains or pretty print the output for a single chain. |
#'  [`$time()`][fit-method-time]  |  Report total and chain-specific run times. |
#'  [`$return_codes()`][fit-method-return_codes]  |  Return the return codes from the CmdStan runs. |
#'
#'  ## Expose Stan functions and additional methods to R
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$expose_functions()`][fit-method-expose_functions] |  Expose Stan functions for use in R. |
#'  [`$init_model_methods()`][fit-method-init_model_methods] | Expose methods for log-probability, gradients, parameter constraining and unconstraining. |
#'  [`$log_prob()`][fit-method-log_prob] | Calculate log-prob. |
#'  [`$grad_log_prob()`][fit-method-grad_log_prob] | Calculate log-prob and gradient. |
#'  [`$hessian()`][fit-method-hessian] | Calculate log-prob, gradient, and hessian. |
#'  [`$constrain_variables()`][fit-method-constrain_variables] | Transform a set of unconstrained parameter values to the constrained scale. |
#'  [`$unconstrain_variables()`][fit-method-unconstrain_variables] | Transform a set of parameter values to the unconstrained scale. |
#'  [`$unconstrain_draws()`][fit-method-unconstrain_draws] | Transform all parameter draws to the unconstrained scale. |
#'  [`$variable_skeleton()`][fit-method-variable_skeleton] | Helper function to re-structure a vector of constrained parameter values. |
#'
CmdStanMCMC <- R6::R6Class(
  classname = "CmdStanMCMC",
  inherit = CmdStanFit,
  public = list(
    # override the CmdStanFit initialize method
    initialize = function(runset) {
      super$initialize(runset)
      if (!length(self$output_files())) {
        warning("No chains finished successfully. Unable to retrieve the fit.",
                call. = FALSE)
      } else {
        # Only metadata and sampler diagnostics are read up front
        # (variables = ""). Parameter draws are read lazily by $draws().
        if (runset$args$method_args$fixed_param) {
          private$read_csv_(variables = "", sampler_diagnostics = "")
        } else {
          diagnostics <- self$runset$args$method_args$diagnostics
          private$read_csv_(
            variables = "",
            sampler_diagnostics = convert_hmc_diagnostic_names(diagnostics)
          )
          invisible(self$diagnostic_summary(diagnostics, quiet = FALSE))
        }
      }
    },
    # override the CmdStanFit output method
    output = function(id = NULL) {
      if (is.null(id)) {
        self$runset$procs$proc_output()
      } else {
        cat(paste(self$runset$procs$proc_output(id), collapse = "\n"))
      }
    },

    # override the CmdStanFit draws method
    draws = function(variables = NULL, inc_warmup = FALSE, format = getOption("cmdstanr_draws_format", "draws_array")) {
      if (inc_warmup && !private$metadata_$save_warmup) {
        stop("Warmup draws were requested from a fit object without them! ",
             "Please rerun the model with save_warmup = TRUE.", call. = FALSE)
      }
      format <- assert_valid_draws_format(format)
      # Only variables not already cached need to be read from the CSVs.
      to_read <- remaining_columns_to_read(
        requested = variables,
        currently_read = posterior::variables(private$draws_),
        all = private$metadata_$variables
      )
      private$draws_ <- maybe_convert_draws_format(private$draws_, format)
      private$warmup_draws_ <- maybe_convert_draws_format(private$warmup_draws_, format)
      private$sampler_diagnostics_ <- maybe_convert_draws_format(private$sampler_diagnostics_, format)
      private$warmup_sampler_diagnostics_ <- maybe_convert_draws_format(private$warmup_sampler_diagnostics_, format)
      if (is.null(to_read) || any(nzchar(to_read))) {
        private$read_csv_(variables = to_read, sampler_diagnostics = "", format = format)
      }
      if (is.null(variables)) {
        variables <- private$metadata_$variables
      } else {
        matching_res <- matching_variables(variables, private$metadata_$variables)
        if (length(matching_res$not_found)) {
          stop("Can't find the following variable(s) in the output: ",
              paste(matching_res$not_found, collapse = ", "), call. = FALSE)
        }
        variables <- matching_res$matching
      }
      if (inc_warmup) {
        posterior::subset_draws(posterior::bind_draws(private$warmup_draws_, private$draws_, along = "iteration"), variable = variables)
      } else {
        posterior::subset_draws(private$draws_, variable = variables)
      }
    }
  ),
  private = list(
    # also inherits draws_ and metadata_ from CmdStanFit
    sampler_diagnostics_ = NULL,
    warmup_sampler_diagnostics_ = NULL,
    warmup_draws_ = NULL,
    inv_metric_ = NULL,
    # Read (a subset of) the CmdStan CSV output and merge it into the cached
    # draws. Variables already cached are kept as-is; only newly read ones
    # are appended.
    read_csv_ = function(variables = NULL, sampler_diagnostics = NULL, format = getOption("cmdstanr_draws_format", "draws_array")) {
      if (!length(self$output_files(include_failed = FALSE))) {
        stop("No chains finished successfully. Unable to retrieve the draws.", call. = FALSE)
      }
      csv_contents <- read_cmdstan_csv(
        files = self$output_files(include_failed = FALSE),
        variables = variables,
        sampler_diagnostics = sampler_diagnostics,
        format = format
      )
      private$inv_metric_ <- csv_contents$inv_metric
      private$metadata_ <- csv_contents$metadata

      private$draws_ <- private$append_new_variables_(
        private$draws_,
        csv_contents$post_warmup_draws
      )
      private$sampler_diagnostics_ <- private$append_new_variables_(
        private$sampler_diagnostics_,
        csv_contents$post_warmup_sampler_diagnostics
      )
      # Warmup draws and diagnostics are only present when save_warmup was set.
      if (!is.null(csv_contents$metadata$save_warmup)
         && csv_contents$metadata$save_warmup) {
        private$warmup_draws_ <- private$append_new_variables_(
          private$warmup_draws_,
          csv_contents$warmup_draws
        )
        private$warmup_sampler_diagnostics_ <- private$append_new_variables_(
          private$warmup_sampler_diagnostics_,
          csv_contents$warmup_sampler_diagnostics
        )
      }
      invisible(self)
    },
    # Merge freshly read draws into a cached draws object.
    # Returns `current` unchanged when nothing was read, returns `incoming`
    # when nothing is cached yet, and otherwise binds (along "variable") only
    # the variables of `incoming` that `current` does not already contain.
    append_new_variables_ = function(current, incoming) {
      if (is.null(incoming)) {
        return(current)
      }
      if (is.null(current)) {
        return(incoming)
      }
      new_variables <- setdiff(
        posterior::variables(incoming),
        posterior::variables(current)
      )
      posterior::bind_draws(
        current,
        posterior::subset_draws(incoming, variable = new_variables),
        along = "variable"
      )
    }
  )
)

#' Leave-one-out cross-validation (LOO-CV)
#'
#' @name fit-method-loo
#' @aliases loo
#' @description The `$loo()` method computes approximate LOO-CV using the
#'   \pkg{loo} package. In order to use this method you must compute and save
#'   the pointwise log-likelihood in your Stan program. See [loo::loo.array()]
#'   and the \pkg{loo} package [vignettes](https://mc-stan.org/loo/articles/)
#'   for details.
#'
#' @param variables (character vector) The name(s) of the variable(s) in the
#'   Stan program containing the pointwise log-likelihood. The default is to
#'   look for `"log_lik"`. This argument is passed to the
#'   [`$draws()`][fit-method-draws] method.
#' @param r_eff (multiple options) How to handle the `r_eff` argument for `loo()`:
#'   * `TRUE` (the default) will automatically call [loo::relative_eff.array()]
#'   to compute the `r_eff` argument to pass to [loo::loo.array()].
#'   * `FALSE` or `NULL` will avoid computing `r_eff` (which can sometimes be slow),
#'   but the reported ESS and MCSE estimates can be over-optimistic if the
#'   posterior draws are not (near) independent.
#'   * If `r_eff` is anything else, that object will be passed as the `r_eff`
#'   argument to [loo::loo.array()].
#' @param moment_match (logical) Whether to use a
#'   [moment-matching][loo::loo_moment_match()] correction for problematic
#'   observations. The default is `FALSE`. Using `moment_match=TRUE` will result
#'   in compiling the additional methods described in
#'   [fit-method-init_model_methods]. This allows CmdStanR to automatically
#'   supply the functions for the `log_lik_i`, `unconstrain_pars`,
#'   `log_prob_upars`, and `log_lik_i_upars` arguments to
#'   [loo::loo_moment_match()].
#' @param ... Other arguments (e.g., `cores`, `save_psis`, etc.) passed to
#'   [loo::loo.array()] or [loo::loo_moment_match.default()]
#'   (if `moment_match` = `TRUE` is set).
#'
#' @return The object returned by [loo::loo.array()] or
#'   [loo::loo_moment_match.default()].
#'
#' @seealso The \pkg{loo} package website with
#'   [documentation](https://mc-stan.org/loo/reference/index.html) and
#'   [vignettes](https://mc-stan.org/loo/articles/).
#'
#' @examples
#'
#' \dontrun{
#' # the "logistic" example model has "log_lik" in generated quantities
#' fit <- cmdstanr_example("logistic")
#' loo_result <- fit$loo(cores = 2)
#' print(loo_result)
#' }
#'
loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) {
  require_suggested_package("loo")
  # Pointwise log-likelihood draws for the requested variable(s).
  LLarray <- self$draws(variables, format = "draws_array")

  # Resolve `r_eff`: TRUE computes it, FALSE becomes NULL (skip), and any
  # non-logical value (including NULL) is passed through to loo unchanged.
  if (is.logical(r_eff)) {
    if (isTRUE(r_eff)) {
      # Reuse a user-supplied `cores` argument so relative_eff() runs with the
      # same parallelism as loo itself.
      r_eff_cores <- list(...)[["cores"]] %||% getOption("mc.cores", 1)
      r_eff <- loo::relative_eff(exp(LLarray), cores = r_eff_cores)
    } else {
      r_eff <- NULL
    }
  }

  # as.logical() keeps inputs such as 1 or "TRUE" working as before, while
  # NULL / NA / zero-length values now mean FALSE instead of erroring in `if`.
  if (!isTRUE(as.logical(moment_match))) {
    return(loo::loo.array(LLarray, r_eff = r_eff, ...))
  }

  # Moment-matching requires log-prob, constrain, and unconstrain methods
  if (is.null(private$model_methods_env_$model_ptr)) {
    self$init_model_methods()
  }

  # Warnings from this initial PSIS-LOO fit are silenced; the moment-matching
  # step below is what is meant to correct problematic observations.
  suppressWarnings(loo_result <- loo::loo.array(LLarray, r_eff = r_eff, ...))

  # Log-likelihood draws of observation `i`. `parameter_name` defaults to the
  # variable requested via `variables` (previously it was always "log_lik",
  # so a custom log-likelihood name was ignored during moment matching).
  log_lik_i <- function(x, i, parameter_name = variables, ...) {
    ll_array <- x$draws(variables = parameter_name, format = "draws_array")[, , i]
    # draws_array types don't drop the last dimension when it's 1, so we do this manually
    attr(ll_array, "dim") <- attributes(ll_array)$dim[1:2]
    ll_array
  }

  # Log-likelihood of observation `i` for each row of unconstrained draws,
  # obtained by constraining each row and extracting element `i`.
  log_lik_i_upars <- function(x, upars, i, parameter_name = variables, ...) {
    apply(upars, 1, function(up_i) { x$constrain_variables(up_i)[[parameter_name]][i] })
  }

  loo::loo_moment_match.default(
    x = self,
    loo = loo_result,
    post_draws = function(x, ...) { x$draws(format = "draws_matrix") },
    log_lik_i = log_lik_i,
    unconstrain_pars = function(x, pars, ...) { x$unconstrain_draws(format = "draws_matrix") },
    log_prob_upars = function(x, upars, ...) { apply(upars, 1, x$log_prob) },
    log_lik_i_upars = log_lik_i_upars,
    ...
  )
}
CmdStanMCMC$set("public", name = "loo", value = loo)

#' Extract sampler diagnostics after MCMC
#'
#' @name fit-method-sampler_diagnostics
#' @aliases sampler_diagnostics
#' @description Extract the values of sampler diagnostics for each iteration and
#'   chain of MCMC. To instead get summaries of these diagnostics and associated
#'   warning messages use the
#'   [`$diagnostic_summary()`][fit-method-diagnostic_summary] method.
#'
#' @param inc_warmup (logical) Should warmup draws be included? Defaults to `FALSE`.
#' @param format (string) The draws format to return. See
#'   [draws][fit-method-draws] for details.
#'
#' @return
#' Depends on `format`, but the default is a 3-D
#' [`draws_array`][posterior::draws_array] object (iteration x chain x
#' variable). The variables for Stan's default MCMC algorithm are
#' `"accept_stat__"`, `"stepsize__"`, `"treedepth__"`, `"n_leapfrog__"`,
#' `"divergent__"`, `"energy__"`.
#'
#' @seealso [`CmdStanMCMC`]
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("logistic")
#' sampler_diagnostics <- fit$sampler_diagnostics()
#' str(sampler_diagnostics)
#'
#' library(posterior)
#' as_draws_df(sampler_diagnostics)
#'
#' # or specify format to get a data frame instead of calling as_draws_df
#' fit$sampler_diagnostics(format = "df")
#' }
#'
sampler_diagnostics <- function(inc_warmup = FALSE, format = getOption("cmdstanr_draws_format", "draws_array")) {
  # Only when nothing is cached yet do we require at least one successful
  # chain; the output files are not consulted if diagnostics are cached.
  if (is.null(private$sampler_diagnostics_)) {
    successful_files <- self$output_files(include_failed = FALSE)
    if (length(successful_files) == 0) {
      stop("No chains finished successfully. Unable to retrieve the sampler diagnostics.", call. = FALSE)
    }
  }

  # Which diagnostic columns still need to be loaded from the CSV files.
  pending <- remaining_columns_to_read(
    requested = NULL,
    currently_read = posterior::variables(private$sampler_diagnostics_),
    all = private$metadata_$sampler_diagnostics
  )

  # Bring both cached diagnostic sets into the requested draws format.
  private$warmup_sampler_diagnostics_ <- maybe_convert_draws_format(private$warmup_sampler_diagnostics_, format)
  private$sampler_diagnostics_ <- maybe_convert_draws_format(private$sampler_diagnostics_, format)

  needs_read <- is.null(pending) || any(nzchar(pending))
  if (needs_read) {
    private$read_csv_(variables = "", sampler_diagnostics = NULL, format = format)
  }

  if (!inc_warmup) {
    return(private$sampler_diagnostics_)
  }
  if (!private$metadata_$save_warmup) {
    stop("Warmup sampler diagnostics were requested from a fit object without them! ",
         "Please rerun the model with save_warmup = TRUE.", call. = FALSE)
  }
  # Warmup iterations come first, followed by post-warmup iterations.
  posterior::bind_draws(
    private$warmup_sampler_diagnostics_,
    private$sampler_diagnostics_,
    along = "iteration"
  )
}
CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnostics)

#' Sampler diagnostic summaries and warnings
#'
#' @name fit-method-diagnostic_summary
#' @aliases diagnostic_summary
#' @description Warnings and summaries of sampler diagnostics. To instead get
#'   the underlying values of the sampler diagnostics for each iteration and
#'   chain use the [`$sampler_diagnostics()`][fit-method-sampler_diagnostics]
#'   method.
#'
#'   Currently parameter-specific diagnostics like R-hat and effective sample
#'   size are _not_ handled by this method. Those diagnostics are provided via
#'   the [`$summary()`][fit-method-summary] method (using
#'   [posterior::summarize_draws()]).
#'
#' @param diagnostics (character vector) One or more diagnostics to check. The
#'   currently supported diagnostics are `"divergences"`, `"treedepth"`, and
#'   `"ebfmi"`. The default is to check all of them.
#' @param quiet (logical) Should warning messages about the diagnostics be
#'   suppressed? The default is `FALSE`, in which case warning messages are
#'   printed in addition to returning the values of the diagnostics.
#'
#' @return A list with as many named elements as `diagnostics` selected. The
#'   possible elements and their values are:
#'   * `"num_divergent"`: A vector of the number of divergences per chain.
#'   * `"num_max_treedepth"`: A vector of the number of times `max_treedepth` was hit per chain.
#'   * `"ebfmi"`: A vector of E-BFMI values per chain.
#'
#' @seealso [`CmdStanMCMC`] and the
#'   [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("schools")
#' fit$diagnostic_summary()
#' fit$diagnostic_summary(quiet = TRUE)
#' }
#'
diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) {
  # Nothing requested: return an empty list without touching the draws.
  if (is.null(diagnostics) || identical(diagnostics, "")) {
    return(list())
  }
  diagnostics <- match.arg(
    diagnostics,
    choices = available_hmc_diagnostics(),
    several.ok = TRUE
  )

  # Evaluate a check (passed lazily), muting its messages when quiet = TRUE.
  run_check <- function(check) {
    if (quiet) suppressMessages(check) else check
  }

  diag_draws <- self$sampler_diagnostics(inc_warmup = FALSE)
  result <- list()
  if ("divergences" %in% diagnostics) {
    result[["num_divergent"]] <- run_check(check_divergences(diag_draws))
  }
  if ("treedepth" %in% diagnostics) {
    result[["num_max_treedepth"]] <- run_check(check_max_treedepth(diag_draws, self$metadata()))
  }
  if ("ebfmi" %in% diagnostics) {
    # A NULL E-BFMI result is reported as NA.
    result[["ebfmi"]] <- run_check(check_ebfmi(diag_draws)) %||% NA
  }
  result
}
CmdStanMCMC$set("public", name = "diagnostic_summary", value = diagnostic_summary)


#' Extract inverse metric (mass matrix) after MCMC
#'
#' @name fit-method-inv_metric
#' @aliases inv_metric
#' @description Extract the inverse metric (mass matrix) for each MCMC chain.
#'
#' @param matrix (logical) If a diagonal metric was used, setting `matrix =
#'   FALSE` returns a list containing just the diagonals of the matrices instead
#'   of the full matrices. Setting `matrix = FALSE` has no effect for dense
#'   metrics.
#'
#' @return A list of length equal to the number of MCMC chains. See the `matrix`
#'   argument for details.
#'
#' @seealso [`CmdStanMCMC`]
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("logistic")
#' fit$inv_metric()
#' fit$inv_metric(matrix=FALSE)
#'
#' fit <- cmdstanr_example("logistic", metric = "dense_e")
#' fit$inv_metric()
#' }
#'
inv_metric <- function(matrix = TRUE) {
  if (!length(self$output_files(include_failed = FALSE))) {
    stop("No chains finished successfully. Unable to retrieve the inverse metrics.", call. = FALSE)
  }
  # The inverse metrics are cached after the first CSV read.
  if (is.null(private$inv_metric_)) {
    private$read_csv_(variables = "", sampler_diagnostics = "")
  }
  out <- private$inv_metric_
  # Dense metrics are stored as matrices and must be returned untouched
  # (`matrix = FALSE` has no effect for them). Previously a 1x1 dense metric
  # fell through to the scalar branch and lost its matrix structure.
  is_dense <- is.matrix(out[[1]])
  if (matrix && !is_dense) {
    # convert each vector to a diagonal matrix
    out <- lapply(out, function(x) diag(x, nrow = length(x)))
  } else if (!is_dense && length(out[[1]]) == 1) {
    # convert each scalar to an array with dimension 1
    out <- lapply(out, array, dim = c(1))
  }
  out
}
CmdStanMCMC$set("public", name = "inv_metric", value = inv_metric)

#' Extract number of chains after MCMC
#'
#' @name fit-method-num_chains
#' @aliases num_chains
#' @description The `$num_chains()` method returns the number of MCMC chains.
#' @return An integer.
#'
#' @seealso [`CmdStanMCMC`]
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example(chains = 2)
#' fit_mcmc$num_chains()
#' }
#'
num_chains <- function() {
  # Delegates to the superclass; the chain count reported is its process count.
  chain_count <- super$num_procs()
  chain_count
}
CmdStanMCMC$set("public", name = "num_chains", value = num_chains)


# CmdStanMLE -------------------------------------------------------------
#' CmdStanMLE objects
#'
#' @name CmdStanMLE
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanMLE` object is the fitted model object returned by the
#'   [`$optimize()`][model-method-optimize] method of a [`CmdStanModel`] object.
#'
#' @section Methods: `CmdStanMLE` objects have the following associated methods,
#'   all of which have their own (linked) documentation pages.
#'
#'  ## Extract contents of fitted model object
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$draws()`][fit-method-draws]  |  Return the point estimate as a 1-row [`draws_matrix`][posterior::draws_matrix]. |
#'  [`$mle()`][fit-method-mle]  |  Return the point estimate as a numeric vector. |
#'  [`$lp()`][fit-method-lp]  |  Return the total log probability density (`target`). |
#'  [`$init()`][fit-method-init]  |  Return user-specified initial values. |
#'  [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#'  [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#'  ## Summarize inferences
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$summary()`][fit-method-summary]  |  Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#'
#'  ## Save fitted model object and temporary files
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$save_object()`][fit-method-save_object] |  Save fitted model object to a file. |
#'  [`$save_output_files()`][fit-method-save_output_files]  |  Save output CSV files to a specified location. |
#'  [`$save_data_file()`][fit-method-save_data_file]  |  Save JSON data file to a specified location. |
#'
#'  ## Report run times, console output, return codes
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$time()`][fit-method-time]      |  Report the total run time. |
#'  [`$output()`][fit-method-output]  |  Pretty print the output that was printed to the console. |
#'  [`$return_codes()`][fit-method-return_codes]  |  Return the return codes from the CmdStan runs. |
#'
#'  ## Expose Stan functions and additional methods to R
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$expose_functions()`][fit-method-expose_functions] |  Expose Stan functions for use in R. |
#'  [`$init_model_methods()`][fit-method-init_model_methods] | Expose methods for log-probability, gradients, parameter constraining and unconstraining. |
#'  [`$log_prob()`][fit-method-log_prob] | Calculate log-prob. |
#'  [`$grad_log_prob()`][fit-method-grad_log_prob] | Calculate log-prob and gradient. |
#'  [`$hessian()`][fit-method-hessian] | Calculate log-prob, gradient, and hessian. |
#'  [`$constrain_variables()`][fit-method-constrain_variables] | Transform a set of unconstrained parameter values to the constrained scale. |
#'  [`$unconstrain_variables()`][fit-method-unconstrain_variables] | Transform a set of parameter values to the unconstrained scale. |
#'  [`$unconstrain_draws()`][fit-method-unconstrain_draws] | Transform all parameter draws to the unconstrained scale. |
#'  [`$variable_skeleton()`][fit-method-variable_skeleton] | Helper function to re-structure a vector of constrained parameter values. |
#'
CmdStanMLE <- R6::R6Class(
  classname = "CmdStanMLE",
  inherit = CmdStanFit,
  public = list(),
  private = list(
    # draws_ and metadata_ slots are inherited from CmdStanFit.
    # Parses the optimization CSV output, caching the point estimates (as the
    # draws) and the run metadata.
    read_csv_ = function(format = getOption("cmdstanr_draws_format", "draws_matrix")) {
      succeeded <- self$output_files(include_failed = FALSE)
      if (length(succeeded) == 0) {
        stop("Optimization failed. Unable to retrieve the draws.", call. = FALSE)
      }
      parsed <- read_cmdstan_csv(self$output_files(), format = format)
      private$draws_ <- parsed$point_estimates
      private$metadata_ <- parsed$metadata
      invisible(self)
    }
  )
)

#' Extract (penalized) maximum likelihood estimate after optimization
#'
#' @name fit-method-mle
#' @aliases mle
#' @description The `$mle()` method is only available for [`CmdStanMLE`] objects.
#' It returns the penalized maximum likelihood estimate (posterior mode) as a
#' numeric vector with one element per variable. The returned vector does *not*
#' include `lp__`, the total log probability (`target`) accumulated in the
#' model block of the Stan program, which is available via the
#' [`$lp()`][fit-method-lp] method and also included in the
#' [`$draws()`][fit-method-draws] method.
#'
#' @param variables (character vector) The variables (parameters, transformed
#'   parameters, and generated quantities) to include. If NULL (the default)
#'   then all variables are included.
#'
#' @return A numeric vector. See **Examples**.
#'
#' @seealso [`CmdStanMLE`]
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("logistic", method = "optimize")
#' fit$mle("alpha")
#' fit$mle("beta")
#' fit$mle("beta[2]")
#' }
#'
mle <- function(variables = NULL) {
  estimates <- self$draws(variables)
  # lp__ is reported separately by $lp(), so drop it from the estimate vector.
  is_lp <- colnames(estimates) == "lp__"
  estimates <- estimates[, !is_lp]
  stats::setNames(as.numeric(estimates), posterior::variables(estimates))
}
CmdStanMLE$set("public", name = "mle", value = mle)

# CmdStanLaplace ---------------------------------------------------------------
#' CmdStanLaplace objects
#'
#' @name CmdStanLaplace
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanLaplace` object is the fitted model object returned by the
#'   [`$laplace()`][model-method-laplace] method of a
#'   [`CmdStanModel`] object.
#'
#' @section Methods: `CmdStanLaplace` objects have the following associated methods,
#'   all of which have their own (linked) documentation pages.
#'
#'  ## Extract contents of fitted model object
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$draws()`][fit-method-draws]  |  Return approximate posterior draws as a [`draws_matrix`][posterior::draws_matrix]. |
#'  `$mode()` | Return the mode as a [`CmdStanMLE`] object. |
#'  [`$lp()`][fit-method-lp]  |  Return the total log probability density (`target`) computed in the model block of the Stan program. |
#'  [`$lp_approx()`][fit-method-lp]  |  Return the log density of the approximation to the posterior. |
#'  [`$init()`][fit-method-init] |  Return user-specified initial values. |
#'  [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#'  [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#'  ## Summarize inferences
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$summary()`][fit-method-summary]  | Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#'
#'  ## Save fitted model object and temporary files
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$save_object()`][fit-method-save_object] |  Save fitted model object to a file. |
#'  [`$save_output_files()`][fit-method-save_output_files] |  Save output CSV files to a specified location. |
#'  [`$save_data_file()`][fit-method-save_data_file] |  Save JSON data file to a specified location. |
#'  [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] |  Save diagnostic CSV files to a specified location. |
#'
#'  ## Report run times, console output, return codes
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$time()`][fit-method-time]  |  Report the run time of the Laplace sampling step. |
#'  [`$output()`][fit-method-output]  |  Pretty print the output that was printed to the console. |
#'  [`$return_codes()`][fit-method-return_codes]  |  Return the return codes from the CmdStan runs. |
#'
CmdStanLaplace <- R6::R6Class(
  classname = "CmdStanLaplace",
  inherit = CmdStanFit,
  public = list(
    # Returns the mode object recorded in the run's method arguments
    # (documented above as a CmdStanMLE object).
    mode = function() {
      self$runset$args$method_args$mode_object
    }
  ),
  private = list(
    # draws_ and metadata_ slots are inherited from CmdStanFit.
    # Parses the Laplace output CSV, caching the approximate draws and metadata.
    read_csv_ = function(format = getOption("cmdstanr_draws_format", "draws_matrix")) {
      finished <- self$output_files(include_failed = FALSE)
      if (length(finished) == 0) {
        stop("Laplace inference failed. Unable to retrieve the draws.", call. = FALSE)
      }
      laplace_csv <- read_cmdstan_csv(self$output_files(), format = format)
      private$draws_ <- laplace_csv$draws
      private$metadata_ <- laplace_csv$metadata
      invisible(self)
    }
  )
)
CmdStanLaplace$set("public", name = "lp_approx", value = lp_approx)


# CmdStanVB ---------------------------------------------------------------
#' CmdStanVB objects
#'
#' @name CmdStanVB
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanVB` object is the fitted model object returned by the
#'   [`$variational()`][model-method-variational] method of a
#'   [`CmdStanModel`] object.
#'
#' @section Methods: `CmdStanVB` objects have the following associated methods,
#'   all of which have their own (linked) documentation pages.
#'
#'  ## Extract contents of fitted model object
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$draws()`][fit-method-draws]  |  Return approximate posterior draws as a [`draws_matrix`][posterior::draws_matrix]. |
#'  [`$lp()`][fit-method-lp]  |  Return the total log probability density (`target`) computed in the model block of the Stan program. |
#'  [`$lp_approx()`][fit-method-lp]  |  Return the log density of the variational approximation to the posterior. |
#'  [`$init()`][fit-method-init] |  Return user-specified initial values. |
#'  [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#'  [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#'  ## Summarize inferences
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$summary()`][fit-method-summary]  | Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#'  [`$cmdstan_summary()`][fit-method-cmdstan_summary] |  Run and print CmdStan's `bin/stansummary`. |
#'
#'  ## Save fitted model object and temporary files
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$save_object()`][fit-method-save_object] |  Save fitted model object to a file. |
#'  [`$save_output_files()`][fit-method-save_output_files] |  Save output CSV files to a specified location. |
#'  [`$save_data_file()`][fit-method-save_data_file] |  Save JSON data file to a specified location. |
#'  [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] |  Save diagnostic CSV files to a specified location. |
#'
#'  ## Report run times, console output, return codes
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$time()`][fit-method-time]  |  Report the total run time. |
#'  [`$output()`][fit-method-output]  |  Pretty print the output that was printed to the console. |
#'  [`$return_codes()`][fit-method-return_codes]  |  Return the return codes from the CmdStan runs. |
#'
#'  ## Expose Stan functions and additional methods to R
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$expose_functions()`][fit-method-expose_functions] |  Expose Stan functions for use in R. |
#'  [`$init_model_methods()`][fit-method-init_model_methods] | Expose methods for log-probability, gradients, parameter constraining and unconstraining. |
#'  [`$log_prob()`][fit-method-log_prob] | Calculate log-prob. |
#'  [`$grad_log_prob()`][fit-method-grad_log_prob] | Calculate log-prob and gradient. |
#'  [`$hessian()`][fit-method-hessian] | Calculate log-prob, gradient, and hessian. |
#'  [`$constrain_variables()`][fit-method-constrain_variables] | Transform a set of unconstrained parameter values to the constrained scale. |
#'  [`$unconstrain_variables()`][fit-method-unconstrain_variables] | Transform a set of parameter values to the unconstrained scale. |
#'  [`$unconstrain_draws()`][fit-method-unconstrain_draws] | Transform all parameter draws to the unconstrained scale. |
#'  [`$variable_skeleton()`][fit-method-variable_skeleton] | Helper function to re-structure a vector of constrained parameter values. |
#'
CmdStanVB <- R6::R6Class(
  classname = "CmdStanVB",
  inherit = CmdStanFit,
  public = list(),
  private = list(
    # draws_ and metadata_ slots are inherited from CmdStanFit.
    # Parses the variational output CSV, caching the approximate draws and
    # the run metadata.
    read_csv_ = function(format = getOption("cmdstanr_draws_format", "draws_matrix")) {
      if (length(self$output_files(include_failed = FALSE)) == 0) {
        stop("Variational inference failed. Unable to retrieve the draws.", call. = FALSE)
      }
      vb_csv <- read_cmdstan_csv(self$output_files(), format = format)
      private$draws_ <- vb_csv$draws
      private$metadata_ <- vb_csv$metadata
      invisible(self)
    }
  )
)
CmdStanVB$set("public", name = "lp_approx", value = lp_approx)

# CmdStanPathfinder ---------------------------------------------------------------
#' CmdStanPathfinder objects
#'
#' @name CmdStanPathfinder
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanPathfinder` object is the fitted model object returned by the
#'   [`$pathfinder()`][model-method-pathfinder] method of a
#'   [`CmdStanModel`] object.
#'
#' @section Methods: `CmdStanPathfinder` objects have the following associated methods,
#'   all of which have their own (linked) documentation pages.
#'
#'  ## Extract contents of fitted model object
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$draws()`][fit-method-draws]  |  Return approximate posterior draws as a [`draws_matrix`][posterior::draws_matrix]. |
#'  [`$lp()`][fit-method-lp]  |  Return the total log probability density (`target`) computed in the model block of the Stan program. |
#'  [`$lp_approx()`][fit-method-lp]  |  Return the log density of the approximation to the posterior. |
#'  [`$init()`][fit-method-init] |  Return user-specified initial values. |
#'  [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#'  [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#'  ## Summarize inferences
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$summary()`][fit-method-summary]  | Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#'  [`$cmdstan_summary()`][fit-method-cmdstan_summary] |  Run and print CmdStan's `bin/stansummary`. |
#'
#'  ## Save fitted model object and temporary files
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$save_object()`][fit-method-save_object] |  Save fitted model object to a file. |
#'  [`$save_output_files()`][fit-method-save_output_files] |  Save output CSV files to a specified location. |
#'  [`$save_data_file()`][fit-method-save_data_file] |  Save JSON data file to a specified location. |
#'  [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] |  Save diagnostic CSV files to a specified location. |
#'
#'  ## Report run times, console output, return codes
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$time()`][fit-method-time]  |  Report the total run time. |
#'  [`$output()`][fit-method-output]  |  Pretty print the output that was printed to the console. |
#'  [`$return_codes()`][fit-method-return_codes]  |  Return the return codes from the CmdStan runs. |
#'
CmdStanPathfinder <- R6::R6Class(
  classname = "CmdStanPathfinder",
  inherit = CmdStanFit,
  public = list(),
  private = list(
    # draws_ and metadata_ slots are inherited from CmdStanFit.
    # Parses the Pathfinder output CSV, caching the approximate draws and
    # the run metadata.
    read_csv_ = function(format = getOption("cmdstanr_draws_format", "draws_matrix")) {
      ok_files <- self$output_files(include_failed = FALSE)
      if (length(ok_files) == 0) {
        stop("Pathfinder failed. Unable to retrieve the draws.", call. = FALSE)
      }
      pathfinder_csv <- read_cmdstan_csv(self$output_files(), format = format)
      private$draws_ <- pathfinder_csv$draws
      private$metadata_ <- pathfinder_csv$metadata
      invisible(self)
    }
  )
)

#' @rdname fit-method-lp
lp_approx <- function() {
  # Request a draws_matrix explicitly: `[, "lp_approx__"]` followed by
  # as.numeric() relies on matrix-style column extraction, which other formats
  # chosen via the global `cmdstanr_draws_format` option may not support.
  approx_draws <- self$draws(format = "draws_matrix")
  as.numeric(approx_draws[, "lp_approx__"])
}
CmdStanPathfinder$set("public", name = "lp_approx", value = lp_approx)



# CmdStanGQ ---------------------------------------------------------------
#' CmdStanGQ objects
#'
#' @name CmdStanGQ
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanGQ` object is the fitted model object returned by the
#'   [`$generate_quantities()`][model-method-generate-quantities] method of a
#'   [`CmdStanModel`] object.
#'
#' @section Methods: `CmdStanGQ` objects have the following associated methods,
#'   all of which have their own (linked) documentation pages.
#'
#'  ## Extract contents of generated quantities object
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$draws()`][fit-method-draws] | Return the generated quantities as a [`draws_array`][posterior::draws_array]. |
#'  [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#'  [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#'  ## Summarize inferences
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#'
#'  ## Save fitted model object and temporary files
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$save_object()`][fit-method-save_object] | Save fitted model object to a file. |
#'  [`$save_output_files()`][fit-method-save_output_files] | Save output CSV files to a specified location. |
#'  [`$save_data_file()`][fit-method-save_data_file] | Save JSON data file to a specified location. |
#'
#'  ## Report run times, console output, return codes
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$time()`][fit-method-time] | Report the total run time. |
#'  [`$output()`][fit-method-output] | Return the stdout and stderr of all chains or pretty print the output for a single chain. |
#'  [`$return_codes()`][fit-method-return_codes]  |  Return the return codes from the CmdStan runs. |
#'
#' @inherit model-method-generate-quantities examples
#'
CmdStanGQ <- R6::R6Class(
  classname = "CmdStanGQ",
  inherit = CmdStanFit,
  public = list(
    # The `fitted_params` input recorded in the run's method arguments.
    fitted_params_files = function() {
      self$runset$args$method_args$fitted_params
    },
    num_chains = function() {
      super$num_procs()
    },
    # override CmdStanFit draws method
    draws = function(variables = NULL, inc_warmup = FALSE, format = getOption("cmdstanr_draws_format", "draws_array")) {
      if (!length(self$output_files(include_failed = FALSE))) {
        stop("Generating quantities for all MCMC chains failed. Unable to retrieve the generated quantities.", call. = FALSE)
      }
      if (inc_warmup) {
        warning("'inc_warmup' is ignored except when used with CmdStanMCMC objects.",
                call. = FALSE)
      }
      format <- assert_valid_draws_format(format)
      # posterior::variables() works for every draws format, unlike
      # dimnames()$variable, which is only populated for formats with a named
      # `variable` dimension (e.g. not for data-frame or list caches), and
      # matches how read_csv_() below detects already-read variables.
      to_read <- remaining_columns_to_read(
        requested = variables,
        currently_read = posterior::variables(private$draws_),
        all = private$metadata_$variables
      )
      private$draws_ <- maybe_convert_draws_format(private$draws_, format)
      if (is.null(to_read) || any(nzchar(to_read))) {
        private$read_csv_(variables = to_read, format = format)
      }
      if (is.null(variables)) {
        variables <- private$metadata_$variables
      } else {
        matching_res <- matching_variables(variables, private$metadata_$variables)
        if (length(matching_res$not_found)) {
          stop("Can't find the following variable(s) in the output: ",
              paste(matching_res$not_found, collapse = ", "), call. = FALSE)
        }
        variables <- matching_res$matching
      }
      posterior::subset_draws(private$draws_, variable = variables)
    },
    # override CmdStanFit output method
    # With no `id`, returns the process output of all runs; with an `id`,
    # prints that run's output lines joined by newlines.
    output = function(id = NULL) {
      if (is.null(id)) {
        self$runset$procs$proc_output()
      } else {
        cat(paste(self$runset$procs$proc_output(id), collapse = "\n"))
      }
    }
  ),
  private = list(
    # inherits draws_ and metadata_ slots from CmdStanFit
    # Reads the requested generated quantities from the successful output
    # files and appends any variables not already cached to private$draws_.
    read_csv_ = function(variables = NULL, format = getOption("cmdstanr_draws_format", "draws_array")) {
      if (!length(self$output_files(include_failed = FALSE))) {
        stop("Generating quantities for all MCMC chains failed. Unable to retrieve the generated quantities.", call. = FALSE)
      }
      csv_contents <- read_cmdstan_csv(
        files = self$output_files(include_failed = FALSE),
        variables = variables,
        sampler_diagnostics = "",
        format = format
      )
      private$metadata_ <- csv_contents$metadata
      new_draws <- csv_contents$generated_quantities
      if (!is.null(new_draws)) {
        new_vars <- posterior::variables(new_draws)
        missing_variables <- new_vars[!(new_vars %in% posterior::variables(private$draws_))]
        # Skip the bind when everything just read is already cached.
        if (length(missing_variables) > 0) {
          private$draws_ <-
            posterior::bind_draws(
              private$draws_,
              posterior::subset_draws(new_draws, variable = missing_variables),
              along = "variable"
            )
        }
      }
      invisible(self)
    }
  )
)


# CmdStan Diagnose --------------------------------------------------------
#' CmdStanDiagnose objects
#'
#' @name CmdStanDiagnose
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanDiagnose` object is what the
#'   [`$diagnose()`][model-method-diagnose] method of a [`CmdStanModel`] object
#'   returns.
#'
#' @section Methods: `CmdStanDiagnose` objects have the following associated
#'   methods:
#'
#'  |**Method**|**Description**|
#'  |:----------|:---------------|
#'  [`$gradients()`][fit-method-gradients] |  Return gradients from diagnostic mode. |
#'  [`$lp()`][fit-method-lp] |  Return the total log probability density (`target`). |
#'  [`$init()`][fit-method-init] |  Return user-specified initial values. |
#'  [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#'  [`$save_output_files()`][fit-method-save_output_files] |  Save output CSV files to a specified location. |
#'  [`$save_data_file()`][fit-method-save_data_file] |  Save JSON data file to a specified location. |
#'
#' @examples
#' \dontrun{
#' test <- cmdstanr_example("logistic", method = "diagnose")
#'
#' # retrieve the gradients
#' test$gradients()
#' }
#'
CmdStanDiagnose <- R6::R6Class(
  classname = "CmdStanDiagnose",
  public = list(
    # The CmdStanRun whose diagnose-mode output this object wraps.
    runset = NULL,
    # Check that `runset` is a CmdStanRun, keep it, then parse its output
    # CSV file(s) once and cache metadata, gradients and lp in private fields.
    initialize = function(runset) {
      checkmate::assert_r6(runset, classes = "CmdStanRun")
      self$runset <- runset
      parsed <- read_cmdstan_csv(self$runset$output_files())
      private$metadata_ <- parsed$metadata
      private$gradients_ <- parsed$gradients
      private$lp_ <- parsed$lp
      invisible(self)
    },
    # Return the metadata cached by initialize().
    metadata = function() private$metadata_
  ),
  private = list(
    # Filled by initialize(); init_ is not set in this class body.
    metadata_ = NULL,
    gradients_ = NULL,
    lp_ = NULL,
    init_ = NULL
  )
)

#' Extract gradients after diagnostic mode
#'
#' @name fit-method-gradients
#' @aliases gradients
#' @description Return the gradients for all parameters that CmdStan computed
#'   in diagnostic mode, as parsed from the output CSV file.
#'
#' @return The gradients parsed from the diagnostic-mode output CSV file. See
#'   **Examples**.
#'
#' @seealso [`CmdStanDiagnose`]
#' @inherit CmdStanDiagnose examples
#'
gradients <- function() private$gradients_

# Method body installed as CmdStanDiagnose$lp(): return the cached
# diagnostic-mode log density, coerced with as.numeric().
lp_diagnose <- function() {
  lp_value <- private$lp_
  as.numeric(lp_value)
}

# Attach the shared fit methods to CmdStanDiagnose as public members, in the
# listed order. local() keeps the helper variables out of the namespace.
local({
  diagnose_methods <- list(
    gradients = gradients,
    lp = lp_diagnose,
    init = init,
    save_output_files = save_output_files,
    output_files = output_files,
    save_data_file = save_data_file,
    data_file = data_file
  )
  for (method_name in names(diagnose_methods)) {
    CmdStanDiagnose$set(
      "public",
      name = method_name,
      value = diagnose_methods[[method_name]]
    )
  }
})



# as_draws ----------------------------------------------------------------
#' Create a `draws` object from a CmdStanR fitted model object
#'
#' Create a `draws` object supported by the \pkg{posterior} package. These
#' methods are thin convenience wrappers that forward to CmdStanR's
#' [`$draws()`][fit-method-draws] method.
#'
#' @aliases as_draws
#' @importFrom posterior as_draws
#' @export
#' @export as_draws
#'
#' @param x A CmdStanR fitted model object.
#' @param ... Optional arguments passed to the [`$draws()`][fit-method-draws]
#'   method (e.g., `variables`, `inc_warmup`, etc.).
#'
#' @details To subset iterations, chains, or draws, use the
#'   [posterior::subset_draws()] method after creating the `draws` object.
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example()
#' as_draws(fit)
#'
#' # posterior's as_draws_*() methods will also work
#' posterior::as_draws_rvars(fit)
#' posterior::as_draws_list(fit)
#' }
#'
as_draws.CmdStanMCMC <- function(x, ...) x$draws(...)

#' @rdname as_draws.CmdStanMCMC
#' @export
as_draws.CmdStanMLE <- function(x, ...) x$draws(...)

#' @rdname as_draws.CmdStanMCMC
#' @export
as_draws.CmdStanLaplace <- function(x, ...) x$draws(...)

#' @rdname as_draws.CmdStanMCMC
#' @export
as_draws.CmdStanVB <- function(x, ...) x$draws(...)

#' @rdname as_draws.CmdStanMCMC
#' @export
as_draws.CmdStanGQ <- function(x, ...) x$draws(...)

#' @rdname as_draws.CmdStanMCMC
#' @export
as_draws.CmdStanPathfinder <- function(x, ...) x$draws(...)
# stan-dev/cmdstanr documentation built on May 16, 2024, 12:58 a.m.