R/crtree.R

Defines functions crtree

Documented in crtree

#' Classification and regression trees based on the rpart package
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/crtree.html} for an example in Radiant
#'
#' @param dataset Dataset
#' @param rvar The response variable in the model
#' @param evar Explanatory variables in the model
#' @param type Model type (i.e., "classification" or "regression")
#' @param lev The level in the response variable defined as _success_
#' @param wts Weights to use in estimation
#' @param minsplit The minimum number of observations that must exist in a node in order for a split to be attempted.
#' @param minbucket the minimum number of observations in any terminal <leaf> node. If only one of minbucket or minsplit is specified, the code either sets minsplit to minbucket*3 or minbucket to minsplit/3, as appropriate.
#' @param cp Minimum proportion of root node deviance required for split (default = 0.001)
#' @param pcp Complexity parameter to use for pruning
#' @param nodes Maximum size of tree in number of nodes to return
#' @param K Number of folds used in cross-validation
#' @param seed Random seed used for cross-validation
#' @param split Splitting criterion to use (i.e., "gini" or "information")
#' @param prior Adjust the initial probability for the selected level (e.g., set to .5 in unbalanced samples)
#' @param adjprob Setting a prior will rescale the predicted probabilities. Set adjprob to TRUE to adjust the probabilities back to their original scale after estimation
#' @param cost Cost for each treatment (e.g., mailing)
#' @param margin Margin associated with a successful treatment (e.g., a purchase)
#' @param check Optional estimation parameters (e.g., "standardize")
#' @param data_filter Expression entered in, e.g., Data > View to filter the dataset in Radiant. The expression should be a string (e.g., "price > 10000")
#' @param arr Expression to arrange (sort) the data on (e.g., "color, desc(price)")
#' @param rows Rows to select from the specified dataset
#' @param envir Environment to extract data from
#'
#' @return A list with all variables defined in crtree as an object of class tree
#'
#' @examples
#' crtree(titanic, "survived", c("pclass", "sex"), lev = "Yes") %>% summary()
#' result <- crtree(titanic, "survived", c("pclass", "sex")) %>% summary()
#' result <- crtree(diamonds, "price", c("carat", "clarity"), type = "regression") %>% str()
#' @seealso \code{\link{summary.crtree}} to summarize results
#' @seealso \code{\link{plot.crtree}} to plot results
#' @seealso \code{\link{predict.crtree}} for prediction
#'
#' @importFrom rpart rpart rpart.control prune.rpart
#'
#' @export
crtree <- function(dataset, rvar, evar, type = "", lev = "", wts = "None",
                   minsplit = 2, minbucket = round(minsplit / 3), cp = 0.001,
                   pcp = NA, nodes = NA, K = 10, seed = 1234, split = "gini",
                   prior = NA, adjprob = TRUE, cost = NA, margin = NA, check = "",
                   data_filter = "", arr = "", rows = NULL, envir = parent.frame()) {
  ## NOTE: the return value is built from as.list(environment()) at the end of
  ## this function, so every local variable defined below becomes part of the
  ## returned object (e.g., summary.crtree reads wtsname, rv, and unpruned).
  ## Problems with the inputs are returned as a message string of class
  ## "crtree" rather than raised as errors

  if (rvar %in% evar) {
    return("Response variable contained in the set of explanatory variables.\nPlease update model specification." %>%
      add_class("crtree"))
  }

  ## allow cp to be negative so full tree is built http://stackoverflow.com/q/24150058/1974918
  if (is.empty(cp)) {
    return("Please provide a complexity parameter to split the data." %>% add_class("crtree"))
  } else if (!is.empty(nodes) && nodes < 2) {
    return("The (maximum) number of nodes in the tree should be larger than or equal to 2." %>% add_class("crtree"))
  }

  vars <- c(rvar, evar)

  ## weights can be specified by column name (string) or passed as a vector
  if (is.empty(wts, "None")) {
    wts <- NULL
  } else if (is_string(wts)) {
    wtsname <- wts
    vars <- c(rvar, evar, wtsname)
  }

  df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
  dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)

  if (!is.empty(wts)) {
    if (exists("wtsname")) {
      ## extract the weights and drop the weights column from the model data
      wts <- dataset[[wtsname]]
      dataset <- select_at(dataset, .vars = base::setdiff(colnames(dataset), wtsname))
    }
    if (length(wts) != nrow(dataset)) {
      return(
        paste0("Length of the weights variable is not equal to the number of rows in the dataset (", format_nr(length(wts), dec = 0), " vs ", format_nr(nrow(dataset), dec = 0), ")") %>%
          add_class("crtree")
      )
    }
  }

  not_vary <- colnames(dataset)[summarise_all(dataset, does_vary) == FALSE]
  if (length(not_vary) > 0) {
    return(paste0("The following variable(s) show no variation. Please select other variables.\n\n** ", paste0(not_vary, collapse = ", "), " **") %>%
      add_class("crtree"))
  }

  rv <- dataset[[rvar]]

  if (type == "classification" && !is.factor(rv)) {
    dataset[[rvar]] <- as_factor(dataset[[rvar]])
  }

  ## the model type is ultimately determined by the class of the response
  if (is.factor(dataset[[rvar]])) {
    if (type == "regression") {
      return("Cannot estimate a regression when the response variable is of type factor." %>% add_class("crtree"))
    }

    if (lev == "") {
      lev <- levels(dataset[[rvar]])[1]
    } else {
      if (!lev %in% levels(dataset[[rvar]])) {
        return(paste0("Specified level is not a level in ", rvar) %>% add_class("crtree"))
      }

      ## make the selected level the first level of the response
      dataset[[rvar]] <- factor(dataset[[rvar]], levels = unique(c(lev, levels(dataset[[rvar]]))))
    }

    type <- "classification"
    method <- "class"
  } else {
    type <- "regression"
    method <- "anova"
  }

  ## logicals would get < 0.5 and >= 0.5 otherwise
  ## also need to update data in predict_model
  ## so the correct type is used in prediction
  dataset <- mutate_if(dataset, is.logical, as.factor)

  ## standardize data ...
  if ("standardize" %in% check) {
    dataset <- scale_df(dataset, wts = wts)
  }

  vars <- evar
  ## in case : is used
  if (length(vars) < (ncol(dataset) - 1)) vars <- evar <- colnames(dataset)[-1]

  form <- paste(rvar, "~ . ")

  ## only set the seed if it contains at least one digit
  seed %>%
    gsub("[^0-9]", "", .) %>%
    (function(x) if (!is.empty(x)) set.seed(seed))

  minsplit <- ifelse(is.empty(minsplit), 2, minsplit)
  minbucket <- ifelse(is.empty(minbucket), round(minsplit / 3), minbucket)

  ## make max tree
  # http://stackoverflow.com/questions/24150058/rpart-doesnt-build-a-full-tree-problems-with-cp
  control <- rpart::rpart.control(
    cp = cp,
    xval = K,
    minsplit = minsplit,
    minbucket = minbucket
  )

  parms <- list(split = split)
  if (type == "classification") {
    ## position of the selected level determines the order of the loss/prior
    ## values. lev is always the first level at this point (see above)
    ind <- if (match(lev, levels(dataset[[rvar]])) == 1) c(1, 2) else c(2, 1)

    ## a prior cannot be combined with a complete cost/margin specification
    if (!is.empty(prior) && !is_not(cost) && !is_not(margin)) {
      return("Choose either a prior or cost and margin values but not both.\nPlease adjust your settings and try again" %>% add_class("crtree"))
    }

    if (!is_not(cost) && !is_not(margin)) {
      loss2 <- as_numeric(cost)
      loss1 <- as_numeric(margin) - loss2

      if (loss1 <= 0) {
        return("Cost must be smaller than the specified margin.\nPlease adjust the settings and try again" %>% add_class("crtree"))
      } else if (loss2 <= 0) {
        return("Cost must be larger than zero.\nPlease adjust the settings and try again" %>% add_class("crtree"))
      } else {
        ## 2 x 2 loss matrix with zeros on the diagonal
        parms[["loss"]] <- c(loss1, loss2) %>%
          .[ind] %>%
          {
            matrix(c(0, .[1], .[2], 0), byrow = TRUE, nrow = 2)
          }
      }
    } else if (!is.empty(prior)) {
      if (!is.numeric(prior)) {
        return("Prior did not resolve to a numeric factor" %>% add_class("crtree"))
      } else if (prior > 1 || prior < 0) {
        return("Prior is not a valid probability" %>% add_class("crtree"))
      } else {
        ## prior is applied to the selected level
        parms[["prior"]] <- c(prior, 1 - prior) %>%
          .[ind]
      }
    }
  }

  ## using an input list with do.call ensure that a full "call" is available for cross-validation
  crtree_input <- list(
    formula = as.formula(form),
    data = dataset,
    method = method,
    parms = parms,
    weights = wts,
    control = control
  )

  model <- do.call(rpart::rpart, crtree_input)

  ## prune to a maximum number of nodes or, alternatively, to a pruning
  ## complexity parameter. The full tree is kept in 'unpruned'
  if (!is_not(nodes)) {
    unpruned <- model
    if (nrow(model$frame) > 1) {
      cptab <- as.data.frame(model$cptable, stringsAsFactors = FALSE)
      cptab$nodes <- cptab$nsplit + 1
      ind <- max(which(cptab$nodes <= nodes))
      model <- sshhr(rpart::prune.rpart(model, cp = cptab$CP[ind]))
    }
  } else if (!is_not(pcp)) {
    unpruned <- model
    if (nrow(model$frame) > 1) {
      model <- sshhr(rpart::prune.rpart(model, cp = pcp))
    }
  }

  ## rpart::rpart does not return residuals by default
  model$residuals <- residuals(model, type = "pearson")

  ## rescale probabilities only for classification trees estimated with a
  ## prior and only when adjprob has not been explicitly switched off
  if (type == "classification" && is_not(cost) && is_not(margin) &&
    !is.empty(prior) && !is.empty(adjprob) && !isFALSE(adjprob)) {

    ## when prior = 0.5 can use pp <- p / (p + (1 - p) * (1 - bp) / bp)
    ## more generally, use Theorem 2 from "The Foundations of Cost-Sensitive Learning" by Charles Elkan
    ## in the equation below prior equivalent of b
    p <- model$frame$yval2[, 4]
    bp <- mean(dataset[[rvar]] == lev)
    model$frame$yval2[, 4] <- bp * (p - p * prior) / (prior - p * prior + bp * p - prior * bp)
    model$frame$yval2[, 5] <- 1 - model$frame$yval2[, 4]
  }

  ## tree model object does not include the data by default
  model$model <- dataset

  ## passing on variable classes for plotting
  model$var_types <- sapply(dataset, class)

  rm(dataset, envir) ## dataset not needed elsewhere

  as.list(environment()) %>% add_class(c("crtree", "model"))
}

