#' Runs random forest with grid-search for hyper parameters.
#'
#' Fits a \code{ranger} random forest. When \code{mtry}, \code{node_size} and
#' \code{num_trees} are all supplied, every combination of their values is
#' scored (by out-of-bag prediction error, or by cross-validated RMSE) and the
#' final model is refit on the full training data using the combination with
#' the lowest error. If any of the three is \code{NULL}, a single model is fit
#' with the default \code{ranger} settings.
#'
#' @param formula Formula for model specification.
#' @param train_df An input dataframe with \code{y} and \code{X}.
#' @param probability Logical. Whether predicted values are probabilities or
#' \code{0, 1} values.
#' @param predict_df (Optional) A dataframe matching \code{train_df}.
#' This is to generate predictions using the trained & tested model.
#' @param mtry (Optional) Numeric vector including all values to try.
#' Defines number of variables available for splitting at each tree node.
#' @param node_size (Optional) Numeric vector including all values to try.
#' Defines minimum number of observations in a terminal node.
#' @param num_trees (Optional) Numeric vector including all values to try.
#' Defines number of trees to grow.
#' @param nfolds (Optional) Numeric value. Use to specify number of CV folds.
#' Supplying a value while \code{error_type} is \code{"OOB"} switches
#' \code{error_type} to \code{"CV"}.
#' @param error_type (Optional) String of either "CV" or "OOB" for error
#' type to use for choosing optimal hyper parameters.
#' @param verbose (Optional) Logical. Whether to print progress or not.
#' @return A list with element \code{model} (the final \code{ranger} fit).
#' When \code{predict_df} is supplied the list also contains \code{values}
#' (predictions for \code{predict_df}). When a tuning grid is supplied it
#' also contains \code{grid} (every combination with its \code{error},
#' sorted by increasing error) and \code{which_min} (the combination with
#' the lowest error).
#' @examples
#' \dontrun{
#' idx <- train_test_validate(iris$Sepal.Length, train.p = .6, test.p = .2)
#'
#' initialize_parallel()
#'
#' rf_model <- rf(train_df = iris[idx$train, ],
#' formula = Sepal.Length ~ .,
#' probability = FALSE,
#' predict_df = iris[idx$validate, ])
#' }
#' @export
rf <- function(formula,
               train_df,
               probability,
               predict_df = NULL,
               mtry = NULL,
               node_size = NULL,
               num_trees = NULL,
               nfolds = NULL,
               error_type = "OOB",
               verbose = FALSE) {
  # Error checking
  assert(error_type %in% c("OOB", "CV"),
         "Argument 'error_type' must be either 'OOB' or 'CV'")
  if (error_type == "CV") {
    cat("Cross-validation is not typically recommended for RF.",
        "Performance will be much slower in many cases.\n")
  }
  # Scalar conditions use `&&` so they short-circuit and never yield a
  # zero-length or vector result inside `if`.
  if (!is.null(nfolds) && error_type == "OOB") {
    assert(nfolds > 0, "Argument nfolds must be NULL or > 0")
    error_type <- "CV"
    if (isTRUE(verbose)) {
      cat("Argument nfolds is non-null. Setting error_type to 'CV'\n")
    }
  }

  # Split out the response and the predictors
  x <- dplyr::select(train_df, -!!formula_lhs(formula))
  y <- dplyr::pull(train_df, formula_lhs(formula))
  if (!is.null(predict_df)) {
    predict_x <- dplyr::select(predict_df, -!!formula_lhs(formula))
  }

  # Pull predictions out of a fitted model. ranger returns a matrix of class
  # probabilities only for probability forests; otherwise `predictions` is a
  # plain vector (numeric for regression, factor for classification), so the
  # column lookup must be limited to the probability case.
  # NOTE(review): column 2 is taken as the probability of interest, which
  # presumes a binary outcome whose positive class is the second level.
  predicted_values <- function(model, newdata) {
    preds <- stats::predict(model, newdata)$predictions
    if (is.factor(y) && isTRUE(probability)) {
      preds <- preds[, 2]
    }
    preds
  }

  # Run models across tuning grid if specified
  # If not, run model with default values
  use_grid <- !is.null(mtry) && !is.null(node_size) && !is.null(num_trees)
  if (use_grid) {
    # Transposed so that each column is one (mtry, node_size, num_trees)
    # combination, which is what future_lapply iterates over.
    grid <- as.data.frame(
      t(expand.grid(mtry = mtry,
                    node_size = node_size,
                    num_trees = num_trees))
    )
    if (error_type == "OOB") {
      # future.seed = TRUE gives each task a proper parallel RNG stream;
      # ranger draws its seed from R's RNG.
      models <- future.apply::future_lapply(
        grid,
        function(i) {
          model <- ranger::ranger(y ~ .,
                                  data = x,
                                  probability = probability,
                                  mtry = i[[1]],
                                  min.node.size = i[[2]],
                                  num.trees = i[[3]],
                                  verbose = FALSE)
          model$prediction.error
        },
        future.seed = TRUE
      )
    } else {
      assert(is.numeric(nfolds) && length(nfolds) == 1L && isTRUE(nfolds > 0),
             "Argument nfolds must be > 0 for cross validation\n")
      # Folds are drawn once so every combination is scored on the same
      # splits and the errors are directly comparable.
      folds <- caret::createFolds(as_numeric(y), k = nfolds)
      models <- future.apply::future_lapply(
        grid,
        function(i) {
          cv_errors <- lapply(folds, function(j) {
            model <- ranger::ranger(y[-j] ~ .,
                                    data = x[-j, ],
                                    probability = probability,
                                    mtry = i[[1]],
                                    min.node.size = i[[2]],
                                    num.trees = i[[3]],
                                    verbose = FALSE)
            preds <- predicted_values(model, x[j, ])
            # Class predictions come back as a factor; convert them the same
            # way the observed response is converted before computing RMSE.
            if (is.factor(preds)) {
              preds <- as_numeric(preds)
            }
            caret::RMSE(preds, as_numeric(y[j]))
          })
          mean(unlist(cv_errors))
        },
        future.seed = TRUE
      )
    }

    errors <- unlist(models)
    grid_tbl <- tibble::as_tibble(t(grid))
    hyper_params <- grid_tbl[which.min(errors), ]

    # Attach the errors under an explicit `error` column name and sort by it.
    grid_errors <- grid_tbl
    grid_errors$error <- errors
    grid_errors <- grid_errors[order(errors), ]

    final_model <- ranger::ranger(y ~ .,
                                  data = x,
                                  probability = probability,
                                  mtry = hyper_params$mtry,
                                  min.node.size = hyper_params$node_size,
                                  num.trees = hyper_params$num_trees)
  } else {
    if (isTRUE(verbose)) {
      cat("At least one of mtry, node_size, and num_trees are null,",
          "so using default ranger values\n")
    }
    final_model <- ranger::ranger(y ~ .,
                                  data = x,
                                  probability = probability)
  }

  # List of outs
  out <- list(model = final_model)
  if (!is.null(predict_df)) {
    values <- predicted_values(final_model, predict_x)
    out <- c(out, list(values = as.vector(values)))
  }
  if (use_grid) {
    out <- c(out,
             list(
               grid = grid_errors,
               which_min = hyper_params
             ))
  }
  out
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.