R/autoplot.R

Defines functions check_penalty_value autoplot_glmnet top_coefs reformat_coefs map_glmnet_coefs autoplot.glmnet autoplot.model_fit

Documented in autoplot.glmnet autoplot.model_fit

#' Create a ggplot for a model object
#'
#' This method provides a good visualization method for model results.
#' Currently, only methods for glmnet models are implemented.
#'
#' @param object A model fit object.
#' @param min_penalty A single, non-negative number for the smallest penalty
#' value that should be shown in the plot. If left `NULL`, the whole data
#' range is used.
#' @param best_penalty A single, non-negative number that will show a vertical
#' line marker. If left `NULL`, no line is shown. When this argument is used,
#' the \pkg{ggrepl} package is required.
#' @param top_n A non-negative integer for how many model predictors to label.
#' The top predictors are ranked by their absolute coefficient value. For
#' multinomial or multivariate models, the `top_n` terms are selected within
#' class or response, respectively.
#' @param ... For [autoplot.glmnet()], options to pass to
#' [ggrepel::geom_label_repel()]. Otherwise, this argument is ignored.
#' @return A ggplot object with penalty on the x-axis and coefficients on the
#' y-axis. For multinomial or multivariate models, the plot is faceted.
#' @details The \pkg{glmnet} package will need to be attached or loaded for
#' its `autoplot()` method to work correctly.
#'
#' @export
autoplot.model_fit <- function(object, ...) {
  autoplot(object$fit, ...)
}

# glmnet is not a formal dependency here.
# unit tests are located at https://github.com/tidymodels/extratests
# nocov start

#' @export
#' @rdname autoplot.model_fit
autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
                            top_n = 3L) {
  autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...)
}


map_glmnet_coefs <- function(x) {
  coefs <- coef(x)
  # If parsnip is used to fit the model, glmnet should be attached and this will
  # work. If an object is loaded from a new session, they will need to load the
  # package.
  if (is.null(coefs)) {
    rlang::abort("Please load the glmnet package before running `autoplot()`.")
  }
  p <- x$dim[1]
  if (is.list(coefs)) {
    classes <- names(coefs)
    coefs <- purrr::map(coefs, reformat_coefs, p = p, penalty = x$lambda)
    coefs <- purrr::map2_dfr(coefs, classes, ~ dplyr::mutate(.x, class = .y))
  } else {
    coefs <- reformat_coefs(coefs, p = p, penalty = x$lambda)
  }
  coefs
}

reformat_coefs <- function(x, p, penalty) {
  x <- as.matrix(x)
  num_estimates <- nrow(x)
  if (num_estimates > p) {
    # The intercept is first
    x <- x[-(num_estimates - p),, drop = FALSE]
  }
  term_lab <- rownames(x)
  colnames(x) <- paste(seq_along(penalty))
  x <- tibble::as_tibble(x)
  x$term <- term_lab
  x <- tidyr::pivot_longer(x, cols = -term, names_to = "index", values_to = "estimate")
  x$penalty <- rep(penalty, p)
  x$index <- NULL
  x
}

top_coefs <- function(x, top_n = 5) {
  x %>%
    dplyr::group_by(term) %>%
    dplyr::arrange(term, dplyr::desc(abs(estimate))) %>%
    dplyr::slice(1) %>%
    dplyr::ungroup() %>%
    dplyr::arrange(dplyr::desc(abs(estimate))) %>%
    dplyr::slice(seq_len(top_n))
}

autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
  check_penalty_value(min_penalty)

  tidy_coefs <-
    map_glmnet_coefs(x) %>%
    dplyr::filter(penalty >= min_penalty)

  actual_min_penalty <- min(tidy_coefs$penalty)
  num_terms <- length(unique(tidy_coefs$term))
  top_n <- min(top_n[1], num_terms)
  if (top_n < 0) {
    top_n <- 0
  }

  has_groups <- any(names(tidy_coefs) == "class")

  # Keep the large values
  if (has_groups) {
    label_coefs <-
      tidy_coefs %>%
      dplyr::group_nest(class) %>%
      dplyr::mutate(data = purrr::map(data, top_coefs, top_n = top_n)) %>%
      dplyr::select(class, data) %>%
      tidyr::unnest(cols = data)
  } else {
    if (is.null(best_penalty)) {
      label_coefs <- tidy_coefs %>%
        top_coefs(top_n)
    } else {
      label_coefs <- tidy_coefs %>%
        dplyr::filter(penalty > best_penalty) %>%
        dplyr::filter(penalty == min(penalty)) %>%
        dplyr::arrange(dplyr::desc(abs(estimate))) %>%
        dplyr::slice(seq_len(top_n))
    }
  }

  label_coefs <-
    label_coefs %>%
    dplyr::mutate(penalty = best_penalty %||% actual_min_penalty) %>%
    dplyr::mutate(label = gsub(".pred_no_", "", term))

  # plot the paths and highlight the large values
  p <-
    tidy_coefs %>%
    ggplot2::ggplot(ggplot2::aes(x = penalty, y = estimate, group = term, col = term))

  if (has_groups) {
    p <- p + ggplot2::facet_wrap(~ class)
  }

  if (!is.null(best_penalty)) {
    check_penalty_value(best_penalty)
    p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3)
  }

  p <- p +
    ggplot2::geom_line(alpha = .4, show.legend = FALSE) +
    ggplot2::scale_x_log10()

  if(top_n > 0) {
    rlang::check_installed("ggrepel")
    p <- p +
      ggrepel::geom_label_repel(
        data = label_coefs,
        ggplot2::aes(y = estimate, label = label),
        show.legend = FALSE,
        ...
      )
  }
  p
}

check_penalty_value <- function(x) {
  cl <- match.call()
  arg_val <- as.character(cl$x)
  if (!is.vector(x) || length(x) != 1 || !is.numeric(x) || x < 0) {
    msg <- paste0("Argument '", arg_val, "' should be a single, non-negative value.")
    rlang::abort(msg)
  }
  invisible(x)
}

# nocov end

Try the parsnip package in your browser

Any scripts or data that you put into this service are public.

parsnip documentation built on Aug. 18, 2023, 1:07 a.m.