#' Summary method for the crtree function
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/crtree.html} for an example in Radiant
#'
#' @param object Return value from \code{\link{crtree}}
#' @param prn Print tree in text form
#' @param splits Print the tree splitting metrics used
#' @param cptab Print the cp table
#' @param modsum Print the model summary
#' @param ... further arguments passed to or from other methods
#'
#' @examples
#' result <- crtree(titanic, "survived", c("pclass", "sex"), lev = "Yes")
#' summary(result)
#' result <- crtree(diamonds, "price", c("carat", "color"), type = "regression")
#' summary(result)
#' @seealso \code{\link{crtree}} to generate results
#' @seealso \code{\link{plot.crtree}} to plot results
#' @seealso \code{\link{predict.crtree}} for prediction
#'
#' @export
summary.crtree <- function(object, prn = TRUE, splits = FALSE, cptab = FALSE, modsum = FALSE, ...) {
  ## crtree returns a message string when the model could not be estimated
  if (is.character(object)) {
    return(object)
  }

  ## header describing the data and settings used for estimation
  if (object$type == "classification") {
    cat("Classification tree")
  } else {
    cat("Regression tree")
  }
  cat("\nData                 :", object$df_name)
  if (!is.empty(object$data_filter)) {
    cat("\nFilter               :", gsub("\\n", "", object$data_filter))
  }
  if (!is.empty(object$arr)) {
    cat("\nArrange              :", gsub("\\n", "", object$arr))
  }
  if (!is.empty(object$rows)) {
    cat("\nSlice                :", gsub("\\n", "", object$rows))
  }
  cat("\nResponse variable    :", object$rvar)
  if (object$type == "classification") {
    cat("\nLevel                :", object$lev, "in", object$rvar)
  }
  cat("\nExplanatory variables:", paste0(object$evar, collapse = ", "), "\n")
  if (length(object$wtsname) > 0) {
    cat("Weights used         :", object$wtsname, "\n")
  }
  cat("Complexity parameter :", object$cp, "\n")
  cat("Minimum observations :", object$minsplit, "\n")
  if (!is_not(object$nodes)) {
    ## number of terminal nodes in the tree before pruning
    max_nodes <- sum(object$unpruned$frame$var == "<leaf>")
    cat("Maximum nr. nodes    :", object$nodes, "out of", max_nodes, "\n")
  }
  if (!is.empty(object$cost) && !is.empty(object$margin) && object$type == "classification") {
    cat("Cost:Margin          :", object$cost, ":", object$margin, "\n")
    ## replace the prior by a note so the block below reports it was ignored
    if (!is.empty(object$prior)) object$prior <- "Prior ignored when cost and margin set"
  }
  if (!is.empty(object$prior) && object$type == "classification") {
    cat("Priors               :", object$prior, "\n")
    cat("Adjusted prob.       :", object$adjprob, "\n")
  }
  ## integer weights are summed to get the number of observations
  if (!is.empty(object$wts, "None") && inherits(object$wts, "integer")) {
    cat("Nr obs               :", format_nr(sum(object$wts), dec = 0), "\n\n")
  } else {
    cat("Nr obs               :", format_nr(length(object$rv), dec = 0), "\n\n")
  }

  ## extra output
  if (splits) {
    ## use the full component name rather than relying on partial matching
    print(object$model$splits)
  }

  if (cptab) {
    print(object$model$cptable)
  }

  if (modsum) {
    ## drop the call before generating the model summary
    object$model$call <- NULL
    print(summary(object$model))
  } else if (prn) {
    ## text version of the tree without the first two lines of printed output
    cat(paste0(capture.output(print(object$model))[c(-1, -2)], collapse = "\n"))
  }
}

