# R/mcmcRocPrc.R

# Defines functions mcmcRocPrc.mcmc mcmcRocPrc.bugs mcmcRocPrc.brmsfit mcmcRocPrc.stanreg identify_link_function is_binary_model mcmcRocPrc.stanfit mcmcRocPrc.runjags mcmcRocPrc.rjags mcmcRocPrc.jags auc_pr auc_roc compute_pr compute_roc mcmcRocPrc.default new_mcmcRocPrc mcmcRocPrc

# Documented in compute_pr compute_roc identify_link_function is_binary_model mcmcRocPrc mcmcRocPrc.brmsfit mcmcRocPrc.bugs mcmcRocPrc.default mcmcRocPrc.jags mcmcRocPrc.mcmc mcmcRocPrc.rjags mcmcRocPrc.runjags mcmcRocPrc.stanfit mcmcRocPrc.stanreg new_mcmcRocPrc

#
#   This file contains the mcmcRocPrc() S3 generic, which constructs objects
#   of class "mcmcRocPrc". For methods for this class, see mcmcRocPrc-methods.R
#   S3 methods for the mcmcRocPrc() generic handle different types of input
#   e.g. "rjags" input produced by R2jags.
#



#' ROC and Precision-Recall Curves using Bayesian MCMC estimates
#' 
#' Generate ROC and Precision-Recall curves after fitting a Bayesian logit or 
#' probit regression using [rstan::stan()], [rstanarm::stan_glm()], 
#' [R2jags::jags()], [R2WinBUGS::bugs()], [MCMCpack::MCMClogit()], or other 
#' functions that provide samples from a posterior density. 
#' 
#' @param object A fitted binary choice model, e.g. "rjags" object 
#'   (see [R2jags::jags()]), or a `[N, iter]` matrix of predicted probabilities.
#' @param curves logical indicator of whether or not to return values to plot 
#'   the ROC or Precision-Recall curves. If set to `FALSE` (default), 
#'   results are returned as a list without the extra values. 
#' @param fullsims logical indicator of whether full object (based on all MCMC
#'   draws rather than their average) will be returned. Default is `FALSE`. 
#'   Note: If `TRUE` is chosen, the function takes notably longer to execute.
#' @param yvec A `numeric(N)` vector of observed outcomes. 
#' @param yname (`character(1)`)\cr
#'   The name of the dependent variable, should match the variable name in the 
#'   JAGS data object.
#' @param xnames ([base::character()])\cr
#'   A character vector of the independent variable names, should match the 
#'   corresponding names in the JAGS data object.
#' @param posterior_samples a "mcmc" object with the posterior samples
#' @param ... Used by methods
#' @param x a `mcmcRocPrc()` object
#' 
#' @details If only the average AUC-ROC and PR are of interest, setting 
#'   `curves = FALSE` and `fullsims = FALSE` can greatly speed up calculation 
#'   time. The curve data (`curves = TRUE`) is needed for plotting. The plot
#'   method will always plot both the ROC and PR curves, but the underlying
#'   data can easily be extracted from the output for your own plotting; 
#'   see the documentation of the value returned below. 
#'   
#'   The default method works with a matrix of predicted probabilities and the 
#'   vector of observed outcomes as input. Other methods accommodate some of the 
#'   common Bayesian modeling packages like rstan (which returns class "stanfit"),
#'   rstanarm ("stanreg"), R2jags ("jags"), R2WinBUGS ("bugs"), and 
#'   MCMCpack ("mcmc"). Even if a package-specific method is not implemented, 
#'   the default method can always be used as a fallback by manually calculating
#'   the matrix of predicted probabilities for each posterior sample. 
#'   
#'   Note that MCMCpack returns generic "mcmc" output that is annotated with 
#'   some additional information as attributes, including the original function
#'   call. There is no inherent way to distinguish any other kind of "mcmc" 
#'   object from one generated by a proper MCMCpack modeling function, but as a
#'   basic precaution, `mcmcRocPrc()` will check the saved call and return an 
#'   error if the function called was not `MCMClogit()` or `MCMCprobit()`. 
#'   This behavior can be suppressed by setting `force = TRUE`. 
#' 
#' @references Beger, Andreas. 2016. “Precision-Recall Curves.” Available at 
#' \doi{10.2139/ssrn.2765419}
#' 
#' @return Returns a list with length 2 or 4, depending on the "curves" 
#'   and "fullsims" argument values:
#'   
#'   - "area_under_roc": `numeric()`; either length 1 if `fullsims = FALSE`, or 
#'     one value for each posterior sample otherwise
#'   - "area_under_prc": `numeric()`; either length 1 if `fullsims = FALSE`, or 
#'     one value for each posterior sample otherwise
#'   - "prc_dat": only if `curves = TRUE`; a list with length 1 if 
#'     `fullsims = FALSE`, longer otherwise
#'   - "roc_dat": only if `curves = TRUE`; a list with length 1 if 
#'     `fullsims = FALSE`, longer otherwise
#'
#' @examples
#' \donttest{
#' if (interactive()) {
#' # load simulated data and fitted model (see ?sim_data and ?jags_logit)
#' data("jags_logit")
#' 
#' # using mcmcRocPrc
#' fit_sum <- mcmcRocPrc(jags_logit,
#'                       yname = "Y",
#'                       xnames = c("X1", "X2"),
#'                       curves = TRUE,
#'                       fullsims = FALSE)
#' fit_sum                     
#' plot(fit_sum)
#' 
#' # Equivalently, we can calculate the matrix of predicted probabilities 
#' # ourselves; using the example from ?jags_logit:
#' library(R2jags)
#' 
#' data("sim_data")
#' yvec <- sim_data$Y
#' xmat <- sim_data[, c("X1", "X2")]
#' 
#' # add intercept to the X data
#' xmat <- as.matrix(cbind(Intercept = 1L, xmat))
#' 
#' beta <- as.matrix(as.mcmc(jags_logit))[, c("b[1]", "b[2]", "b[3]")]
#' pred_mat <- plogis(xmat %*% t(beta)) 
#' 
#' # the matrix of predictions has rows matching the number of rows in the data;
#' # the columns are the predictions for each of the 2,000 posterior samples
#' nrow(sim_data)
#' dim(pred_mat)
#' 
#' # now we can call mcmcRocPrc; the default method works with the matrix
#' # of predictions and vector of outcomes as input
#' mcmcRocPrc(object = pred_mat, curves = TRUE, fullsims = FALSE, yvec = yvec)
#' }
#' }
#' 
#' @export
#' @md
mcmcRocPrc <- function(object, curves = FALSE, fullsims = FALSE, ...) {
  # S3 generic: with no explicit dispatch object, UseMethod() dispatches on the
  # first argument of the enclosing function, i.e. on the class of `object`.
  UseMethod("mcmcRocPrc")
}

