# Functions in script:
# - rang_roc_cut().
# - rang_mtry().
# - rang_oob_err().
######################
#' ROC curve optimal cut for a ranger object
#'
#' Calls `roc_cut()` for a fitted ranger model. See `roc_cut()` for details. Requires
#' `probability = TRUE` to be used when the ranger model was fitted.
#'
#' The `rf$predictions` is a matrix with one column per class. The class corresponding to
#' the predictions in `actual` can be named in `class_name`. The default is to select the
#' "maximum" column name, i.e. "TRUE" if the classes are TRUE/FALSE and the largest number
#' if every column name is numeric (compared as numbers, so "10" is chosen over "9").
#'
#' @param rf A ranger fitted model.
#' @param actual A binary class target vector matching `rf`.
#' @param class_name String giving name of the class predicted in `actual`. A number is
#'   converted to a string and matched against the column names (it is never used as a
#'   column position).
#' @param plot Produce a plot. Defaults to `TRUE`.
#'
#' @return The value returned by `roc_cut()` for the selected probability column.
#'
#' @export
rang_roc_cut <- function(rf, actual, class_name = NULL, plot = TRUE) {
  probs <- rf$predictions
  if (!is.matrix(probs)) {
    stop("'rf$predictions' is not a matrix. When fitting the ranger model 'rf', you must use `probability = TRUE`.",
         call. = FALSE)
  }
  classes <- colnames(probs)
  if (is.null(class_name)) {
    # Compare numerically when all class names are numbers; otherwise compare as strings.
    class_num <- suppressWarnings(as.numeric(classes))
    class_name <- if (!anyNA(class_num)) classes[which.max(class_num)] else max(classes)
    message("Using ", class_name, " as positive class.")
  } else {
    # Coerce so that e.g. `class_name = 1` selects the column *named* "1", not column 1.
    class_name <- as.character(class_name)
    if (length(class_name) != 1L) {
      stop("`class_name` must be a single string.", call. = FALSE)
    }
    if (!(class_name %in% classes)) {
      stop("The column '", class_name, "' does not exist in 'rf$predictions'.", call. = FALSE)
    }
  }
  roc_cut(probs[, class_name], actual, plot = plot)
}
#' Tune ranger models for `mtry`
#'
#' Fits `ranger()` models for a given range of values of `mtry`. Output is a table and graph giving
#' OOB errors.
#'
#' @param data A data frame containing both input and target variables.
#' @param fmla A formula for the model.
#' @param m_vec Integer vector of values of tuning parameter `mtry`.
#' @param seed,num.trees,respect.unordered.factors Optional arguments passed to `ranger()`. Defaults are
#' `1`, `200` and `TRUE` respectively.
#' @param ... Other arguments passed to `ranger::ranger()`.
#'
#' @return A tibble with one row per value of `m_vec` and columns `mtry`, `oob_err` (the
#'   `prediction.error` of each fitted forest) and `time` (elapsed seconds reported by
#'   `system.time()`). A line plot of `oob_err` against `mtry` is printed as a side effect.
#'
#' @export
rang_mtry <- function(data, fmla, m_vec, seed = 1, num.trees = 200, respect.unordered.factors = TRUE, ...) {
  if (!is.atomic(m_vec) || !is.numeric(m_vec) || length(m_vec) == 0) {
    stop("`m_vec` must be a non-empty numeric vector.", call. = FALSE)
  }
  n_inputs <- length(attr(stats::terms(fmla, data = data), "term.labels"))
  if (max(m_vec) > n_inputs) {
    stop("`mtry` cannot be larger than the number of input variables.\n The max `m_vec` is ",
         max(m_vec),
         ", but there are ",
         n_inputs, " input variables in the model.",
         call. = FALSE)
  }
  nn <- length(m_vec)
  oob_err <- double(nn)
  time <- double(nn)
  cat("mtry completed: ")
  for (i in seq_len(nn)) {
    mtry <- m_vec[i]
    # Element 3 of system.time() is elapsed (wall-clock) seconds. The forest itself is
    # not needed afterwards, so it is not written.
    time[i] <- system.time(
      rf <- ranger::ranger(fmla, data = data, num.trees = num.trees, mtry = mtry,
                           seed = seed,
                           respect.unordered.factors = respect.unordered.factors,
                           write.forest = FALSE, ...)
    )[3]
    oob_err[i] <- rf$prediction.error
    cat(mtry, " ")
  }
  cat("\n")
  res <- tibble::tibble(mtry = m_vec, oob_err, time)
  g <- ggplot2::ggplot(res, ggplot2::aes(x = .data$mtry, y = .data$oob_err)) +
    ggplot2::geom_line() +
    ggplot2::geom_point() +
    ggplot2::ylab("OOB Error")
  print(g)
  res
}
#' Helper for `rang_oob_err()`
#'
#' Computes a majority-vote error indicator for each observation: 0 when most of its
#' out-of-bag trees predict the true class, 1 when most predict another class, and 0.5 when
#' the vote is tied. The entry is `NaN` when an observation has no out-of-bag predictions.
#'
#' @param oob_mat Matrix of out-of-bag individual-tree class predictions, with an `NA` where
#'   the observation was in bag. Rows are observations, columns are trees.
#' @param n_trees Positive integer. The number of trees used in the vote (columns `1:n_trees`
#'   of `oob_mat` are used).
#' @param target Vector of true classes (either 1 or 2).
#'
#' @return Numeric vector with one entry per row of `oob_mat`.
#'
#' @keywords internal
oob_errors <- function(oob_mat, n_trees, target) {
  # Restrict the vote to the first `n_trees` trees, keeping a matrix even for one tree.
  votes <- oob_mat[, 1:n_trees, drop = FALSE]
  # TRUE where a tree's prediction disagrees with that row's true class; NA stays NA.
  wrong <- sweep(votes, MARGIN = 1, STATS = target, FUN = `!=`)
  share_wrong <- rowMeans(wrong, na.rm = TRUE)
  # which() skips the NaN rows, so observations with no OOB trees stay NaN.
  tied <- which(share_wrong == 0.5)
  verdict <- round(share_wrong)
  verdict[tied] <- 0.5
  verdict
}
#' Out-of-bag error rates by number of trees for a ranger random forest
#'
#' Returns a table of out-of-bag (OOB) error rates for a ranger classification forest, using
#' numbers of trees from `start` up to all trees in steps of `by` (by default 5, 10, 15, ...).
#' The supplied ranger object must have been created with the `keep.inbag = TRUE` argument.
#'
#' For each number of trees, every observation is classified by majority vote of those trees
#' for which it was out of bag. A tied vote counts as half an error, and observations with no
#' out-of-bag trees are left out of the averages. The `class_1`/`class_2` columns assume a
#' two-class target whose values map to 1 and 2 under `as.numeric()` (e.g. a two-level factor).
#'
#' @param rf A ranger random forest object, created with `keep.inbag = TRUE`.
#' @param data A data frame used to fit `rf`. Must contain the target variable.
#' @param start,by Error rates are evaluated for numbers of trees from `start` to the total
#'   number of trees in steps of `by`. `start` must be between 1 and `rf$num.trees`.
#' @param plot Optional logical. Output a plot or not.
#'
#' @return A tibble with columns `num.trees`, `total`, `class_1` and `class_2` giving the OOB
#'   error rates.
#'
#' @export
rang_oob_err <- function(rf, data, start = 5L, by = 5L, plot = TRUE) {
  if (!("inbag.counts" %in% names(rf))) {
    stop("`inbag.counts` not found in `rf`. Must use `keep.inbag = TRUE` when fitting `rf`.",
         call. = FALSE)
  }
  nn <- nrow(data)
  ntr <- rf$num.trees
  if (length(start) != 1L || is.na(start) || start < 1 || start > ntr) {
    stop("`start` must be between 1 and the number of trees in `rf` (", ntr, ").",
         call. = FALSE)
  }
  n_trees_vec <- seq(start, ntr, by = by)

  target_name <- rang_target_name(rf)
  if (is.na(target_name) || !(target_name %in% names(data))) {
    stop("Could not find the target variable '", target_name, "' of `rf` in `data`.",
         call. = FALSE)
  }
  target <- as.numeric(data[[target_name]]) # use classes in {1, 2}

  # One column per tree; an entry > 0 means the observation was in bag for that tree.
  inbag_mat <- do.call(cbind, rf$inbag.counts)
  if (!is.matrix(inbag_mat) || nrow(inbag_mat) != nn || ncol(inbag_mat) != ntr) {
    stop("`data` must be the data frame used to fit `rf` (its size does not match `rf$inbag.counts`).",
         call. = FALSE)
  }

  preds <- stats::predict(rf, data, predict.all = TRUE)$predictions
  if (length(preds) != length(inbag_mat)) {
    stop("Expected one class prediction per observation and tree from `predict(rf, data, predict.all = TRUE)`. ",
         "Was `rf` fitted as a classification forest without `probability = TRUE`?",
         call. = FALSE)
  }
  # Keep only out-of-bag predictions by masking the in-bag ones with NA.
  oob_mat <- matrix(as.numeric(preds), nrow = nn)
  oob_mat[inbag_mat > 0] <- NA_real_

  errs <- matrix(NA_real_, nrow = nn, ncol = length(n_trees_vec))
  for (j in seq_along(n_trees_vec)) {
    errs[, j] <- oob_errors(oob_mat, n_trees = n_trees_vec[j], target = target)
  }
  # drop = FALSE keeps a matrix when a class has a single observation or only one
  # tree count is evaluated; colMeans() needs at least two dimensions.
  res <- tibble::tibble(
    num.trees = n_trees_vec,
    total = colMeans(errs, na.rm = TRUE),
    class_1 = colMeans(errs[target == 1, , drop = FALSE], na.rm = TRUE),
    class_2 = colMeans(errs[target == 2, , drop = FALSE], na.rm = TRUE)
  )
  if (plot) {
    res_long <- tidyr::pivot_longer(res, !c("num.trees"), names_to = "pred", values_to = "error_rate")
    g <- ggplot2::ggplot(res_long, ggplot2::aes(x = .data$num.trees, y = .data$error_rate, color = .data$pred)) +
      ggplot2::geom_line() +
      ggplot2::ylab("OOB Error Rate")
    print(g)
  }
  res
}

# Internal helper: name of the target (dependent) variable of a ranger fit `rf`.
# Returns NA_character_ when it cannot be determined.
rang_target_name <- function(rf) {
  # NOTE(review): some ranger versions may store this field directly; fall back to the call otherwise.
  dep <- rf[["dependent.variable.name"]]
  if (is.character(dep) && length(dep) == 1L && nzchar(dep)) {
    return(dep)
  }
  cl <- rf[["call"]]
  if (is.null(cl) || length(cl) < 2L) {
    return(NA_character_)
  }
  # Prefer an explicitly named `formula =` argument, else the first argument.
  fmla <- cl[["formula"]]
  if (is.null(fmla)) {
    fmla <- cl[[2]]
  }
  if (is.call(fmla) && identical(fmla[[1]], as.name("~")) && length(fmla) == 3L) {
    lhs <- fmla[[2]]
    # as.character() on a symbol avoids the backticks deparse() adds to non-syntactic names.
    return(if (is.name(lhs)) as.character(lhs) else paste(deparse(lhs), collapse = ""))
  }
  # Fallback: text before "~" in the first argument (e.g. the name of a formula variable).
  trimws(strsplit(as.character(cl)[2], "~", fixed = TRUE)[[1]][1])
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.