party-predict: Tree Predictions

party-predict    R Documentation

Tree Predictions

Description

Compute predictions from party objects.

Usage

## S3 method for class 'party'
predict(object, newdata = NULL, perm = NULL, ...)
predict_party(party, id, newdata = NULL, ...)
## Default S3 method:
predict_party(party, id, newdata = NULL, FUN = NULL, ...) 
## S3 method for class 'constparty'
predict_party(party, id, newdata = NULL,
    type = c("response", "prob", "quantile", "density", "node"),
    at = if (type == "quantile") c(0.1, 0.5, 0.9),
    FUN = NULL, simplify = TRUE, ...)
## S3 method for class 'simpleparty'
predict_party(party, id, newdata = NULL,
    type = c("response", "prob", "node"), ...)

Arguments

object

an object of class party.

newdata

an optional data frame in which to look for variables with which to predict. If omitted, the fitted values are used.

perm

an optional character vector of variable names. Splits of nodes with a primary split in any of these variables will be permuted (after dealing with surrogates). Note that surrogate splits in the perm variables will not be permuted.
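A minimal sketch of perm, assuming the rpart/solder setup from the Examples and that a character vector of variable names is accepted as documented. Permuting the primary splits in Opening perturbs the predictions, which is the idea behind permutation-based variable importance. The seed and the mean absolute difference used as a summary are illustrative choices, not part of the documented method.

```r
library("rpart")     # rpart() and the solder data
library("partykit")  # as.party() and the predict() method for party objects

rp <- rpart(skips ~ Opening + Solder + Mask + PadType + Panel,
            data = solder, method = "anova")
pr <- as.party(rp)

## predictions with the original data
p_orig <- predict(pr, newdata = solder)

## predictions with splits in 'Opening' permuted (random, so fix the seed)
set.seed(1)
p_perm <- predict(pr, newdata = solder, perm = "Opening")

## illustrative summary: average change in the predicted mean number of skips
mean(abs(p_orig - p_perm))
```

A larger change suggests the tree relies more heavily on that variable.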

party

an object of class party.

id

a vector of terminal node identifiers.

type

a character string denoting the type of predicted value returned; ignored when argument FUN is given. For "response", the mean of a numeric response, the predicted class for a categorical response, or the median survival time for a censored response is returned. For "prob", a categorical response yields the matrix of conditional class probabilities (simplify = TRUE) or a list with the conditional class probabilities for each observation (simplify = FALSE). For numeric and censored responses, type = "prob" returns a list with the empirical cumulative distribution functions or the empirical survivor functions (Kaplan-Meier estimates), respectively. "node" returns an integer vector of terminal node identifiers.
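The Examples below use a numeric response. To illustrate type for a categorical response, this sketch fits a conditional inference tree with ctree() on the built-in iris data. That model and data set are assumptions chosen for illustration. "response" gives a factor of predicted classes, "prob" a matrix of class probabilities whose rows sum to one, and "node" the terminal node identifiers.

```r
library("partykit")  # ctree() returns a constparty object

ct <- ctree(Species ~ ., data = iris)
nd <- iris[c(1, 51, 101), ]

cls  <- predict(ct, newdata = nd, type = "response")  # predicted classes (factor)
prob <- predict(ct, newdata = nd, type = "prob")      # one row per observation
node <- predict(ct, newdata = nd, type = "node")      # terminal node identifiers

cls
round(prob, 3)
node
```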

FUN

a function to extract (default method) or compute (constparty method) summary statistics. For the default method, this is a function of a terminal node only. For the constparty method, the prediction for each node is computed from arguments (y, w), where y is the response and w are the case weights of the learning observations in that node.
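For a constparty object, FUN is applied to the response and case weights in each terminal node. This sketch assumes the rpart setup from the Examples. That fit has no case weights, so w is assumed to be all ones. Under that assumption, a weighted mean should reproduce the default type = "response" predictions for a numeric response.

```r
library("rpart")
library("partykit")

rp <- rpart(skips ~ Opening + Solder + Mask + PadType + Panel,
            data = solder, method = "anova")
pr <- as.party(rp)
nd <- solder[c(3, 541, 640), ]

## custom summary: weighted mean of the response within each terminal node
p_fun  <- predict(pr, newdata = nd, FUN = function(y, w) weighted.mean(y, w))
## built-in mean prediction for comparison
p_resp <- predict(pr, newdata = nd, type = "response")

all.equal(unname(p_fun), unname(p_resp))
```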

at

if the return value is a function (such as the empirical cumulative distribution function or the empirical quantile function), this function is evaluated at the values given in at, and the resulting numeric values are returned. If at is NULL, the functions themselves are returned in a list.
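This sketch of at assumes the rpart setup from the Examples. With type = "quantile" and numeric values in at, the node-wise empirical quantile functions are evaluated at those probabilities. With type = "prob", at defaults to NULL, so the empirical distribution functions themselves come back as a list. They can then be evaluated by hand.

```r
library("rpart")
library("partykit")

rp <- rpart(skips ~ Opening + Solder + Mask + PadType + Panel,
            data = solder, method = "anova")
pr <- as.party(rp)
nd <- solder[c(3, 541, 640), ]

## quartiles of the response distribution in each observation's terminal node
q <- predict(pr, newdata = nd, type = "quantile", at = c(0.25, 0.5, 0.75))
q

## at = NULL: a list of empirical distribution functions, one per observation
cdfs <- predict(pr, newdata = nd, type = "prob")
cdfs[[1]](c(0, 5, 10))  # proportion of node responses <= 0, <= 5, <= 10
```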

simplify

a logical indicating whether the resulting list of predictions should be converted to a suitable vector or matrix (if possible).
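This sketch again assumes a ctree() fit on iris for illustration. With a categorical response, simplify decides whether type = "prob" returns a matrix or a list holding one probability vector per observation.

```r
library("partykit")

ct <- ctree(Species ~ ., data = iris)
nd <- iris[c(1, 51, 101), ]

p_mat  <- predict(ct, newdata = nd, type = "prob", simplify = TRUE)   # matrix
p_list <- predict(ct, newdata = nd, type = "prob", simplify = FALSE)  # list

dim(p_mat)
## each list element holds class probabilities that sum to one
sapply(p_list, sum)
```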

...

additional arguments.

Details

The predict method for party objects computes the identifiers of the predicted terminal nodes, either for new data in newdata or for the learning sample (the latter is only possible for objects of class constparty). These identifiers are then passed to the corresponding predict_party method, which computes (via FUN, for class constparty) or extracts (for class simpleparty) the actual predictions.
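The two steps described above can be reproduced by hand. This sketch assumes the rpart setup from the Examples. It first obtains the terminal node identifiers, then passes them to predict_party(). The result should match a direct call to predict(). Omitting newdata yields predictions for the learning sample, one per row of solder.

```r
library("rpart")
library("partykit")

rp <- rpart(skips ~ Opening + Solder + Mask + PadType + Panel,
            data = solder, method = "anova")
pr <- as.party(rp)
nd <- solder[c(3, 541, 640), ]

## step 1: terminal node identifiers for the new observations
ids <- predict(pr, newdata = nd, type = "node")
## step 2: compute predictions from those identifiers
p_two_step <- predict_party(pr, id = ids, newdata = nd)
## one-step equivalent
p_direct <- predict(pr, newdata = nd)

all.equal(unname(p_two_step), unname(p_direct))

## no newdata: predictions for the learning sample (constparty only)
p_fit <- predict(pr)
length(p_fit) == nrow(solder)
```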

Value

A list of predictions, possibly simplified to a numeric vector, numeric matrix or factor.

Examples


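  ## fit tree using rpart
  library("rpart")     # rpart() and the solder data
  library("partykit")  # as.party() and the predict() method for party objects
  rp <- rpart(skips ~ Opening + Solder + Mask + PadType + Panel,
              data = solder, method = "anova")

  ## coerce to 'constparty'
  pr <- as.party(rp)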

  ## mean predictions
  predict(pr, newdata = solder[c(3, 541, 640),])

  ## ecdf
  predict(pr, newdata = solder[c(3, 541, 640),], type = "prob")

  ## terminal node identifiers
  predict(pr, newdata = solder[c(3, 541, 640),], type = "node")

  ## median predictions
  predict(pr, newdata = solder[c(3, 541, 640),], 
          FUN = function(y, w = 1) median(y))


partykit documentation built on April 14, 2023, 5:09 p.m.