#' Constructor for mcmcRocPrc objects
#' 
#' This function actually does the heavy lifting once we have a matrix of 
#' predicted probabilities from a model, plus the vector of observed outcomes.
#' The reason to have it here in a single function is that we don't replicate 
#' it in each function that accommodates a JAGS, BUGS, RStan, etc. object.
#' 
#' @param pred_prob a `\[N, iter\]` matrix of predicted probabilities 
#' @param yvec a `numeric(N)` vector of observed outcomes
#' @param curves include curve data in output?
#' @param fullsims keep one result per posterior sample (`TRUE`), or collapse
#'   the samples into a single summary via the row-wise median (`FALSE`)?
#' 
#' @return An object of class "mcmcRocPrc" with a "y_pos_rate" attribute 
#'   holding `mean(yvec)`. The underlying data is a list with the AUC values 
#'   and, if `curves = TRUE`, the curve data; for `curves = FALSE` and 
#'   `fullsims = TRUE` the AUC values are stored as data frame columns.
#' 
#' @md
#' @keywords internal
new_mcmcRocPrc <- function(pred_prob, yvec, curves, fullsims) {
  
  # Coerce the two switches once, so that every branch below sees the same
  # scalar TRUE/FALSE value (e.g. `fullsims = 0` is FALSE throughout, rather
  # than skipping the summary step but still being reported as "not fullsims")
  curves   <- as.logical(curves)
  fullsims <- as.logical(fullsims)
  
  stopifnot(
    "curves must be a single TRUE or FALSE"                    = length(curves) == 1L && !is.na(curves),
    "fullsims must be a single TRUE or FALSE"                  = length(fullsims) == 1L && !is.na(fullsims),
    "number of predictions and observed outcomes do not match" = nrow(pred_prob)==length(yvec),
    "yvec must be 0 or 1"                                      = all(yvec %in% c(0L, 1L)),
    "pred_prob must be in the interval [0, 1]"                 = all(pred_prob >= 0 & pred_prob <= 1)
  )
  
  # pred_prob is a [N, iter] matrix, i.e. each column holds the predictions 
  # from one posterior sample; without fullsims, collapse it to one column with
  # the per-case median
  if (!fullsims) {
    pred_prob <- as.matrix(apply(pred_prob, MARGIN = 1, stats::median))
  }
  
  # One ROC and one PR curve per column (as.data.frame() makes lapply() iterate
  # over columns)
  pred_prob  <- as.data.frame(pred_prob)
  curve_data <- lapply(pred_prob, yy = yvec, FUN = function(x, yy) {
    list(
      prc_dat = compute_pr(yvec = yy, pvec = x),
      roc_dat = compute_roc(yvec = yy, pvec = x)
    )
  })
  prc_dat <- lapply(curve_data, `[[`, "prc_dat")
  roc_dat <- lapply(curve_data, `[[`, "roc_dat")
  
  # Areas under the curves; vapply() guarantees one numeric value per curve
  v_auc_roc <- vapply(roc_dat, function(xy) {
    caTools::trapz(xy$x, xy$y)
  }, FUN.VALUE = numeric(1))
  v_auc_pr  <- vapply(prc_dat, function(xy) {
    # the PR curve data starts with an undefined (NaN) precision; skip it
    xy <- subset(xy, !is.nan(xy$y))
    caTools::trapz(xy$x, xy$y)
  }, FUN.VALUE = numeric(1))
  
  # Recreate original output formats; exactly one branch is taken
  if (curves && fullsims) {
    out <- list(
      area_under_roc = v_auc_roc,
      area_under_prc = v_auc_pr,
      prc_dat = prc_dat,
      roc_dat = roc_dat
    )
  } else if (curves) {
    out <- list(
      area_under_roc = v_auc_roc,
      area_under_prc = v_auc_pr,
      prc_dat = prc_dat[1],
      roc_dat = roc_dat[1]
    )
  } else if (fullsims) {
    out <- data.frame(
      area_under_roc = v_auc_roc,
      area_under_prc = v_auc_pr
    )
  } else {
    out <- list(
      area_under_roc = v_auc_roc[[1]],
      area_under_prc = v_auc_pr[[1]]
    )
  }
  
  structure(
    out,
    y_pos_rate = mean(yvec),
    class = "mcmcRocPrc"
  )
}

