R/gg_partial.R

Defines functions gg_partial.randomForest gg_partial.rfsrc gg_partial

Documented in gg_partial gg_partial.randomForest gg_partial.rfsrc

####**********************************************************************
####**********************************************************************
####
####  ----------------------------------------------------------------
####  Written by:
####    John Ehrlinger, Ph.D.
####
####    email:  john.ehrlinger@gmail.com
####    URL:    https://github.com/ehrlinger/ggRandomForests
####  ----------------------------------------------------------------
####
####**********************************************************************
####**********************************************************************
#' Partial variable dependence object
#'
#' @description The \code{\link[randomForestSRC]{plot.variable}} function
#' returns a list of either marginal variable dependence or partial variable
#' dependence data from a \code{\link[randomForestSRC]{rfsrc}} object.
#' The \code{gg_partial} function formulates the
#' \code{\link[randomForestSRC]{plot.variable}} output for partial plots
#' (where \code{partial=TRUE}) into a data object for creation of partial
#' dependence plots using the \code{\link{plot.gg_partial}} function.
#'
#' Partial variable dependence plots are the risk adjusted estimates of the
#' specified response as a function of a single covariate, possibly subsetted
#' on other covariates.
#'
#' An optional \code{named} argument, passed through \code{...}, adds an
#' \code{id} column holding the supplied value to every data set. This
#' identifies the source of each data set when merging multiple plots
#' together.
#'
#' @param object the partial variable dependence data object from
#'   \code{\link[randomForestSRC]{plot.variable}} function
#' @param ... optional arguments, such as \code{named} (see Description).
#'
#' @return \code{gg_partial} object. A \code{data.frame} when the
#' \code{\link[randomForestSRC]{plot.variable}} output contains a single
#' variable, otherwise a \code{gg_partial_list}: a named \code{list} of
#' \code{data.frame}s, one per variable.
#'
#' @seealso \code{\link{plot.gg_partial}}
#' @seealso \code{\link[randomForestSRC]{plot.variable}}
#'
#' @importFrom parallel mclapply
#'
#' @references
#' Friedman, Jerome H. 2001. "Greedy Function Approximation: A Gradient
#' Boosting Machine." Annals of Statistics 29 (5): 1189-1232.
#'
#' @examples
#' ## ------------------------------------------------------------
#' ## classification
#' ## ------------------------------------------------------------
#' ## -------- iris data
#' ## iris "Petal.Width" partial dependence plot
#' ##
#' rfsrc_iris <- rfsrc(Species ~ ., data = iris)
#' partial_iris <- plot.variable(rfsrc_iris,
#'   xvar.names = "Petal.Width",
#'   partial = TRUE, show.plots = FALSE
#' )
#'
#' gg_dta <- gg_partial(partial_iris)
#' plot(gg_dta)
#'
#' ## ------------------------------------------------------------
#' ## regression
#' ## ------------------------------------------------------------
#' \dontrun{
#' ## -------- air quality data
#' ## airquality "Wind" partial dependence plot
#' ##
#' rfsrc_airq <- rfsrc(Ozone ~ ., data = airquality)
#' partial_airq <- plot.variable(rfsrc_airq,
#'   xvar.names = "Wind",
#'   partial = TRUE, show.plots = FALSE
#' )
#'
#' gg_dta <- gg_partial(partial_airq)
#' plot(gg_dta)
#' }
#' \dontrun{
#' ## -------- Boston data
#' data(Boston, package = "MASS")
#' Boston$chas <- as.logical(Boston$chas)
#' rfsrc_boston <- rfsrc(medv ~ .,
#'   data = Boston,
#'   forest = TRUE,
#'   importance = TRUE,
#'   tree.err = TRUE,
#'   save.memory = TRUE
#' )
#'
#' varsel_boston <- var.select(rfsrc_boston)
#'
#' partial_boston <- plot.variable(rfsrc_boston,
#'   xvar.names = varsel_boston$topvars,
#'   sorted = FALSE,
#'   partial = TRUE,
#'   show.plots = FALSE
#' )
#' gg_dta <- gg_partial(partial_boston)
#' plot(gg_dta, panel = TRUE)
#' }
#' \dontrun{
#' ## -------- mtcars data
#' rfsrc_mtcars <- rfsrc(mpg ~ ., data = mtcars)
#' varsel_mtcars <- var.select(rfsrc_mtcars)
#'
#' partial_mtcars <- plot.variable(rfsrc_mtcars,
#'   xvar.names = varsel_mtcars$topvars,
#'   sorted = FALSE,
#'   partial = TRUE,
#'   show.plots = FALSE
#' )
#'
#' gg_dta <- gg_partial(partial_mtcars)
#'
#' gg_dta.cat <- gg_dta
#' gg_dta.cat[["disp"]] <- gg_dta.cat[["wt"]] <- gg_dta.cat[["hp"]] <- NULL
#' gg_dta.cat[["drat"]] <- gg_dta.cat[["carb"]] <- gg_dta.cat[["qsec"]] <- NULL
#'
#' plot(gg_dta.cat, panel = TRUE, notch = TRUE)
#'
#' gg_dta[["cyl"]] <- gg_dta[["vs"]] <- gg_dta[["am"]] <- NULL
#' gg_dta[["gear"]] <- NULL
#' plot(gg_dta, panel = TRUE)
#' }
#'
#' ## ------------------------------------------------------------
#' ## survival examples
#' ## ------------------------------------------------------------
#' \dontrun{
#' ## -------- veteran data
#' ## survival "age" partial variable dependence plot
#' ##
#' data(veteran, package = "randomForestSRC")
#' rfsrc_veteran <- rfsrc(Surv(time, status) ~ ., veteran,
#'   nsplit = 10,
#'   ntree = 100
#' )
#'
#' varsel_rfsrc <- var.select(rfsrc_veteran)
#'
#' ## 30 day partial plots for all variables
#' partial_veteran <- plot.variable(rfsrc_veteran,
#'   surv.type = "surv",
#'   partial = TRUE, time = 30,
#'   show.plots = FALSE
#' )
#'
#' gg_dta <- gg_partial(partial_veteran)
#' plot(gg_dta, panel = TRUE)
#'
#' gg_dta.cat <- gg_dta
#' gg_dta[["celltype"]] <- gg_dta[["trt"]] <- gg_dta[["prior"]] <- NULL
#' plot(gg_dta, panel = TRUE)
#'
#' gg_dta.cat[["karno"]] <- gg_dta.cat[["diagtime"]] <-
#'   gg_dta.cat[["age"]] <- NULL
#' plot(gg_dta.cat, panel = TRUE, notch = TRUE)
#'
#' ## 30 and 90 day partial plots: one plot.variable object per time
#' partial_veteran_times <- lapply(c(30, 90), function(tm) {
#'   plot.variable(rfsrc_veteran,
#'     surv.type = "surv",
#'     partial = TRUE, time = tm,
#'     show.plots = FALSE
#'   )
#' })
#'
#' gg_dta <- lapply(partial_veteran_times, gg_partial)
#' gg_dta <- combine.gg_partial(gg_dta[[1]], gg_dta[[2]])
#'
#' plot(gg_dta[["karno"]])
#' plot(gg_dta[["celltype"]])
#'
#' gg_dta.cat <- gg_dta
#' gg_dta[["celltype"]] <- gg_dta[["trt"]] <- gg_dta[["prior"]] <- NULL
#' plot(gg_dta, panel = TRUE)
#'
#' gg_dta.cat[["karno"]] <- gg_dta.cat[["diagtime"]] <-
#'   gg_dta.cat[["age"]] <- NULL
#' plot(gg_dta.cat, panel = TRUE, notch = TRUE)
#'
#' ## ------------------------------------------------------------
#' ## -------- pbc data
#' # Load and prepare the pbc data set
#' data(pbc, package = "randomForestSRC")
#' # Convert 0/1 indicator columns to logical, and columns with few
#' # distinct values to factors.
#' for (ind in seq_len(dim(pbc)[2])) {
#'   if (!is.factor(pbc[, ind])) {
#'     if (length(unique(pbc[which(!is.na(pbc[, ind])), ind])) <= 2) {
#'       if (sum(range(pbc[, ind], na.rm = TRUE) == c(0, 1)) == 2) {
#'         pbc[, ind] <- as.logical(pbc[, ind])
#'       }
#'     }
#'   } else {
#'     if (length(unique(pbc[which(!is.na(pbc[, ind])), ind])) <= 2) {
#'       if (sum(sort(unique(pbc[, ind])) == c(0, 1)) == 2) {
#'         pbc[, ind] <- as.logical(pbc[, ind])
#'       }
#'       if (sum(sort(unique(pbc[, ind])) == c(FALSE, TRUE)) == 2) {
#'         pbc[, ind] <- as.logical(pbc[, ind])
#'       }
#'     }
#'   }
#'   if (!is.logical(pbc[, ind]) &&
#'     length(unique(pbc[which(!is.na(pbc[, ind])), ind])) <= 5) {
#'     pbc[, ind] <- factor(pbc[, ind])
#'   }
#' }
#' # The age variable is recorded in days; convert it to years
#' pbc$age <- pbc$age / 365.25
#'
#' pbc$years <- pbc$days / 365.25
#' pbc <- pbc[, -which(colnames(pbc) == "days")]
#' pbc$treatment <- as.numeric(pbc$treatment)
#' pbc$treatment[which(pbc$treatment == 1)] <- "DPCA"
#' pbc$treatment[which(pbc$treatment == 2)] <- "placebo"
#' pbc$treatment <- factor(pbc$treatment)
#' dta_train <- pbc[-which(is.na(pbc$treatment)), ]
#' # Create a test set from the remaining patients
#' pbc_test <- pbc[which(is.na(pbc$treatment)), ]
#'
#' # ========
#' # build the forest:
#' rfsrc_pbc <- randomForestSRC::rfsrc(
#'   Surv(years, status) ~ .,
#'   dta_train,
#'   nsplit = 10,
#'   na.action = "na.impute",
#'   forest = TRUE,
#'   importance = TRUE,
#'   save.memory = TRUE
#' )
#'
#' varsel_pbc <- var.select(rfsrc_pbc)
#'
#' xvar <- varsel_pbc$topvars
#'
#' # Partial dependence of the selected variables at 1 and 3 years
#' partial_pbc <- lapply(c(1, 3), function(tm) {
#'   plot.variable(rfsrc_pbc,
#'     surv.type = "surv",
#'     time = tm,
#'     xvar.names = xvar,
#'     partial = TRUE,
#'     show.plots = FALSE
#'   )
#' })
#'
#' # Convert all partial plots to gg_partial objects
#' gg_dta <- lapply(partial_pbc, gg_partial)
#'
#' # Combine the objects to get multiple time curves
#' # along variables on a single figure.
#' pbc_ggpart <- combine.gg_partial(gg_dta[[1]], gg_dta[[2]],
#'   lbls = c("1 Year", "3 Years")
#' )
#'
#' summary(pbc_ggpart)
#' class(pbc_ggpart[["bili"]])
#'
#' # Plot the highest ranked variable, by name.
#' # plot(pbc_ggpart[["bili"]])
#'
#' # Create a temporary holder and remove the edema data
#' ggpart <- pbc_ggpart
#' ggpart$edema <- NULL
#'
#' # Panel plot the remainder.
#' plot(ggpart, panel = TRUE)
#'
#' plot(pbc_ggpart[["edema"]])
#' }
#' @aliases gg_partial gg_partial_list gg_partial.rfsrc gg_partial.randomForest
#' @name gg_partial
#' @export
gg_partial <- function(object, ...) {
  # S3 generic: dispatch on the class of `object`.
  UseMethod("gg_partial", object)
}

