R/plot.R

Defines functions quantile_clip autoplot.tabnet_explain autoplot.tabnet_fit

Documented in autoplot.tabnet_explain autoplot.tabnet_fit

#' Plot tabnet_fit model loss along epochs
#'
#' @param object A `tabnet_fit` or `tabnet_pretrain` object, as returned by
#'  [tabnet_fit()] or [tabnet_pretrain()].
#' @param ...  not used.
#' @return A `ggplot` object.
#' @details
#'  Draws one line per dataset showing the mean loss of each epoch on a log10
#'  y-axis. The validation curve is only drawn when validation losses exist.
#'  A dot is drawn on the training curve at each epoch for which a model
#'  snapshot (checkpoint) is available, which helps choosing a `from_epoch`
#'  value when resuming model training later.
#'
#' @examples
#' \donttest{
#' if (torch::torch_is_installed()) {
#' library(ggplot2)
#' data("attrition", package = "modeldata")
#' attrition_fit <- tabnet_fit(Attrition ~. , data=attrition, valid_split=0.2, epoch=11)
#'
#' # Plot the model loss over epochs
#' autoplot(attrition_fit)
#' }
#' }
#' @importFrom rlang .data
#'
autoplot.tabnet_fit <- function(object, ...) {
  fit <- object$fit

  # 1-based rank of every checkpointed epoch: one checkpoint is expected each
  # `checkpoint_epochs` epochs
  checkpoint_ranks <- seq_along(fit$checkpoints) * fit$config$checkpoint_epochs

  # one row per (epoch, dataset), with the per-epoch losses in `loss`
  metrics_nested <- tibble::enframe(fit$metrics, name = "epoch")
  metrics_by_set <- tidyr::unnest_longer(
    metrics_nested,
    value,
    indices_to = "dataset"
  )
  metrics_loss <- tidyr::unnest_wider(metrics_by_set, value)
  # entries coming from pretraining have a missing `dataset`: discard them
  metrics_loss <- tidyr::drop_na(metrics_loss, dataset)

  # wide round-trip: a dataset column (e.g. validation) holding only NAs is
  # removed so that ggplot does not warn about it
  metrics_wide <- tidyr::pivot_wider(
    metrics_loss,
    values_from = loss,
    names_from = dataset
  )
  metrics_wide <- dplyr::select_if(
    metrics_wide,
    function(column) any(!is.na(column))
  )
  metrics_long <- tidyr::pivot_longer(
    metrics_wide,
    cols = !epoch,
    names_to = "dataset",
    values_to = "loss"
  )

  # average the losses of each epoch and flag the checkpointed epochs;
  # checkpoint ranks are shifted so that rank 1 is the first plotted epoch
  metrics_flagged <- dplyr::mutate(
    metrics_long,
    mean_loss = purrr::map_dbl(loss, mean),
    has_checkpoint = epoch %in% (checkpoint_ranks + min(epoch, na.rm = TRUE) - 1)
  )
  loss_per_epoch <- dplyr::select(metrics_flagged, -loss)

  # checkpoint dots are only drawn on the training curve
  checkpoint_dots <- dplyr::filter(
    loss_per_epoch,
    has_checkpoint,
    dataset == "train"
  )
  checkpoint_dots <- dplyr::mutate(checkpoint_dots, size = 2)

  dataset_legend <- ggplot2::guide_legend(
    "Dataset",
    order = 1,
    override.aes = list(size = 1.5, shape = " ")
  )
  checkpoint_legend <- ggplot2::guide_legend(
    "has checkpoint",
    order = 2,
    override.aes = list(size = 3, color = "#F8766D"),
    label.theme = ggplot2::element_text(colour = "#FFFFFF")
  )

  ggplot2::ggplot(
    loss_per_epoch,
    ggplot2::aes(x = epoch, y = mean_loss, color = dataset)
  ) +
    ggplot2::geom_line() +
    ggplot2::geom_point(
      data = checkpoint_dots,
      ggplot2::aes(x = epoch, y = mean_loss, color = dataset, size = .data$size)
    ) +
    ggplot2::scale_y_log10() +
    ggplot2::guides(colour = dataset_legend, size = checkpoint_legend) +
    ggplot2::theme(legend.position = "bottom") +
    ggplot2::labs(y = "Mean loss (log scale)")
}

# `tabnet_pretrain` objects are plotted by the very same function as
# `tabnet_fit` objects, hence the plain alias and the shared help page.
#' @rdname autoplot.tabnet_fit
autoplot.tabnet_pretrain <- autoplot.tabnet_fit