#' Plot method for the crtree function
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/crtree.html} for an example in Radiant. The standard tree plot used by the rpart package can be generated by \code{plot.rpart(result$model)}. See \code{\link{plot.rpart}} for additional details.
#'
#' @param x Return value from \code{\link{crtree}}
#' @param plots Plots to produce for the specified rpart tree. "tree" shows a tree diagram. "prune" shows a line graph to evaluate appropriate tree pruning. "imp" shows a variable importance plot
#' @param orient Plot orientation for tree: LR for vertical and TD for horizontal
#' @param width Plot width in pixels for tree (default is "900px")
#' @param labs Use factor labels in plot (TRUE) or revert to default letters used by tree (FALSE)
#' @param nrobs Number of data points to show in dashboard scatter plots (-1 for all)
#' @param dec Decimal places to round results to
#' @param incl Which variables to include in a coefficient plot or PDP plot
#' @param incl_int Which interactions to investigate in PDP plots
#' @param shiny Did the function call originate inside a shiny app
#' @param custom Logical (TRUE, FALSE) to indicate if ggplot object (or list of ggplot objects) should be returned. This option can be used to customize plots (e.g., add a title, change x and y labels, etc.). See examples and \url{https://ggplot2.tidyverse.org} for options.
#' @param ... further arguments passed to or from other methods
#'
#' @examples
#' result <- crtree(titanic, "survived", c("pclass", "sex"), lev = "Yes")
#' plot(result)
#' result <- crtree(diamonds, "price", c("carat", "clarity", "cut"))
#' plot(result, plots = "prune")
#' result <- crtree(dvd, "buy", c("coupon", "purch", "last"), cp = .01)
#' plot(result, plots = "imp")
#'
#' @importFrom DiagrammeR DiagrammeR mermaid
#' @importFrom rlang .data
#'
#' @seealso \code{\link{crtree}} to generate results
#' @seealso \code{\link{summary.crtree}} to summarize results
#' @seealso \code{\link{predict.crtree}} for prediction
#'
#' @export
plot.crtree <- function(x, plots = "tree", orient = "LR",
                        width = "900px", labs = TRUE,
                        nrobs = Inf, dec = 2,
                        incl = NULL, incl_int = NULL,
                        shiny = FALSE, custom = FALSE, ...) {
  ## two main branches: (1) a tree diagram rendered with DiagrammeR/mermaid
  ## and (2) one or more ggplot based plots collected in plot_list
  if (is.empty(plots) || "tree" %in% plots) {
    ## show the message from crtree in a single node graph
    if ("character" %in% class(x)) {
      return(paste0("graph LR\n A[\"", x, "\"]") %>% DiagrammeR::DiagrammeR(.))
    }

    ## avoid error when dec is NA or NULL
    if (is_not(dec)) dec <- 2

    ## x can be the return value from crtree (tree stored in x$model) or
    ## an object that has a 'frame' component itself
    df <- x$frame
    if (is.null(df)) {
      df <- x$model$frame ## make it easier to call from code
      type <- x$model$method
      nlabs <- labels(x$model, minlength = 0, collapse = FALSE)
    } else {
      nlabs <- labels(x, minlength = 0, collapse = FALSE)
      type <- x$method
    }

    if (nrow(df) == 1) {
      return(paste0("graph LR\n A[Graph unavailable for single node tree]") %>% DiagrammeR::DiagrammeR(.))
    }

    ## value shown in each node: a percentage for classification trees and
    ## a rounded prediction otherwise
    if (type == "class") {
      df$yval <- format_nr(df$yval2[, 4], dec = dec, perc = TRUE)
      pre <- "<b>p:</b> "
    } else {
      df$yval <- round(df$yval, dec)
      pre <- "<b>b:</b> "
    }
    df$yval2 <- NULL

    ## node ids are taken from the row names of the frame. For each non-leaf
    ## node, to1 is the id in the next row and to2 is that id plus one
    df$id <- as.integer(rownames(df))
    df$to1 <- NA
    df$to2 <- NA
    non_leafs <- which(df$var != "<leaf>")
    df$to1[non_leafs] <- df$id[non_leafs + 1]
    df$to2[non_leafs] <- df$to1[non_leafs] + 1
    ## long format: one row per (node, branch) combination
    df <- gather(df, "level", "to", !!c("to1", "to2"))

    df$split1 <- nlabs[, 1]
    df$split2 <- nlabs[, 2]

    isInt <- x$model$var_types %>%
      (function(x) names(x)[x == "integer"])
    if (length(isInt) > 0) {
      # inspired by https://stackoverflow.com/a/35556288/1974918
      ## round split values up for variables of type integer
      int_labs <- function(x) {
        paste(
          gsub("(>=|<)\\s*(-{0,1}[0-9]+.*)", "\\1", x),
          gsub("(>=|<)\\s*(-{0,1}[0-9]+.*)", "\\2", x) %>%
            as.numeric() %>%
            ceiling(.)
        )
      }
      int_ind <- df$var %in% isInt
      df[int_ind, "split1"] %<>% int_labs()
      df[int_ind, "split2"] %<>% int_labs()
    }

    ## keep the full labels for the tooltips and truncate long edge labels
    df$split1_full <- df$split1
    df$split2_full <- df$split2

    bnr <- 20
    df$split1 <- df$split1 %>% ifelse(nchar(.) > bnr, paste0(strtrim(., bnr), " ..."), .)
    df$split2 <- df$split2 %>% ifelse(nchar(.) > bnr, paste0(strtrim(., bnr), " ..."), .)

    df$to <- as.integer(df$to)
    df$edge <- ifelse(df$level == "to1", df$split1, df$split2) %>%
      (function(x) paste0(" --- |", x, "|"))
    ## seems like only unicode letters are supported in mermaid at this time
    # df$edge <- ifelse (df$level == "to1", df$split1, df$split2) %>% {paste0("--- |", sub("^>", "\u2265",.), "|")}
    # df$edge <- iconv(df$edge, "UTF-8", "ASCII",sub="")
    non_leafs <- which(df$var != "<leaf>")
    df$from <- NA
    df$from[non_leafs] <- paste0("id", df$id[non_leafs], "[", df$var[non_leafs], "]")

    ## label for the node each edge points to (variable, n, and prediction)
    df$to_lab <- NA
    to_lab <- sapply(df$to[non_leafs], function(x) which(x == df$id))[1, ]
    df$to_lab[non_leafs] <- paste0("id", df$to[non_leafs], "[", ifelse(df$var[to_lab] == "<leaf>", "", paste0(df$var[to_lab], "<br>")), "<b>n:</b> ", format_nr(df$n[to_lab], dec = 0), "<br>", pre, df$yval[to_lab], "]")
    ## rows for leafs have NA values at this point and are dropped
    df <- na.omit(df)

    ## nodes that are pointed to but never split
    leafs <- paste0("id", base::setdiff(df$to, df$id))

    ## still need the below setup to keep the "chance" class
    ## when a decision analysis plot is in the report
    # style <- paste0(
    #   "classDef default fill:none, bg:none, stroke-width:0px;
    #   classDef leaf fill:#9ACD32,stroke:#333,stroke-width:1px;
    #   class ", paste(leafs, collapse = ","), " leaf;"
    # )

    ## still need this to keep the "chance" class
    ## when a decision analysis plot is in the report
    style <- paste0(
      "classDef default fill:none, bg:none, stroke-width:0px;
      classDef leaf fill:#9ACD32,stroke:#333,stroke-width:1px;
      classDef chance fill:#FF8C00,stroke:#333,stroke-width:1px;
      classDef chance_with_cost fill:#FF8C00,stroke:#333,stroke-width:3px,stroke-dasharray:4,5;
      classDef decision fill:#9ACD32,stroke:#333,stroke-width:1px;
      classDef decision_with_cost fill:#9ACD32,stroke:#333,stroke-width:3px,stroke-dasharray:4,5;
      class ", paste(leafs, collapse = ","), " leaf;"
    )

    ## check orientation for branch labels
    brn <- if (orient %in% c("LR", "RL")) {
      c("top", "bottom")
    } else {
      c("left", "right")
    }
    brn <- paste0("<b>", brn, ":</b> ")

    ## don't print full labels that don't add information
    df[df$split1_full == df$split1 & df$split2_full == df$split2, c("split1_full", "split2_full")] <- ""
    df$split1_full <- ifelse(df$split1_full == "", "", paste0("<br>", brn[1], gsub(",", ", ", df$split1_full)))
    df$split2_full <- ifelse(df$split2_full == "", "", paste0("<br>", brn[2], gsub(",", ", ", df$split2_full)))

    ## tooltips are generated from the first half of the rows
    ttip_ind <- 1:(nrow(df) / 2)
    ttip <- df[ttip_ind, , drop = FALSE] %>%
      {
        paste0("click id", .$id, " callback \"<b>n:</b> ", format_nr(.$n, dec = 0), "<br>", pre, .$yval, .$split1_full, .$split2_full, "\"", collapse = "\n")
      }

    ## try to link a tooltip directly to an edge using mermaid
    ## see https://github.com/rich-iannone/DiagrammeR/issues/267
    # ttip_lev <- filter(df[-ttip_ind,], split1_full != "")
    # if (nrow(ttip_lev) == 0) {
    #   ttip_lev <- ""
    # } else {
    #   ttip_lev <- paste0("click id", ttip_lev$id, " callback \"", ttip_lev$split1_full, "\"", collapse = "\n")
    # }

    ## combine orientation, edges, styling, and tooltips into a mermaid graph
    paste(paste0("graph ", orient), paste(paste0(df$from, df$edge, df$to_lab), collapse = "\n"), style, ttip, sep = "\n") %>%
      DiagrammeR::mermaid(., width = width, height = "100%")
  } else {
    ## pass along the message from crtree
    if ("character" %in% class(x)) {
      return(x)
    }
    plot_list <- list()
    nrCol <- 1
    if ("prune" %in% plots) {
      ## use the cp table from the unpruned tree when available
      if (is.null(x$unpruned)) {
        df <- data.frame(x$model$cptable, stringsAsFactors = FALSE)
      } else {
        df <- data.frame(x$unpruned$cptable, stringsAsFactors = FALSE)
      }

      if (nrow(df) < 2) {
        return("Evaluation of tree pruning not available for single node tree")
      }

      ## geometric mean of each CP value and the previous one (Inf for row 1)
      df$CP <- sqrt(df$CP * c(Inf, head(df$CP, -1))) %>% round(5)
      ## from here on nsplit contains the number of nodes (splits + 1)
      df$nsplit <- as.integer(df$nsplit + 1)
      ## ind1: row with the lowest cross-validated error
      ## ind2: first row with an error below that minimum plus one xstd
      ind1 <- min(which(df$xerror == min(df$xerror)))
      size1 <- c(df$nsplit[ind1], df$CP[ind1])
      ind2 <- min(which(df$xerror < (df$xerror[ind1] + df$xstd[ind1])))
      size2 <- c(df$nsplit[ind2], df$CP[ind2])

      p <- ggplot(data = df, aes(x = .data$nsplit, y = .data$xerror)) +
        geom_line() +
        geom_vline(xintercept = size1[1], linetype = "dashed") +
        geom_hline(yintercept = min(df$xerror), linetype = "dashed") +
        labs(
          title = "Evaluate tree pruning based on cross-validation",
          x = "Number of nodes",
          y = "Relative error"
        )
      if (nrow(df) < 10) p <- p + scale_x_continuous(breaks = df$nsplit)

      ## http://stats.stackexchange.com/questions/13471/how-to-choose-the-number-of-splits-in-rpart
      if (size1[1] == Inf) size1[2] <- NA
      footnote <- paste0("\nMinimum error achieved at prune complexity ", format(size1[2], scientific = FALSE), " (", size1[1], " nodes)")
      p <- p +
        geom_vline(xintercept = size2[1], linetype = "dotdash", color = "blue") +
        geom_hline(yintercept = df$xerror[ind1] + df$xstd[ind1], linetype = "dotdash", color = "blue")

      if (size2[1] < size1[1]) {
        footnote <- paste0(footnote, ".\nError at pruning complexity ", format(size2[2], scientific = FALSE), " (", size2[1], " nodes) is within one std. of minimum")
      }

      plot_list[["prune"]] <- p + labs(caption = footnote)
    }

    if ("imp" %in% plots) {
      imp <- x$model$variable.importance
      if (is.null(imp)) {
        return("Variable importance information not available for singlenode tree")
      }

      ## importance as a proportion of the total, sorted for plotting
      df <- data.frame(
        vars = names(imp),
        imp = imp / sum(imp),
        stringsAsFactors = FALSE
      ) %>%
        arrange_at(.vars = "imp")
      df$vars <- factor(df$vars, levels = df$vars)

      plot_list[["imp"]] <-
        visualize(df, yvar = "imp", xvar = "vars", type = "bar", custom = TRUE) +
        labs(
          title = "Variable importance",
          x = "",
          y = "Importance"
        ) +
        coord_flip() +
        theme(axis.text.y = element_text(hjust = 0))
    }

    if ("vip" %in% plots) {
      # imp <- x$model$variable.importance
      # if (is.null(imp)) {
      #   return("Variable importance information not available for singlenode tree")
      # } else {
      vi_scores <- varimp(x)
      plot_list[["vip"]] <-
        visualize(vi_scores, yvar = "Importance", xvar = "Variable", type = "bar", custom = TRUE) +
        labs(
          title = "Permutation Importance",
          x = NULL,
          y = ifelse(x$type == "regression", "Importance (R-square decrease)", "Importance (AUC decrease)")
        ) +
        coord_flip() +
        theme(axis.text.y = element_text(hjust = 0))
    }

    if ("pred_plot" %in% plots) {
      nrCol <- 2
      if (length(incl) > 0 || length(incl_int) > 0) {
        plot_list <- pred_plot(x, plot_list, incl, incl_int, ...)
      } else {
        return("Select one or more variables to generate Prediction plots")
      }
    }

    if ("pdp" %in% plots) {
      nrCol <- 2
      if (length(incl) > 0 || length(incl_int) > 0) {
        plot_list <- pdp_plot(x, plot_list, incl, incl_int, ...)
      } else {
        return("Select one or more variables to generate Partial Dependence Plots")
      }
    }

    ## NOTE(review): the dashboard replaces any plots already in plot_list
    if (x$type == "regression" && "dashboard" %in% plots) {
      plot_list <- plot.regress(x, plots = "dashboard", lines = "line", nrobs = nrobs, custom = TRUE)
      nrCol <- 2
    }

    if (length(plot_list) > 0) {
      if (custom) {
        ## return the ggplot object(s) so they can be customized
        if (length(plot_list) == 1) plot_list[[1]] else plot_list
      } else {
        patchwork::wrap_plots(plot_list, ncol = nrCol) %>%
          (function(x) if (isTRUE(shiny)) x else print(x))
      }
    }
  }
}

