#### File: R/plot.variable.R

####**********************************************************************
####**********************************************************************
####
####  RANDOM FORESTS FOR SURVIVAL, REGRESSION, AND CLASSIFICATION (RF-SRC)
####  Version 2.4.1 (_PROJECT_BUILD_ID_)
####
####  Copyright 2016, University of Miami
####
####  This program is free software; you can redistribute it and/or
####  modify it under the terms of the GNU General Public License
####  as published by the Free Software Foundation; either version 3
####  of the License, or (at your option) any later version.
####
####  This program is distributed in the hope that it will be useful,
####  but WITHOUT ANY WARRANTY; without even the implied warranty of
####  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
####  GNU General Public License for more details.
####
####  You should have received a copy of the GNU General Public
####  License along with this program; if not, write to the Free
####  Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
####  Boston, MA  02110-1301, USA.
####
####  ----------------------------------------------------------------
####  Project Partially Funded By: 
####  ----------------------------------------------------------------
####  Dr. Ishwaran's work was funded in part by DMS grant 1148991 from the
####  National Science Foundation and grant R01 CA163739 from the National
####  Cancer Institute.
####
####  Dr. Kogalur's work was funded in part by grant R01 CA163739 from the 
####  National Cancer Institute.
####  ----------------------------------------------------------------
####  Written by:
####  ----------------------------------------------------------------
####    Hemant Ishwaran, Ph.D.
####    Director of Statistical Methodology
####    Professor, Division of Biostatistics
####    Clinical Research Building, Room 1058
####    1120 NW 14th Street
####    University of Miami, Miami FL 33136
####
####    email:  hemant.ishwaran@gmail.com
####    URL:    http://web.ccs.miami.edu/~hishwaran
####    --------------------------------------------------------------
####    Udaya B. Kogalur, Ph.D.
####    Adjunct Staff
####    Department of Quantitative Health Sciences
####    Cleveland Clinic Foundation
####    
####    Kogalur & Company, Inc.
####    5425 Nestleway Drive, Suite L1
####    Clemmons, NC 27012
####
####    email:  ubk@kogalur.com
####    URL:    http://www.kogalur.com
####    --------------------------------------------------------------
####
####**********************************************************************
####**********************************************************************


