# Nothing
## ----load_data----------------------------------------------------------------
# Load the built-in mtcars data set and preview its first rows.
data(mtcars)
head(mtcars)
## ----linear_mod---------------------------------------------------------------
# Regress mpg on every other column of mtcars and print the fit summary.
lm_mod <- lm(mpg ~ ., data = mtcars)
summary(lm_mod)
## ----get_naive_error----------------------------------------------------------
# Mean squared residual of the fit above. It is computed on the same rows the
# model was fitted to, and is stored in `err` without being printed.
err <- mean(resid(lm_mod)^2)
## ----define_fun_cv_lm---------------------------------------------------------
# Fit a linear model on one fold's training set and score it on that fold's
# validation set.
#
# Arguments:
#   fold     - the fold to evaluate. It is not referenced by name in the body:
#              training() and validation() are called with the data only and
#              presumably look `fold` up in this function's frame (confirm
#              against the origami documentation).
#   data     - data frame holding the outcome and the predictors.
#   reg_form - regression formula as a character string, e.g. "mpg ~ ."
#              (whitespace around `~` is optional; a formula object also works).
#
# Returns a list with:
#   coef - one-row data frame of the fitted coefficients.
#   SE   - squared prediction errors on the validation set.
cv_lm <- function(fold, data, reg_form) {
  reg_formula <- as.formula(reg_form)

  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(reg_formula, data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # Observed outcome on the validation set, read from the formula's left-hand
  # side rather than by splitting the formula string on spaces (which broke
  # for inputs such as "mpg~."). na.pass keeps one value per validation row so
  # the result lines up with `preds`.
  observed <- model.response(
    model.frame(reg_formula, data = valid_data, na.action = na.pass)
  )

  # capture results to be returned as output
  list(
    coef = data.frame(t(coef(mod))),
    SE = (preds - observed)^2
  )
}
## ----load_pkgs----------------------------------------------------------------
# Packages are attached after cv_lm is defined; that is fine because the
# function body is only evaluated when it is called below.
library(origami)
library(stringr) # used in defining the cv_lm function above
## ----cv_lm_resub--------------------------------------------------------------
# resubstitution estimate
# Take the first fold produced with folds_resubstitution and run cv_lm on it
# directly (presumably a fold whose training and validation sets are both the
# full data -- confirm against the origami documentation).
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1]]
resub_results <- cv_lm(fold = resub, data = mtcars, reg_form = "mpg ~ .")
# Average of the squared errors returned by cv_lm for that fold.
mean(resub_results$SE)
## ----cv_lm_cross_valdate------------------------------------------------------
# cross-validated estimate
# Build folds with make_folds' default fold function, then let cross_validate
# call cv_lm with these folds; `data` and `reg_form` are passed through to it.
folds <- make_folds(mtcars)
cvlm_results <- cross_validate(cv_fun = cv_lm, folds = folds, data = mtcars,
reg_form = "mpg ~ .")
# Average of the SE component of the cross_validate output.
mean(cvlm_results$SE)
## ----cv_fun_randomForest------------------------------------------------------
# Fit a random forest regression on one fold's training set and score it on
# that fold's validation set.
#
# Arguments:
#   fold     - the fold to evaluate. It is not referenced by name in the body:
#              training() and validation() are called with the data only and
#              presumably look `fold` up in this function's frame (confirm
#              against the origami documentation).
#   data     - data frame holding the outcome and the predictors.
#   reg_form - regression formula as a character string, e.g. "mpg ~ ."
#              (whitespace around `~` is optional; a formula object also works).
#
# Returns a list with:
#   coef - data.frame(mod$coefs); see the review note in the body.
#   SE   - squared prediction errors on the validation set.
cv_rf <- function(fold, data, reg_form) {
  reg_formula <- as.formula(reg_form)

  # define training and validation sets based on input object of class "folds"
  train_data <- training(data)
  valid_data <- validation(data)

  # fit Random Forest regression on training set and predict on holdout set
  mod <- randomForest(formula = reg_formula, data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # Observed outcome on the validation set, read from the formula's left-hand
  # side rather than by splitting the formula string on spaces (which broke
  # for inputs such as "mpg~."). na.pass keeps one value per validation row so
  # the result lines up with `preds`.
  observed <- model.response(
    model.frame(reg_formula, data = valid_data, na.action = na.pass)
  )

  # define output object to be returned as list (for flexibility)
  # NOTE(review): randomForest fits do not appear to have a `coefs` component,
  # so `coef` is presumably an empty data frame; it is kept unchanged so the
  # returned list has the same names as before -- confirm whether it is needed.
  list(
    coef = data.frame(mod$coefs),
    SE = (preds - observed)^2
  )
}
## -----------------------------------------------------------------------------
# Cross-validate the random forest wrapper on mtcars and report the average of
# the returned squared errors.
library(randomForest)
folds <- make_folds(mtcars)
cvrf_results <- cross_validate(
  cv_fun = cv_rf,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cvrf_results$SE)
## -----------------------------------------------------------------------------
# Load and display the AirPassengers series.
data(AirPassengers)
print(AirPassengers)
## -----------------------------------------------------------------------------
# Rolling-origin folds for the time series: first_window = 36 and
# validation_size = 24 are passed through make_folds together with
# folds_rolling_origin. `folds` is overwritten here; `fold` keeps the first
# fold as a top-level object.
library(forecast)
folds <- make_folds(
  AirPassengers,
  fold_fun = folds_rolling_origin,
  first_window = 36,
  validation_size = 24
)
fold <- folds[[1]]
# Cross-validated forecast error for one fold.
#
# Fits a seasonal ARIMA model and an stlm model to the log10-transformed
# training portion of `data`, forecasts as many steps ahead as there are
# validation observations, back-transforms the forecasts with 10^x and returns
# each model's mean squared error against the validation values.
#
# Arguments:
#   fold - the fold to evaluate. It is not referenced by name in the body:
#          training(), validation() and fold_index() are called without it and
#          presumably look `fold` up in this function's frame (confirm against
#          the origami documentation).
#   data - the series to split; it is log10-transformed, so values are assumed
#          to be positive.
#
# Returns a list with one element, `mse`: a one-row data frame with columns
# `fold`, `arima` and `stl`.
cv_forecasts <- function(fold, data) {
  train_part <- training(data)
  holdout <- validation(data)
  horizon <- length(holdout)

  # Both models are fitted on the log10 scale with a frequency of 12.
  log_train <- ts(log10(train_part), frequency = 12)

  # MSE of a log10-scale forecast after mapping it back to the original scale.
  holdout_mse <- function(log_forecast) {
    mean((10^log_forecast - holdout)^2)
  }

  # borrowed from AirPassengers help
  sarima_fit <- arima(
    log_train,
    order = c(0, 1, 1),
    seasonal = list(order = c(0, 1, 1), period = 12)
  )
  sarima_forecast <- predict(sarima_fit, n.ahead = horizon)
  sarima_mse <- holdout_mse(sarima_forecast$pred)

  # stl model
  stl_model <- stlm(log_train, s.window = 12)
  stl_forecast <- forecast(stl_model, h = horizon)
  stl_mse <- holdout_mse(stl_forecast$mean)

  list(
    mse = data.frame(fold = fold_index(), arima = sarima_mse, stl = stl_mse)
  )
}
# Run cv_forecasts across the rolling-origin folds, keep the `mse` component
# of the result, and average the two model columns.
mses <- cross_validate(
  cv_fun = cv_forecasts,
  folds = folds,
  data = AirPassengers
)$mse
colMeans(mses[, c("arima", "stl")])
## ----sessionInfo, echo=FALSE--------------------------------------------------
sessionInfo()
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.