R/print.R

Defines functions convert_old_object .k_help .warn .fr print_mcse_summary print_reff_summary print_dims.psis_loo_ss print_dims.kfold print_dims.waic print_dims.importance_sampling_loo print_dims.psis_loo print_dims.importance_sampling print_dims print.importance_sampling print.psis print.psis_loo_ap print.importance_sampling_loo print.psis_loo print.waic print.loo

Documented in print_dims print_dims.importance_sampling print_dims.importance_sampling_loo print_dims.kfold print_dims.psis_loo print_dims.psis_loo_ss print_dims.waic print.importance_sampling print.importance_sampling_loo print.loo print.psis print.psis_loo print.psis_loo_ap print.waic

#' Print methods
#'
#' @export
#' @param x An object returned by [loo()], [psis()], or [waic()].
#' @param digits An integer passed to [base::round()].
#' @param plot_k Logical. If `TRUE` the estimates of the Pareto shape
#'   parameter \eqn{k} are plotted. Ignored if `x` was generated by
#'   [waic()]. To just plot \eqn{k} without printing use the
#'   [plot()][pareto-k-diagnostic] method for 'loo' objects.
#' @param ... Arguments passed to [plot.psis_loo()] if `plot_k` is
#'   `TRUE`.
#'
#' @return `x`, invisibly.
#'
#' @seealso [pareto-k-diagnostic]
#'
print.loo <- function(x, digits = 1, ...) {
  cat("\n")
  print_dims(x)
  # Objects without an `estimates` element (older loo objects) are converted
  # into a printable form. The conversion is kept in a separate variable so
  # the object handed back to the caller is exactly the one passed in.
  shown <- x
  if (!("estimates" %in% names(shown))) {
    shown <- convert_old_object(shown)
  }
  cat("\n")
  print(.fr(as.data.frame(shown$estimates), digits), quote = FALSE)
  invisible(x)
}

#' @export
#' @rdname print.loo
print.waic <- function(x, digits = 1, ...) {
  # Shared estimates table first, then the p_waic notes (printed, not warned).
  print.loo(x, digits = digits, ...)
  p_waic_values <- x$pointwise[, "p_waic"]
  throw_pwaic_warnings(p_waic_values, digits = digits, warn = FALSE)
  invisible(x)
}

#' @export
#' @rdname print.loo
print.psis_loo <- function(x, digits = 1, plot_k = FALSE, ...) {
  print.loo(x, digits = digits, ...)
  cat("------\n")
  print_mcse_summary(x, digits = digits)

  # The k threshold is derived from the first dimension of `x` (the rows of
  # the log-likelihood matrix, per print_dims).
  n_rows <- dim(x)[1]
  khat_cutoff <- ps_khat_threshold(n_rows)
  flagged <- pareto_k_ids(x, threshold = khat_cutoff)
  if (length(flagged) > 0) {
    # Separate the table with a blank line when some k values are flagged.
    cat("\n")
  }

  print(pareto_k_table(x), digits = digits)
  cat(.k_help())

  if (plot_k) {
    graphics::plot(x, ...)
  }
  invisible(x)
}

#' @export
#' @rdname print.loo
print.importance_sampling_loo <- function(x, digits = 1, plot_k = FALSE, ...) {
  # Only the estimates table and a separator line are shown; `plot_k` is
  # accepted for signature parity with the other print methods but unused.
  print.loo(x, digits = digits, ...)
  separator <- "------\n"
  cat(separator)
  invisible(x)
}

#' @export
#' @rdname print.loo
print.psis_loo_ap <- function(x, digits = 1, plot_k = FALSE, ...) {
  print.loo(x, digits = digits, ...)
  cat("------\n")
  cat("Posterior approximation correction used.\n")
  # The diagnostics below are printed with r_eff forced to 1. That override
  # is applied to a print-only copy so the object returned to the caller is
  # not modified (previously the mutated `x` was returned).
  x_shown <- x
  attr(x_shown, "r_eff") <- 1
  print_mcse_summary(x_shown, digits = digits)
  S <- dim(x_shown)[1]
  k_threshold <- ps_khat_threshold(S)
  if (length(pareto_k_ids(x_shown, threshold = k_threshold)) > 0) {
    cat("\n")
  }
  print(pareto_k_table(x_shown), digits = digits)
  cat(.k_help())
  if (plot_k) {
    graphics::plot(x_shown, ...)
  }
  invisible(x)
}


#' @export
#' @rdname print.loo
print.psis <- function(x, digits = 1, plot_k = FALSE, ...) {
  # Header: matrix dimensions and the r_eff assumption used for MCSE/ESS.
  print_dims(x)
  print_reff_summary(x, digits)

  # Pareto k diagnostic table followed by a pointer to the help page.
  k_summary <- pareto_k_table(x)
  print(k_summary, digits = digits)
  cat(.k_help())

  if (plot_k) {
    graphics::plot(x, ...)
  }
  invisible(x)
}

#' @export
#' @rdname print.loo
print.importance_sampling <- function(x, digits = 1, plot_k = FALSE, ...) {
  # Only the dimensions line is printed; `digits` is not used here.
  print_dims(x)
  if (plot_k) {
    # Extra arguments are forwarded to the plot method.
    graphics::plot(x, ...)
  }
  invisible(x)
}

# internal ----------------------------------------------------------------

