Nothing
################################################################################
# This is the part of the 'tidyrules' R package hosted at
# https://github.com/talegari/tidyrules with GPL-3 license.
################################################################################
#' @name tidyRules.rpart
#' @title Obtain rules as a tidy tibble from a rpart model
#' @description Each row corresponds to a rule. A rule can be copied into
#' `dplyr::filter` to filter the observations corresponding to a rule
#' @author Amith Kumar U R, \email{amith54@@gmail.com}
#' @param object Fitted model object with rules
#' @param ... Other arguments (currently unused)
#' @details NOTE: For rpart rules, one should build the model without
#' \bold{ordered factor} variable. We recommend you to convert \bold{ordered
#' factor} to \bold{factor} or \bold{integer} class.
#'
#' Optional named arguments:
#'
#' \itemize{
#'
#' \item language (string, default: "r"): language where the rules are parsable.
#' The allowed options is one among: r, python, sql
#'
#' }
#'
#' @return A tibble where each row corresponds to a rule. The columns are:
#' support, confidence, lift, LHS, RHS
#' @examples
#' iris_rpart <- rpart::rpart(Species ~ .,data = iris)
#' tidyRules(iris_rpart)
#' @export
tidyRules.rpart <- function(object, ...){
# asserts for 'language'
arguments = list(...)
if(is.null(arguments[["language"]])){
arguments[["language"]] = "r"
} else {
assertthat::assert_that(assertthat::is.string(arguments[["language"]]))
arguments[["language"]] = stringr::str_to_lower(arguments[["language"]])
assertthat::assert_that(arguments[["language"]] %in% c("r", "python", "sql"))
}
# check for rpart object
stopifnot(inherits(object, "rpart"))
if(object$method == "class"){
# Stop if only root node is present in the object
if(nrow(object$frame) == 1){
stop(paste0("Only root is present in the rpart object"
)
)
}
# Stop if any ordered factor is present:
# partykit doesn't handle the ordered factors while processing rules.
if(sum(object$ordered) > 0){
stop(paste0("Ordered variables detected!!"
, "convert ordered variables"
, " to factor or numberic before model fit"))
}
if(is.null(object$y)){
stop(
stringr::str_c(
"Unable to find target variable in the model object!! "
, "Model should be built using argument `y = TRUE`."
)
)
}
# column names from the object: This will be used at the end to handle the
# variables with a space.
col_names <- stringr::str_remove_all(attr(object$terms,which = "term.labels")
, pattern = "`")
# throw error if there are consecutive spaces in the column names ----
if(any(stringr::str_count(col_names, " ") > 0)){
stop("Variable names should not have two or more consecutive spaces.")
}
# convert to class "party"
party_obj <- partykit::as.party(object)
# extracting rules
rules <- list.rules.party(party_obj) %>%
stringr::str_replace_all(pattern = "\\\"","'") %>%
stringr::str_remove_all(pattern = ", 'NA'") %>%
stringr::str_remove_all(pattern = "'NA',") %>%
stringr::str_remove_all(pattern = "'NA'") %>%
stringr::str_squish()
# terminal nodes from party object
terminal_nodes <- partykit::nodeids(party_obj, terminal = T)
# extract metrics from rpart object
metrics <- object$frame[terminal_nodes,c("n","dev","yval")]
metrics$confidence <- (metrics$n + 1 - metrics$dev) / (metrics$n + 2)
metrics <- metrics[,c("n","yval","confidence")] %>%
magrittr::set_colnames(c("support","predict_class","confidence"))
# prevelance for lift calculation
prevelance <- object$y %>%
table() %>%
prop.table() %>%
as.numeric()
# Actual labels for RHS
metrics$RHS <- attr(object, "ylevels")[metrics$predict_class]
metrics$prevelance <- prevelance[metrics$predict_class]
metrics$lift <- metrics$confidence / metrics$prevelance
metrics$LHS <- rules
tidy_rules <- metrics
# replace variable names with spaces within backquotes ----
for(i in 1:length(col_names)){
tidy_rules[["LHS"]] <- stringr::str_replace_all(
tidy_rules[["LHS"]]
, col_names[i]
, addBackquotes(col_names[i])
)
}
# return ----
tidy_rules <- tibble::rowid_to_column(tidy_rules, "id")
tidy_rules <- tidy_rules[, c("id"
, "LHS"
, "RHS"
, "support"
, "confidence"
, "lift")
] %>%
tibble::as_tibble()
# handle the rule parsable language
lang = arguments[["language"]]
if (lang == "python"){
res[["LHS"]] = ruleRToPython(res[["LHS"]])
} else if (lang == "sql"){
res[["LHS"]] = ruleRToSQL(res[["LHS"]])
}
return(tidy_rules)
} else {
# Stop if only root node is present in the object
if(nrow(object$frame) == 1){
stop(paste0("Only root is present in the rpart object"
)
)
}
# Stop if any ordered factor is present:
# partykit doesn't handle the ordered factors while processing rules.
if(sum(object$ordered) > 0){
stop(paste0("Ordered variables detected!!"
, "convert ordered variables"
, " to factor or numberic before model fit"))
}
# column names from the object: This will be used at the end to handle the
# variables with a space.
col_names <- stringr::str_remove_all(attr(object$terms,which = "term.labels")
, pattern = "`")
# throw error if there are consecutive spaces in the column names ----
if(any(stringr::str_count(col_names, " ") > 0)){
stop("Variable names should not have two or more consecutive spaces.")
}
# convert to class "party"
party_obj <- partykit::as.party(object)
# extracting rules
rules <- list.rules.party(party_obj) %>%
stringr::str_replace_all(pattern = "\\\"","'") %>%
stringr::str_remove_all(pattern = ", 'NA'") %>%
stringr::str_remove_all(pattern = "'NA',") %>%
stringr::str_remove_all(pattern = "'NA'") %>%
stringr::str_squish()
# terminal nodes from party object
terminal_nodes <- partykit::nodeids(party_obj, terminal = T)
# extract metrics from rpart object
metrics <- object$frame[terminal_nodes,c("n","dev","yval")]
# metrics$confidence <- (metrics$n + 1 - metrics$dev) / (metrics$n + 2)
metrics <- metrics[,c("n","yval")] %>%
magrittr::set_colnames(c("support","RHS"))
metrics$LHS <- rules
tidy_rules <- metrics
# replace variable names with spaces within backquotes ----
for(i in 1:length(col_names)){
tidy_rules[["LHS"]] <- stringr::str_replace_all(
tidy_rules[["LHS"]]
, col_names[i]
, addBackquotes(col_names[i])
)
}
# return ----
tidy_rules <- tibble::rowid_to_column(tidy_rules, "id")
tidy_rules <- tidy_rules[, c("id"
, "LHS"
, "RHS"
, "support"
)
] %>%
tibble::as_tibble()
# handle the rule parsable language
lang = arguments[["language"]]
if (lang == "python"){
res[["LHS"]] = ruleRToPython(res[["LHS"]])
} else if (lang == "sql"){
res[["LHS"]] = ruleRToSQL(res[["LHS"]])
}
return(tidy_rules)
}
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.