R/dlcv_lr.R

Defines functions fipLr dlcvLr

Documented in dlcvLr fipLr

#' Double Loop Cross Validation Logistic Regression
#'
#' Fits a logistic regression classifier (parsnip `logistic_reg()` with the
#' `"glm"` engine) on every outer fold of a Double Loop Cross Validation (DLCV)
#' and collects the outer-loop predictions via `dlcvOuter()`.
#'
#' Logistic regression has no tuning parameters, so each outer-fold model is
#' fitted once on the fold's analysis set. That same fitted workflow is
#' stored both as `lr_model` and as `final_wf`.
#'
#' @param folds rsample object with either group V-fold or the standard V-fold
#'   cross validation folds.
#' @param rec recipes recipe used for training.
#' @return Tibble with one row per outer fold: the fold id column(s), the
#'   fitted workflows (`lr_model`, `final_wf`) and the columns returned by
#'   `dlcvOuter()` (training and testing predictions). The `splits` column is
#'   dropped.
#' @export
dlcvLr <- function(folds, rec) {
  # Model specification: plain (unpenalised) logistic regression via glm.
  # Named lr_spec so it is not confused with the lr_model column below.
  lr_spec <- logistic_reg() %>%
    set_engine("glm") %>%
    set_mode(mode = "classification")

  # Workflow bundling the preprocessing recipe and the model
  lr_wflow <- workflow() %>%
    add_recipe(rec) %>%
    add_model(lr_spec)

  lr_folds <- folds %>%
    mutate(
      # Fit one model per outer fold on its analysis (training) set.
      lr_model = map2(splits, id, ~ {
        # Seed = last character of the fold id + 1 (e.g. "Fold03" -> 4).
        # NOTE(review): only the last character is used, so seeds repeat once
        # there are more than 10 folds ("Fold01" and "Fold11" both give 2).
        set.seed(as.integer(stringr::str_sub(.y, -1L)) + 1)
        lr_wflow %>% fit(data = analysis(.x))
      }),
      # There are no tune() parameters to finalize, so the seeded fit above is
      # already the final workflow. Reuse it instead of refitting (unseeded).
      final_wf = lr_model,
      # Unnamed data-frame result: dplyr splices its columns into the tibble.
      map2_dfr(splits, final_wf, dlcvOuter)
    )

  lr_folds %>%
    select(-splits)
}

#' Feature importance Logistic Regression
#'
#' Summarise, and optionally plot, the coefficients of the logistic regression
#' models produced by `dlcvLr()` across all outer folds.
#'
#' @param x DLCV logistic regression object generated by the `dlcvLr()`
#'   function; must contain an `lr_model` list-column of fitted workflows.
#' @param plot Logical; if `TRUE` (default), print a box plot of each
#'   feature's per-fold coefficients.
#'
#' @return Tibble with one row per feature, ordered by increasing `mean_coef`.
#'   Columns:
#'   - `feature`: the coefficient term;
#'   - `data`: nested tibble holding that feature's sign-flipped coefficient
#'     (`estimate`) in every fold;
#'   - `mean_coef`: the mean of those estimates across folds.
#' @export
fipLr <- function(x, plot = TRUE) {
  x <- x %>%
    mutate(coef_estimates = map(lr_model, ~ {
      .x %>%
        extract_fit_parsnip() %>%
        tidy() %>%
        # NOTE(review): the sign is flipped. glm models the probability of
        # the non-first outcome level, so this presumably makes positive
        # values favour the first level -- confirm.
        transmute(feature = term,
                  estimate = -estimate) %>%
        arrange(-estimate) %>%
        # Drops the intercept and zero coefficients. NA coefficients (e.g.
        # aliased terms) are dropped too, because filter() discards NA rows.
        filter(estimate != 0.0 & feature != "(Intercept)")
    }))

  # One row per feature with its per-fold estimates nested in `data`
  lr_folds_feature_counts <- x %>%
    select(coef_estimates) %>%
    unnest(coef_estimates) %>%
    with_groups(feature, nest) %>%
    mutate(mean_coef = map_dbl(data, ~ mean(.x$estimate))) %>%
    arrange(mean_coef)

  if (plot) {
    p <- lr_folds_feature_counts %>%
      # Keep the mean_coef ordering on the feature axis
      mutate(feature = factor(feature, levels = lr_folds_feature_counts$feature)) %>%
      unnest(data) %>%
      ggplot(aes(x = feature, y = estimate)) +
      # Every fold's value is drawn by geom_point(), so suppress the box
      # plot's own outlier markers to avoid drawing those points twice.
      geom_boxplot(outlier.shape = NA) +
      geom_point() +
      labs(x = "Feature",
           y = "Importance (coefficient)") +
      coord_flip()

    print(p)
  }

  lr_folds_feature_counts
}
mikeniemant/nbs documentation built on June 23, 2022, 4:52 a.m.