R/beset_rf.R

Defines functions beset_rf

Documented in beset_rf

#' Beset Random Forest
#'
#' \code{beset_rf} is a wrapper to \code{\link[randomForest]{randomForest}} that
#' estimates predictive performance of the random forest using repeated k-fold
#' cross-validation. \code{beset_rf} insures that the correct arguments are
#' provided to \code{\link[randomForest]{randomForest}} and that enough
#' information is retained for compatibility with \code{beset} methods such as
#' variable \code{\link{importance}} and partial \code{\link{dependence}}.
#'
#' @param n_trees Number of trees. Defaults to 500.
#'
#' @param sample_rate Row sample rate per tree (from \code{0 to 1}). Defaults to
#' \code{1 - exp(1), or ~ 0.632}.
#'
#' @param mtry (Optional) \code{integer} number of variables randomly sampled
#' as candidates at each split. If omitted, defaults to the square root of the
#' number of predictors for classification and one-third the number of
#' predictors for regression.
#'
#' @param min_obs_in_node (Optional) \code{integer} number specifying the
#' fewest allowed observations in a terminal node. If omitted, defaults to 1 for
#' classification and 5 for regression.
#'
#' @param class_wt Priors of the classes. Ignored for regression.
#'
#' @param x A \code{"beset_rf"} object to plot
#'
#' @param metric Prediction metric to plot. Options are mean squared error
#' (\code{"mse"}) or R-squared (\code{"rsq"}) for regression, and
#' misclassification error (\code{"err.rate"}) for classification. Default
#' \code{"auto"} plots MSE for regression and error rate for classification.
#'
#' @inheritParams beset_glm
#' @inheritParams randomForest::randomForest
#'
#' @return A "beset_rf" object with the following components:
#' \describe{
#'   \item{forests}{list of "randomForest" objects for each fold and repetition}
#'   \item{stats}{a "cross_valid" object giving cross-validation metrics}
#'   \item{data}{the data frame used to train random forest}
#'  }
#'
#' @examples
#' # Using default 10 X 10 repeated k-fold cross-validation
#' data("prostate", package = "beset")
#' rf <- beset_rf(tumor ~ ., data = prostate)
#' summary(rf)
#' plot(rf)
#'
#' # Using a single independent test set instead of cross-validation
#' inTrain <- sample.int(nrow(prostate), nrow(prostate)/2)
#' data <- data_partition(
#'   train = prostate[inTrain,], test = prostate[-inTrain,], y = "tumor"
#' )
#' rf <- beset_rf(tumor ~ ., data = data)
#' summary(rf)
#' plot(rf)
#'
#' # Example with continuous outcome
#' rf <- beset_rf(gleason ~ ., data = data)
#' summary(rf)
#' plot(rf)
#' @import dplyr
#' @import purrr
#' @export
beset_rf <- function(form, data, n_trees = 500, sample_rate = 1 - exp(-1),
                     mtry = NULL, min_obs_in_node = NULL,
                     n_folds = 10, n_reps = 10, seed = 42,
                     class_wt = NULL, cutoff = NULL, strata = NULL,
                     parallel_type = NULL, n_cores = NULL, cl = NULL){
  if(inherits(data, "data.frame")){
    data <- mutate_if(data, is.logical, factor)
    mf <- model.frame(form, data = data)
    n_omit <- nrow(data) - nrow(mf)
    if(n_omit > 0){
      warning(paste("Dropping", n_omit, "rows with missing data."),
              immediate. = TRUE)
      attr(data, "na.action") <- NULL
    }
    attr(mf, "terms") <- NULL
    data <- mf
    x <- mf[-1]
    y <- mf[[1]]
    names(y) <- row.names(mf)
  } else if(inherits(data, "data_partition")){
    data$train$`(offset)` <- data$train$`(weights)` <-
      data$test$`(offset)` <- NULL
    data$train <- mutate_if(data$train, is.logical, factor)
    data$test <- mutate_if(data$test, is.logical, factor)
    data$train <- model.frame(form, data = data$train)
    attr(data$train, "terms") <- NULL
    data$test<- model.frame(form, data = data$test)
    attr(data$test, "terms") <- NULL
    x <- data$train[-1]
    y <- data$train[[1]]
  } else {
    stop("`data` argument must inherit class 'data.frame' or 'data_partition'")
  }

  #======================================================================
  # Set up x and y arguments for randomForest.default
  #----------------------------------------------------------------------
  if(is.factor(y) && n_distinct(y) > 2){
      stop(
        paste("Multinomial classification not supported by beset at this time.",
              "Please use the `randomForest` package directly for this.",
              sep = "\n")
      )
  }
  if(!is.factor(y) && n_distinct(y) == 2){
    y <- factor(y)
    if(inherits(data, "data_partition")){
      data$train[[1]] <- factor(data$train[[1]])
      data$test[[1]] <- factor(data$test[[1]])
    }
  }
  type <- if(is.factor(y)) "prob" else "response"
  y_orig <- y

  #======================================================================
  # Set up arguments for randomForest
  #----------------------------------------------------------------------
  rf_par <- list(
    ntree = n_trees,
    mtry = if(is.null(mtry)){
      if(is.factor(y)){
        floor(sqrt(ncol(x)))
      } else {
        max(floor(ncol(x)/3), 1)
      }
    } else mtry,
    replace = FALSE, classwt = class_wt,
    cutoff = if(is.null(cutoff) && is.factor(y)){
      n_class <- length(levels(y))
      rep(1/n_class, n_class)
    } else cutoff, strata = strata,
    sampsize = ceiling(sample_rate * nrow(x)),
    nodesize = if(is.null(min_obs_in_node)){
      if(!is.null(y) && !is.factor(y)) 5 else 1
    } else min_obs_in_node,
    importance = TRUE, localImp = TRUE, keep.forest = TRUE
  )

  if(inherits(data, "data_partition")){
    train_data <- c(
      list(x = x, y = y, xtest = data$test[-1], ytest = data$test[[1]]), rf_par
    )
    set.seed(seed)
    fit <- do.call(randomForest::randomForest, train_data)
    y_hat <- predict(fit, data$test, type = type)
    n_obs <- length(y_hat)
    if(is.factor(y)){
      y_hat <- y_hat[, 2]
      y_hat[y_hat == 0] <- 1/n_obs
      y_hat[y_hat == 1] <- 1 - 1/n_obs
      family <- "binomial"
    } else {
      family <- "gaussian"
    }
    y <- data$test[[1]]
    stats <- predict_metrics_(y, y_hat, family = family)
    return(
      structure(
        list(
          forests = list(fit),
          stats = stats,
          data = data$train
        ), class = c("beset", "rf")
      )
    )
  }
  #======================================================================
  # Set up parallel operations
  #----------------------------------------------------------------------
  if(!is.null(cl)){
    if(!inherits(cl, "cluster")) stop("Not a valid parallel socket cluster")
    n_cores <- length(cl)
  } else if(is.null(n_cores) || n_cores > 1){
    if(is.null(parallel_type)) parallel_type <- "sock"
    parallel_control <- setup_parallel(
      parallel_type = parallel_type, n_cores = n_cores, cl = cl)
    n_cores <- parallel_control$n_cores
    cl <- parallel_control$cl
  }
  #======================================================================
  # Set up cross-validation
  #----------------------------------------------------------------------
  n_obs <- length(y)
  cv_par <- set_cv_par(n_obs, n_folds, n_reps)
  n_folds <- cv_par$n_folds; n_reps <- cv_par$n_reps
  fold_ids <- create_folds(
    y = y, n_folds = n_folds, n_reps = n_reps, seed = seed
  )
  train_data <- lapply(fold_ids, function(i)
    c(list(x = x[-i, , drop = FALSE],
           y = y[-i],
           xtest = x[i, , drop = FALSE],
           ytest = y[i]),
      rf_par)
  )
  #======================================================================
  # Train and cross-validate random forests
  #----------------------------------------------------------------------
  cv_fits <- if(n_cores > 1L){
    if(is.null(cl)){
      parallel::mclapply(train_data, function(x, seed){
        set.seed(seed); a <- do.call(randomForest::randomForest, x)
      }, seed = seed, mc.cores = n_cores)
    } else {
      parallel::parLapply(cl, train_data, function(x, seed){
        set.seed(seed); do.call(randomForest::randomForest, x)
      }, seed = seed)
    }
  } else {
    lapply(train_data, function(x, seed){
      set.seed(seed); do.call(randomForest::randomForest, x)
    }, seed = seed)
  }
  if(!is.null(cl)) parallel::stopCluster(cl)

  y_hat <- map2(cv_fits, train_data,
                ~ as.matrix(predict(.x, .y$xtest, type = type)))
  if(is.factor(y)){
    y_names <- names(y)
    y <- as.integer(y) - 1L
    names(y) <- y_names
    y_hat <- map(y_hat, ~ .x[, 2, drop = FALSE])
    for(i in seq_along(y_hat)){
      temp <- y_hat[[i]]
      temp[temp[,1] == 0, 1] <- 1/nrow(temp)
      temp[temp[,1] == 1, 1] <- 1 - 1/nrow(temp)
      y_hat[[i]] <- temp
    }
    family <- "binomial"
  } else {
    family <- "gaussian"
  }
  cv_stats <- get_cv_stats(y = y, y_hat = y_hat, family = family,
                           n_folds = n_folds, n_reps = n_reps)
  fold_assignments <- get_fold_ids(fold_ids, n_reps)
  cv_stats <- structure(
    c(cv_stats, list(
      fold_assignments = fold_assignments,
      parameters = list(family = family,
                        n_obs = n_obs,
                        n_folds = n_folds,
                        n_reps = n_reps,
                        seed = seed,
                        y = y_orig))),
    class = "cross_valid"
  )
  structure(
    list(
      forests = cv_fits,
      stats = cv_stats,
      data = data
    ), class = c("beset", "rf")
  )
}

