R/perf_mod.R

#' Bayesian Analysis of Resampling Statistics
#'
#'  A Bayesian analysis is used here to answer the question: "when looking at
#'  resampling results, are the differences between models 'real'?" To answer
#'  this, a model can be created where the _outcome_ is the resampling statistics
#'  (e.g. accuracy or RMSE). These values are explained by the model types. In
#'  doing this, we can get parameter estimates for each model's effect on
#'  performance and make statistical (and practical) comparisons between models.
#'
#' @param object Depending on the context (see Details below):
#'
#'   * A data frame with `id` columns for the resampling groups and metric
#'     results in all of the other columns.
#'   * An `rset` object (such as [rsample::vfold_cv()]) containing the `id`
#'     column(s) and at least two numeric columns of model performance
#'     statistics (e.g. accuracy).
#'   * An object from `caret::resamples`.
#'   * An object with class `tune_results`, which could be produced by
#'     `tune::tune_grid()`, `tune::tune_bayes()` or similar.
#'   * A workflow set where all results contain the metric value given in the
#'     `metric` argument.
#'
#' @param formula An optional model formula to use for the Bayesian hierarchical model
#' (see Details below).
#' @param ... Additional arguments to pass to [rstanarm::stan_glmer()] such as
#'  `verbose`, `prior`, `seed`, `refresh`, `family`, etc.
#' @param metric A single character string for the metric used in the
#' `tune_results` that should be used in the Bayesian analysis. If none is given,
#' the first metric value is used.
#' @param filter A conditional logic statement that can be used to filter the
#' statistics generated by `tune_results` using the tuning parameter values or
#' the `.config` column.
#' @return An object of class `perf_mod`. If a workflow set is given in
#' `object`, there is an extra class of `"perf_mod_workflow_set"`.
#' @details These functions can be used to process and analyze matched
#'  resampling statistics from different models using a Bayesian generalized
#'  linear model with effects for the model and the resamples.
#'
#' ## Bayesian Model formula
#'
#' By default, a generalized linear model with Gaussian error and an identity
#'  link is fit to the data and has terms for the predictive model grouping
#'  variable. In this way, the performance metrics can be compared between
#'  models.
#'
#' Random effect terms are also used. For most resampling
#'  methods (except repeated _V_-fold cross-validation), a simple random
#'  intercept model is used with an exchangeable (i.e. compound-symmetric)
#'  variance structure. In the case of repeated cross-validation, two random
#'  intercept terms are used; one for the repeat and another for the fold within
#'  repeat. These also have exchangeable correlation structures.
#'
#'   The above model specification assumes that the variance in the performance
#'  metrics is the same across models. However, this is unlikely to be true in
#'  some cases. For example, for simple binomial accuracy, it is well known that the
#'  variance is highest when the accuracy is near 50 percent. When the argument
#'  `hetero_var = TRUE`, the variance structure uses random intercepts for each
#'  model term. This may produce more realistic posterior distributions but may
#'  take more time to converge.
#'
#'  Examples of the default formulas are:
#'
#'  \preformatted{
#'    # One ID field and common variance:
#'      statistic ~ model + (model | id)
#'
#'    # One ID field and heterogeneous variance:
#'      statistic ~ model + (model + 0 | id)
#'
#'    # Repeated CV (id = repeat, id2 = fold within repeat)
#'    # with a common variance:
#'      statistic ~ model + (model | id2/id)
#'
#'    # Repeated CV (id = repeat, id2 = fold within repeat)
#'    # with a heterogeneous variance:
#'      statistic ~ model + (model + 0 | id2/id)
#'
#'    # Default for unknown resampling method and
#'    # multiple ID fields:
#'      statistic ~ model + (model | idN/../id)
#'  }
#'
#'  Custom formulas should use `statistic` as the outcome variable and `model`
#'  as the factor variable with the model names.
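#'
#'  For example, a sketch of a call with heterogeneous model variances via a
#'  custom formula (using the `resamples_df` object created in the examples
#'  below):
#'
#' ```r
#'    perf_mod(
#'      resamples_df,
#'      formula = statistic ~ model + (model + 0 | id),
#'      seed = 101
#'    )
#' ```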
#'
#'   Also, as shown in the package vignettes, the Gaussian assumption may be
#'  unrealistic. In this case, there are at least two approaches that can be
#'  used. First, the outcome statistics can be transformed prior to fitting the
#'  model. For example, for accuracy, the logit transformation can be used to
#'  convert the outcome values to be on the real line and a model is fit to
#'  these data. Once the posterior distributions are computed, the inverse
#'  transformation can be used to put them back into the original units. The
#'  `transform` argument can be used to do this.
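#'
#'  A sketch of this approach for accuracy statistics, using the built-in
#'  [logit_trans()] object (and assuming `acc_df` is a data frame of accuracy
#'  values in the format described below):
#'
#' ```r
#'    # fit the model on the logit scale; the posteriors are put back
#'    # into the original units via the inverse transformation
#'    acc_mod <- perf_mod(acc_df, transform = logit_trans, seed = 101)
#' ```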
#'
#'   The second approach would be to use a different error distribution from the
#'  exponential family. For RMSE values, the Gamma distribution may produce
#'  better results at the expense of model computational complexity. This can be
#'  achieved by passing the `family` argument to `perf_mod` as one might with
#'  the `glm` function.
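#'
#'  A sketch, assuming `rmse_df` holds RMSE statistics in the format described
#'  below:
#'
#' ```r
#'    rmse_mod <- perf_mod(rmse_df, family = Gamma(link = "log"), seed = 101)
#' ```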
#'
#' ## Input formats
#'
#' There are several ways to give resampling results to the `perf_mod()` function. To
#' illustrate, here are some example objects using 10-fold cross-validation for a
#' simple two-class problem:
#'
#'
#' ```r
#'    library(tidymodels)
#'    library(tidyposterior)
#'    library(workflowsets)
#'
#'    data(two_class_dat, package = "modeldata")
#'
#'    set.seed(100)
#'    folds <- vfold_cv(two_class_dat)
#' ```
#'
#' We can define two different models (for simplicity, with no tuning parameters).
#'
#'
#' ```r
#'    logistic_reg_glm_spec <-
#'      logistic_reg() %>%
#'      set_engine('glm')
#'
#'    mars_earth_spec <-
#'      mars(prod_degree = 1) %>%
#'      set_engine('earth') %>%
#'      set_mode('classification')
#' ```
#'
#' For tidymodels, the [tune::fit_resamples()] function can be used to estimate
#' performance for each model/resample:
#'
#'
#' ```r
#'    rs_ctrl <- control_resamples(save_workflow = TRUE)
#'
#'    logistic_reg_glm_res <-
#'      logistic_reg_glm_spec %>%
#'      fit_resamples(Class ~ ., resamples = folds, control = rs_ctrl)
#'
#'    mars_earth_res <-
#'      mars_earth_spec %>%
#'      fit_resamples(Class ~ ., resamples = folds, control = rs_ctrl)
#' ```
#'
#' From these, there are several ways to pass the results to `perf_mod()`.
#'
#' ### Data Frame as Input
#'
#' The most general approach is to have a data frame with the resampling labels (i.e.,
#' one or more id columns) as well as columns for each model that you would like to
#' compare.
#'
#' For the model results above, [tune::collect_metrics()] can be used along with some
#' basic data manipulation steps:
#'
#'
#' ```r
#'    logistic_roc <-
#'      collect_metrics(logistic_reg_glm_res, summarize = FALSE) %>%
#'      dplyr::filter(.metric == "roc_auc") %>%
#'      dplyr::select(id, logistic = .estimate)
#'
#'    mars_roc <-
#'      collect_metrics(mars_earth_res, summarize = FALSE) %>%
#'      dplyr::filter(.metric == "roc_auc") %>%
#'      dplyr::select(id, mars = .estimate)
#'
#'    resamples_df <- full_join(logistic_roc, mars_roc, by = "id")
#'    resamples_df
#' ```
#'
#' ```
#'    ## # A tibble: 10 x 3
#'    ##   id     logistic  mars
#'    ##   <chr>     <dbl> <dbl>
#'    ## 1 Fold01    0.908 0.875
#'    ## 2 Fold02    0.904 0.917
#'    ## 3 Fold03    0.924 0.938
#'    ## 4 Fold04    0.881 0.881
#'    ## 5 Fold05    0.863 0.864
#'    ## 6 Fold06    0.893 0.889
#'    ## # … with 4 more rows
#' ```
#'
#' We can then give this directly to `perf_mod()`:
#'
#'
#' ```r
#'    set.seed(101)
#'    roc_model_via_df <- perf_mod(resamples_df, refresh = 0)
#'    tidy(roc_model_via_df) %>% summary()
#' ```
#'
#' ```
#'    ## # A tibble: 2 x 4
#'    ##   model     mean lower upper
#'    ##   <chr>    <dbl> <dbl> <dbl>
#'    ## 1 logistic 0.892 0.879 0.906
#'    ## 2 mars     0.888 0.875 0.902
#' ```
#'
#' ### rsample Object as Input
#'
#' Alternatively, the result columns can be merged back into the original `rsample`
#' object. The up-side to using this method is that `perf_mod()` will know exactly
#' which model formula to use for the Bayesian model:
#'
#'
#' ```r
#'    resamples_rset <-
#'      full_join(folds, logistic_roc, by = "id") %>%
#'      full_join(mars_roc, by = "id")
#'
#'    set.seed(101)
#'    roc_model_via_rset <- perf_mod(resamples_rset, refresh = 0)
#'    tidy(roc_model_via_rset) %>% summary()
#' ```
#'
#' ```
#'    ## # A tibble: 2 x 4
#'    ##   model     mean lower upper
#'    ##   <chr>    <dbl> <dbl> <dbl>
#'    ## 1 logistic 0.892 0.879 0.906
#'    ## 2 mars     0.888 0.875 0.902
#' ```
#'
#' ### Workflow Set Object as Input
#'
#' Finally, for tidymodels, a workflow set object can be used. This is a collection of
#' models/preprocessing combinations in one object. We can emulate a workflow set using
#' the existing example results then pass that to `perf_mod()`:
#'
#'
#' ```r
#'    example_wset <-
#'      as_workflow_set(logistic = logistic_reg_glm_res, mars = mars_earth_res)
#'
#'    set.seed(101)
#'    roc_model_via_wflowset <- perf_mod(example_wset, refresh = 0)
#'    tidy(roc_model_via_wflowset) %>% summary()
#' ```
#'
#' ```
#'    ## # A tibble: 2 x 4
#'    ##   model     mean lower upper
#'    ##   <chr>    <dbl> <dbl> <dbl>
#'    ## 1 logistic 0.892 0.879 0.906
#'    ## 2 mars     0.888 0.875 0.902
#' ```
#'
#' ### caret resamples object
#'
#' The `caret` package can also be used. An equivalent set of models is created:
#'
#'
#' ```r
#'    library(caret)
#'
#'    set.seed(102)
#'    logistic_caret <- train(Class ~ ., data = two_class_dat, method = "glm",
#'                            trControl = trainControl(method = "cv"))
#'
#'    set.seed(102)
#'    mars_caret <- train(Class ~ ., data = two_class_dat, method = "gcvEarth",
#'                        tuneGrid = data.frame(degree = 1),
#'                        trControl = trainControl(method = "cv"))
#' ```
#'
#' Note that these two models use the same resamples as one another due to setting the
#' seed prior to calling `train()`. However, these are different from the tidymodels
#' results used above (so the final results will be different).
#'
#' `caret` has a `resamples()` function that can collect and collate the resamples.
#' This can also be given to `perf_mod()`:
#'
#'
#' ```r
#'    caret_resamples <- resamples(list(logistic = logistic_caret, mars = mars_caret))
#'
#'    set.seed(101)
#'    roc_model_via_caret <- perf_mod(caret_resamples, refresh = 0)
#'    tidy(roc_model_via_caret) %>% summary()
#' ```
#'
#' ```
#'    ## # A tibble: 2 x 4
#'    ##   model     mean lower upper
#'    ##   <chr>    <dbl> <dbl> <dbl>
#'    ## 1 logistic 0.821 0.801 0.842
#'    ## 2 mars     0.822 0.802 0.842
#' ```
#' @references
#' Kuhn and Silge (2021) _Tidy Models with R_, Chapter 11,
#' \url{https://www.tmwr.org/compare.html}
#' @seealso [tidy.perf_mod()], [tidyposterior::contrast_models()]
#' @export
perf_mod <- function(object, ...) {
  UseMethod("perf_mod")
}


#' @export
perf_mod.default <- function(object, ...) {
  rlang::abort(
    paste0(
      "`object` should have at least one of these classes: ",
      "'rset', 'workflow_set', 'data.frame', 'resamples', or 'tune_results'. ",
      "See ?perf_mod"
    )
  )
}

#' @rdname perf_mod
#' @param transform A named list of transformation and inverse
#'  transformation functions. See [logit_trans()] as an example.
#' @param hetero_var A logical; if `TRUE`, then different
#'  variances are estimated for each model group. Otherwise, the
#'  same variance is used for each group. Estimating heterogeneous
#'  variances may slow or prevent convergence.
#' @export

perf_mod.rset <-
  function(object, transform = no_trans, hetero_var = FALSE, formula = NULL, ...) {
    check_trans(transform)
    rset_type <- try(pretty(object), silent = TRUE)
    if (inherits(rset_type, "try-error")) {
      rset_type <- NA
    }

    ## dplyr::filter (and `[` !) drops the other classes =[
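    ## (the "Apparent" row is the optional apparent-error estimate, not a
    ## true resample, so it is dropped here)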
    if (inherits(object, "bootstraps")) {
      oc <- class(object)
      object <- object %>% dplyr::filter(id != "Apparent")
      class(object) <- oc
    }

    if (any(names(object) == "splits")) {
      object$splits <- NULL
    }
    resamples <-
      tidyr::pivot_longer(object,
        cols = c(-dplyr::matches("(^id$)|(^id[0-9])")),
        names_to = "model",
        values_to = "statistic"
      ) %>%
      dplyr::mutate(statistic = transform$func(statistic))

    ## Make a formula based on resampling type (repeatedcv, rof),
    ## This could be done with more specific classes
    id_cols <- grep("(^id$)|(^id[1-9]$)", names(object), value = TRUE)
    formula <- make_formula(id_cols, hetero_var, formula)

    model_names <- unique(as.character(resamples$model))

    mod <- stan_glmer(formula, data = resamples, ...)

    res <- list(
      stan = mod,
      hetero_var = hetero_var,
      names = model_names,
      rset_type = rset_type,
      ids = get_id_vals(resamples),
      transform = transform,
      metric = list(name = NA_character_, direction = NA_character_)
    )
    class(res) <- "perf_mod"
    res
  }

make_formula <- function(ids, hetero_var, formula) {
  if (is.null(formula)) {
    ids <- sort(ids)
    p <- length(ids)
    if (p > 1) {
      msg <-
        paste0(
          "There were multiple resample ID columns in the data. It is ",
          "unclear what the model formula should be for the hierarchical ",
          "model. This analysis used the formula: "
        )
      nested <- paste0(rev(ids), collapse = "/")
      if (hetero_var) {
        f_chr <- paste0("statistic ~  model + (model + 0 |", nested, ")")
        f <- as.formula(f_chr)
      } else {
        f_chr <- paste0("statistic ~  model + (1 |", nested, ")")
        f <- as.formula(f_chr)
      }
      msg <- paste0(
        msg, rlang::expr_label(f),
        " The `formula` arg can be used to change this value."
      )
      rlang::warn(msg)
    } else {
      if (hetero_var) {
        f <- statistic ~ model + (model + 0 | id)
      } else {
        f <- statistic ~ model + (1 | id)
      }
    }
  } else {
    f <- formula
  }
  attr(f, ".Environment") <- rlang::base_env()
  f
}
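
## For illustration, the defaults produce (a sketch):
##   make_formula("id", hetero_var = FALSE, formula = NULL)
##     #> statistic ~ model + (1 | id)
##   make_formula(c("id", "id2"), hetero_var = TRUE, formula = NULL)
##     #> statistic ~ model + (model + 0 | id2/id)  (plus a warning)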

#' @export
print.perf_mod <- function(x, ...) {
  cat("Bayesian Analysis of Resampling Results\n")
  if (!is.na(x$rset_type)) {
    cat("Original data: ")
    cat(x$rset_type, sep = "\n")
  }
  cat("\n")
  invisible(x)
}

#' @export
summary.perf_mod <- function(object, ...) {
  summary(object$stan)
}


#' @export
#' @rdname perf_mod
#' @param metric A single character value for the statistic from
#'  the `resamples` object that should be analyzed.
perf_mod.resamples <-
  function(object,
           transform = no_trans,
           hetero_var = FALSE,
           metric = object$metrics[1],
           ...) {
    suffix <- paste0("~", metric, "$")
    metric_cols <- grep(suffix,
      names(object$values),
      value = TRUE
    )
    object$values <- object$values %>%
      dplyr::select(Resample, !!metric_cols) %>%
      setNames(gsub(suffix, "", names(.)))

    if (is_repeated_cv(object)) {
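      ## caret labels repeated CV resamples as "Fold1.Rep1"; split the label
      ## into the repeat (`id`) and the fold within that repeat (`id2`)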
      split_up <- strsplit(as.character(object$values$Resample), "\\.")
      object$values <- object$values %>%
        dplyr::mutate(
          id = map_chr(split_up, function(x) x[2]),
          id2 = map_chr(split_up, function(x) x[1])
        ) %>%
        dplyr::select(-Resample)
      class(object$values) <- c("vfold_cv", "rset", class(object$values))
      cv_att <- list(
        v = length(unique(object$values$id2)),
        repeats = length(unique(object$values$id)),
        strata = FALSE
      )
      for (i in names(cv_att)) attr(object$values, i) <- cv_att[[i]]
    } else {
      object$values <- object$values %>%
        dplyr::rename(id = Resample)
      class(object$values) <- c("rset", class(object$values))
    }


    res <- perf_mod(object$values, transform = transform, hetero_var = hetero_var, ...)
    res$metric <- list(name = metric_cols[1], direction = NA_character_)
    res
  }

#' @export
#' @rdname perf_mod
perf_mod.data.frame <-
  function(object,
           transform = no_trans,
           hetero_var = FALSE,
           formula = NULL,
           ...) {
    id_cols <- grep("(^id)|(^id[1-9]$)", names(object), value = TRUE)
    if (length(id_cols) == 0) {
      rlang::abort("One or more `id` columns are required.")
    }

    class(object) <- c("rset", class(object))

    res <- perf_mod(object,
      transform = transform, hetero_var = hetero_var,
      formula = formula, ...
    )
    res$metric <- list(name = NA_character_, direction = NA_character_)
    res
  }


#' @export
#' @rdname perf_mod
perf_mod.tune_results <-
  function(object, metric = NULL, transform = no_trans,
           hetero_var = FALSE, formula = NULL, filter = NULL, ...) {
    metric_info <- tune::.get_tune_metrics(object)
    metric_info <- tune::metrics_info(metric_info)
    if (!is.null(metric)) {
      if (all(metric != metric_info$.metric)) {
        rlang::abort(
          paste0(
            "'metric` should be one of: ",
            paste0("'", metric_info$.metric, "'", collapse = ", ")
          )
        )
      }
      metric <- metric[1]
    } else {
      metric <- metric_info$.metric[1]
    }
    metric_dir <- metric_info$direction[metric_info$.metric == metric]


    dat <- tune::collect_metrics(object, summarize = FALSE)
    dat <- dplyr::filter(dat, .metric == metric)

    filters <- rlang::enexpr(filter)
    if (!is.null(filters)) {
      dat <- dplyr::filter(dat, !!filters)
    }

    id_vars <- grep("(^id$)|(^id[0-9])", names(dat), value = TRUE)
    keep_vars <- c(id_vars, ".estimate", ".config")
    if (any(names(dat) == ".iter")) {
      keep_vars <- c(keep_vars, ".iter")
    }
    dat <- dplyr::select(dat, dplyr::all_of(keep_vars))

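    ## reshape to one row per resample with one column of statistics per
    ## tuning parameter candidate (`.config`)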
    dat <- tidyr::pivot_wider(dat,
      id_cols = dplyr::all_of(id_vars),
      names_from = ".config", values_from = ".estimate"
    )

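    ## restore the original rset class/attributes so that the rset method can
    ## choose an appropriate hierarchical formula for the resampling scheme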
    rset_info <- attributes(object)$rset_info$att
    rset_info$class <- c(rset_info$class, class(dplyr::tibble()))
    dat <- rlang::exec("structure", .Data = dat, !!!rset_info)
    res <- perf_mod(dat, transform = transform, hetero_var = hetero_var, formula = formula, ...)
    res$metric <- list(name = metric, direction = metric_dir)
    res
  }

#' @export
#' @rdname perf_mod
perf_mod.workflow_set <-
  function(object, metric = NULL, transform = no_trans, hetero_var = FALSE, formula = NULL, ...) {
    check_trans(transform)
    metric_info <- tune::.get_tune_metrics(object$result[[1]])
    metric_info <- tune::metrics_info(metric_info)
    if (!is.null(metric)) {
      if (all(metric != metric_info$.metric)) {
        rlang::abort(
          paste0(
            "'metric` should be one of: ",
            paste0("'", metric_info$.metric, "'", collapse = ", ")
          )
        )
      }
      metric <- metric[1]
    } else {
      metric <- metric_info$.metric[1]
    }
    metric_dir <- metric_info$direction[metric_info$.metric == metric]

    resamples <-
      tune::collect_metrics(object, summarize = FALSE) %>%
      dplyr::filter(.metric == metric & id != "Apparent")

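    ## keep only the best sub-model (tuning candidate) for each workflow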
    ranked <-
      workflowsets::rank_results(object, rank_metric = metric, select_best = TRUE) %>%
      dplyr::select(wflow_id, .config)
    resamples <- dplyr::inner_join(resamples, ranked, by = c("wflow_id", ".config"))

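    ## include the search iteration in the sub-model label when present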
    if (any(names(resamples) == ".iter")) {
      resamples$sub_model <- paste(resamples$.config, resamples$.iter, sep = "_")
    } else {
      resamples$sub_model <- resamples$.config
    }

    resamples <-
      resamples %>%
      dplyr::select(model = wflow_id, sub_model, dplyr::starts_with("id"), statistic = .estimate)

    ## Make a formula based on resampling type (repeatedcv, rof),
    ## This could be done with more specific classes
    id_cols <- grep("(^id)|(^id[1-9]$)", names(object), value = TRUE)
    formula <- make_formula(id_cols, hetero_var, formula)

    model_names <- unique(as.character(resamples$model))

    mod <- rstanarm::stan_glmer(formula, data = resamples, ...)

    res <- list(
      stan = mod,
      hetero_var = hetero_var,
      names = model_names,
      rset_type = attributes(object$result[[1]])$rset_info$label,
      metric = list(name = metric, direction = metric_dir),
      ids = get_id_vals(resamples),
      transform = transform
    )
    structure(res, class = c("perf_mod_workflow_set", "perf_mod"))
  }