## Plot the marginal or partial effect of x-variables on the ensemble
## prediction of a random forest.
##
## Arguments:
##   x              Object of class (rfsrc, grow), (rfsrc, synthetic),
##                  (rfsrc, predict) or (rfsrc, plot.variable).  When a
##                  (rfsrc, plot.variable) object is supplied, the stored
##                  results are re-plotted without any recomputation.
##   xvar.names     Names of the x-variables to plot (default: all of
##                  object$xvar.names).  Unknown names are dropped.
##   which.class    Classification: class label or index (default: first
##                  level).  Competing risks: event index (default: 1).
##   outcome.target Target outcome; resolved by get.univariate.target().
##   time           Time point used for survival families (default: the
##                  median of the event times of interest).
##   surv.type      Predicted value used for survival families.
##   class.type     Predicted value used for classification.
##   partial        If TRUE, partial plots are computed via partial.rfsrc().
##   oob            If TRUE, out-of-bag predictions are used.
##   show.plots     If FALSE, nothing is drawn; only the object is returned.
##   plots.per.page Number of panels per row of the plotting page.
##   granule        Numeric variables with at most 'granule' unique values
##                  are displayed as boxplots, like factors.
##   sorted         If TRUE (and VIMP is available), variables are ordered
##                  by decreasing importance.
##   nvar           Number of variables to plot (default: all).
##   npts           Maximum number of grid points per continuous variable
##                  in a partial plot.
##   smooth.lines   If TRUE, partial plot lines are lowess-smoothed.
##   subset         Row indices (or logical vector) restricting the data.
##   ...            Passed on to plot() / boxplot().
##
## Value: invisibly, a list of class c("rfsrc", "plot.variable", family)
## holding everything needed to redraw the plots.
plot.variable.rfsrc <- function(
  x,
  xvar.names,
  which.class,
  outcome.target=NULL,
  time,
  surv.type = c("mort", "rel.freq", "surv", "years.lost", "cif", "chf"),
  class.type = c("prob", "bayes"),
  partial = FALSE,
  oob = TRUE,
  show.plots = TRUE,
  plots.per.page = 4,
  granule = 5,
  sorted = TRUE,
  nvar,
  npts = 25,
  smooth.lines = FALSE,
  subset,
  ...)
{
  ## A synthetic forest carries the forest to be plotted in 'rfSyn'.
  if (sum(inherits(x, c("rfsrc", "synthetic"), TRUE) == c(1, 2)) == 2) {
    x <- x$rfSyn
  }
  object <- x
  remove(x)
  if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2 &&
      sum(inherits(object, c("rfsrc", "predict"), TRUE) == c(1, 2)) != 2 &&
      sum(inherits(object, c("rfsrc", "plot.variable"), TRUE) == c(1,2)) != 2) {
    stop("this function only works for objects of class `(rfsrc, grow)', '(rfsrc, predict)' or '(rfsrc, plot.variable)'")
  }
  if (object$family == "unsupv") {
    stop("this function does not apply to unsupervised forests")
  }
  if (partial && is.null(object$forest)) {
    stop("forest is empty:  re-run rfsrc (grow) call with forest=TRUE")
  }
  ## Use imputed x-values for the rows that were imputed.
  xvar <- object$xvar
  if (!is.null(object$imputed.indv)) {
    xvar[object$imputed.indv, ] <- object$imputed.data[, object$xvar.names]
  }
  n <- nrow(xvar)
  if (!inherits(object, "plot.variable")) {
    ## ------------------------------------------------------------------
    ## Compute path: build a fresh plot.variable object.
    ## ------------------------------------------------------------------
    if (missing(subset)) {
      subset <- seq_len(n)
    } else {
      if (is.logical(subset)) subset <- which(subset)
      subset <- unique(subset[subset >= 1 & subset <= n])
      if (length(subset) == 0) {
        stop("'subset' not set properly")
      }
    }
    xvar <- xvar[subset,, drop = FALSE]
    n <- nrow(xvar)
    family <- object$family
    outcome.target <- get.univariate.target(object, outcome.target)
    if (grepl("surv", family)) {
      event.info <- get.event.info(object, subset)
      cens <- event.info$cens
      event.type <- event.info$event.type
      if (missing(time)) {
        time <- median(event.info$time.interest, na.rm = TRUE)
      }
      if (is.null(surv.type)) {
        stop("surv.type is specified incorrectly:  ", surv.type)
      }
      if (family == "surv-CR") {
        if (missing(which.class)) {
          which.class <- 1
        } else {
          if (which.class < 1 || which.class > max(event.type, na.rm = TRUE)) {
            stop("'which.class' is specified incorrectly")
          }
        }
        VIMP <- object$importance[, which.class]
        ## First requested type that is not a right-censoring-only type.
        pred.type <- setdiff(surv.type, c("rel.freq", "mort", "chf", "surv"))[1]
        pred.type <- match.arg(pred.type, c("years.lost", "cif", "chf"))
        ylabel <- switch(pred.type,
                         "years.lost" = paste("Years lost for event ", which.class),
                         "cif"        = paste("CIF for event ", which.class, " (time=", time, ")", sep = ""),
                         "chf"        = paste("CHF for event ", which.class, " (time=", time, ")", sep = ""))
      } else {
        which.class <- 1
        VIMP <- object$importance
        ## First requested type that is not a competing-risk-only type.
        pred.type <- setdiff(surv.type, c("years.lost", "cif", "chf"))[1]
        pred.type <- match.arg(pred.type, c("rel.freq", "mort", "chf", "surv"))
        ylabel <- switch(pred.type,
                         "rel.freq"   = "standardized mortality",
                         "mort"       = "mortality",
                         "chf"        = paste("CHF (time=", time, ")", sep = ""),
                         "surv"       = paste("predicted survival (time=", time, ")", sep = ""))
      }
    } else {
      event.info <- time <- NULL
      if(is.factor(coerce.multivariate(object, outcome.target)$yvar)) {
        ## Classification outcome.
        object.yvar.levels <- levels(coerce.multivariate(object, outcome.target)$yvar)
        pred.type <- match.arg(class.type, c("prob", "bayes"))
        if (missing(which.class)) {
          which.class <- object.yvar.levels[1]
        }
        if (is.character(which.class)) {
          which.class <- match(match.arg(which.class, object.yvar.levels), object.yvar.levels)
        } else {
          if ((which.class > length(object.yvar.levels)) || (which.class < 1)) {
            stop("which.class is specified incorrectly:", which.class)
          }
        }
        if (pred.type == "prob") {
          ## Column 1 of the importance matrix is the overall VIMP, hence
          ## the class-specific column is offset by one.
          VIMP <- coerce.multivariate(object, outcome.target)$importance[, 1 + which.class]
          ylabel <- paste("probability", object.yvar.levels[which.class])
          remove(object.yvar.levels)
        } else {
          VIMP <- coerce.multivariate(object, outcome.target)$importance[, 1]
          ylabel <- paste("bayes")
          remove(object.yvar.levels)
        }
      } else {
        ## Regression outcome.
        pred.type <- "y"
        which.class <- NULL
        VIMP <- coerce.multivariate(object, outcome.target)$importance
        ylabel <- expression(hat(y))
      }
    }
    if (missing(xvar.names)) {
      xvar.names <- object$xvar.names
    } else {
      xvar.names <- intersect(xvar.names, object$xvar.names)
      if (length(xvar.names) ==  0){
        stop("none of the x-variable supplied match available ones:\n",
             paste(object$xvar.names, collapse = ", "))
      }
    }
    if (sorted && !is.null(VIMP)) {
      xvar.names <- xvar.names[rev(order(VIMP[xvar.names]))]
    }
    if (!missing(nvar)) {
      nvar <- max(round(nvar), 1)
      xvar.names <- xvar.names[1:min(length(xvar.names), nvar)]
    }
    nvar <- length(xvar.names)
    ## Normalize 'granule' BEFORE it is used below, so that the
    ## continuous-vs-boxplot decision made while computing the partial
    ## values agrees with the one made later when plotting.
    granule <- max(round(granule), 1)
    if (!partial) {
      yhat <- extract.pred(predict.rfsrc(object,
                                         outcome.target = outcome.target,
                                         importance = "none"),
                           pred.type,
                           subset,
                           time,
                           outcome.target,
                           which.class,
                           oob = oob)
    } else {
      if (npts < 1) npts <- 1 else npts <- round(npts)
      prtl <- lapply(seq_len(nvar), function(k) {
        x <- na.omit(object$xvar[, object$xvar.names == xvar.names[k]])
        if (is.factor(x)) x <- factor(x, exclude = NULL)
        n.x <- length(unique(x))
        if (!is.factor(x) && n.x > npts) {
          ## Thin the unique values down to (at most) 'npts' grid points.
          x.uniq <- sort(unique(x))[unique(as.integer(seq(1, n.x, length = min(npts, n.x))))]
        } else {
          x.uniq <- sort(unique(x))
        }
        n.x <- length(x.uniq)
        yhat <- yhat.se <- NULL
        factor.x <- !(!is.factor(x) && (n.x > granule))
        ## Rows are selected with 'subset' (not 1:n): the partial values
        ## are computed from the forest alone, so selecting the first n
        ## rows would ignore which observations the user asked for.
        pred.temp <- extract.partial.pred(partial.rfsrc(object$forest,
                                                        outcome.target = outcome.target,
                                                        partial.type = pred.type,
                                                        partial.xvar = xvar.names[k],
                                                        partial.values = x.uniq,
                                                        partial.time = time,
                                                        oob = oob),
                                          pred.type,
                                          subset,
                                          time,
                                          outcome.target,
                                          which.class)
        if (!is.null(dim(pred.temp))) {
          mean.temp <- apply(pred.temp, 2, mean, na.rm = TRUE)
        } else {
          mean.temp <- mean(pred.temp, na.rm = TRUE)
        }
        if (!factor.x) {
          yhat <- mean.temp
          if (coerce.multivariate(object, outcome.target)$family == "class") {
            yhat.se <- mean.temp * (1 - mean.temp) / sqrt(n)
          } else {
            sd.temp <- apply(pred.temp, 2, sd, na.rm = TRUE)
            yhat.se <- sd.temp / sqrt(n)
          }
        } else {
          ## Shrink every value toward the mean of ITS OWN column (one
          ## column per partial value).  A plain 'mean.temp + (pred.temp -
          ## mean.temp)' would recycle the column means down the rows.
          if (!is.null(dim(pred.temp))) {
            pred.temp <- sweep(pred.temp, 2, mean.temp, "-") / sqrt(n)
            pred.temp <- sweep(pred.temp, 2, mean.temp, "+")
          } else {
            pred.temp <- mean.temp + (pred.temp - mean.temp) / sqrt(n)
          }
          yhat <- c(yhat, pred.temp)
        }
        list(xvar.name = xvar.names[k], yhat = yhat, yhat.se = yhat.se, n.x = n.x, x.uniq = x.uniq, x = x)
      })
    }
    plots.per.page <- max(round(min(plots.per.page,nvar)), 1)
    plot.variable.obj <- list(family = family,
                              partial = partial,
                              event.info = event.info,
                              which.class = which.class,
                              ylabel = ylabel,
                              n = n,
                              xvar.names = xvar.names,
                              nvar = nvar,
                              plots.per.page = plots.per.page,
                              granule = granule,
                              smooth.lines = smooth.lines)
    if (partial) {
      plot.variable.obj$pData <- prtl
    } else {
      plot.variable.obj$yhat <- yhat
      plot.variable.obj$xvar <- xvar
    }
    class(plot.variable.obj) <- c("rfsrc", "plot.variable", family)
  } else {
    ## ------------------------------------------------------------------
    ## Replay path: unpack a previously computed plot.variable object.
    ## ------------------------------------------------------------------
    plot.variable.obj <- object
    remove(object)
    family <- plot.variable.obj$family
    partial <- plot.variable.obj$partial
    event.info <- plot.variable.obj$event.info
    which.class <- plot.variable.obj$which.class
    ylabel <- plot.variable.obj$ylabel
    n <- plot.variable.obj$n
    xvar.names <- plot.variable.obj$xvar.names
    nvar <- plot.variable.obj$nvar
    plots.per.page <- plot.variable.obj$plots.per.page
    granule <- plot.variable.obj$granule
    smooth.lines <- plot.variable.obj$smooth.lines
    if (partial) {
      prtl <- plot.variable.obj$pData
    } else {
      yhat <- plot.variable.obj$yhat
      xvar <- plot.variable.obj$xvar
    }
    if (!is.null(event.info)){
      cens <- event.info$cens
      event.type <- event.info$event.type
    }
  }
  if (show.plots) {
    ## Restore the graphics state on exit, including when plotting fails.
    old.par <- par(no.readonly = TRUE)
    on.exit(par(old.par), add = TRUE)
  }
  if (!partial && show.plots) {
    ## ------------------------------------------------------------------
    ## Marginal plots.
    ## ------------------------------------------------------------------
    par(mfrow = c(min(plots.per.page, ceiling(nvar / plots.per.page)), plots.per.page))
    if (n > 500) cex.pt <- 0.5 else cex.pt <- 0.75
    for (k in seq_len(nvar)) {
      x <- xvar[, which(colnames(xvar) == xvar.names[k])]
      x.uniq <- unique(x)
      n.x <- length(x.uniq)
      if (!is.factor(x) && n.x > granule) {
        ## Continuous variable: scatter plot with a lowess trend.
        plot(x,
             yhat,
             xlab = xvar.names[k],
             ylab = ylabel,
             type = "n", ...)
        if (grepl("surv", family)) {
          ## Events of the selected type in color, censored cases in black.
          points(x[cens == which.class], yhat[cens == which.class], pch = 16, col = 4, cex = cex.pt)
          points(x[cens == 0], yhat[cens == 0], pch = 16, cex = cex.pt)
        }
        lines(lowess(x[!is.na(x)], yhat[!is.na(x)]), col = 2, lwd=3)
      } else {
        ## Factor (or coarse numeric) variable: boxplots per level.
        if (is.factor(x)) x <- factor(x, exclude = NULL)
        boxplot(yhat ~ x, na.action = "na.omit",
                xlab = xvar.names[k],
                ylab = ylabel,
                notch = TRUE,
                outline = FALSE,
                col = "bisque",
                names = rep("", n.x),
                xaxt = "n", ...)
        at.pretty <- unique(round(pretty(1:n.x, min(30, n.x))))
        at.pretty <- at.pretty[at.pretty >= 1 & at.pretty <= n.x]
        axis(1,
             at = at.pretty,
             labels = format(sort(x.uniq)[at.pretty], trim = TRUE, digits = 4),
             tick = TRUE)
      }
    }
  }
  if (partial && show.plots) {
    ## ------------------------------------------------------------------
    ## Partial plots.
    ## ------------------------------------------------------------------
    plots.per.page <- max(round(min(plots.per.page,nvar)), 1)
    granule <- max(round(granule),1)
    par(mfrow = c(min(plots.per.page, ceiling(nvar/plots.per.page)), plots.per.page))
    for (k in seq_len(nvar)) {
      x <- prtl[[k]]$x
      if (is.factor(x)) x <- factor(x, exclude = NULL)
      x.uniq <- prtl[[k]]$x.uniq
      n.x <- prtl[[k]]$n.x
      if (n.x > 25) cex.pt <- 0.5 else cex.pt <- 0.75
      yhat <- prtl[[k]]$yhat
      yhat.se <- prtl[[k]]$yhat.se
      factor.x <- !(!is.factor(x) && (n.x > granule))
      if (!factor.x) {
        ## The NA-padded coordinates only serve to fix the axis ranges.
        plot(c(min(x), x.uniq, max(x), x.uniq, x.uniq),
             c(NA, yhat, NA, yhat + 2 * yhat.se, yhat - 2 * yhat.se),
             xlab = prtl[[k]]$xvar.name,
             ylab = ylabel,
             type = "n", ...)
        points(x.uniq, yhat, pch = 16, cex = cex.pt, col = 2)
        ## 'yhat.se' is a vector: reduce it to a single logical before
        ## using '&&' (a length > 1 operand is an error in R >= 4.3).
        if (!is.null(yhat.se) && !anyNA(yhat.se) && any(yhat.se > 0)) {
          if (smooth.lines) {
            lines(lowess(x.uniq, yhat + 2 * yhat.se), lty = 3, col = 2)
            lines(lowess(x.uniq, yhat - 2 * yhat.se), lty = 3, col = 2)
          } else {
            lines(x.uniq, yhat + 2 * yhat.se, lty = 3, col = 2)
            lines(x.uniq, yhat - 2 * yhat.se, lty = 3, col = 2)
          }
        }
        if (smooth.lines) {
          lines(lowess(x.uniq, yhat), lty = 2, lwd=2)
        } else {
          lines(x.uniq, yhat, lty = 2, lwd=2)
        }
        rug(x, ticksize=0.03)
      } else {
        y.se <- 0.005
        ## First pass (not drawn) only to obtain the whisker limits.
        bxp.call <- boxplot(yhat ~ rep(x.uniq, rep(n, n.x)), range = 2, plot = FALSE)
        boxplot(yhat ~ rep(x.uniq, rep(n, n.x)),
                xlab = xvar.names[k],
                ylab = ylabel,
                notch = TRUE,
                outline = FALSE,
                range = 2,
                ylim = c(min(bxp.call$stats[1,], na.rm=TRUE) * ( 1 - y.se ),
                  max(bxp.call$stats[5,], na.rm=TRUE) * ( 1 + y.se )),
                col = "bisque",
                names = rep("",n.x),
                xaxt = "n", ...)
        at.pretty <- unique(round(pretty(1:n.x, min(30,n.x))))
        at.pretty <- at.pretty[at.pretty >= 1 & at.pretty <= n.x]
        axis(1,
             at = at.pretty,
             labels = format(sort(x.uniq)[at.pretty], trim = TRUE, digits = 4),
             tick = TRUE)
      }
    }
  }
  invisible(plot.variable.obj)
}
## Extract the predicted values to be plotted from a forest/predict object.
##
## Arguments:
##   obj            Forest or predict object; reduced to a single outcome
##                  by coerce.multivariate().
##   type           Kind of prediction: one of "years.lost", "cif", "chf"
##                  (competing risks); "rel.freq", "mort", "chf", "surv"
##                  (survival); "prob", "bayes" (classification).  Ignored
##                  for the remaining families.
##   subset         Rows to return (default: all).
##   time           Time point; the last time of interest not exceeding it
##                  is used for time-dependent survival quantities.
##   outcome.target Passed to coerce.multivariate().
##   which.class    Event or class index (default: 1).
##   oob            If TRUE (default) out-of-bag values are used, otherwise
##                  in-bag values.  The former default 'oob = oob' referred
##                  to itself and failed whenever the argument was omitted.
##
## Value: the selected predictions for the rows in 'subset'.
extract.pred <- function(obj, type, subset, time, outcome.target, which.class, oob = TRUE) {
  obj <- coerce.multivariate(obj, outcome.target)
  if (oob) {
    pred <- obj$predicted.oob
    surv <- obj$survival.oob
    chf <- obj$chf.oob
    cif <- obj$cif.oob
  } else {
    pred <- obj$predicted
    surv <- obj$survival
    chf <- obj$chf
    cif <- obj$cif
  }
  if (grepl("surv", obj$family)) {
    if (obj$family == "surv-CR") {
      ## Competing risks: predictions carry one column per event type.
      n <- dim(pred)[1]
      if (missing(subset)) subset <- seq_len(n)
      if (missing(which.class)) which.class <- 1
      type <- match.arg(type, c("years.lost", "cif", "chf"))
      time.idx <-  max(which(obj$time.interest <= time))
      return(switch(type,
                    "years.lost" = pred[subset, which.class],
                    "cif"        = cif[subset, time.idx, which.class],
                    "chf"        = chf[subset, time.idx, which.class]
                    ))
    } else {
      n <- length(pred)
      if (missing(subset)) subset <- seq_len(n)
      type <- match.arg(type, c("rel.freq", "mort", "chf", "surv"))
      time.idx <-  max(which(obj$time.interest <= time))
      return(switch(type,
                    "rel.freq" = pred[subset]/max(n, na.omit(pred)),
                    "mort"     = pred[subset],
                    "chf"      = 100 * chf[subset, time.idx],
                    "surv"     = 100 * surv[subset, time.idx]
                    ))
    }
  } else {
    if (obj$family == "class") {
      n <- dim(pred)[1]
      if (missing(subset)) subset <- seq_len(n)
      type <- match.arg(type, c("prob", "bayes"))
      if (missing(which.class)) which.class <- 1
      prob <- pred[subset,, drop = FALSE]
      return(switch(type,
                    "prob" = prob[, which.class],
                    "bayes" =  bayes.rule(prob)))
    } else {
      n <- length(pred)
      if (missing(subset)) subset <- seq_len(n)
      return(pred[subset])
    }
  }
}
## Extract the values to be plotted from the output of a partial call.
##
## Arguments:
##   obj            List holding 'family' and, depending on the family,
##                  'survOutput' (with 'partial.time'), 'classOutput' or
##                  'regrOutput' (the latter two are lists named by outcome).
##   type           Kind of prediction (see the survival/classification
##                  types handled below); unused for regression.
##   subset         Rows (observations) to return (default: all).
##   time           Time point for time-dependent survival quantities.
##   outcome.target Name of the outcome; NULL selects the first available.
##   which.class    Event or class index (default: 1).
##
## Value: the selected slice of the partial output.  An error is raised
## when 'outcome.target' matches neither a classification nor a
## regression outcome.
extract.partial.pred <- function(obj, type, subset, time, outcome.target, which.class) {
  if (grepl("surv", obj$family)) {
    n <- dim(obj$survOutput)[1]
    if (missing(subset)) subset <- seq_len(n)
    time.idx <-  max(which(obj$partial.time <= time))
    if (obj$family == "surv-CR") {
      if (missing(which.class)) which.class <- 1
      type <- match.arg(type, c("years.lost", "cif", "chf"))
      return(switch(type,
                    "years.lost" = obj$survOutput[subset, which.class, ],
                    "cif" = obj$survOutput[subset, time.idx, which.class, ],
                    "chf" = obj$survOutput[subset, time.idx, which.class, ]
                    ))
    } else {
      if (type == "rel.freq") {
        ## Standardize each column by its number of non-missing values.
        sz <- apply(obj$survOutput[subset, ], 2, function(x) {length(na.omit(x))})
        rs <- t(apply(obj$survOutput[subset, ], 1, function(x) {x / sz}))
      }
      return(switch(type,
                    "rel.freq" = rs,
                    "mort"     = obj$survOutput[subset, ],
                    "chf"      =  obj$survOutput[subset, time.idx, ],
                    "surv"     =  obj$survOutput[subset, time.idx, ]
                    ))
    }
  }
  ## Classification outcomes are tried first.
  if (!is.null(obj$classOutput)) {
    if (is.null(outcome.target)) {
      outcome.target <- names(obj$classOutput)[1]
    }
    target <- which(names(obj$classOutput) == outcome.target)
    if (length(target) > 1) {
      stop("Invalid number of target outcomes specified in partial plot extraction.")
    }
    if (length(target) == 1) {
      type <- match.arg(type, c("prob", "bayes"))
      if (missing(which.class)) which.class <- 1
      n <- dim(obj$classOutput[[target]])[1]
      if (missing(subset)) subset <- seq_len(n)
      ## Column 1 is the overall (bayes) column; class columns follow.
      return(switch(type,
                    "prob" = obj$classOutput[[target]][subset, 1 + which.class, ],
                    "bayes" =  obj$classOutput[[target]][subset, 1, ]))
    }
    ## No classification outcome matched: fall through to regression
    ## instead of silently returning NULL.
  }
  if (!is.null(obj$regrOutput)) {
    if (is.null(outcome.target)) {
      outcome.target <- names(obj$regrOutput)[1]
    }
    target <- which(names(obj$regrOutput) == outcome.target)
    if (length(target) > 1) {
      stop("Invalid number of target outcomes specified in partial plot extraction.")
    }
    if (length(target) == 1) {
      ## The row count must come from the regression output itself.
      n <- dim(obj$regrOutput[[target]])[1]
      if (missing(subset)) subset <- seq_len(n)
      return(obj$regrOutput[[target]][subset, ])
    }
  }
  stop("Invalid target specified in partial plot extraction:  ", outcome.target)
}
## Make the same function available under the shorter name 'plot.variable'.
plot.variable <- plot.variable.rfsrc
#### Source: ehrlinger/randomForestSRC documentation built on May 16, 2019, 1:20 a.m.