# inst/doc/picking_the_number_of_folds_for_cross-validation.R

## ----include = FALSE----------------------------------------------------------
# Disable rmarkdown's html_vignette title check
options(rmarkdown.html_vignette.check_title = FALSE)

# Default chunk options: collapsed output prefixed with "#>", and figures
# written under man/figures/ with a vignette-specific file name prefix
knitr::opts_chunk$set(
  fig.path = "man/figures/vignette_pick_k-",
  dpi = 92,
  fig.retina = 2,
  collapse = TRUE,
  comment = "#>"
)

## ----echo=FALSE, warning=FALSE, message=FALSE---------------------------------
# Switches for the expensive analyses below:
#   rerun_analysis: recompute the results instead of loading the saved ones
#   save_analysis:  write recomputed results to inst/vignette_data/
rerun_analysis <- FALSE
save_analysis <- FALSE

if (isTRUE(rerun_analysis)) {
  # To parallelize a rerun, temporarily uncomment:
  # library(doParallel)
  # doParallel::registerDoParallel(6)
}

## ----echo=FALSE, warning=FALSE, message=FALSE, fig.align='center'-------------
library(cvms)
library(dplyr)
library(ggplot2)
library(groupdata2)

# Simulate a noisy cubic relationship between `x` and `y`, split the points
# into 3 and 10 folds, and cross-validate the linear model `y ~ x` on both
# fold columns.
#
# num_points: Number of points to simulate.
# seed: Seed passed to xpectr::set_test_seed().
#
# Returns an unnamed list:
#   [[1]] The simulated data in long format: one row per point and fold
#         column, with `Fold Column` ("k = 3" / "k = 10") and `Fold`.
#   [[2]] Mean RMSE and MAE per fold column (.folds_1 / .folds_2).
generate_data <- function(num_points = 35, seed = 3) {
  xpectr::set_test_seed(seed)

  # Keep the order of the random draws unchanged so a given seed
  # reproduces the same data
  wide_df <- data.frame("x" = runif(num_points)) %>%
    dplyr::mutate(
      y = 1.5 * rnorm(num_points) * runif(num_points) +
        1.2 * x + 1.5 * x ^ 2 + 2.8 * x ^ 3
    ) %>%
    groupdata2::fold(k = c(3, 10), num_fold_cols = 2)

  # Cross-validate on the wide data (one row per point). Running it on the
  # long format would duplicate every observation.
  intuition_cv <- cross_validate(
    wide_df,
    formulas = "y ~ x",
    family = "gaussian",
    fold_cols = paste0(".folds_", 1:2)
  )$Results[[1]] %>%
    dplyr::group_by(`Fold Column`) %>%
    dplyr::summarise(RMSE = mean(RMSE),
                     MAE = mean(MAE))

  # Long format for plotting: one row per point and fold column.
  # Columns are selected by name rather than by position.
  intuition_df <- wide_df %>%
    dplyr::mutate(`k = 3` = as.character(.folds_1),
                  `k = 10` = as.character(.folds_2)) %>%
    tidyr::pivot_longer(
      cols = c("k = 3", "k = 10"),
      names_to = "Fold Column",
      values_to = "Fold"
    ) %>%
    dplyr::mutate(Fold = factor(Fold))

  list(intuition_df, intuition_cv)
}

# Largest difference at: 2
# for (i in 1:100){
#   out <- generate_data(num_points=35, seed=i)
#   print(i)
#   print(out[[2]])
# }

out <- generate_data(num_points = 35, seed = 2)
res <- out[[2]]
# Map the fold column names to facet labels by name (not by row position)
fold_labels <- c(.folds_1 = "k = 3", .folds_2 = "k = 10")
res[["Fold Column"]] <- factor(
  unname(fold_labels[as.character(res[["Fold Column"]])]),
  levels = unname(fold_labels)
)
res[["RMSE"]] <- paste0("RMSE: ", round(res[["RMSE"]], digits = 3))
data <- out[[1]]
data$`Fold Column` <- factor(data$`Fold Column`, levels = c("k = 3", "k = 10"))
data$Fold <- factor(data$Fold, levels = 1:10)