#' Plot tabnet_explain mask importance heatmap
#'
#' @param object A `tabnet_explain` object, as returned by [tabnet_explain()].
#' @param type a character value. Either `"mask_agg"`, the default, for a single
#'  heatmap of the aggregated mask importance per predictor along the dataset,
#'  or `"steps"` for one heatmap per mask step.
#' @param quantile numerical value between 0 and 1. Mask values above this
#'  quantile are clipped down to it.
#' @param ...  not used.
#' @return A `ggplot` object.
#' @details
#'  Plot the tabnet_explain object mask importance per variable along the
#'  predicted dataset.
#'  `type="mask_agg"` outputs a single heatmap of the aggregated mask values,
#'  `type="steps"` outputs a plot faceted along the masks present in the model,
#'  with the clipping computed separately within each step.
#'  `quantile=.995` may be used for strong outlier clipping, in order to better
#'  highlight low values. `quantile=1`, the default, does not clip any value.
#'
#' @examples
#' \donttest{
#' if (torch::torch_is_installed()) {
#' library(ggplot2)
#' data("attrition", package = "modeldata")
#' attrition_fit <- tabnet_fit(Attrition ~. , data=attrition, epoch=11)
#' attrition_explain <- tabnet_explain(attrition_fit, attrition)
#' # Plot the model aggregated mask interpretation heatmap
#' autoplot(attrition_explain)
#' }
#' }
#'
autoplot.tabnet_explain <- function(object, type = c("mask_agg", "steps"), quantile = 1, ...) {
  type <- match.arg(type)

  if (type == "mask_agg") {
    # single aggregated mask: one row per observation, one column per predictor
    indexed_mask <- dplyr::mutate(
      object$M_explain,
      rowname = dplyr::row_number()
    )
    long_mask <- tidyr::pivot_longer(
      indexed_mask,
      -rowname,
      names_to = "variable",
      values_to = "mask_agg"
    )
    heatmap_data <- dplyr::mutate(
      long_mask,
      mask_agg = quantile_clip(mask_agg, probs = quantile),
      step = "mask_aggregate"
    )
  } else {
    # type == "steps": stack every mask, labelled with its position in the list
    stacked_masks <- purrr::imap_dfr(
      object$masks,
      function(.x, .y) {
        dplyr::mutate(
          .x,
          step = sprintf("Step %d", .y),
          rowname = dplyr::row_number()
        )
      }
    )
    long_masks <- tidyr::pivot_longer(
      stacked_masks,
      -c(rowname, step),
      names_to = "variable",
      values_to = "mask_agg"
    )
    # the clipping threshold is computed within each step, not globally
    masks_by_step <- dplyr::group_by(long_masks, step)
    clipped_masks <- dplyr::mutate(
      masks_by_step,
      mask_agg = quantile_clip(mask_agg, probs = quantile)
    )
    heatmap_data <- dplyr::ungroup(clipped_masks)
  }

  # one tile per (observation, predictor); one facet per step
  ggplot2::ggplot(
    heatmap_data,
    ggplot2::aes(x = rowname, y = variable, fill = mask_agg)
  ) +
    ggplot2::geom_tile() +
    ggplot2::scale_fill_viridis_c() +
    ggplot2::facet_wrap(~step) +
    ggplot2::theme_minimal()
}

#' Clip the values of a numeric vector at one of its quantiles
#'
#' Every value of `x` greater than the `probs` quantile of `x` is replaced by
#' that quantile; lower values are returned unchanged.
#'
#' @param x A numeric vector. `NA` values are ignored when computing the
#'  quantile and are returned as `NA`.
#' @param probs Probability, between 0 and 1, of the quantile used as upper
#'  bound. When several probabilities are given, the smallest resulting
#'  quantile is used.
#' @return A double vector of the same length as `x`.
#' @noRd
quantile_clip <- function(x, probs) {
  # `min()` collapses a vector `probs` to a single bound, so that `pmin()`
  # below never recycles several bounds along `x`
  upper_bound <- min(stats::quantile(x, probs = probs, na.rm = TRUE, names = FALSE))
  # vectorised clipping, in place of one `min()` call per element
  pmin(x, upper_bound)
}

Try the tabnet package in your browser

Any scripts or data that you put into this service are public.

tabnet documentation built on May 31, 2023, 6:27 p.m.