#' Predict method for the crtree function
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/crtree.html} for an example in Radiant
#'
#' @param object Return value from \code{\link{crtree}}
#' @param pred_data Provide the dataframe to generate predictions (e.g., titanic). The dataset must contain all columns used in the estimation
#' @param pred_cmd Generate predictions using a command. For example, `pclass = levels(pclass)` would produce predictions for the different levels of factor `pclass`. To add another variable, create a vector of prediction strings, (e.g., c('pclass = levels(pclass)', 'age = seq(0,100,20)')
#' @param conf_lev Confidence level used to estimate confidence intervals (.95 is the default)
#' @param se Logical that indicates if prediction standard errors should be calculated (default = FALSE)
#' @param dec Number of decimals to show
#' @param envir Environment to extract data from
#' @param ... further arguments passed to or from other methods
#'
#' @examples
#' result <- crtree(titanic, "survived", c("pclass", "sex"), lev = "Yes")
#' predict(result, pred_cmd = "pclass = levels(pclass)")
#' result <- crtree(titanic, "survived", "pclass", lev = "Yes")
#' predict(result, pred_data = titanic) %>% head()
#' @seealso \code{\link{crtree}} to generate the result
#' @seealso \code{\link{summary.crtree}} to summarize results
#'
#' @export
predict.crtree <- function(object, pred_data = NULL, pred_cmd = "",
                           conf_lev = 0.95, se = FALSE, dec = 3,
                           envir = parent.frame(), ...) {
  ## a character object is the message returned by crtree when estimation
  ## was not possible; pass it along unchanged
  if (is.character(object)) {
    return(object)
  }

  ## label for the prediction data: the expression used in the call when a
  ## data.frame was provided, otherwise the value of pred_data itself
  df_name <- if (is.data.frame(pred_data)) {
    deparse(substitute(pred_data))
  } else {
    pred_data
  }

  ## prediction function handed to predict_model. se and conf_lev are part
  ## of the signature but are not used for tree predictions
  pfun <- function(model, pred, se, conf_lev) {
    pred_val <- try(sshhr(predict(model, pred)), silent = TRUE)

    ## return the try-error as-is so the caller can deal with it
    if (inherits(pred_val, "try-error")) {
      return(pred_val)
    }

    ## keep only the first column of the predictions
    pred_val <- as.data.frame(pred_val, stringsAsFactors = FALSE)
    set_colnames(select(pred_val, 1), "Prediction")
  }

  result <- predict_model(object, pfun, "crtree.predict", pred_data, pred_cmd, conf_lev, se, dec, envir = envir)
  set_attr(result, "radiant_pred_data", df_name)
}

