R/loo-functions.R

#' Generic functions for LOO predictions
#'
#' See the methods in the \pkg{rstanarm} package for examples.
#'
#' @name loo-prediction
#'
#' @template args-object
#' @template args-dots
#'
#' @return `loo_predict()`, `loo_linpred()`, and `loo_pit()`
#'   (probability integral transform) methods should return a vector with length
#'   equal to the number of observations in the data.
#'   `loo_predictive_interval()` methods should return a two-column matrix
#'   formatted in the same way as for [predictive_interval()].
#'
#' @template seealso-rstanarm-pkg
#' @template seealso-vignettes
#'

#' @rdname loo-prediction
#' @export
loo_linpred <- function(object, ...) {
  UseMethod("loo_linpred")
}

#' @rdname loo-prediction
#' @export
loo_predict <- function(object, ...) {
  UseMethod("loo_predict")
}

#' @rdname loo-prediction
#' @export
loo_predictive_interval <- function(object, ...) {
  UseMethod("loo_predictive_interval")
}

#' @rdname loo-prediction
#' @export
loo_pit <- function(object, ...) {
  UseMethod("loo_pit")
}
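The four functions above are bare S3 generics; model-fitting packages are expected to supply the methods. Below is a minimal sketch of how a package might add a `loo_linpred()` method. The class `my_fit`, its `linpred` and `lw` fields, and the importance-weighted averaging rule are hypothetical assumptions for illustration, not part of rstantools. The generic is redefined locally so the snippet runs without rstantools installed, and this local copy masks the exported one.

```r
# Local copy of the generic so this runs standalone (rstantools exports the real one).
loo_linpred <- function(object, ...) {
  UseMethod("loo_linpred")
}

# Hypothetical method for an assumed "my_fit" class:
# `object$linpred` is an S x N matrix of linear-predictor draws and
# `object$lw` is an S x N matrix of normalized log-weights (assumed fields).
loo_linpred.my_fit <- function(object, ...) {
  w <- exp(object$lw)
  # Importance-weighted LOO mean per observation: a vector of length N.
  colSums(w * object$linpred)
}

fit <- structure(
  list(
    linpred = matrix(c(1, 2, 3, 4, 5, 6), nrow = 3),
    lw = matrix(log(1 / 3), nrow = 3, ncol = 2)  # equal weights of 1/3
  ),
  class = "my_fit"
)
loo_linpred(fit)  # c(2, 5), up to floating-point rounding
```

Returning one value per column follows the documented contract that these methods return a vector with one element per observation.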

#' @rdname loo-prediction
#' @export
#' @param y For the default method of `loo_pit()`, a numeric vector of
#'   observed `y` values, with length equal to the number of columns in the
#'   matrix used as `object`.
#' @param lw For the default method of `loo_pit()`, a matrix of log-weights
#'   with the same dimensions as the matrix used as `object`. For the results
#'   to be probabilities, each column should be normalized so that its
#'   exponentiated values sum to one.
#'
loo_pit.default <- function(object, y, lw, ...) {
  if (!is.matrix(object))
    stop("For the default method, 'object' should be a matrix.")
  stopifnot(
    is.numeric(object), is.numeric(y), length(y) == ncol(object),
    is.matrix(lw), identical(dim(lw), dim(object))
  )
  .loo_pit(y = y, yrep = object, lw = lw)
}
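As a rough illustration of what the default method computes, here is a standalone base-R sketch. For each observation `j`, the LOO-PIT value is the total normalized importance weight of the draws with `yrep[, j] <= y[j]`. The helper name `pit_sketch()` and the simulated data are hypothetical. With rstantools installed, the same matrix could instead be passed as `loo_pit(yrep, y = y, lw = lw)`, which dispatches to `loo_pit.default()`. That method expects `lw` to be normalized already, whereas this sketch normalizes it first.

```r
# Hypothetical standalone re-implementation of the default LOO-PIT computation.
pit_sketch <- function(yrep, y, lw) {
  # Normalize each column of log-weights so that exp(lw[, j]) sums to 1,
  # subtracting the column's log-sum-exp (computed with the max trick).
  lse <- apply(lw, 2, function(x) {
    m <- max(x)
    m + log(sum(exp(x - m)))
  })
  lw <- sweep(lw, 2, lse)
  # For each observation, sum the weights of draws at or below the observed value.
  vapply(seq_along(y), function(j) {
    sum(exp(lw[yrep[, j] <= y[j], j]))
  }, numeric(1))
}

set.seed(1)
S <- 1000
N <- 5
yrep <- matrix(rnorm(S * N), nrow = S, ncol = N)  # simulated predictive draws
y <- rnorm(N)                                      # simulated observations
lw <- matrix(0, nrow = S, ncol = N)                # equal, unnormalized log-weights

pit <- pit_sketch(yrep, y, lw)
# With equal weights, LOO-PIT reduces to the empirical proportion of yrep <= y.
all.equal(pit, colMeans(sweep(yrep, 2, y, "<=")))
```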

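# internal ----------------------------------------------------------------

# Compute LOO-PIT values: for each observation j, the total exponentiated
# log-weight of the draws in column j of `yrep` that are <= y[j].
.loo_pit <- function(y, yrep, lw) {
  vapply(seq_len(ncol(yrep)), function(j) {
    sel <- yrep[, j] <= y[j]
    .exp_log_sum_exp(lw[sel, j])
  }, FUN.VALUE = numeric(1))
}

# Sum of exp(x), computed on the log scale by factoring out max(x).
# If x is empty, max() returns -Inf (its warning is suppressed) and the
# result is 0.
.exp_log_sum_exp <- function(x) {
  m <- suppressWarnings(max(x))
  exp(m + log(sum(exp(x - m))))
}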
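The internal helper's handling of an empty selection explains why the `max()` warning is suppressed. When no draw satisfies `yrep[, j] <= y[j]`, the helper receives an empty vector and returns a PIT value of 0. The short check below copies the helper so it runs on its own. The example weights are made up for illustration.

```r
# Copy of the internal helper, reproduced here so the snippet runs standalone.
.exp_log_sum_exp <- function(x) {
  m <- suppressWarnings(max(x))
  exp(m + log(sum(exp(x - m))))
}

lw_col <- log(c(0.2, 0.3, 0.5))    # normalized log-weights for one observation
.exp_log_sum_exp(lw_col[c(1, 3)])  # total weight of two selected draws: about 0.7
.exp_log_sum_exp(numeric(0))       # no draws selected: max() gives -Inf, result is 0
```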

rstantools documentation built on July 26, 2023, 5:35 p.m.