#' @rdname mcmcRocPrc
#' 
#' @md
#' @export
mcmcRocPrc.default <- function(object, curves = FALSE, fullsims = FALSE, yvec, 
                               ...) {
  # `curves` and `fullsims` need their own defaults here: when the generic is
  # called without them, S3 dispatch leaves them missing in the method rather
  # than passing on the generic's default values.
  
  # `object` is expected to be the [N, iter] matrix of predicted probabilities
  stopifnot(
    "mcmcRocPrc.default requires 'matrix' like input" = inherits(object, "matrix")
  )
  
  new_mcmcRocPrc(pred_prob = object, yvec = yvec, curves = curves, 
                 fullsims = fullsims)
}

# Under the hood ROC/PRC calculations -------------------------------------

#' Compute ROC and PR curve points
#' 
#' Faster replacements for calculating ROC and PR curve data than with 
#' [ROCR::prediction()] and [ROCR::performance()]
#' 
#' @details Replacements to use instead of a combination of [ROCR::prediction()] 
#' and [ROCR::performance()] to calculate ROC and PR curves. These functions are
#' about 10 to 20 times faster when using [mcmcRocPrc()] with `curves = TRUE` 
#' and/or `fullsims = TRUE`. 
#' 
#' Cases with tied predicted probabilities share a single classification 
#' threshold, so each group of ties contributes exactly one curve point, 
#' computed after all of the tied cases have been counted.
#' 
#' See this [issue on GH (ShanaScogin/BayesPostEst#25)](https://github.com/ShanaScogin/BayesPostEst/issues/25) for more general details.
#' 
#' And [here is a note](https://github.com/andybega/BayesPostEst/blob/f1da23b9db86461d4f9c671d9393265dd10578c5/tests/profile-mcmcRocPrc.md) with specific performance benchmarks, compared to the 
#' old approach relying on ROCR.
#' 
#' @param yvec vector of observed 0/1 outcomes
#' @param pvec vector of predicted probabilities, same length as `yvec`
#' 
#' @return A data frame with columns `x` and `y`. For `compute_roc()` these 
#'   are the false and true positive rates, starting at (0, 0). For 
#'   `compute_pr()` they are recall and precision, starting with recall 0 and 
#'   an undefined (`NaN`) precision.
#' 
#' @keywords internal
#' @md
compute_roc <- function(yvec, pvec) {
  # Walk through the cases from the highest to the lowest predicted probability
  porder <- order(pvec, decreasing = TRUE)
  yvecs  <- yvec[porder]
  pvecs  <- pvec[porder]
  
  p  <- sum(yvecs)               # observed positives
  n  <- length(yvecs) - p        # observed negatives
  tp <- cumsum(yvecs)            # true positives when cutting after case i
  fp <- seq_along(yvecs) - tp    # false positives when cutting after case i
  
  # Tied predictions cannot be separated by any threshold: keep only the last
  # case of each group of ties, whose cumulative counts cover the whole group.
  # Every retained step adds at least one true or false positive, so the
  # remaining points are all distinct.
  keep <- !duplicated(pvecs, fromLast = TRUE)
  
  data.frame(x = c(0, (fp / n)[keep]),
             y = c(0, (tp / p)[keep]))
}