data %>%
  ggplot(aes(x = x, y = y, color = Fold)) +
  geom_point() +
  # Since ggplot2 3.4.0, line thickness is set with `linewidth` (`size` is
  # deprecated for lines)
  stat_smooth(method = "lm", se = FALSE, alpha = 0.5, linewidth = 0.3) +
  facet_wrap(`Fold Column` ~ .) +
  geom_text(data = res, aes(label = RMSE, color = NULL), x = 0.75, y = -1) +
  theme_minimal() +
  labs(caption = paste0(
    "Generated data split into 3 (left) and 10 (right) folds.",
    "\nLines are optimal linear models for each fold.",
    "\nRMSE (Root Mean Square Error) is for a cross-validated linear model (higher = worse).",
    "\nIn k=3, the third fold differs, impacting the prediction negatively.",
    "\nIn k=10, the two 'outlier' folds don't impact the average error as much."
    )
  )


## ----echo=FALSE, warning=FALSE, message=FALSE---------------------------------
if (isTRUE(rerun_analysis)) {
  # Cross-validation results (k = 3 and k = 10) for 100 simulated data sets
  # (seeds 1 to 100), one row per seed and fold column
  bootstrapped <- lapply(seq_len(100), function(i) {
    generate_data(num_points = 35, seed = i)[[2]]
  }) %>%
    dplyr::bind_rows()
  bootstrapped <- bootstrapped %>%
    dplyr::mutate(`Fold Column` = dplyr::case_when(
      `Fold Column` == ".folds_1" ~ "k = 3",
      `Fold Column` == ".folds_2" ~ "k = 10",
      TRUE ~ `Fold Column`
    ))
  if (isTRUE(save_analysis)) {
    save(bootstrapped,
         file = "inst/vignette_data/bootstrapped_cv_picking_k.rda")
  }
} else {
  bootstrapped_file_path <- cvms:::get_vignette_data_path("bootstrapped_cv_picking_k.rda")
  load(bootstrapped_file_path)
}

bootstrapped <- bootstrapped %>%
  dplyr::mutate(`Fold Column` = factor(`Fold Column`, levels = c("k = 3", "k = 10")))

# Average of each metric per `k` setting
bootstrapped %>%
  dplyr::group_by(`Fold Column`) %>%
  dplyr::summarise(dplyr::across(dplyr::everything(), mean))

# Count the seeds where k = 3 has a higher RMSE than k = 10.
# Consecutive pairs of rows belong to the same seed; the two settings are
# compared by name, so the order within a pair does not matter.
three_is_larger <- bootstrapped %>%
  groupdata2::group(n = 2, method = "greedy") %>%
  dplyr::summarize(
    larger = RMSE[`Fold Column` == "k = 3"] > RMSE[`Fold Column` == "k = 10"]
  ) %>%
  dplyr::summarize(larger = sum(larger))

## ----echo=FALSE, warning=FALSE, message=FALSE, fig.width=5.5, fig.height=4, fig.align='center'----
# Distribution of the cross-validated RMSE for each `k` setting
ggplot(
  data = bootstrapped,
  mapping = aes(x = `Fold Column`, y = RMSE, fill = `Fold Column`)
) +
  geom_violin() +
  theme_minimal()


## ----warning=FALSE, message=FALSE---------------------------------------------
library(cvms)  # version >= 1.2.2 
library(groupdata2)  # version >= 1.4.1
library(dplyr)
library(ggplot2)

# Seed the random number generator so the random steps below are reproducible
xpectr::set_test_seed(1)

## ----eval=FALSE---------------------------------------------------------------
#  # Enable parallelization
#  # NOTE: Uncomment to run
#  # library(doParallel)
#  # doParallel::registerDoParallel(6)
#  

## -----------------------------------------------------------------------------
# Load the built-in iris data set
data("iris")

# Convert it to a tibble and show it
iris <- iris %>% dplyr::as_tibble()
iris

## -----------------------------------------------------------------------------
# Number of observations per species
dplyr::count(iris, Species)


## -----------------------------------------------------------------------------
# Ten `k` settings evenly spaced in the 3-50 range (rounded),
# with each setting repeated 3 times
fold_counts <- rep(
  round(seq(from = 3, to = 50, length.out = 10)),
  each = 3
)

fold_counts


## ----echo=FALSE, warning=FALSE, message=FALSE---------------------------------
if (!isTRUE(rerun_analysis)) {
  # Load the precomputed folds
  data_file_path <- cvms:::get_vignette_data_path("folded_iris_picking_k.rda")
  load(data_file_path)
} else {
  # One fold column per element of `fold_counts`, balanced by species
  data <- groupdata2::fold(
    iris,
    k = fold_counts,
    cat_col = "Species",
    num_fold_cols = length(fold_counts),  # Must match the length of `k`
    parallel = TRUE
  )
  if (isTRUE(save_analysis)) {
    save(data, file = "inst/vignette_data/folded_iris_picking_k.rda")
  }
}