# gg_partial method for objects of class "rfsrc". `object` must also inherit
# from "plot.variable", otherwise an error is signalled.
#
# Arguments:
#   object - output of randomForestSRC::plot.variable.
#   ...    - optional arguments. If `named` is supplied, its value is stored
#            in an `id` column of every data set. When `object` does not hold
#            partial dependence data, all arguments are forwarded to
#            gg_variable.
#
# Returns (invisibly) a single gg_partial data.frame when object$pData has
# exactly one element. Otherwise it returns a "gg_partial_list": a list of
# gg_partial data.frames named by object$xvar.names.
#' @export
gg_partial.rfsrc <- function(object,
                             ...) {
  if (!inherits(object, "plot.variable")) {
    stop(
      paste(
        "gg_partial expects a plot.variable object, ",
        "Run plot.variable with partial=TRUE"
      )
    )
  }

  # plot.variable output created without partial = TRUE is handed to
  # gg_variable. Return immediately so the partial-data code below never runs
  # on it. isTRUE() also covers a missing `partial` element.
  if (!isTRUE(object$partial)) {
    return(invisible(gg_variable(object, ...)))
  }

  # Optional `named` value supplied through `...`; NULL when absent.
  named <- list(...)[["named"]]

  n_var <- length(object$pData)

  # Build one data.frame per variable. Plain lapply is used because the
  # per-variable work is trivial. Errors also surface directly, instead of as
  # try-error list elements (the behaviour of parallel::mclapply).
  gg_dta <- lapply(seq_len(n_var), function(ind) {
    p_dta <- object$pData[[ind]]

    if (length(p_dta$x.uniq) == length(p_dta$yhat)) {
      # One estimate per unique x value. Build the data.frame directly:
      # cbind() would coerce all columns to a common type (character yhat
      # for character x, integer codes for factor x).
      dta <- data.frame(yhat = p_dta$yhat, x = p_dta$x.uniq)
      if (object$family != "surv") {
        # The survival family has unusual standard errors because of
        # non-normal transforms, so se is only kept for other families.
        dta$se <- p_dta$yhat.se
      }
    } else {
      # Label yhat with its x level. Each unique x value is repeated
      # object$n times, which assumes yhat is laid out in blocks of
      # object$n values per level.
      x_lbl <- rep(
        as.character(p_dta$x.uniq),
        rep(object$n, p_dta$n.x)
      )
      dta <- data.frame(yhat = p_dta$yhat, x = factor(x_lbl))
    }

    # Name the x column after the variable so plot labels come out correctly.
    colnames(dta)[colnames(dta) == "x"] <- object$xvar.names[ind]
    if (!is.null(named)) {
      dta$id <- named
    }
    class(dta) <- c("gg_partial", class(dta), object$family)
    dta
  })
  names(gg_dta) <- object$xvar.names

  if (n_var == 1) {
    # A single variable needs no list wrapper.
    invisible(gg_dta[[1]])
  } else {
    # Otherwise, add a class label so the list is handled correctly.
    class(gg_dta) <- c("gg_partial_list", class(gg_dta))
    invisible(gg_dta)
  }
}

# gg_partial method for randomForest objects: not implemented yet, so it
# always signals an error.
#' @export
gg_partial.randomForest <- function(object,
                                    ...) {
  stop("gg_partial is not yet supported for randomForest objects")
}
ehrlinger/ggRandomForests documentation built on June 12, 2025, 10:59 a.m.