#' @rdname compute_roc
#' @aliases compute_pr
compute_pr <- function(yvec, pvec) {
  # Walk through the cases from the highest to the lowest predicted probability
  porder <- order(pvec, decreasing = TRUE)
  yvecs  <- yvec[porder]
  pvecs  <- pvec[porder]
  
  p    <- sum(yvecs)               # observed positives
  tp   <- cumsum(yvecs)            # true positives when cutting after case i
  tpr  <- tp / p                   # recall
  prec <- tp / seq_along(yvecs)    # precision: true / predicted positives
  
  # Tied predictions cannot be separated by any threshold: keep only the last
  # case of each group of ties, whose cumulative counts cover the whole group
  keep <- !duplicated(pvecs, fromLast = TRUE)
  tpr  <- tpr[keep]
  prec <- prec[keep]
  
  # Drop points that repeat an earlier (recall, precision) pair, e.g. the run 
  # of (0, 0) points before the first true positive. The pair is compared 
  # jointly: a point that merely shares its recall with one earlier point and 
  # its precision with another is a genuine curve point and must be kept.
  keep <- !duplicated(data.frame(tpr, prec))
  
  # Precision is undefined before any case is predicted positive, hence NaN
  data.frame(x = c(0, tpr[keep]),
             y = c(NaN, prec[keep]))
}


# auc_roc and auc_pr are not really used, but keep around just in case
auc_roc <- function(obs, pred) {
  # Area under the ROC curve: trapezoid rule over the points from compute_roc()
  roc_points <- compute_roc(obs, pred)
  caTools::trapz(roc_points[["x"]], roc_points[["y"]])
}

auc_pr <- function(obs, pred) {
  # Area under the precision-recall curve, by the trapezoid rule
  pr_points <- compute_pr(obs, pred)
  # compute_pr() opens the curve with an undefined (NaN) precision at recall 0;
  # integrating over it would turn the whole area into NaN, so drop it first
  # (new_mcmcRocPrc() applies the same filter before integrating)
  pr_points <- pr_points[!is.nan(pr_points$y), , drop = FALSE]
  caTools::trapz(pr_points$x, pr_points$y)
}



# JAGS-like input (rjags, R2jags, runjags) --------------------------------