## ----eval=FALSE---------------------------------------------------------------
#  data <- iris %>%
#    groupdata2::fold(
#      k = fold_counts,
#      cat_col = "Species",
#      num_fold_cols = length(fold_counts),  # Must match the length of `k`
#      parallel = TRUE
#    )
#  

## -----------------------------------------------------------------------------
# Inspect the data with the added fold columns
data

## -----------------------------------------------------------------------------
# Names of the fold columns: ".folds_1", ".folds_2", ...
# (one per element of `fold_counts`; `seq_along(x)` equals `1:length(x)`
# for non-empty `x`)
fold_columns <- paste(".folds", seq_along(fold_counts), sep = "_")
fold_columns

## -----------------------------------------------------------------------------
fold_stats <- groupdata2::summarize_group_cols(
  data = data,
  group_cols = fold_columns
) %>%
  dplyr::rename(`Fold Column` = `Group Column`)

# View the statistics of the first fold column for each `k` setting
fold_stats %>%
  dplyr::filter(!duplicated(`Num Groups`)) %>%
  knitr::kable()

## -----------------------------------------------------------------------------
# Model and predict functions that cvms provides for e1071::svm()
# with a multinomial (multiclass) target
model_fn <- cvms::model_functions("svm_multinomial")
predict_fn <- cvms::predict_functions("svm_multinomial")

# Hyperparameters for the model function
hyperparameters <- list(kernel = "radial", cost = 10)

## -----------------------------------------------------------------------------
# Candidate models: each predicts `Species` from a different set of predictors
formulas <- paste(
  "Species ~",
  c(
    "Sepal.Length + Sepal.Width",
    "Petal.Length + Petal.Width",
    "Sepal.Length + Sepal.Width + Petal.Length + Petal.Width",
    "Sepal.Length * Sepal.Width + Petal.Length * Petal.Width"
  )
)

## ----echo=FALSE---------------------------------------------------------------
if (!isTRUE(rerun_analysis)) {
  # Load the precomputed cross-validation results
  cv_file_path <- cvms:::get_vignette_data_path("cv_iris_picking_k.rda")
  load(cv_file_path)
} else {
  cv <- cvms::cross_validate_fn(
    data = data,
    formulas = formulas,
    type = "multinomial",
    model_fn = model_fn,
    predict_fn = predict_fn,
    hyperparameters = hyperparameters,
    fold_cols = fold_columns,
    parallel = TRUE
  )
  # Replace the predictions with a placeholder before (optionally) saving
  cv$Predictions <- "Removed to save memory"
  if (isTRUE(save_analysis)) {
    save(cv, file = "inst/vignette_data/cv_iris_picking_k.rda")
  }
}

## ----eval=FALSE---------------------------------------------------------------
#  cv <- cross_validate_fn(
#    data = data,
#    formulas = formulas,
#    type = "multinomial",
#    model_fn = model_fn,
#    predict_fn = predict_fn,
#    hyperparameters = hyperparameters,
#    fold_cols = fold_columns,
#    parallel = TRUE
#  )
#  

## -----------------------------------------------------------------------------
# Print the cross-validation results (one row per formula)
cv


## -----------------------------------------------------------------------------
# Pull out the fold column level results:
# a list with one data frame per formula
fold_column_results <- cv[["Results"]]

## -----------------------------------------------------------------------------
# Name each data frame after its formula (fixed effects) and stack them,
# storing those names in a new 'Formula' column
fold_column_results <- dplyr::bind_rows(
  stats::setNames(fold_column_results, cv$Fixed),
  .id = "Formula"
)

## -----------------------------------------------------------------------------
# Shorten the formula strings for plotting by dropping every character
# except S, P, W, L, + and *
# (Sepal -> S; Petal -> P; Width -> W; Length -> L)
fold_column_results <- fold_column_results %>%
  dplyr::mutate(Formula = gsub("[^SPWL+*]", "", Formula))

## -----------------------------------------------------------------------------
# Inspect the combined fold column results
fold_column_results


## -----------------------------------------------------------------------------
# Keep each fold column's name and number of folds
# (`Num Groups`, renamed to `Num Folds`)
fold_counts_df <- dplyr::select(
  fold_stats,
  `Fold Column`,
  `Num Folds` = `Num Groups`
)

