
Defines functions gg_roc.randomForest gg_roc gg_roc.rfsrc (plus gg_roc.default, an alias of gg_roc.rfsrc)

Documented in gg_roc gg_roc.randomForest gg_roc.rfsrc

####  ----------------------------------------------------------------
####  Written by:
####    John Ehrlinger, Ph.D.
####    email:  john.ehrlinger@gmail.com
####    URL:    https://github.com/ehrlinger/ggRandomForests
####  ----------------------------------------------------------------
#' ROC (receiver operating characteristic) data from a classification random
#' forest.
#'
#' Extracts the sensitivity and specificity of a classification forest grown
#' with either \code{\link[randomForestSRC]{rfsrc}} or
#' \code{\link[randomForest]{randomForest}}, for plotting ROC curves.
#'
#' @param object an \code{\link[randomForestSRC]{rfsrc}} or
#'   \code{\link[randomForest]{randomForest}} classification object.
#' @param which_outcome select the classification outcome of interest. When
#'   missing, \code{"all"} is used.
#' @param oob use OOB estimates (default \code{TRUE}). Only forwarded by the
#'   \code{rfsrc} method; the \code{randomForest} method does not pass it on.
#' @param ... extra arguments (not used)
#' @return A \code{gg_roc} \code{data.frame} (class \code{"gg_roc"} prepended
#'   to the underlying class) for plotting ROC curves, returned invisibly.
#' @seealso \code{\link{plot.gg_roc}} \code{\link[randomForestSRC]{rfsrc}}
#' \code{\link[randomForest]{randomForest}}
#' @examples
#' ## ------------------------------------------------------------
#' ## classification example
#' ## ------------------------------------------------------------
#' ## -------- iris data
#' rfsrc_iris <- rfsrc(Species ~ ., data = iris)
#' # ROC for setosa
#' gg_dta <- gg_roc(rfsrc_iris, which_outcome=1)
#' plot(gg_dta)
#' # ROC for versicolor
#' gg_dta <- gg_roc(rfsrc_iris, which_outcome=2)
#' plot(gg_dta)
#' # ROC for virginica
#' gg_dta <- gg_roc(rfsrc_iris, which_outcome=3)
#' plot(gg_dta)
#' ## -------- iris data (randomForest)
#' rf_iris <- randomForest::randomForest(Species ~ ., data = iris)
#' # ROC for setosa
#' gg_dta <- gg_roc(rf_iris, which_outcome=1)
#' plot(gg_dta)
#' # ROC for versicolor
#' gg_dta <- gg_roc(rf_iris, which_outcome=2)
#' plot(gg_dta)
#' # ROC for virginica
#' gg_dta <- gg_roc(rf_iris, which_outcome=3)
#' plot(gg_dta)
#' @aliases gg_roc gg_roc.rfsrc gg_roc.randomForest
#' @export
gg_roc.rfsrc <- function(object, which_outcome, oob, ...) {
  # Require "rfsrc" as the first class and "grow"/"predict" as the second.
  # NOTE(review): randomForest objects pass this check, but without a
  # "class" class they are rejected by the next check anyway.
  if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2 &&
      sum(inherits(object, c("rfsrc", "predict"), TRUE) == c(1, 2)) != 2 &&
      !inherits(object, "randomForest")) {
    stop(
      "This function only works for objects of class `(rfsrc, grow)',
      '(rfsrc, predict)' or 'randomForest."
    )
  }
  if (!inherits(object, "class")) {
    stop("gg_roc only works with classification forests")
  }

  # Want to remove the which_outcome argument to plot ROC for all
  # outcomes simultaneously.
  if (missing(which_outcome)) {
    which_outcome <- "all"
  }
  # Apply the documented default: forwarding a missing `oob` below would make
  # the callee fail as soon as it evaluates the argument.
  if (missing(oob)) {
    oob <- TRUE
  }
  # identical() instead of `!=` so a NULL $family gives the intended error
  # rather than "argument is of length zero".
  if (!identical(object$family, "class")) {
    stop("gg_roc is intended for classification forests only.")
  }

  # calc_roc.rfsrc is defined elsewhere in the package; `object$yvar` is
  # assumed to hold the observed class labels -- TODO confirm.
  gg_dta <- calc_roc.rfsrc(object,
                           object$yvar,
                           which_outcome = which_outcome,
                           oob = oob)
  class(gg_dta) <- c("gg_roc", class(gg_dta))
  # Return the data itself; ending on the class<- assignment would return
  # only the class vector.
  invisible(gg_dta)
}
#' @export
gg_roc <- function(object, which_outcome, oob, ...) {
  # S3 generic: dispatch on the class of `object` (this file provides
  # gg_roc.rfsrc, gg_roc.randomForest and gg_roc.default).
  UseMethod("gg_roc", object)
}

#' @export
gg_roc.randomForest <- function(object, which_outcome, oob, ...) {
  # Accept objects whose first or second class is "randomForest".
  rf_pos <- inherits(object, "randomForest", which = TRUE)
  if (!(rf_pos %in% c(1L, 2L))) {
    stop(
      "This function only works for objects of class `(rfsrc, grow)',
      '(rfsrc, predict)' or 'randomForest."
    )
  }

  # Want to remove the which_outcome argument to plot ROC for all
  # outcomes simultaneously.
  if (missing(which_outcome)) {
    which_outcome <- "all"
  }
  # identical() instead of `==` so a NULL $type gives the intended error
  # rather than "argument is of length zero".
  if (!identical(object$type, "classification")) {
    stop("gg_roc only works with classification forests")
  }

  # `oob` is accepted for compatibility with the generic but is not forwarded.
  # calc_roc.randomForest is defined elsewhere in the package; `object$y` is
  # assumed to hold the observed class labels -- TODO confirm.
  gg_dta <- calc_roc.randomForest(object,
                                  object$y,
                                  which_outcome = which_outcome)
  class(gg_dta) <- c("gg_roc", class(gg_dta))
  # Return the data itself; ending on the class<- assignment would return
  # only the class vector.
  invisible(gg_dta)
}

# Fallback method for gg_roc(): an alias bound to the gg_roc.rfsrc closure as
# it exists when this line is evaluated (a copy, not a forwarding wrapper).
# Objects without a more specific method therefore go through the rfsrc
# checks and are rejected unless they are rfsrc classification forests.
#' @export
gg_roc.default <- gg_roc.rfsrc

Try the ggRandomForests package in your browser

Any scripts or data that you put into this service are public.

ggRandomForests documentation built on Sept. 1, 2022, 5:07 p.m.