R/rf.R

Defines functions rf

Documented in rf

#' Runs random forest with grid-search for hyper parameters.
#'
#' Fits a \code{ranger} random forest. When \code{mtry}, \code{node_size} and
#' \code{num_trees} are all supplied, every combination of their values is
#' scored (in parallel through \code{future.apply}) by out-of-bag or
#' cross-validated error, and the best-scoring combination is used to fit the
#' final model. Otherwise a single model with ranger's default hyper
#' parameters is fit.
#'
#' @param formula Formula for model specification. Only the response
#'   (left-hand side) is taken from it; every other column of
#'   \code{train_df} is used as a predictor.
#' @param train_df An input dataframe with \code{y} and \code{X}.
#' @param probability Logical. Whether predicted values are probabilities or
#'   \code{0, 1} values.
#' @param predict_df (Optional) A dataframe matching \code{train_df}
#'   (including the response column). This is to generate predictions using
#'   the trained & tested model.
#' @param mtry (Optional) Numeric vector including all values to try.
#'   Defines number of variables available for splitting at each tree node.
#' @param node_size (Optional) Numeric vector including all values to try.
#'   Defines minimum number of observations in a terminal node.
#' @param num_trees (Optional) Numeric vector including all values to try.
#'   Defines number of trees to grow.
#' @param nfolds (Optional) Numeric value. Use to specify number of CV folds.
#'   Supplying it switches \code{error_type} from "OOB" to "CV". Required
#'   when \code{error_type = "CV"}.
#' @param error_type (Optional) String of either "CV" or "OOB" for error
#'   type to use for choosing optimal hyper parameters. "OOB" uses ranger's
#'   \code{prediction.error}; "CV" uses the mean RMSE across folds.
#' @param verbose (Optional) Logical. Whether to print progress or not.
#' @return A list containing \code{model}, the final \code{ranger} fit, plus
#'   \code{values} (predictions for \code{predict_df}; for probability
#'   forests the second column of the probability matrix) when
#'   \code{predict_df} is supplied, and \code{grid} (each hyper parameter
#'   combination with its \code{error}, sorted ascending) and
#'   \code{which_min} (the selected combination) when a tuning grid is used.
#' @examples 
#' \dontrun{
#' idx <- train_test_validate(iris$Sepal.Length, train.p = .6, test.p = .2)
#' 
#' initialize_parallel()
#'
#' rf_model <- rf(train_df = iris[idx$train, ],
#'                formula = Sepal.Length ~ .,
#'                probability = FALSE,
#'                predict_df = iris[idx$validate, ])
#' }
#' @export
rf <- function(formula,
               train_df,
               probability,
               predict_df = NULL,
               mtry = NULL,
               node_size = NULL,
               num_trees = NULL,
               nfolds = NULL,
               error_type = "OOB",
               verbose = FALSE) {
  # Error checking
  assert(error_type %in% c("OOB", "CV"),
         "Argument 'error_type' must be either 'OOB' or 'CV'")
  if (error_type == "CV") {
    cat("Cross-validation is not typically recommended for RF.",
        "Performance will be much slower in many cases.\n")
  }
  if (!is.null(nfolds) && error_type == "OOB") {
    assert(nfolds > 0, "Argument nfolds must be NULL or > 0")
    error_type <- "CV"
    if (isTRUE(verbose)) {
      cat("Argument nfolds is non-null. Setting error_type to 'CV'\n")
    }
  }

  # Split out predictors and response: every non-response column is used as
  # a predictor. `!!` injects the response so a column that happens to be
  # named like a local variable cannot shadow it.
  response <- formula_lhs(formula)
  x <- dplyr::select(train_df, -!!response)
  y <- dplyr::pull(train_df, !!response)
  if (!is.null(predict_df)) {
    predict_x <- dplyr::select(predict_df, -!!response)
  }

  # Run models across tuning grid if specified
  # If not, run model with default values
  tune <- !is.null(mtry) && !is.null(node_size) && !is.null(num_trees)
  if (tune) {
    # One row per hyper parameter combination.
    param_grid <- expand.grid(mtry = mtry,
                              node_size = node_size,
                              num_trees = num_trees,
                              KEEP.OUT.ATTRS = FALSE)
    grid_rows <- seq_len(nrow(param_grid))

    if (error_type == "OOB") {
      models <- future.apply::future_lapply(
        grid_rows,
        function(k) {
          model <- ranger::ranger(y ~ .,
                                  data = x,
                                  probability = probability,
                                  mtry = param_grid$mtry[k],
                                  min.node.size = param_grid$node_size[k],
                                  num.trees = param_grid$num_trees[k],
                                  verbose = FALSE)
          model$prediction.error
        },
        # ranger draws random numbers; request parallel-safe RNG streams.
        future.seed = TRUE
      )
    } else {
      # `&&` short-circuits, so a NULL nfolds yields FALSE (and this message)
      # instead of a zero-length condition.
      assert(is.numeric(nfolds) && length(nfolds) == 1L && nfolds > 0,
             "Argument nfolds must be > 0 for cross validation\n")
      # Draw the folds once so every combination is scored on the same
      # partition; redrawing per combination adds noise to the comparison.
      folds <- caret::createFolds(as_numeric(y), k = nfolds)
      models <- future.apply::future_lapply(
        grid_rows,
        function(k) {
          cv_errors <- vapply(folds, function(j) {
            # drop = FALSE keeps single-predictor data.frames as data.frames.
            model <- ranger::ranger(y[-j] ~ .,
                                    data = x[-j, , drop = FALSE],
                                    probability = probability,
                                    mtry = param_grid$mtry[k],
                                    min.node.size = param_grid$node_size[k],
                                    num.trees = param_grid$num_trees[k],
                                    verbose = FALSE)
            preds <- stats::predict(model, x[j, , drop = FALSE])$predictions
            # Probability forests return a class-probability matrix; keep
            # the second column. Other forests return a vector.
            if (is.matrix(preds)) preds <- preds[, 2]
            caret::RMSE(as_numeric(preds), as_numeric(y[j]))
          }, numeric(1))
          mean(cv_errors)
        },
        future.seed = TRUE
      )
    }

    errors <- unlist(models)
    grid_errors <- tibble::as_tibble(param_grid)
    grid_errors$error <- errors
    hyper_params <- grid_errors[which.min(errors),
                                c("mtry", "node_size", "num_trees")]
    grid_errors <- grid_errors[order(grid_errors$error), ]

    final_model <- ranger::ranger(y ~ .,
                                  data = x,
                                  probability = probability,
                                  mtry = hyper_params$mtry,
                                  min.node.size = hyper_params$node_size,
                                  num.trees = hyper_params$num_trees)
  } else {
    if (isTRUE(verbose)) {
      cat("At least one of mtry, node_size, and num_trees are null,",
          "so using default ranger values\n")
    }
    final_model <- ranger::ranger(y ~ .,
                                  data = x,
                                  probability = probability)
  }

  # List of outs
  out <- list(model = final_model)
  if (!is.null(predict_df)) {
    values <- stats::predict(final_model, predict_x)$predictions
    # Same extraction rule as in cross-validation above.
    if (is.matrix(values)) values <- values[, 2]
    # Keep factor class labels intact; otherwise strip names/dims.
    out$values <- if (is.factor(values)) values else as.vector(values)
  }
  if (tune) {
    out$grid <- grid_errors
    out$which_min <- hyper_params
  }
  out
}
dmolitor/umbrella documentation built on Nov. 10, 2020, 1:25 a.m.