#' Print dimensions of log-likelihood or log-weights matrix
#'
#' S3 generic called by the print methods to report what an object was
#' computed from.
#'
#' @export
#' @keywords internal
#'
#' @param x The object returned by [psis()], [loo()], or [waic()].
#' @param ... Ignored.
print_dims <- function(x, ...) {
  UseMethod("print_dims")
}

#' @rdname print_dims
#' @export
print_dims.importance_sampling <- function(x, ...) {
  # e.g. "Computed from 4000 by 32 log-weights matrix."
  size_label <- paste(dim(x), collapse = " by ")
  cat("Computed from", size_label, "log-weights matrix.\n")
}

#' @rdname print_dims
#' @export
print_dims.psis_loo <- function(x, ...) {
  # Reports the size of the log-likelihood matrix, rows "by" columns.
  cat(sprintf(
    "Computed from %s log-likelihood matrix.\n",
    paste(dim(x), collapse = " by ")
  ))
}

#' @rdname print_dims
#' @export
print_dims.importance_sampling_loo <- function(x, ...) {
  # The sampling method is identified by the object's primary class. It is
  # pasted onto the trailing period so cat() does not insert a space before
  # it (previously printed as "using tis_loo .").
  cat(
    "Computed from",
    paste(dim(x), collapse = " by "),
    "log-likelihood matrix using",
    paste0(class(x)[1], ".\n")
  )
}

#' @rdname print_dims
#' @export
print_dims.waic <- function(x, ...) {
  # Same wording as the psis_loo method: size of the log-likelihood matrix.
  dims_text <- paste(dim(x), collapse = " by ")
  msg <- c("Computed from", dims_text, "log-likelihood matrix.\n")
  cat(msg)
}

#' @rdname print_dims
#' @export
print_dims.kfold <- function(x, ...) {
  # Nothing is printed when the number of folds was not recorded.
  n_folds <- attr(x, "K", exact = TRUE)
  if (is.null(n_folds)) {
    return(invisible(NULL))
  }
  cat("Based on", paste0(n_folds, "-fold"), "cross-validation.\n")
}

#' @rdname print_dims
#' @export
print_dims.psis_loo_ss <- function(x, ...) {
  # Rows of the subsampled log-likelihood, number of subsampled observations,
  # and the total number of observations the subsample was drawn from.
  n_rows <- dim(x)[1]
  n_subsampled <- nobs(x)
  n_total <- length(x$loo_subsampling$elpd_loo_approx)
  cat(
    "Computed from",
    paste(c(n_rows, n_subsampled), collapse = " by "),
    "subsampled log-likelihood\nvalues from",
    n_total,
    "total observations.\n"
  )
}

# Print one line describing the r_eff assumption behind MCSE/ESS estimates.
# r_eff is taken from x$diagnostics$r_eff if present, otherwise from the
# "r_eff" attribute of x$psis_object (when that exists) or of x itself.
# Prints nothing when no r_eff can be found.
print_reff_summary <- function(x, digits) {
  r_eff <- x$diagnostics$r_eff
  if (is.null(r_eff)) {
    source_obj <- if (is.null(x$psis_object)) x else x$psis_object
    r_eff <- attr(source_obj, "r_eff")
  }
  if (is.null(r_eff)) {
    return(invisible(NULL))
  }
  if (all(r_eff == 1)) {
    cat("MCSE and ESS estimates assume independent draws (r_eff=1).\n")
  } else {
    lo <- .fr(min(r_eff), digits)
    hi <- .fr(max(r_eff), digits)
    cat(paste0(
      "MCSE and ESS estimates assume MCMC draws (r_eff in [",
      lo, ", ", hi,
      "]).\n"
    ))
  }
}

# Print the Monte Carlo standard error of elpd_loo (from mcse_loo()),
# followed by the r_eff assumption line from print_reff_summary().
print_mcse_summary <- function(x, digits) {
  mcse_text <- .fr(mcse_loo(x), digits)
  cat("MCSE of elpd_loo is", paste0(mcse_text, ".\n"))
  print_reff_summary(x, digits)
}

# print and warning helpers

# Round to `digits` decimals and format, padding with trailing zeros so
# exactly `digits` decimals are shown (e.g. 2 -> "2.0" for digits = 1).
.fr <- function(x, digits) {
  rounded <- round(x, digits)
  format(rounded, nsmall = digits)
}

# warning() that omits the call from the message by default.
.warn <- function(..., call. = FALSE) {
  warning(..., call. = call.)
}

# Footer pointing users to the Pareto k diagnostic help page.
.k_help <- function() {
  "See help('pareto-k-diagnostic') for details.\n"
}

# compatibility with old loo objects
#
# Build list(estimates = <data.frame>) from an object that stores its
# estimates as separate named elements. Elements whose names match
# "pointwise", "pareto_k" or "n_eff" are dropped. The remaining values are
# split into estimates and standard errors by whether the name contains "se".
# `digits` and `...` are unused; they are kept for interface compatibility.
convert_old_object <- function(x, digits = 1, ...) {
  # Logical indexing instead of x[-grep(...)]: when nothing matches, grep()
  # returns integer(0) and x[-integer(0)] would drop *every* element.
  keep <- !grepl("pointwise|pareto_k|n_eff", names(x))
  uz <- unlist(x[keep])
  nms <- names(uz)
  ses <- grepl("se", nms)
  list(estimates = data.frame(Estimate = uz[!ses], SE = uz[ses]))
}
jgabry/loo documentation built on April 19, 2024, 4:08 a.m.