#' Double Loop Cross Validation Logistic Regression
#'
#' Functionality to perform (unpenalised) logistic regression, using the `glm`
#' engine, in a Double Loop Cross Validation (DLCV). Logistic regression has no
#' tuning parameters, so each outer fold's model is fitted once on that fold's
#' analysis set and passed straight to `dlcvOuter()` to produce predictions.
#'
#' @param folds rsample object with either group V-fold or the standard V-fold
#'   cross validation folds.
#' @param rec recipes recipe used for training
#' @return Tibble with k outer loop models, and training and testing
#'   predictions. The `splits` column is removed. `lr_model` and `final_wf`
#'   both hold the fitted workflow of each fold.
#' @export
dlcvLr <- function(folds, rec) {
  # Model specification: plain logistic regression via glm.
  lr_spec <- logistic_reg() %>%
    set_engine("glm") %>%
    set_mode(mode = "classification")

  # Workflow bundling the preprocessing recipe with the model specification.
  lr_wflow <- workflow() %>%
    add_recipe(rec) %>%
    add_model(lr_spec)

  lr_folds <- folds %>%
    mutate(
      # Fit one model per outer fold. The seed comes from the last character
      # of the fold id plus one (e.g. "Fold3" -> 4). This only matters if the
      # recipe contains random steps. It assumes fold ids end in a digit.
      lr_model = map2(splits, id, ~ {
        set.seed(as.integer(stringr::str_sub(.y, nchar(.y), nchar(.y))) + 1)
        lr_wflow %>% fit(data = analysis(.x))
      }),
      # There is nothing to tune, so the final workflow is simply the model
      # fitted above. finalize_workflow() would be a no-op here, and refitting
      # would only repeat the same training on the same analysis set.
      final_wf = lr_model,
      # Unnamed data-frame result: its columns (the outer-loop predictions
      # produced by dlcvOuter) are spliced into the fold tibble.
      map2_dfr(splits, final_wf, dlcvOuter)
    )

  lr_folds %>%
    select(-splits)
}
#' Feature importance Logistic Regression
#'
#' Visualize feature importance (model coefficients) of the logistic
#' regression DLCV results.
#'
#' @param x DLCV logistic regression object generated by the `dlcvLr` function.
#'   It must contain the `lr_model` column of fitted workflows.
#' @param plot Logical. If `TRUE`, print a box plot of each feature's
#'   coefficients across the folds.
#'
#' @return Tibble with the feature importances in all folds. There is one row
#'   per feature: `feature`; a nested `data` column holding that feature's
#'   estimate in every fold; and `mean_coef`, the mean estimate across folds.
#'   Rows are sorted by increasing `mean_coef`.
#' @export
fipLr <- function(x, plot = TRUE) {
  # Per-fold coefficient table, excluding the intercept.
  x <- x %>%
    mutate(coef_estimates = map(lr_model, ~ {
      .x %>%
        extract_fit_parsnip() %>%
        tidy() %>%
        # Sign is flipped. NOTE(review): presumably so that positive values
        # point towards the first outcome level, since glm models the second
        # level. Confirm.
        transmute(feature = term,
                  estimate = -estimate) %>%
        arrange(-estimate) %>%
        # filter() also drops rows whose estimate is NA (e.g. aliased terms).
        filter(estimate != 0.0 & feature != "(Intercept)")
    }))

  # One row per feature, with its per-fold estimates nested in `data`.
  feature_coefs <- x %>%
    select(coef_estimates) %>%
    unnest(coef_estimates) %>%
    with_groups(feature, nest) %>%
    mutate(mean_coef = map_dbl(data, ~ mean(.x$estimate))) %>%
    arrange(mean_coef)

  if (plot) {
    p <- feature_coefs %>%
      # Order the axis by mean coefficient. After coord_flip() the largest
      # mean ends up at the top.
      mutate(feature = factor(feature, levels = feature_coefs$feature)) %>%
      unnest(data) %>%
      ggplot(aes(x = feature, y = estimate)) +
      # Outliers are hidden because geom_point() already draws every per-fold
      # estimate. Otherwise those points would be drawn twice.
      geom_boxplot(outlier.shape = NA) +
      geom_point() +
      labs(x = "Feature",
           y = "Importance (coefficient)") +
      coord_flip()
    print(p)
  }

  feature_coefs
}
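# NOTE: the two lines below are web-page embed boilerplate (not R code) that
# was captured along with this file. They are commented out so the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.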