# R/dlcv_boruta.R

#' DLCV Boruta
#'
#' Double-loop cross-validation with Boruta. Features are selected in the inner loop and used in the outer loop to train a model that is subsequently evaluated on the left-out fold of the outer loop. The underlying data are expected to contain an outcome column named "y".
#'
#' @param folds rsample rset object with either group V-fold or standard V-fold cross-validation folds.
#' @param rec Recipe object used for training (currently unused; a new recipe is built per fold from the Boruta-selected predictors).
#' @param features Character vector with the names of the candidate predictor columns passed to Boruta.
#' @return Tibble with one row per outer fold, containing the Boruta importances, the selected predictors, the fitted workflow, and the training and testing predictions.
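#'
#' @examples
#' \dontrun{
#' # Minimal usage sketch. Assumptions (not from this file): a data frame `df`
#' # with a factor outcome column named "y" (required by this function) and
#' # candidate predictor columns "x1", "x2" and "x3"; all object names are
#' # illustrative only.
#' folds <- rsample::vfold_cv(df, v = 5, strata = y)
#' rec <- recipes::recipe(y ~ ., data = df)
#'
#' res <- dlcvBoruta(folds = folds, rec = rec,
#'                   features = c("x1", "x2", "x3"))
#' }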
#' @export
dlcvBoruta <- function(folds, rec, features) {
  lr_spec <- parsnip::logistic_reg() %>%
    parsnip::set_mode("classification") %>%
    parsnip::set_engine("glm")

  # dt_spec <- decision_tree() %>%
  #   set_mode("classification") %>%
  #   set_engine("rpart")

  # Feature selection: run Boruta on the analysis set of each outer fold
  bor_folds <- folds %>%
    dplyr::mutate(bor_imp = purrr::map(splits, ~ {
      bor_fit <- Boruta::Boruta(x = rsample::analysis(.x) %>% dplyr::select(dplyr::all_of(features)),
                                y = rsample::analysis(.x) %>% dplyr::pull(y))

      # Combine the full importance history with the final decision per variable
      bor_imp <- bor_fit$ImpHistory %>%
        dplyr::as_tibble() %>%
        tidyr::pivot_longer(cols = dplyr::everything(),
                            names_to = "var",
                            values_to = "importance") %>%
        dplyr::left_join(data.frame(var = names(bor_fit$finalDecision),
                                    decision = bor_fit$finalDecision,
                                    row.names = NULL),
                         by = "var")

      return(bor_imp)
    }))

  # Extract predictors
  bor_folds <- bor_folds %>%
    dplyr::mutate(bor_predictors = purrr::map(bor_imp, ~ .x %>%
                                  dplyr::filter(decision == "Confirmed") %>%
                                  # filter(decision %in% c("Confirmed", "Tentative")) %>%
                                  dplyr::distinct(var) %>%
                                  dplyr::pull(var)))

  # bor_folds %>%
  #   select(id, bor_imp) %>%
  #   unnest(bor_imp) %>%
  #   filter(decision == "Confirmed") %>%
  #   count(var) %>%
  #   ggplot(aes(x = reorder(var, -n), y = n)) +
  #   geom_point() +
  #   labs(x = "Feature",
  #        y = "Times selected")

  # Fit a logistic regression workflow on the Boruta-selected predictors in each outer fold
  bor_lr_folds <- bor_folds %>%
    dplyr::mutate(final_wf = purrr::map2(splits, bor_predictors, ~ {
      form <- stats::as.formula(paste0("y ~ ",
                                       paste(.y, collapse = " + ")))

      bor_recipe <- recipes::recipe(form, data = rsample::analysis(.x))

      bor_workflow <- workflows::workflow() %>%
        workflows::add_recipe(bor_recipe) %>%
        workflows::add_model(lr_spec) # dt_spec

      bor_workflow <- bor_workflow %>%
        parsnip::fit(rsample::analysis(.x))

      return(bor_workflow)
    }),
    # dlcvOuter is expected to return a one-row tibble of training and testing
    # predictions per fold; this unnamed data frame result is spliced into new
    # columns by mutate()
    purrr::map2_dfr(splits, final_wf, dlcvOuter))

  bor_lr_folds <- bor_lr_folds %>%
    dplyr::select(-splits)

  return(bor_lr_folds)
}
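
# Note: dlcvOuter() is used above but defined elsewhere in the package. From the
# way it is called (one split and one fitted workflow per outer fold) and the
# documented return value, it is assumed to return a one-row tibble with the
# training and testing predictions for that fold. A minimal sketch of that
# assumed contract (not the package's actual implementation) could look like:
#
# dlcvOuter <- function(split, wf) {
#   dplyr::tibble(
#     train_preds = list(predict(wf, rsample::analysis(split), type = "prob")),
#     test_preds  = list(predict(wf, rsample::assessment(split), type = "prob"))
#   )
# }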

#' Feature importance Boruta DLCV
#'
#' Visualise the feature importance of the Boruta DLCV results.
#'
#' @param x DLCV Boruta tibble generated by the dlcvBoruta function.
#' @param plot Logical; if TRUE, the feature importance plot is printed.
#' @param min_n Minimum number of outer folds in which a feature must have been selected to be included.
#' @param y_nudge Y coordinate at which the label showing the number of folds a feature was selected in is drawn. Defaults to the maximum mean coefficient plus one standard deviation.
#'
#' @return Tibble with the feature importances (model coefficients) and the number of outer folds in which each feature was selected.
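#'
#' @examples
#' \dontrun{
#' # Assuming `res` is the tibble returned by dlcvBoruta() (name illustrative):
#' feature_counts <- fipBoruta(res, plot = TRUE, min_n = 2)
#' }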
#' @export
fipBoruta <- function(x, plot = TRUE, min_n = 0, y_nudge = NULL) {
  # Extract coefficients
  x <- x %>%
    dplyr::mutate(coef_estimates = purrr::map(final_wf, ~ {
      .x %>%
        workflows::extract_fit_parsnip() %>%
        broom::tidy() %>%
        dplyr::transmute(feature = term,
                  estimate = -estimate) %>% # tidymodels bug
        dplyr::arrange(-estimate) %>%
        dplyr::filter(estimate != 0.0 & feature != "(Intercept)")
    }))

  # Count features
  bor_lr_feature_counts <- x %>%
    dplyr::select(coef_estimates) %>%
    tidyr::unnest(coef_estimates) %>%
    dplyr::with_groups(feature, tidyr::nest) %>%
    dplyr::mutate(n = purrr::map_dbl(data, nrow),
           mean_coef = purrr::map_dbl(data, ~ mean(.x$estimate))) %>%
    dplyr::filter(n >= min_n) %>%
    dplyr::arrange(mean_coef)

  # Plot
  if(plot) {
    if(is.null(y_nudge)) {
      y_nudge <- max(bor_lr_feature_counts$mean_coef) +
        sd(bor_lr_feature_counts$mean_coef)
    }

    p <- bor_lr_feature_counts %>%
      dplyr::mutate(feature = factor(feature, levels = bor_lr_feature_counts$feature)) %>%
      tidyr::unnest(data) %>%
      ggplot2::ggplot(ggplot2::aes(x = feature, y = estimate)) +
      ggplot2::geom_boxplot() +
      ggplot2::geom_point() +
      ggplot2::geom_label(mapping = ggplot2::aes(x = feature, y = y_nudge, label = n)) +
      ggplot2::labs(x = "Feature",
           y = "Importance (coefficient)") +
      ggplot2::coord_flip()

    print(p)

    # bor_lr_feature_counts %>%
    #   mutate(feature = factor(feature, levels = bor_lr_feature_counts$feature)) %>%
    #   unnest(data) %>%
    #   ggplot(aes(x = factor(n), y = estimate, colour = feature)) +
    #   geom_boxplot() +
    #   geom_point(position = position_dodge(width = 0.75)) +
    #   labs(x = "Number of times selected in outer folds",
    #        y = "Importance (coefficient)")
  }

  return(bor_lr_feature_counts)
}