#' @rdname mcmcRocPrc
#' 
#' @export
mcmcRocPrc.jags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
                            xnames, posterior_samples, ...) {
  
  stopifnot(
    inherits(posterior_samples, c("mcmc", "mcmc.list"))
  )
  
  # The link function is guessed from the model code: any mention of "logit" 
  # or "probit" counts, with "logit" taking precedence if both appear
  model_code  <- object$model()
  link_logit  <- any(grepl("logit", model_code))
  link_probit <- any(grepl("probit", model_code))
  
  if (!(link_logit || link_probit)) {
    stop("Could not identify model link function")
  }
  
  mdl_data <- object$data()
  stopifnot(all(xnames %in% names(mdl_data)))
  stopifnot(all(yname %in% names(mdl_data)))
  
  # add intercept by default, maybe revisit this
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
  yvec  <- mdl_data[[yname]]
  
  pardraws <- as.matrix(posterior_samples)
  # this is not very robust, assumes pars are 'b[x]': one coefficient per 
  # column of xdata, with b[1] being the intercept
  # for both this and the intercept addition above, maybe a more robust solution
  # down the road would be to dig into the object$model$model() string
  beta_names   <- sprintf("b[%s]", seq_len(ncol(xdata)))
  missing_pars <- setdiff(beta_names, colnames(pardraws))
  if (length(missing_pars) > 0) {
    stop("posterior samples do not contain the expected coefficient(s): ", 
         paste(missing_pars, collapse = ", "))
  }
  # drop = FALSE keeps a [iter, k] matrix even with a single posterior draw
  betadraws <- pardraws[, beta_names, drop = FALSE]
  
  # [N, k] %*% [k, iter] gives the [N, iter] linear predictor
  linpred <- xdata %*% t(betadraws)
  if (link_logit) {
    pred_prob <- stats::plogis(linpred)
  } else {
    pred_prob <- stats::pnorm(linpred)
  }
  
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
                 fullsims = fullsims)
}

#' @rdname mcmcRocPrc
#' 
#' @export
mcmcRocPrc.rjags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
                             xnames, ...) {
  
  r2jags_available <- requireNamespace("R2jags", quietly = TRUE)
  if (!r2jags_available) {
    stop("Package \"R2jags\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
  }
  
  # Split the fit into the underlying model object and its posterior draws
  inner_model <- object$model
  draws       <- coda::as.mcmc(object)
  
  # Re-dispatch on the inner model object, which is handled by the "jags" 
  # method
  mcmcRocPrc(
    object            = inner_model, 
    curves            = curves, 
    fullsims          = fullsims, 
    yname             = yname, 
    xnames            = xnames, 
    posterior_samples = draws, 
    ...
  )
}

#' @rdname mcmcRocPrc
#' 
#' @export
mcmcRocPrc.runjags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
                               xnames, ...) {
  
  # Same guard as the other methods that rely on an optional package
  if (!requireNamespace("runjags", quietly = TRUE)) {
    stop("Package \"runjags\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
  }
  
  jags_object <- runjags::as.jags(object, quiet = TRUE)
  # as.mcmc.runjags will issue a warning when converting multiple chains
  # because it combines them
  pardraws    <- suppressWarnings(coda::as.mcmc(object))
  
  # pass it on to the "jags" method
  mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, 
             yname = yname, xnames = xnames, posterior_samples = pardraws, ...)
}


# STAN-like input (rstan, rstanarm, brms) ---------------------------------



