#' predict.rfCountData
#' @title predict method for random forest Count Data objects
#' @description Prediction of test data using random forest.
#' @param object an object of class \code{rfCountData}, as that created by the function \code{rf}
#' @param newdata a data frame or matrix containing new data.
#' (Note: If not given, the out-of-bag prediction in \code{object} is returned.)
#' @param offset a vector with a length equal to the number of rows in \code{newdata}. Log of time of exposure.
#' @param type To be removed, or score scale vs. response scale ?
#' @param norm.votes not used
#' @param predict.all Should the predictions of all trees be kept?
#' @param proximity To be removed
#' @param nodes Should the terminal node indicators (an n by ntree matrix) be returned?
#' If so, it is in the 'nodes' attribute of the returned object.
#' @param cutoff To be removed
#' @param ... Currently not used
#'
#' @return {A vector of predicted values is returned. \cr
#' If \code{predict.all=TRUE}, then the returned object is a list of two components: \code{aggregate},
#' which is the vector of predicted values by the forest, and \code{individual}, which
#' is a matrix where each column contains prediction by a tree in the forest.\cr
#' For classification objects, if \code{predict.all=TRUE}, then the \code{individual} component of the returned object is a
#' character matrix where each column contains the predicted class by a tree in the forest. \cr
#' If \code{nodes=TRUE}, the returned object has a 'nodes' attribute, which is an n by ntree matrix,
#' each column containing the node number that the cases fall in for that tree.}
#'
#' @references Breiman, L. (2001), \emph{Random Forests}, Machine Learning 45(1),5-32.
#' @author Andy Liaw \email{andy_liaw@merck.com} and Matthew Wiener \email{matthew_wiener@merck.com},
#' based on original Fortran code by Leo Breiman and Adele Cutler.
#' @seealso \code{\link{rfPoisson}}
#' @keywords regression
#'
#' @export
## S3 predict method for "rfCountData" objects (count-data random forests).
## If `newdata` is missing, the stored out-of-bag predictions are returned;
## otherwise the fitted forest is applied to the new cases through the
## compiled routines "regForest" (regression/count branch) or "classForest"
## (classification branch, largely vestigial here but kept from the
## randomForest code this method is adapted from).
##
## Arguments: see the roxygen header above. Returns a named numeric vector
## of predictions, or a list (with `aggregate`/`individual`, or
## `predicted`/`proximity` components) depending on `predict.all` and
## `proximity`; a 'nodes' attribute is attached when `nodes=TRUE`.
predict.rfCountData <-
function (object, newdata, offset, type = "response", norm.votes = TRUE,
predict.all=FALSE, proximity = FALSE, nodes=FALSE, cutoff, ...)
{
## ---- Validate the object and the requested prediction type ----
if (!inherits(object, "rfCountData"))
stop("object not of class rfCountData")
if (is.null(object$forest)) stop("No forest component in the object")
## Map `type` to an integer code: 1=response, 2=prob, 3=vote, 4=class.
## `charmatch` allows unambiguous abbreviations ("resp", "p", ...).
out.type <- charmatch(tolower(type),
c("response", "prob", "vote", "class"))
if (is.na(out.type))
stop("type must be one of 'response', 'prob', 'vote'")
## "class" is treated the same as "response".
if (out.type == 4) out.type <- 1
if (out.type != 1 && object$type == "regression")
stop("'prob' or 'vote' not meaningful for regression")
## "prob" implies normalized votes.
if (out.type == 2)
norm.votes <- TRUE
## ---- No new data: return the stored out-of-bag predictions ----
if (missing(newdata)) {
## Re-expand predictions to the original row set if rows were
## dropped by an na.action during fitting.
p <- if (! is.null(object$na.action)) {
napredict(object$na.action, object$predicted)
} else {
object$predicted
}
if (object$type == "regression") return(p)
## Below here is the classification path (votes/proximity on the
## training data); it only applies when object$type != "regression".
if (proximity & is.null(object$proximity))
warning("cannot return proximity without new data if random forest object does not already have proximity")
if (out.type == 1) {
if (proximity) {
return(list(pred = p,
proximity = object$proximity))
} else return(p)
}
## type = "prob"/"vote": return the (optionally row-normalized)
## vote matrix stored on the object.
v <- object$votes
if (!is.null(object$na.action)) v <- napredict(object$na.action, v)
if (norm.votes) {
t1 <- t(apply(v, 1, function(x) { x/sum(x) }))
class(t1) <- c(class(t1), "votes")
if (proximity) return(list(pred = t1, proximity = object$proximity))
else return(t1)
} else {
if (proximity) return(list(pred = v, proximity = object$proximity))
else return(v)
}
}
## ---- Classification cutoff handling (unused on the count branch) ----
## A user-supplied cutoff must be a positive vector, one entry per class,
## summing to at most 1; if named, names must match object$classes and
## are used to reorder it.
if (missing(cutoff)) {
cutoff <- object$forest$cutoff
} else {
if (sum(cutoff) > 1 || sum(cutoff) < 0 || !all(cutoff > 0) ||
length(cutoff) != length(object$classes)) {
stop("Incorrect cutoff specified.")
}
if (!is.null(names(cutoff))) {
if (!all(names(cutoff) %in% object$classes)) {
stop("Wrong name(s) for cutoff")
}
cutoff <- cutoff[object$classes]
}
}
if (object$type == "unsupervised")
stop("Can't predict unsupervised forest.")
## ---- Build the predictor matrix `x` from `newdata` ----
## `keep` maps rows of x back to rows of newdata (rows with NAs are
## dropped by na.omit on the formula path); `rn` carries row names so
## the returned vector is aligned with, and named like, newdata.
if (inherits(object, "randomForest.formula")) {
newdata <- as.data.frame(newdata)
rn <- row.names(newdata)
Terms <- delete.response(object$terms)
x <- model.frame(Terms, newdata, na.action = na.omit)
keep <- match(row.names(x), rn)
} else {
## Non-formula path: a bare vector is treated as a single row.
if (is.null(dim(newdata)))
dim(newdata) <- c(1, length(newdata))
x <- newdata
if (nrow(x) == 0)
stop("newdata has 0 rows")
if (any(is.na(x)))
stop("missing values in newdata")
keep <- 1:nrow(x)
rn <- rownames(x)
if (is.null(rn)) rn <- keep
}
## Training variable names come from the importance matrix/vector.
vname <- if (is.null(dim(object$importance))) {
names(object$importance)
} else {
rownames(object$importance)
}
## Check newdata has the training predictors; reorder columns to match
## the training order when column names are available.
if (is.null(colnames(x))) {
if (ncol(x) != length(vname)) {
stop("number of variables in newdata does not match that in the training data")
}
} else {
if (any(! vname %in% colnames(x)))
stop("variables in the training data missing in newdata")
x <- x[, vname, drop=FALSE]
}
## ---- Factor-level consistency checks (data.frame input only) ----
if (is.data.frame(x)) {
isFactor <- function(x) is.factor(x) & ! is.ordered(x)
xfactor <- which(sapply(x, isFactor))
if (length(xfactor) > 0 && "xlevels" %in% names(object$forest)) {
for (i in xfactor) {
if (any(! levels(x[[i]]) %in% object$forest$xlevels[[i]]))
stop("New factor levels not present in the training data")
## Re-level each factor so its integer codes line up with the
## level ordering used at training time.
x[[i]] <-
factor(x[[i]],
levels=levels(x[[i]])[match(levels(x[[i]]), object$forest$xlevels[[i]])])
}
}
## ncat: number of categories per predictor (1 for non-factors);
## must match what the forest was grown with.
cat.new <- sapply(x, function(x) if (is.factor(x) && !is.ordered(x))
length(levels(x)) else 1)
if (!all(object$forest$ncat == cat.new))
stop("Type of predictors in new data do not match that of the training data.")
}
## ---- Set up dimensions and buffers for the .C calls ----
mdim <- ncol(x)
ntest <- nrow(x)
ntree <- object$forest$ntree
maxcat <- max(object$forest$ncat)
nclass <- object$forest$nclass
nrnodes <- object$forest$nrnodes
## get rid of warning:
## (data.matrix() below can warn when coercing; warnings are suppressed
## for the rest of the function and options restored on exit.)
op <- options(warn=-1)
on.exit(options(op))
## The C code expects predictors transposed: mdim x ntest.
x <- t(data.matrix(x))
## Per-tree prediction buffer: a full ntest x ntree matrix only when
## predict.all=TRUE, otherwise a length-ntest placeholder.
if (predict.all) {
treepred <- if (object$type == "regression") {
matrix(double(ntest * ntree), ncol=ntree)
} else {
matrix(integer(ntest * ntree), ncol=ntree)
}
} else {
treepred <- numeric(ntest)
}
## Placeholder-sized buffers when proximity/nodes are not requested.
proxmatrix <- if (proximity) matrix(0, ntest, ntest) else numeric(1)
nodexts <- if (nodes) integer(ntest * ntree) else integer(ntest)
if (object$type == "regression") {
## ---- Regression / count branch ----
## Older objects store the tree topology in a combined `treemap`
## array; split it into left/right daughter matrices as the C code
## expects.
if (!is.null(object$forest$treemap)) {
object$forest$leftDaughter <-
object$forest$treemap[,1,, drop=FALSE]
object$forest$rightDaughter <-
object$forest$treemap[,2,, drop=FALSE]
object$forest$treemap <- NULL
}
## Names of the .C output components to keep, depending on options.
keepIndex <- "ypred"
if (predict.all) keepIndex <- c(keepIndex, "treepred")
if (proximity) keepIndex <- c(keepIndex, "proximity")
if (nodes) keepIndex <- c(keepIndex, "nodexts")
## Ensure storage mode is what is expected in C.
if (! is.integer(object$forest$leftDaughter))
storage.mode(object$forest$leftDaughter) <- "integer"
if (! is.integer(object$forest$rightDaughter))
storage.mode(object$forest$rightDaughter) <- "integer"
if (! is.integer(object$forest$nodestatus))
storage.mode(object$forest$nodestatus) <- "integer"
if (! is.double(object$forest$xbestsplit))
storage.mode(object$forest$xbestsplit) <- "double"
if (! is.double(object$forest$nodepred))
storage.mode(object$forest$nodepred) <- "double"
if (! is.integer(object$forest$bestvar))
storage.mode(object$forest$bestvar) <- "integer"
if (! is.integer(object$forest$ndbigtree))
storage.mode(object$forest$ndbigtree) <- "integer"
if (! is.integer(object$forest$ncat))
storage.mode(object$forest$ncat) <- "integer"
## NOTE(review): `offset` is passed straight to the C routine below
## but is never checked with missing() or length-validated against
## nrow(newdata) — calling predict() with newdata but no offset will
## fail inside .C. TODO confirm whether a default (e.g. zeros)
## should be supplied.
## NOTE(review): DUP=FALSE has been ignored/defunct in .C since
## R 3.1.0 — presumably harmless but worth removing; confirm against
## the supported R versions.
ans <- .C("regForest",
as.double(x),
as.double(offset),
ypred = double(ntest),
as.integer(mdim),
as.integer(ntest),
as.integer(ntree),
object$forest$leftDaughter,
object$forest$rightDaughter,
object$forest$nodestatus,
nrnodes,
object$forest$xbestsplit,
object$forest$nodepred,
object$forest$bestvar,
object$forest$ndbigtree,
object$forest$ncat,
as.integer(maxcat),
as.integer(predict.all),
treepred = as.double(treepred),
as.integer(proximity),
proximity = as.double(proxmatrix),
nodes = as.integer(nodes),
nodexts = as.integer(nodexts),
DUP=FALSE,
PACKAGE = "rfCountData")[keepIndex]
## Apply bias correction if needed.
## Rows dropped by na.omit stay NA; `keep` realigns predictions to
## the original newdata rows.
yhat <- rep(NA, length(rn))
names(yhat) <- rn
if (!is.null(object$coefs)) {
yhat[keep] <- object$coefs[1] + object$coefs[2] * ans$ypred
} else {
yhat[keep] <- ans$ypred
}
if (predict.all) {
treepred <- matrix(NA, length(rn), ntree,
dimnames=list(rn, NULL))
treepred[keep,] <- ans$treepred
}
## Assemble the return value per the requested options.
if (!proximity) {
res <- if (predict.all)
list(aggregate=yhat, individual=treepred) else yhat
} else {
res <- list(predicted = yhat,
proximity = structure(ans$proximity,
dim=c(ntest, ntest), dimnames=list(rn, rn)))
}
if (nodes) {
attr(res, "nodes") <- matrix(ans$nodexts, ntest, ntree,
dimnames=list(rn[keep], 1:ntree))
}
} else {
## ---- Classification branch (inherited from randomForest) ----
## countts accumulates per-class vote counts for each test case.
countts <- matrix(0, ntest, nclass)
t1 <- .C("classForest",
mdim = as.integer(mdim),
ntest = as.integer(ntest),
nclass = as.integer(object$forest$nclass),
maxcat = as.integer(maxcat),
nrnodes = as.integer(nrnodes),
jbt = as.integer(ntree),
xts = as.double(x),
xbestsplit = as.double(object$forest$xbestsplit),
pid = object$forest$pid,
cutoff = as.double(cutoff),
countts = as.double(countts),
treemap = as.integer(aperm(object$forest$treemap,
c(2, 1, 3))),
nodestatus = as.integer(object$forest$nodestatus),
cat = as.integer(object$forest$ncat),
nodepred = as.integer(object$forest$nodepred),
treepred = as.integer(treepred),
jet = as.integer(numeric(ntest)),
bestvar = as.integer(object$forest$bestvar),
nodexts = as.integer(nodexts),
ndbigtree = as.integer(object$forest$ndbigtree),
predict.all = as.integer(predict.all),
prox = as.integer(proximity),
proxmatrix = as.double(proxmatrix),
nodes = as.integer(nodes),
DUP=FALSE,
PACKAGE = "rfCountData")
if (out.type > 1) {
## type = "prob"/"vote": reshape the flat vote counts into an
## ntest x nclass matrix, optionally row-normalized.
out.class.votes <- t(matrix(t1$countts, nrow = nclass, ncol = ntest))
if (norm.votes)
out.class.votes <-
sweep(out.class.votes, 1, rowSums(out.class.votes), "/")
z <- matrix(NA, length(rn), nclass,
dimnames=list(rn, object$classes))
z[keep, ] <- out.class.votes
class(z) <- c(class(z), "votes")
res <- z
} else {
## type = "response": predicted class labels (jet holds the
## winning class index per test case).
out.class <- factor(rep(NA, length(rn)),
levels=1:length(object$classes),
labels=object$classes)
out.class[keep] <- object$classes[t1$jet]
names(out.class)[keep] <- rn[keep]
res <- out.class
}
if (predict.all) {
treepred <- matrix(object$classes[t1$treepred],
nrow=length(keep), dimnames=list(rn[keep], NULL))
res <- list(aggregate=res, individual=treepred)
}
if (proximity)
res <- list(predicted = res, proximity = structure(t1$proxmatrix,
dim = c(ntest, ntest),
dimnames = list(rn[keep], rn[keep])))
if (nodes) attr(res, "nodes") <- matrix(t1$nodexts, ntest, ntree,
dimnames=list(rn[keep], 1:ntree))
}
res
}
# (Removed two lines of non-R trailing text ("Add the following code to your
# website..." / "...Embedding Snippets.") — web-page export residue that
# would prevent this file from parsing.)