#' Print method for predict.crtree
#'
#' @param x Return value from prediction method
#' @param ... further arguments passed to or from other methods
#' @param n Number of lines of prediction results to print. Use -1 to print all lines
#'
#' @export
print.crtree.predict <- function(x, ..., n = 10) {
  ## delegate to the shared print function with a header for tree models
  header <- "Classification and regression trees"
  print_predict_model(
    x,
    ...,
    n = n,
    header = header
  )
}

#' Cross-validation for Classification and Regression Trees
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/crtree.html} for an example in Radiant
#'
#' @param object Object of type "rpart" or "crtree" to use as a starting point for cross validation
#' @param K Number of cross validation passes to use
#' @param repeats Number of times to repeat the K cross-validation steps
#' @param cp Complexity parameter used when building the tree (e.g., 0.0001)
#' @param pcp Complexity parameter to use for pruning
#' @param seed Random seed to use as the starting point
#' @param trace Print progress
#' @param fun Function to use for model evaluation (e.g., auc for classification or RMSE for regression)
#' @param ... Additional arguments to be passed to 'fun'
#'
#' @return A data.frame sorted by the mean, sd, min, and max of the performance metric
#'
#' @seealso \code{\link{crtree}} to generate an initial model that can be passed to cv.crtree
#' @seealso \code{\link{Rsq}} to calculate an R-squared measure for a regression
#' @seealso \code{\link{RMSE}} to calculate the Root Mean Squared Error for a regression
#' @seealso \code{\link{MAE}} to calculate the Mean Absolute Error for a regression
#' @seealso \code{\link{auc}} to calculate the area under the ROC curve for classification
#' @seealso \code{\link{profit}} to calculate profits for classification at a cost/margin threshold
#'
#' @importFrom rpart prune.rpart
#' @importFrom shiny getDefaultReactiveDomain withProgress incProgress
#'
#' @examples
#' \dontrun{
#' result <- crtree(dvd, "buy", c("coupon", "purch", "last"))
#' cv.crtree(result, cp = 0.0001, pcp = seq(0, 0.01, length.out = 11))
#' cv.crtree(result, cp = 0.0001, pcp = c(0, 0.001, 0.002), fun = profit, cost = 1, margin = 5)
#' result <- crtree(diamonds, "price", c("carat", "color", "clarity"), type = "regression", cp = 0.001)
#' cv.crtree(result, cp = 0.001, pcp = seq(0, 0.01, length.out = 11), fun = MAE)
#' }
#'
#' @export
cv.crtree <- function(object, K = 5, repeats = 1, cp, pcp = seq(0, 0.01, length.out = 11), seed = 1234, trace = TRUE, fun, ...) {
  ## accept either the return value from crtree or an rpart object
  if (inherits(object, "crtree")) object <- object$model
  if (inherits(object, "rpart")) {
    ## response variable name and data are recovered from the model call
    dv <- as.character(object$call$formula[[2]])
    m <- eval(object$call[["data"]])
    if (is.numeric(m[[dv]])) {
      type <- "regression"
    } else {
      type <- "classification"
      if (is.factor(m[[dv]])) {
        lev <- levels(m[[dv]])[1]
      } else if (is.logical(m[[dv]])) {
        lev <- TRUE
      } else {
        stop("The level to use for classification is not clear. Use a factor or logical as the response variable")
      }
    }
  } else {
    stop("The model object does not seem to be a decision tree")
  }

  set.seed(seed)
  if (missing(cp)) cp <- object$call$control$cp
  ## evaluate all combinations of cp and pcp
  tune_grid <- expand.grid(cp = cp, pcp = pcp)
  out <- data.frame(mean = NA, std = NA, min = NA, max = NA, cp = tune_grid[["cp"]], pcp = tune_grid[["pcp"]])

  ## default metric: AUC for classification and RMSE for regression
  if (missing(fun)) {
    if (type == "classification") {
      fun <- radiant.model::auc
      cn <- "AUC (mean)"
    } else {
      fun <- radiant.model::RMSE
      cn <- "RMSE (mean)"
    }
  } else {
    cn <- glue("{deparse(substitute(fun))} (mean)")
  }

  ## use shiny progress indicators when a reactive domain is available
  ## and no-op stand-ins that still evaluate the expression otherwise
  if (length(shiny::getDefaultReactiveDomain()) > 0) {
    trace <- FALSE
    incProgress <- shiny::incProgress
    withProgress <- shiny::withProgress
  } else {
    incProgress <- function(...) {}
    withProgress <- function(...) list(...)[["expr"]]
  }

  nitt <- nrow(tune_grid)
  withProgress(message = "Running cross-validation (crtree)", value = 0, {
    for (i in seq_len(nitt)) {
      ## start with NA so that folds skipped because estimation or prediction
      ## failed are excluded from the summary statistics instead of being
      ## counted as a performance score of 0
      perf <- rep(NA_real_, K * repeats)
      object$call[["cp"]] <- tune_grid[i, "cp"]
      if (trace) cat("Working on cp", format(tune_grid[i, "cp"], scientific = FALSE), "pcp", format(tune_grid[i, "pcp"], scientific = FALSE), "\n")
      for (j in seq_len(repeats)) {
        ## randomly assign each row to one of K folds
        rand <- sample(K, nrow(m), replace = TRUE)
        for (k in seq_len(K)) {
          ## re-estimate on all folds except k; the quoted expression is
          ## evaluated when the model call is evaluated
          object$call[["data"]] <- quote(m[rand != k, , drop = FALSE])
          pred <- try(rpart::prune(eval(object$call), tune_grid[i, "pcp"]), silent = TRUE)
          if (inherits(pred, "try-error")) next
          if (length(object$call$parms$prior) > 0) {
            pred <- prob_adj(pred, object$call$parms$prior[1], mean(m[rand != k, dv, drop = FALSE] == lev))
          }
          ## predict for the hold-out fold
          pred <- try(predict(pred, m[rand == k, , drop = FALSE]), silent = TRUE)
          if (inherits(pred, "try-error")) next

          ## empty dots are simply dropped so no separate call is needed
          actual <- unlist(m[rand == k, dv])
          if (type == "classification") {
            perf[k + (j - 1) * K] <- fun(pred[, lev], actual, lev, ...)
          } else {
            perf[k + (j - 1) * K] <- fun(pred, actual, ...)
          }
        }
      }
      ## leave the summary as NA when no fold produced a usable score
      if (!all(is.na(perf))) {
        out[i, 1:4] <- c(mean(perf, na.rm = TRUE), sd(perf, na.rm = TRUE), min(perf, na.rm = TRUE), max(perf, na.rm = TRUE))
      }
      incProgress(1 / nitt, detail = paste("\nCompleted run", i, "out of", nitt))
    }
  })

  ## best setting first: sort descending for classification and
  ## ascending for regression
  if (type == "classification") {
    out <- arrange(out, desc(mean))
  } else {
    out <- arrange(out, mean)
  }

  ## show evaluation metric in column name
  colnames(out)[1] <- cn

  ## re-estimate on the full data with the best setting to report tree size
  object$call[["cp"]] <- out[1, "cp"]
  object$call[["data"]] <- m
  object <- rpart::prune(eval(object$call), out[1, "pcp"])
  cat("\nGiven the provided tuning grid, the pruning complexity parameter\nshould be set to", out[1, "pcp"], "or the number of nodes set to", max(object$cptable[, "nsplit"]) + 1, "\n")
  out
}

prob_adj <- function(mod, prior, bp) {
  ## Rescale the probabilities stored in column 4 of mod$frame$yval2 using
  ## the prior and bp and store the complement in column 5. The updated
  ## model object is returned
  prob <- mod$frame$yval2[, 4]

  numer <- bp * (prob - prob * prior)
  denom <- prior - prob * prior + bp * prob - prior * bp
  adjusted <- numer / denom

  mod$frame$yval2[, 4] <- adjusted
  mod$frame$yval2[, 5] <- 1 - adjusted
  mod
}

Try the radiant.model package in your browser

Any scripts or data that you put into this service are public.

radiant.model documentation built on Oct. 16, 2023, 9:06 a.m.