# inst/doc/generalizedCV.R

## ----load_data----------------------------------------------------------------
# make the built-in mtcars data set available and preview its first rows
data("mtcars")
head(x = mtcars)

## ----linear_mod---------------------------------------------------------------
# regress mpg on every other column of mtcars and print the fit summary
lm_mod <- lm(formula = mpg ~ ., data = mtcars)
summary(lm_mod)

## ----get_naive_error----------------------------------------------------------
# mean squared residual of the model on the same data it was fitted to
err <- mean(residuals(lm_mod)^2)

## ----define_fun_cv_lm---------------------------------------------------------
# Fit a linear model on the training split of one fold and score it on the
# validation split.
#
# Arguments:
#   fold     - the fold to evaluate. It is not referenced by name in the body;
#              training() and validation() are called without a fold argument.
#              NOTE(review): they presumably look `fold` up in this function's
#              frame -- confirm against the origami documentation.
#   data     - data frame holding the outcome and the predictors.
#   reg_form - regression formula supplied as a single string, e.g. "mpg ~ .".
#
# Returns a list with two elements:
#   coef - one-row data frame of the fitted model coefficients
#   SE   - squared prediction errors on the validation split
cv_lm <- function(fold, data, reg_form) {
  # the outcome is whatever precedes "~"; trimming makes the lookup independent
  # of how the formula string is spaced ("mpg~." behaves like "mpg ~ .")
  out_var <- str_trim(str_split(reg_form, fixed("~"))[[1]][1])
  out_var_ind <- match(out_var, colnames(data))
  if (is.na(out_var_ind)) {
    # fail loudly instead of silently computing errors against nothing
    stop("outcome '", out_var, "' is not a column of `data`", call. = FALSE)
  }

  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # capture results to be returned as output
  list(
    coef = data.frame(t(coef(mod))),
    SE = (preds - valid_data[[out_var_ind]])^2
  )
}

## ----load_pkgs----------------------------------------------------------------
library(origami)
library(stringr) # used in defining the cv_lm function above

## ----cv_lm_resub--------------------------------------------------------------
# resubstitution estimate: take the first fold produced by the resubstitution
# fold function, run cv_lm on it directly, and average the squared errors
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1]]
resub_results <- cv_lm(resub, mtcars, "mpg ~ .")
mean(resub_results[["SE"]])

## ----cv_lm_cross_valdate------------------------------------------------------
# cross-validated estimate: build folds with make_folds() defaults, apply cv_lm
# to each of them via cross_validate(), and average the squared errors
folds <- make_folds(mtcars)
cvlm_results <- cross_validate(
  cv_fun = cv_lm,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cvlm_results[["SE"]])

## ----cv_fun_randomForest------------------------------------------------------
# Fit a random forest regression on the training split of one fold and score it
# on the validation split.
#
# Arguments:
#   fold     - the fold to evaluate. It is not referenced by name in the body;
#              training() and validation() are called without a fold argument.
#              NOTE(review): they presumably look `fold` up in this function's
#              frame -- confirm against the origami documentation.
#   data     - data frame holding the outcome and the predictors.
#   reg_form - regression formula supplied as a single string, e.g. "mpg ~ .".
#
# Returns a list with two elements:
#   coef - data frame built from the `coefs` component of the fitted forest
#   SE   - squared prediction errors on the validation split
cv_rf <- function(fold, data, reg_form) {
  # the outcome is whatever precedes "~"; trimming makes the lookup independent
  # of how the formula string is spaced ("mpg~." behaves like "mpg ~ .")
  out_var <- str_trim(str_split(reg_form, fixed("~"))[[1]][1])
  out_var_ind <- match(out_var, colnames(data))
  if (is.na(out_var_ind)) {
    # fail loudly instead of silently computing errors against nothing
    stop("outcome '", out_var, "' is not a column of `data`", call. = FALSE)
  }

  # define training and validation sets based on input object of class "folds"
  train_data <- training(data)
  valid_data <- validation(data)

  # fit Random Forest regression on training set and predict on holdout set
  mod <- randomForest(formula = as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # define output object to be returned as list (for flexibility)
  # NOTE(review): `mod$coefs` is presumably NULL under the default
  # randomForest() settings, which would make `coef` an empty data frame; it is
  # kept so the shape of the returned list does not change -- verify.
  list(
    coef = data.frame(mod$coefs),
    SE = (preds - valid_data[[out_var_ind]])^2
  )
}

## -----------------------------------------------------------------------------
library(randomForest)

# rebuild the folds, apply cv_rf to each of them, and average the squared
# errors collected across folds
folds <- make_folds(mtcars)
cvrf_results <- cross_validate(
  cv_fun = cv_rf,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cvrf_results[["SE"]])

## -----------------------------------------------------------------------------
# make the built-in AirPassengers series available and print it
data("AirPassengers")
print(AirPassengers)

## -----------------------------------------------------------------------------
library(forecast)

# rolling-origin folds over the series; the first window and the validation
# size are passed through to the fold function
folds <- make_folds(
  AirPassengers,
  fold_fun = folds_rolling_origin,
  first_window = 36,
  validation_size = 24
)
# keep the first fold at hand as a standalone object
fold <- folds[[1]]

# Compute, for a single fold, the mean squared forecast error of a seasonal
# ARIMA model and of an STL-based model.
#
# Arguments:
#   fold - the fold to evaluate. It is not referenced by name in the body;
#          training(), validation() and fold_index() are called without a fold
#          argument. NOTE(review): they presumably look `fold` up in this
#          function's frame -- confirm against the origami documentation.
#   data - the series to split into training and validation parts.
#
# Returns a list whose `mse` element is a one-row data frame with the columns
# `fold`, `arima` and `stl`.
cv_forecasts <- function(fold, data) {
  fold_id <- fold_index()
  train_data <- training(data)
  valid_data <- validation(data)
  horizon <- length(valid_data)

  # both models are fitted to the log10 of the training values
  train_ts <- ts(log10(train_data), frequency = 12)

  # mean squared error of back-transformed forecasts against the holdout values
  holdout_mse <- function(log_forecast) {
    mean((10^log_forecast - valid_data)^2)
  }

  # borrowed from AirPassengers help
  arima_fit <- arima(
    train_ts,
    order = c(0, 1, 1),
    seasonal = list(order = c(0, 1, 1), period = 12)
  )
  arima_MSE <- holdout_mse(predict(arima_fit, n.ahead = horizon)$pred)

  # stl model
  stl_fit <- stlm(train_ts, s.window = 12)
  stl_MSE <- holdout_mse(forecast(stl_fit, h = horizon)$mean)

  list(mse = data.frame(fold = fold_id, arima = arima_MSE, stl = stl_MSE))
}

# apply cv_forecasts to every fold and keep the per-fold MSE table
mses <- cross_validate(
  cv_fun = cv_forecasts,
  folds = folds,
  data = AirPassengers
)[["mse"]]
# average each model's MSE over the folds
colMeans(mses[, c("arima", "stl")])

## ----sessionInfo, echo=FALSE--------------------------------------------------
sessionInfo()

# Try the origami package in your browser

# Any scripts or data that you put into this service are public.

# origami documentation built on Oct. 20, 2022, 1:07 a.m.