#' @rdname mcmcRocPrc
#' 
#' @param data the data that was used in the `stan(data = ?, ...)` call
#' 
#' @export
mcmcRocPrc.stanfit <- function(object, curves = FALSE, fullsims = FALSE, data, 
                               xnames, yname, ...) {
  if (!requireNamespace("rstan", quietly = TRUE)) {
    stop("Package \"rstan\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
  }
  
  # Both checks below are heuristics based on the model's Stan code
  if (!is_binary_model(object)) {
    stop("the input model does not seem to be a binary choice model; if this is a mistake please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/")
  }
  link_type <- identify_link_function(object)
  if (is.na(link_type)) {
    stop("could not identify model link function; please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/")
  }
  
  mdl_data <- data
  stopifnot(all(xnames %in% names(mdl_data)))
  stopifnot(all(yname %in% names(mdl_data)))
   
  # add intercept by default, maybe revisit this
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
  yvec  <- mdl_data[[yname]]
  
  pardraws <- as.matrix(object)
  # this is not very robust, assumes pars are 'b[x]': one coefficient per 
  # column of xdata, with b[1] being the intercept
  beta_names   <- sprintf("b[%s]", seq_len(ncol(xdata)))
  missing_pars <- setdiff(beta_names, colnames(pardraws))
  if (length(missing_pars) > 0) {
    stop("posterior samples do not contain the expected coefficient(s): ", 
         paste(missing_pars, collapse = ", "))
  }
  # drop = FALSE keeps a [iter, k] matrix even with a single posterior draw
  betadraws <- pardraws[, beta_names, drop = FALSE]
  
  # [N, k] %*% [k, iter] gives the [N, iter] linear predictor
  linpred <- xdata %*% t(betadraws)
  if (link_type == "logit") {
    pred_prob <- stats::plogis(linpred)
  } else if (link_type == "probit") {
    pred_prob <- stats::pnorm(linpred)
  } 
  
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
                 fullsims = fullsims)
}

#' Try to identify if a stanfit model is a binary choice model
#' 
#' Heuristic check: the model is taken to be a binary choice model if its Stan
#' code contains the string "bernoulli".
#' 
#' @param obj stanfit object
#' 
#' @return Logical, the result of searching the model's Stan code for 
#'   "bernoulli".
#' 
#' @keywords internal
is_binary_model <- function(obj) {
  stopifnot(inherits(obj, "stanfit"))
  grepl(pattern = "bernoulli", x = rstan::get_stancode(obj))
}

#' Try to identify link function 
#' 
#' Heuristic check on the model's Stan code: a mention of "logit" is read as a
#' logit link, otherwise a mention of "Phi" is read as a probit link.
#' 
#' @param obj stanfit object
#' 
#' @return Either "logit" or "probit"; if neither can be identified the function
#' will return `NA_character_`. 
#' 
#' @keywords internal
identify_link_function <- function(obj) {
  stopifnot(inherits(obj, "stanfit"))
  stan_code <- rstan::get_stancode(obj)
  # "logit" is checked first and therefore wins if both patterns are present
  if (grepl("logit", stan_code)) {
    "logit"
  } else if (grepl("Phi", stan_code)) {
    "probit"
  } else {
    NA_character_
  }
}

