R/my_rf_cv.R

Defines functions my_rf_cv

Documented in my_rf_cv

#' Random Forest Cross-Validation Function
#'
#' This function performs a Random Forest Cross-Validation in R
#'
#' @param k number of folds
#'
#' @keywords prediction
#'
#' @return a numeric with the cross-validation error;
#'
#' @examples
#' my_rf_cv(k = 5)
#'
#' @import randomForest stats
#'
#' @export
my_rf_cv <- function(k) {
  my_gapminder <- my_gapminder
  train <- my_gapminder
  folds <- sample(rep(1:k, length = nrow(train)))
  MSE <- rep(NA, k)
  for (i in 1:k) {
    # define training data as all the data not in the ith fold
    data_train <- train[folds != i, ]
    data_test <- train[folds == i, ]
    # train a random forest model with 5dev0 trees
    my_model <- randomForest(lifeExp ~ gdpPercap, data = data_train, ntree = 50)
    # Record predictions
    prediction <- predict(my_model, data_test[, -4])
    # calculate the MSE
    MSE[i] <- mean((data_test$lifeExp - prediction)^2)
  }
  # Compute average MSE to get CV error
  cv_err <- mean(MSE)
  return(cv_err)
}
celestezeng33/STAT302package documentation built on March 22, 2020, 2:09 a.m.