# R/rpart.r

#' Fit a decision tree
#' 
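#' @param x Data frame of features, or the complete data set when \code{y} is
#'   a formula.
#' @param y Response vector, or a formula describing the model in terms of the
#'   columns of \code{x}.
#' @param ... Additional arguments passed to \code{\link[rpart]{rpart}}.
#' @return A fitted decision tree of class \code{rpart}, with \code{y} stored
#'   in its \code{y} element.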
#' @author Christofer \enc{Bäcklin}{Backlin}
#' @references Breiman, L., Friedman, J. H., Olshen, R. A. and Stone, C. J. (1984)
#'   \emph{Classification and Regression Trees}. Wadsworth.
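#' @examples
#' # A minimal sketch using the base R data sets iris and mtcars. It assumes
#' # the package rpart is installed and only shows the two ways of calling
#' # fit_rpart; the cp value is arbitrary.
#' if(requireNamespace("rpart", quietly = TRUE)){
#'     # Classification with features and response given separately
#'     tree <- fit_rpart(iris[-5], iris$Species)
#'     # The same model specified with a formula, using x as the data
#'     tree_formula <- fit_rpart(iris, Species ~ .)
#'     # Regression, with an extra argument passed on to rpart
#'     reg_tree <- fit_rpart(mtcars[-1], mtcars$mpg,
#'                           control = rpart::rpart.control(cp = 0.05))
#' }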
#' @export
fit_rpart <- function(x, y, ...){
    nice_require("rpart", "is needed to fit decision trees")
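    # Accept either a formula (with x as the data) or a response vector
    model <- if(inherits(y, "formula")){
        rpart::rpart(formula = y, data = x, ...)
    } else {
        # `.` expands to all columns of x; y is looked up in this function's
        # environment, where the formula is created
        rpart::rpart(formula = y ~ ., data = x, ...)
    }
    # Keep the response that was passed in on the fitted model
    model$y <- y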
    model
}

#' Predict using a fitted decision tree
#' 
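#' @param object Fitted decision tree, as returned by \code{\link{fit_rpart}}.
#' @param x New data whose response is to be predicted.
#' @return A list. For classification it contains \code{prediction}, a factor
#'   of predicted classes, and \code{probability}, a data frame of class
#'   probabilities. For regression it contains only \code{prediction}, a
#'   numeric vector of predicted values.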
#' @author Christofer \enc{Bäcklin}{Backlin}
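#' @examples
#' # A minimal sketch, assuming rpart is installed. Odd rows of iris are used
#' # for training and even rows for testing; the split is for illustration only.
#' if(requireNamespace("rpart", quietly = TRUE)){
#'     train <- seq(1, nrow(iris), by = 2)
#'     tree <- fit_rpart(iris[train, -5], iris$Species[train])
#'     pred <- predict_rpart(tree, iris[-train, -5])
#'     # Cross-tabulate predicted against true classes
#'     table(predicted = pred$prediction, true = iris$Species[-train])
#'     # Class probabilities, one column per class
#'     head(pred$probability)
#' }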
#' @export
predict_rpart <- function(object, x){
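    # object$method is checked rather than is.factor(object$y), since
    # object$y holds the formula when the tree was fitted with one
    if(identical(object$method, "class")){
        # Classification: predicted classes and class probabilities
        list(prediction = predict(object, newdata = x, type = "class"),
             probability = as.data.frame(predict(object, newdata = x, type = "prob")))
    } else {
        # Regression: predicted values only
        list(prediction = predict(object, newdata = x, type = "vector"))
    }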
}

# emil documentation built on Aug. 1, 2018, 1:03 a.m.