#' @rdname mcmcRocPrc
#' 
#' @export
mcmcRocPrc.stanreg <- function(object, curves = FALSE, fullsims = FALSE, ...) {
  rstanarm_available <- requireNamespace("rstanarm", quietly = TRUE)
  if (!rstanarm_available) {
    stop("Package \"rstanarm\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
  }
  
  # Only binomial-family fits are accepted
  family_name <- stats::family(object)$family
  if (family_name != "binomial") {
    stop("the input model does not seem to be a binary choice model; should be like 'obj <- stan_glm(family = binomial(), ...)'") 
  }
  
  # posterior_linpred() returns a matrix in which data cases are columns and 
  # MCMC samples are rows; the constructor expects samples in columns, hence 
  # the transpose below
  # NOTE(review): newer rstanarm releases may deprecate `transform = TRUE` in 
  # favour of posterior_epred() -- confirm against the supported version
  draws_by_case <- rstanarm::posterior_linpred(object, transform = TRUE)
  cases_by_draw <- t(draws_by_case)
  observed      <- unname(object$y)
  
  new_mcmcRocPrc(pred_prob = cases_by_draw, yvec = observed, curves = curves, 
                 fullsims = fullsims)
}

#' @rdname mcmcRocPrc
#' 
#' @export
mcmcRocPrc.brmsfit <- function(object, curves = FALSE, fullsims = FALSE, ...) {
  brms_available <- requireNamespace("brms", quietly = TRUE)
  if (!brms_available) {
    stop("Package \"brms\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
  }
  
  # Only bernoulli-family fits are accepted
  family_name <- stats::family(object)$family
  if (family_name != "bernoulli") {
    stop("the input model does not seem to be a binary choice model; should be like 'obj <- brm(family = bernoulli(), ...)'") 
  }
  
  # posterior_epred() returns a matrix in which data cases are columns and 
  # MCMC samples are rows; the constructor expects samples in columns, hence 
  # the transpose below
  draws_by_case <- brms::posterior_epred(object)
  cases_by_draw <- t(draws_by_case)
  # observed outcomes are taken from the response of the model frame
  observed      <- stats::model.response(stats::model.frame(object))
  
  new_mcmcRocPrc(pred_prob = cases_by_draw, yvec = observed, curves = curves, 
                 fullsims = fullsims)
}


# Other input types (MCMCpack, ...) ---------------------------------------



#' @rdname mcmcRocPrc
#'
#' @export
mcmcRocPrc.bugs <- function(object, curves = FALSE, fullsims = FALSE, data,
                            xnames, yname, type = c("logit", "probit"), ...) {
  
  # The link function cannot be read off the fit here; the caller states it
  link_type <- match.arg(type)
  mdl_data <- data
  stopifnot(
    all(xnames %in% names(mdl_data)),
    all(yname %in% names(mdl_data))
  )
  
  # add intercept by default, maybe revisit this
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
  yvec  <- mdl_data[[yname]]
  
  sm <- object$sims.matrix
  # Drop "deviance" column; every remaining column is treated as a 
  # coefficient, in the column order of xdata (intercept first).
  # drop = FALSE keeps a [iter, k] matrix even with a single posterior draw
  betadraws <- sm[, !colnames(sm) %in% "deviance", drop = FALSE]
  
  if (ncol(betadraws) != ncol(xdata)) {
    stop("the number of sampled parameters (", ncol(betadraws), 
         ") does not match the number of predictors plus intercept (", 
         ncol(xdata), ")")
  }
  
  # [N, k] %*% [k, iter] gives the [N, iter] linear predictor
  linpred <- xdata %*% t(betadraws)
  if (link_type == "logit") {
    pred_prob <- stats::plogis(linpred)
  } else if (link_type == "probit") {
    pred_prob <- stats::pnorm(linpred)
  } 
  
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
                 fullsims = fullsims)
}


#' @rdname mcmcRocPrc
#' 
#' @param type "logit" or "probit"
#' @param force for MCMCpack models, suppress warning if the model does not 
#'   appear to be a binary choice model?
#'
#' @export
mcmcRocPrc.mcmc <- function(object, curves = FALSE, fullsims = FALSE, data, 
                            xnames, yname, type = c("logit", "probit"), 
                            force = FALSE, ...) {
  
  # Basic precaution: unless forced, only accept output whose saved call was 
  # made to MCMClogit() or MCMCprobit()
  if (!force) {
    if (is.null(attr(object, "call"))) {
      stop("object does not have a 'call' attribute; was it generated with a MCMCpack function?")
    } else {
      func <- as.character(attr(object, "call"))[1]
      # a model fitted as MCMCpack::MCMClogit(...) stores the qualified name; 
      # strip any 'pkg::' or 'pkg:::' prefix before comparing
      func <- sub("^.*:::?", "", func)
      if (!func %in% c("MCMClogit", "MCMCprobit")) {
        stop("object does not appear to have been fitted using MCMCpack::MCMClogit() or MCMCprobit(); mcmcRocPrc only properly works for those function. To be safe, consider manually calculating the matrix of predicted probabilities.")
      }
    }
  }
  
  # Note: `type` alone selects the link; it is not cross-checked against the 
  # function named in the saved call
  link_type <- match.arg(type)
  mdl_data <- data
  stopifnot(
    all(xnames %in% names(mdl_data)),
    all(yname %in% names(mdl_data))
  )
  
  # add intercept by default, maybe revisit this
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
  yvec  <- mdl_data[[yname]]
  
  # every column of the samples is treated as a coefficient, in the column 
  # order of xdata (intercept first)
  betadraws <- as.matrix(object)

  # [N, k] %*% [k, iter] gives the [N, iter] linear predictor
  linpred <- xdata %*% t(betadraws)
  if (link_type == "logit") {
    pred_prob <- stats::plogis(linpred)
  } else if (link_type == "probit") {
    pred_prob <- stats::pnorm(linpred)
  } 
  
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
                 fullsims = fullsims)
}

# Try the BayesPostEst package in your browser

# Any scripts or data that you put into this service are public.

# BayesPostEst documentation built on Nov. 11, 2021, 9:07 a.m.