#' Cross-Validated Ranger Model
#'
#' @description
#' Runs k-fold cross-validation of a `ranger` probability forest and returns
#' the ROC area and the precision-recall curve area obtained on each fold.
#'
#' @details
#' The rows of `data` are shuffled once after calling `set.seed(seed)` (so the
#' global RNG state is modified) and then cut into `k` contiguous folds. For
#' each fold a forest is fitted on the remaining rows with
#' `mtry = round(sqrt(ncol(data) - 1))` and evaluated on the held-out rows via
#' `bin_model_eval()`. The second column of the predicted probability matrix
#' is used as the score (presumably the positive class is the second level of
#' the target; confirm against the data).
#'
#' @param k number of folds. Default is 5. Must be at least 2 and no larger
#'   than `nrow(data)`; a fractional value is truncated.
#' @param data data frame that will be cross-validated
#' @param formula ranger formula, or an object `as.formula()` can coerce
#' @param y name of the target column in `data`; used to extract the observed
#'   values of each test fold
#' @param seed random seed used for the shuffle and passed to `ranger`.
#'   Default is random
#' @return a list with two vectors holding one value per fold: `roc_area` and
#'   `pr_area`
#'
#' @import ranger
#' @export
#'
bin_cv_rf <- function(k = 5, data, formula, y, seed = runif(1, 0, 9999999)) {
  # inherits() copes with objects carrying several classes, where
  # `class(x) != "formula"` would hand `if` a vector condition.
  if (!inherits(formula, "formula")) {
    formula <- as.formula(formula)
  }

  # Fail early with a clear message instead of deep inside cut() or ranger().
  if (!is.data.frame(data)) {
    stop("`data` must be a data frame.", call. = FALSE)
  }
  if (!is.numeric(k) || length(k) != 1L || is.na(k) || k < 2) {
    stop("`k` must be a single number >= 2.", call. = FALSE)
  }
  # A fractional k was previously truncated implicitly by cut() and `1:k`.
  k <- floor(k)
  if (k > nrow(data)) {
    stop("`k` cannot exceed the number of rows in `data`.", call. = FALSE)
  }

  # Shuffle data
  set.seed(seed)
  data <- data[sample(nrow(data)), ]

  # Create K-folds: each row gets the index of the fold it is tested in
  folds <- cut(seq_len(nrow(data)), breaks = k, labels = FALSE)

  # Preallocate the per-fold results
  model_roc_area <- numeric(k)
  model_pr_area <- numeric(k)

  # Perform K-fold cross-validation
  for (i in seq_len(k)) {
    # A logical mask is used rather than negative indices because
    # `data[-integer(0), ]` would select zero rows for an empty fold.
    in_test <- folds == i
    data_test <- data[in_test, ]
    data_train <- data[!in_test, ]

    # Model
    model <- ranger(formula, data_train, verbose = FALSE,
                    seed = seed, probability = TRUE,
                    mtry = round(sqrt(ncol(data_train) - 1)),
                    importance = "impurity")

    # Apply model on test data; column 2 of the probability matrix is the score
    predicted <- predict(model, data_test)$predictions[, 2]
    actual <- data_test[[y]]
    model_eval <- bin_model_eval(actual, predicted)
    model_roc_area[i] <- model_eval$roc_area
    model_pr_area[i] <- model_eval$pr_area
  }

  list(roc_area = model_roc_area,
       pr_area = model_pr_area)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.