#' @export
#' @describeIn beset_rf Plot OOB and holdout MSE, R-squared, or error rate as a
#'  function of number of trees in forest
#'
plot.beset_rf <- function(x, metric = c("auto", "mse", "rsq", "err.rate"), ...){
  metric <- tryCatch(
    match.arg(metric, c("auto", "mse", "rsq", "err.rate")),
    error = function(c){
      c$message <- gsub("arg", "metric", c$message)
      c$call <- NULL
      stop(c)
    }
  )
  if(metric == "auto"){
    metric <- if(x$forests[[1]]$type == "regression") "mse" else "err.rate"
  }
  if(metric %in% c("mse", "rsq") && x$forests[[1]]$type != "regression"){
    warning(paste(metric, " plots not available for classification.\n",
                  "Misclassification rate plotted instead.", sep = ""))
    metric <- "err.rate"
  }
  oob <- map(x$forests, ~ .x[[metric]])
  if(inherits(oob[[1]], "matrix")) oob <- map(oob, ~ .x[,1])
  oob <- oob %>% transpose %>% simplify_all %>% map_dbl(mean)
  oob <- tibble(
    sample = "Out-of-Bag", n_trees = seq(1:length(oob)), mean = oob
  )
  cv <- map(x$forests, ~ .x$test[[metric]])
  if(inherits(cv[[1]], "matrix")) cv <- map(cv, ~ .x[,1])
  cv <- cv %>% transpose %>% simplify_all %>% map_dbl(mean)
  cv <- tibble(
    sample = "Test Holdout", n_trees = seq(1:length(cv)), mean = cv
  )
  data <- bind_rows(oob, cv)
  y_lab <- switch(metric,
                  err.rate = ylab("Misclassification Rate"),
                  mse = ylab("Mean Squared Error"),
                  rsq = ylab(bquote(~R^2)))
  p <- ggplot(data = data, aes(x = n_trees, y = mean, color = sample)) +
    theme_classic() + xlab("Number of trees") + y_lab +
    geom_line(size = 1) + theme(legend.title = element_blank())
  suppressWarnings(print(p))
}
jashu/beset documentation built on April 20, 2023, 5:28 a.m.