R/predict.icrf.R

#' @rdname predict.icrf
#' @name predict.icrf
#'
#' @title icrf predictions
#'
#' @description Predicts survival curves for new (test) data using an interval censored recursive forest.
#' (Quoted statements are from
#' \code{randomForest} by Liaw and Wiener unless otherwise mentioned.)
#'
#' @param object an object of \code{icrf} class generated by the function \code{icrf}.
#' @param newdata 'a data frame or matrix containing new data. (Note: If not given,
#' the predicted survival estimate of the training data set in the \code{object} is returned.)'
#' @param predict.all 'Should the predictions of all trees be kept?'
#' @param proximity 'Should proximity measures be computed?'
#' @param nodes 'Should the terminal node indicators (an n by ntree matrix)
#' be returned? If so, it is in the "nodes" attribute of the returned object.'
#' @param smooth Should the smoothed survival curves be returned? If \code{FALSE}, the non-smoothed estimates are returned.
#' @param ... 'not used currently.'
#'
#' @return A matrix of predicted survival probabilities is returned where the rows represent
#' the observations and the columns represent the time points; the time points themselves
#' are stored in the \code{"time"} attribute of the matrix.
#' If \code{predict.all=TRUE} (used only when \code{newdata} is given), the returned object
#' is a list of two components: \code{aggregate}, which is the matrix of predicted survival
#' probabilities by the forest (as described above), and \code{individual}, which is an array
#' of dimension n by (number of time points) by ntree holding the prediction of each tree
#' in the forest. The forest is either the last forest or the best forest
#' as specified by \code{returnBest} argument in \code{icrf} function.
#'
#'
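#' If \code{proximity=TRUE}, the returned object is a list with two components:
#' the prediction matrix (as described above), named \code{predicted} when \code{newdata}
#' is given and \code{pred} otherwise, and \code{proximity}, the proximity matrix.
#' (Individual tree predictions are not included, even if \code{predict.all=TRUE}.)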
#'
#'
#' 'If \code{nodes=TRUE}, the returned object has a "nodes" attribute,
#' which is an \code{n} by \code{ntree} matrix, each column containing the
#' node number that the cases fall in for that tree.'
#'
#'
#' @examples
#' # rats data example
#' # Note that this is a toy example. Use a larger ntree and nfold in practice.
#' library(survival)  # for Surv()
#' data(rat2)
#' set.seed(1)
#' samp <- sample(1:dim(rat2)[1], 200)
#' rats.train <- rat2[samp, ]
#' rats.test <- rat2[-samp, ]
#' L = ifelse(rats.train$tumor, 0, rats.train$survtime)
#' R = ifelse(rats.train$tumor, rats.train$survtime, Inf)
#' \donttest{
#'  set.seed(2)
#'  rats.icrf.small <-
#'    icrf(survival::Surv(L, R, type = "interval2") ~ dose.lvl + weight + male + cage.no,
#'         data = rats.train, ntree = 10, nfold = 3, proximity = TRUE)
#'
#'  # predicted survival curve for the training data
#'  predict(rats.icrf.small)
#'  predict(rats.icrf.small, smooth = FALSE) # non-smoothed
#'
#'  # predicted survival curve for new data
#'  predict(rats.icrf.small, newdata = rats.test)
#'  predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#'
#'  # time can be extracted using attr()
#'  newpred = predict(rats.icrf.small, newdata = rats.test)
#'  attr(newpred, "time")
#'
#'  newpred2 = predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#'  attr(newpred2$predicted, "time")
#' }
#' \dontshow{
#'  set.seed(2)
#'  rats.icrf.small <-
#'    icrf(survival::Surv(L, R, type = "interval2") ~ dose.lvl + weight + male + cage.no,
#'         data = rats.train, ntree = 2, nfold = 2, proximity = TRUE)
#'
#'  # predicted survival curve for the training data
#'  predict(rats.icrf.small)
#'  predict(rats.icrf.small, smooth = FALSE) # non-smoothed
#'
#'  # predicted survival curve for new data
#'  predict(rats.icrf.small, newdata = rats.test)
#'  predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#'
#'  # time can be extracted using attr()
#'  newpred = predict(rats.icrf.small, newdata = rats.test)
#'  attr(newpred, "time")
#'
#'  newpred2 = predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#'  attr(newpred2$predicted, "time")
#' }
#'
#'
#' @author Hunyong Cho, Nicholas P. Jewell, and Michael R. Kosorok.
#'
#' @references
#' \href{https://arxiv.org/abs/1912.09983}{Cho H., Jewell N. P., and Kosorok M. R. (2020+). "Interval censored
#'  recursive forests"}
#'
#' @export
#' @useDynLib icrf
#'
"predict.icrf" <-
  function (object, newdata, # time.points = NULL,
            predict.all=FALSE, proximity = FALSE, nodes=FALSE,
            smooth = TRUE, ...) {
    if (!inherits(object, "icrf"))
      stop("object not of class icrf")
    prediction = ifelse(smooth, "predicted.Sm", "predicted")
    nodePrediction = ifelse(smooth, "nodepredSm", "nodepred")
    timepts = ifelse(smooth, "time.points.smooth", "time.points")

    if (missing(newdata)) {
      p <- if (! is.null(object$na.action)) {
        napredict(object$na.action, object[[prediction]])
      } else {
        object[[prediction]]
      }
      attr(p, "time") <- object[[timepts]]
      if (proximity & is.null(object$proximity))
        warning("cannot return proximity without new data if random forest object does not already have proximity")
      if (proximity) {
        res = list(pred = p, proximity = object$proximity)
      } else
        res = p
      return(res)
    }
    if (is.null(object$forest)) stop("No forest component in the object")
    if (inherits(object, "icrf.formula")) {
      newdata <- as.data.frame(newdata)
      rn <- row.names(newdata)
      Terms <- delete.response(object$terms)
      x <- model.frame(Terms, newdata, na.action = na.omit)
      keep <- match(row.names(x), rn)
    } else {
      if (is.null(dim(newdata)))
        dim(newdata) <- c(1, length(newdata))
      x <- newdata
      if (nrow(x) == 0)
        stop("newdata has 0 rows")
      if (any(is.na(x)))
        stop("missing values in newdata")
      keep <- 1:nrow(x)
      rn <- rownames(x)
      if (is.null(rn)) rn <- keep
    }
    vname <- if (is.null(dim(object$importance))) {
      names(object$importance)
    } else {
      rownames(object$importance)
    }

    if (is.null(colnames(x))) {
      if (ncol(x) != length(vname)) {
        stop("number of variables in newdata does not match that in the training data")
      }
    } else {
      if (any(! vname %in% colnames(x)))
        stop("variables in the training data missing in newdata")
      x <- x[, vname, drop=FALSE]
    }
    if (is.data.frame(x)) {
      isFactor <- function(x) is.factor(x) & ! is.ordered(x)
      xfactor <- which(sapply(x, isFactor))
      if (length(xfactor) > 0 && "xlevels" %in% names(object$forest)) {
        for (i in xfactor) {
          if (any(! levels(x[[i]]) %in% object$forest$xlevels[[i]]))
            stop("New factor levels not present in the training data")
          x[[i]] <-
            factor(x[[i]],
                   levels=levels(x[[i]])[match(levels(x[[i]]), object$forest$xlevels[[i]])])
        }
      }
      cat.new <- sapply(x, function(x) if (is.factor(x) && !is.ordered(x))
        length(levels(x)) else 1)
      if (!all(object$forest$ncat == cat.new))
        stop("Type of predictors in new data do not match that of the training data.")
    }
    mdim <- ncol(x)
    ntest <- nrow(x)
    ntree <- object$forest$ntree
    maxcat <- max(object$forest$ncat)
    nclass <- object$forest$nclass
    nrnodes <- object$forest$nrnodes
    ## get rid of warning:
    op <- options(warn=-1)
    on.exit(options(op))
    x <- t(data.matrix(x))

    time.points = object[[timepts]]
    t.names <- paste0("t", seq_along(time.points))
    ntime <- length(time.points)

    if (predict.all) {
      treepred <- array(double(ntest * ntime * ntree), dim = c(ntest, ntime, ntree))
    } else {
      treepred <- array(integer(ntest * ntime * ntree), dim = c(ntest, ntime, ntree))
    }
    proxmatrix <- if (proximity) matrix(0, ntest, ntest) else numeric(1)
    nodexts <- if (nodes) integer(ntest * ntree) else integer(ntest)

    if (!is.null(object$forest$treemap)) {
      object$forest$leftDaughter <-
        object$forest$treemap[,1,, drop=FALSE]
      object$forest$rightDaughter <-
        object$forest$treemap[,2,, drop=FALSE]
      object$forest$treemap <- NULL
    }


    keepIndex <- "ypred"
    if (predict.all) keepIndex <- c(keepIndex, "treepred")
    if (proximity) keepIndex <- c(keepIndex, "proximity")
    if (nodes) keepIndex <- c(keepIndex, "nodexts")
    ## Ensure storage mode is what is expected in C.
    if (! is.integer(object$forest$leftDaughter))
      storage.mode(object$forest$leftDaughter) <- "integer"
    if (! is.integer(object$forest$rightDaughter))
      storage.mode(object$forest$rightDaughter) <- "integer"
    if (! is.integer(object$forest$nodestatus))
      storage.mode(object$forest$nodestatus) <- "integer"
    if (! is.double(object$forest$xbestsplit))
      storage.mode(object$forest$xbestsplit) <- "double"
    if (! is.double(object$forest[[nodePrediction]]))
      storage.mode(object$forest[[nodePrediction]]) <- "double"
    if (! is.integer(object$forest$bestvar))
      storage.mode(object$forest$bestvar) <- "integer"
    if (! is.integer(object$forest$ndbigtree))
      storage.mode(object$forest$ndbigtree) <- "integer"
    if (! is.integer(object$forest$ncat))
      storage.mode(object$forest$ncat) <- "integer"

    ans <- .C("survForest",
              as.double(x),
              ypred = double(ntest * ntime),
              as.integer(mdim),
      #as.integer(ntime.rf),
              as.integer(ntime),
              as.integer(ntest),
              as.integer(ntree),
              object$forest$leftDaughter,
              object$forest$rightDaughter,
              object$forest$nodestatus,
              nrnodes,
              object$forest$xbestsplit,
              object$forest[[nodePrediction]],
              object$forest$bestvar,
              object$forest$ndbigtree,
              object$forest$ncat,
              as.integer(maxcat),
              as.integer(predict.all),
              treepred = as.double(treepred),
              as.integer(proximity),
              proximity = as.double(proxmatrix),
              nodes = as.integer(nodes),
              nodexts = as.integer(nodexts),
              #DUP=FALSE,
              PACKAGE = "icrf")[keepIndex]
    ## Apply bias correction if needed.
    yhat <- matrix(NA, nrow = ntest, ncol = ntime)
    rownames(yhat) <- rn
    colnames(yhat) <- t.names
    yhat[keep, ] <- ans$ypred

    attr(yhat, "time") <- time.points

    if (predict.all) {
      treepred <- array(NA, dim = c(ntest, ntime, ntree),
                        dimnames=list(rn, t.names, NULL))
      treepred[keep, , ] <- ans$treepred
    }
    if (!proximity) {
      res <- if (predict.all)
        list(aggregate=yhat, individual=treepred) else yhat
    } else {  ### TBD this part. ###
      res <- list(predicted = yhat,
                  proximity = structure(ans$proximity,
                                        dim=c(ntest, ntest), dimnames=list(rn, rn)))
    }
    # attr(res, "time") <- object[[timepts]]
    if (nodes) {
      attr(res, "nodes") <- matrix(ans$nodexts, ntest, ntree,
                                   dimnames=list(rn[keep], 1:ntree))
    }

    res
  }

Try the icrf package in your browser

Any scripts or data that you put into this service are public.

icrf documentation built on Oct. 30, 2022, 1:05 a.m.