fold_counts_df

# Attach the fold counts to the results by joining on the fold column
fold_column_results <- dplyr::left_join(
  fold_counts_df,
  fold_column_results,
  by = "Fold Column"
)

fold_column_results


## ----message=FALSE, warning=FALSE---------------------------------------------
# Number the repetitions of each `k` setting: within every `Num Folds`
# group, groupdata2 starts a new group whenever `Fold Column` changes
fold_column_results <- fold_column_results %>%
  dplyr::group_by(`Num Folds`) %>%
  groupdata2::group(
    n = "auto",
    method = "l_starts",
    starts_col = "Fold Column",
    col_name = "Repetition"
  ) %>%
  dplyr::ungroup()

# Show the first 20 rows of the columns relevant to the repetitions
fold_column_results %>%
  dplyr::select(`Fold Column`, `Num Folds`, Formula, Repetition) %>%
  utils::head(n = 20) %>%
  knitr::kable()

## -----------------------------------------------------------------------------
# Mean balanced accuracy per `k` setting and formula
# (averaged over the repetitions)
avg_balanced_acc <- fold_column_results %>%
  dplyr::group_by(`Num Folds`, Formula) %>%
  dplyr::summarize(
    `Balanced Accuracy` = mean(`Balanced Accuracy`),
    .groups = "drop"
  )

avg_balanced_acc

## ----warning=FALSE, fig.align='center'----------------------------------------
# Plot the balanced accuracy by the number of folds
# We add horizontal jitter to the points to separate overlapping points slightly
fold_column_results %>%
  ggplot(aes(x = `Num Folds`, y = `Balanced Accuracy`, color = Formula)) +
  geom_point(
    aes(shape = Repetition),
    size = 1,
    # Full argument names; `w`/`h` would rely on partial argument matching
    position = position_jitter(width = 0.6, height = 0)
  ) +
  # Lines connect the per-formula averages over the repetitions
  geom_line(data = avg_balanced_acc) +
  theme_minimal()


## -----------------------------------------------------------------------------
# Class level fold column results for the best model (the fourth formula).
# NOTE(review): index 4 is hard-coded; confirm it still refers to the best
# model if `formulas` changes.
# `Results` is a list of tibbles (one per species), stacked into one tibble.
class_level_fold_results <- dplyr::bind_rows(
  cv[["Class Level Results"]][[4]][["Results"]]
)

# Attach the `Num Folds` counts
class_level_fold_results <- dplyr::left_join(
  fold_counts_df,
  class_level_fold_results,
  by = "Fold Column"
)

class_level_fold_results


## -----------------------------------------------------------------------------
# Number the repetitions of each `k` setting: within every `Num Folds`
# group, groupdata2 starts a new group whenever `Fold Column` changes
class_level_fold_results <- class_level_fold_results %>%
  dplyr::group_by(`Num Folds`) %>%
  groupdata2::group(
    n = "auto",
    method = "l_starts",
    starts_col = "Fold Column",
    col_name = "Repetition"
  ) %>%
  dplyr::ungroup()

# Show the columns relevant to the repetitions
class_level_fold_results %>%
  dplyr::select(`Fold Column`, `Num Folds`, Class, Repetition)

## -----------------------------------------------------------------------------
# Mean balanced accuracy per `k` setting and class
# (averaged over the repetitions)
class_level_avg_balanced_acc <- class_level_fold_results %>%
  dplyr::group_by(`Num Folds`, Class) %>%
  dplyr::summarize(
    `Balanced Accuracy` = mean(`Balanced Accuracy`),
    .groups = "drop"
  )

class_level_avg_balanced_acc

## ----warning=FALSE, fig.align='center'----------------------------------------
# Plot the class level balanced accuracy by the number of folds
# We add horizontal jitter to the points to separate overlapping points slightly
class_level_fold_results %>%
  ggplot(aes(x = `Num Folds`, y = `Balanced Accuracy`, color = Class)) +
  geom_point(
    aes(shape = Repetition),
    size = 1,
    # Full argument names; `w`/`h` would rely on partial argument matching
    position = position_jitter(width = 0.6, height = 0)
  ) +
  # Lines connect the per-class averages over the repetitions
  geom_line(data = class_level_avg_balanced_acc) +
  theme_minimal()

# Try the cvms package in your browser
#
# Any scripts or data that you put into this service are public.
#
# cvms documentation built on Sept. 11, 2024, 6:22 p.m.