Nothing
#' @rdname predict.icrf
#' @name predict.icrf
#'
#' @title icrf predictions
#'
#' @description Prediction method of test data using interval censored recursive forest.
#' (Quoted statements are from
#' \code{randomForest} by Liaw and Wiener unless otherwise mentioned.)
#'
#' @param object an object of \code{icrf} class generated by the function \code{icrf}.
#' @param newdata 'a data frame or matrix containing new data. (Note: If not given,
#' the predicted survival estimate of the training data set in the \code{object} is returned.)'
#' @param predict.all 'Should the predictions of all trees be kept?'
#' @param proximity 'Should proximity measures be computed?'
#' @param nodes 'Should the terminal node indicators (an n by ntree matrix)
#' be returned? If so, it is in the "nodes" attribute of the returned object.'
#' @param smooth Should the smoothed survival curve be returned? If \code{FALSE}, the non-smoothed curve is returned.
#' @param ... 'not used currently.'
#'
#' @return A matrix of predicted survival probabilities is returned where the rows represent
#' the observations and the columns represent the time points.
#' If \code{predict.all=TRUE}, then the returned object is a list of two components:
#' \code{aggregate}, which is the matrix of predicted values by the forest,
#' and \code{individual}, which is a three-dimensional array (observations by time points
#' by trees) containing the prediction by each tree in the forest.
#' The forest is either the last forest or the best forest
#' as specified by \code{returnBest} argument in \code{icrf} function.
#'
#'
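#' 'If \code{proximity=TRUE}, the returned object is a list with two components:
#' \code{pred} is the prediction (as described above) and
#' \code{proximity} is the proximity matrix.' (When \code{newdata} is supplied,
#' the prediction component is named \code{predicted} rather than \code{pred}.)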
#'
#'
#' 'If \code{nodes=TRUE}, the returned object has a "nodes" attribute,
#' which is an \code{n} by \code{ntree} matrix, each column containing the
#' node number that the cases fall in for that tree.'
#'
#'
#' @examples
#' # rats data example
#' # Note that this is a toy example. Use a larger ntree and nfold in practice.
#' library(survival) # for Surv()
#' data(rat2)
#' set.seed(1)
#' samp <- sample(1:dim(rat2)[1], 200)
#' rats.train <- rat2[samp, ]
#' rats.test <- rat2[-samp, ]
#' L = ifelse(rats.train$tumor, 0, rats.train$survtime)
#' R = ifelse(rats.train$tumor, rats.train$survtime, Inf)
#' \donttest{
#' set.seed(2)
#' rats.icrf.small <-
#' icrf(survival::Surv(L, R, type = "interval2") ~ dose.lvl + weight + male + cage.no,
#' data = rats.train, ntree = 10, nfold = 3, proximity = TRUE)
#'
#' # predicted survival curve for the training data
#' predict(rats.icrf.small)
#' predict(rats.icrf.small, smooth = FALSE) # non-smoothed
#'
#' # predicted survival curve for new data
#' predict(rats.icrf.small, newdata = rats.test)
#' predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#'
#' # time can be extracted using attr()
#' newpred = predict(rats.icrf.small, newdata = rats.test)
#' attr(newpred, "time")
#'
#' newpred2 = predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#' attr(newpred2$predicted, "time")
#' }
#' \dontshow{
#' set.seed(2)
#' rats.icrf.small <-
#' icrf(survival::Surv(L, R, type = "interval2") ~ dose.lvl + weight + male + cage.no,
#' data = rats.train, ntree = 2, nfold = 2, proximity = TRUE)
#'
#' # predicted survival curve for the training data
#' predict(rats.icrf.small)
#' predict(rats.icrf.small, smooth = FALSE) # non-smoothed
#'
#' # predicted survival curve for new data
#' predict(rats.icrf.small, newdata = rats.test)
#' predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#'
#' # time can be extracted using attr()
#' newpred = predict(rats.icrf.small, newdata = rats.test)
#' attr(newpred, "time")
#'
#' newpred2 = predict(rats.icrf.small, newdata = rats.test, proximity = TRUE)
#' attr(newpred2$predicted, "time")
#' }
#'
#'
#' @author Hunyong Cho, Nicholas P. Jewell, and Michael R. Kosorok.
#'
#' @references
#' \href{https://arxiv.org/abs/1912.09983}{Cho H., Jewell N. J., and Kosorok M. R. (2020+). "Interval censored
#' recursive forests"}
#'
#' @export
#' @useDynLib icrf
#'
"predict.icrf" <-
  function (object, newdata, # time.points = NULL,
            predict.all = FALSE, proximity = FALSE, nodes = FALSE,
            smooth = TRUE, ...) {
    if (!inherits(object, "icrf"))
      stop("object not of class icrf")

    # Names of the components to read, depending on whether the smoothed or
    # the non-smoothed estimate is requested.
    prediction <- if (smooth) "predicted.Sm" else "predicted"
    nodePrediction <- if (smooth) "nodepredSm" else "nodepred"
    timepts <- if (smooth) "time.points.smooth" else "time.points"

    ## ---- No new data: return the predictions stored in the fitted object ----
    if (missing(newdata)) {
      p <- object[[prediction]]
      if (!is.null(object$na.action)) {
        # Re-expand to the original rows (as dictated by the na.action).
        p <- napredict(object$na.action, p)
      }
      attr(p, "time") <- object[[timepts]]
      if (proximity && is.null(object$proximity))
        warning("cannot return proximity without new data if random forest object does not already have proximity")
      if (proximity) {
        return(list(pred = p, proximity = object$proximity))
      }
      return(p)
    }

    if (is.null(object$forest)) stop("No forest component in the object")

    ## ---- Build the predictor matrix from newdata ----
    # `rn` holds the row labels of ALL supplied rows; `keep` holds the positions
    # (within `rn`) of the rows that are actually sent to the forest.
    if (inherits(object, "icrf.formula")) {
      newdata <- as.data.frame(newdata)
      rn <- row.names(newdata)
      Terms <- delete.response(object$terms)
      # Rows with missing predictors are dropped here; they are returned as NA.
      x <- model.frame(Terms, newdata, na.action = na.omit)
      keep <- match(row.names(x), rn)
    } else {
      if (is.null(dim(newdata)))
        dim(newdata) <- c(1, length(newdata))
      x <- newdata
      if (nrow(x) == 0)
        stop("newdata has 0 rows")
      if (any(is.na(x)))
        stop("missing values in newdata")
      keep <- seq_len(nrow(x))
      rn <- rownames(x)
      if (is.null(rn)) rn <- keep
    }

    # Variable names used in training are taken from the importance component.
    vname <- if (is.null(dim(object$importance))) {
      names(object$importance)
    } else {
      rownames(object$importance)
    }
    if (is.null(colnames(x))) {
      if (ncol(x) != length(vname)) {
        stop("number of variables in newdata does not match that in the training data")
      }
    } else {
      if (!all(vname %in% colnames(x)))
        stop("variables in the training data missing in newdata")
      x <- x[, vname, drop = FALSE]
    }

    ## ---- Check factor predictors against the training levels ----
    if (is.data.frame(x)) {
      isFactor <- function(col) is.factor(col) && !is.ordered(col)
      xfactor <- which(vapply(x, isFactor, logical(1)))
      if (length(xfactor) > 0 && "xlevels" %in% names(object$forest)) {
        for (i in xfactor) {
          if (any(! levels(x[[i]]) %in% object$forest$xlevels[[i]]))
            stop("New factor levels not present in the training data")
          x[[i]] <-
            factor(x[[i]],
                   levels = levels(x[[i]])[match(levels(x[[i]]),
                                                 object$forest$xlevels[[i]])])
        }
      }
      # Number of categories per predictor (1 for anything that is not an
      # unordered factor) must agree with what the forest was grown on.
      cat.new <- vapply(x, function(col) {
        if (is.factor(col) && !is.ordered(col)) nlevels(col) else 1L
      }, integer(1))
      if (!all(object$forest$ncat == cat.new))
        stop("Type of predictors in new data do not match that of the training data.")
    }

    mdim <- ncol(x)
    ntest <- nrow(x)           # rows actually predicted (after NA removal)
    ntree <- object$forest$ntree
    maxcat <- max(object$forest$ncat)
    nclass <- object$forest$nclass
    nrnodes <- object$forest$nrnodes

    ## get rid of warning:
    op <- options(warn = -1)
    on.exit(options(op), add = TRUE)
    # Transposed so that the variables are the rows of the matrix handed to C.
    x <- t(data.matrix(x))

    time.points <- object[[timepts]]
    t.names <- paste0("t", seq_along(time.points))
    ntime <- length(time.points)

    # Work buffers handed to the C routine.
    treepred <- double(ntest * ntime * ntree)
    proxmatrix <- if (proximity) matrix(0, ntest, ntest) else numeric(1)
    nodexts <- if (nodes) integer(ntest * ntree) else integer(ntest)

    # A `treemap` component, if present, is split into the two daughter arrays.
    if (!is.null(object$forest$treemap)) {
      object$forest$leftDaughter <-
        object$forest$treemap[, 1, , drop = FALSE]
      object$forest$rightDaughter <-
        object$forest$treemap[, 2, , drop = FALSE]
      object$forest$treemap <- NULL
    }

    # Components of the .C result that are needed afterwards.
    keepIndex <- "ypred"
    if (predict.all) keepIndex <- c(keepIndex, "treepred")
    if (proximity) keepIndex <- c(keepIndex, "proximity")
    if (nodes) keepIndex <- c(keepIndex, "nodexts")

    ## Ensure storage mode is what is expected in C.
    if (! is.integer(object$forest$leftDaughter))
      storage.mode(object$forest$leftDaughter) <- "integer"
    if (! is.integer(object$forest$rightDaughter))
      storage.mode(object$forest$rightDaughter) <- "integer"
    if (! is.integer(object$forest$nodestatus))
      storage.mode(object$forest$nodestatus) <- "integer"
    if (! is.double(object$forest$xbestsplit))
      storage.mode(object$forest$xbestsplit) <- "double"
    if (! is.double(object$forest[[nodePrediction]]))
      storage.mode(object$forest[[nodePrediction]]) <- "double"
    if (! is.integer(object$forest$bestvar))
      storage.mode(object$forest$bestvar) <- "integer"
    if (! is.integer(object$forest$ndbigtree))
      storage.mode(object$forest$ndbigtree) <- "integer"
    if (! is.integer(object$forest$ncat))
      storage.mode(object$forest$ncat) <- "integer"

    ans <- .C("survForest",
              as.double(x),
              ypred = double(ntest * ntime),
              as.integer(mdim),
              #as.integer(ntime.rf),
              as.integer(ntime),
              as.integer(ntest),
              as.integer(ntree),
              object$forest$leftDaughter,
              object$forest$rightDaughter,
              object$forest$nodestatus,
              nrnodes,
              object$forest$xbestsplit,
              object$forest[[nodePrediction]],
              object$forest$bestvar,
              object$forest$ndbigtree,
              object$forest$ncat,
              as.integer(maxcat),
              as.integer(predict.all),
              treepred = treepred,
              as.integer(proximity),
              proximity = as.double(proxmatrix),
              nodes = as.integer(nodes),
              nodexts = as.integer(nodexts),
              #DUP=FALSE,
              PACKAGE = "icrf")[keepIndex]

    ## ---- Assemble the result ----
    # The output has one row per SUPPLIED observation (length(rn)), not per
    # predicted observation (ntest): rows dropped because of missing
    # predictors stay NA. Sizing by ntest would make the row labels mismatch
    # whenever rows were dropped.
    nall <- length(rn)
    yhat <- matrix(NA_real_, nrow = nall, ncol = ntime)
    rownames(yhat) <- rn
    colnames(yhat) <- t.names
    yhat[keep, ] <- ans$ypred
    attr(yhat, "time") <- time.points

    if (predict.all) {
      treepred <- array(NA_real_, dim = c(nall, ntime, ntree),
                        dimnames = list(rn, t.names, NULL))
      treepred[keep, , ] <- ans$treepred
    }

    if (!proximity) {
      res <- if (predict.all) {
        list(aggregate = yhat, individual = treepred)
      } else {
        yhat
      }
    } else { ### TBD this part. ###
      # The proximity matrix only covers the rows that were predicted, so it is
      # labelled with rn[keep].
      res <- list(predicted = yhat,
                  proximity = structure(ans$proximity,
                                        dim = c(ntest, ntest),
                                        dimnames = list(rn[keep], rn[keep])))
    }
    # attr(res, "time") <- object[[timepts]]
    if (nodes) {
      attr(res, "nodes") <- matrix(ans$nodexts, ntest, ntree,
                                   dimnames = list(rn[keep], seq_len(ntree)))
    }
    res
  }
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.