R/grid_performance.R

Defines functions maybe_surv_prob maybe_estimate unnest_parameters estimate_surv estimate_class_prob estimate_reg .estimate_metrics metrics_info pred_type

Documented in .estimate_metrics metrics_info

# Functions related to computing predicted performance

# will add further options later for surv_prob etc.
# Translate a metric function's class into the prediction type it needs.
# The first class of `x` that is not "function" decides the answer; classes
# that are not listed below (or a missing class) give "unknown".
pred_type <- function(x) {
  type_by_class <- c(
    class_metric = "class",
    prob_metric = "prob",
    numeric_metric = "numeric",
    dynamic_survival_metric = "survival",
    integrated_survival_metric = "survival",
    static_survival_metric = "time"
  )

  all_classes <- class(x)
  primary <- all_classes[all_classes != "function"][1]

  # `match()` gives NA both for an unlisted class and for a missing one
  hit <- match(primary, names(type_by_class))
  if (is.na(hit)) {
    return("unknown")
  }
  type_by_class[[hit]]
}

#' @export
#' @keywords internal
#' @rdname empty_ellipses
metrics_info <- function(x) {
  # The individual metric functions are read from the `fns` binding in the
  # environment of `x`.
  fns <- rlang::env_get(environment(x), "fns")

  # One entry per metric: its "direction" attribute and its prediction type
  directions <- purrr::map_chr(fns, attr, "direction")
  pred_types <- purrr::map_chr(fns, pred_type)

  columns <- list(
    .metric = names(directions),
    direction = unname(directions),
    type = unname(pred_types)
  )
  tibble::new_tibble(columns, nrow = length(directions))
}

#' Internal functions used by other tidymodels packages
#'
#' These are not meant to be invoked directly by users.
#' @param dat A data set.
#' @param metric A metric set.
#' @param param_names A character vector of tuning parameter names.
#' @param outcome_name A character string for the column of `dat` that is the
#' outcome.
#' @param event_level A character string (`"first"` or `"second"`) passed from
#' the control function.
#' @param x A character vector of package names.
#' @param .expr Code to execute.
#' @param ... Object to pass to the internal `tune_log()` function.
#' @param bad_only A logical for whether warnings and errors should be caught.
#' @param notes Character data to add to the logging.
#' @param workflow A workflow.
#' @param grid_preprocessor A tibble with parameter information.
#' @param new_data A data frame or matrix of predictors to process.
#' @param metrics_info The output of `tune:::metrics_info(metric)`---only
#' included as an argument to allow for pre-computing.
#' @param catalog A logical passed to `tune_log()` giving whether the message
#' is compatible with the issue cataloger. Defaults to `TRUE`. Updates that are
#' always unique and do not represent a tuning "issue" can bypass the cataloger
#' by setting `catalog = FALSE`.
#' @keywords internal
#' @name tune-internal-functions
#' @export
.estimate_metrics <- function(dat, metric, param_names, outcome_name, event_level,
                              metrics_info = NULL) {
  # The call stack is:
  #
  # tune_grid_loop_iter()
  #   append_metrics(). <many times>
  #    .estimate_metrics()

  # predictions made in predict_model()

  # Failed predictions arrive as a "try-error" object; nothing to estimate.
  if (inherits(dat, "try-error")) {
    return(NULL)
  }

  # `metrics_info` can be supplied pre-computed. When it is not, derive it
  # from the metric set. The argument shadows the helper of the same name, so
  # the helper is fetched explicitly as a function from the enclosing
  # environment; a default of `metrics_info(...)` in the signature would refer
  # to itself and fail with "promise already under evaluation".
  if (is.null(metrics_info)) {
    info_fn <- get(
      "metrics_info",
      envir = parent.env(environment()),
      mode = "function"
    )
    metrics_info <- info_fn(metric)
  }

  # Determine the type of prediction that is required
  types <- unique(metrics_info$type)

  if (length(outcome_name) > 1L) {
    rlang::abort(paste0(
      "Internal error: Multiple outcomes are not ",
      "supported in `.estimate_metrics()`."
    ))
  }

  # Case weights are only passed on when their column is present in `dat`.
  weights_col <- case_weights_column_name()
  if (weights_col %in% names(dat)) {
    case_weights <- sym(weights_col)
  } else {
    case_weights <- NULL
  }

  # Dispatch on the family of metrics; families cannot be mixed.
  if (all(types == "numeric")) {
    estimate_reg(dat, metric, param_names, outcome_name, case_weights)
  } else if (all(types == "class" | types == "prob")) {
    estimate_class_prob(
      dat, metric, param_names, outcome_name, case_weights, types, event_level
    )
  } else if (all(types == "time" | types == "survival")) {
    estimate_surv(dat, metric, param_names, outcome_name, case_weights, types)
  } else {
    rlang::abort("Metric type not yet supported by tune.")
  }
}

estimate_reg <- function(dat, metric, param_names, outcome_name, case_weights) {
  # Group by the tuning parameters before handing the data to the metric set;
  # the numeric prediction is read from the `.pred` column.
  grouping <- rlang::syms(param_names)
  by_params <- dplyr::group_by(dat, !!!grouping)

  metric(
    by_params,
    estimate = .pred,
    truth = !!sym(outcome_name),
    case_weights = !!case_weights
  )
}

estimate_class_prob <- function(dat, metric, param_names, outcome_name,
                                case_weights, types, event_level) {
  truth_col <- sym(outcome_name)

  # The hard class prediction is only passed when a class metric is present.
  class_col <- NULL
  if (any(types == "class")) {
    class_col <- sym(".pred_class")
  }

  # Probability columns are named `.pred_<level>` after the outcome's levels.
  prob_cols <- NULL
  if (any(types == "prob")) {
    prob_cols <- paste0(".pred_", levels(dat[[outcome_name]]))

    # Special case binary class prob metrics,
    # as yardstick requires only 1 column passed through
    if (length(prob_cols) == 2) {
      if (identical(event_level, "first")) {
        keep <- 1
      } else if (identical(event_level, "second")) {
        keep <- 2
      } else {
        rlang::abort("`event_level` must be either 'first' or 'second'.")
      }
      prob_cols <- prob_cols[[keep]]
    }
  }

  by_params <- dplyr::group_by(dat, !!!rlang::syms(param_names))
  metric(
    by_params,
    truth = !!truth_col,
    estimate = !!class_col,
    !!!prob_cols,
    case_weights = !!case_weights,
    event_level = event_level
  )
}

# ------------------------------------------------------------------------------

estimate_surv <- function(dat, metric, param_names, outcome_name, case_weights, types) {
  # Submodel parameters may sit inside `.pred`; lift them to the outer level
  # first so they can be grouped on. (`types` is accepted but not used here.)
  unnested <- unnest_parameters(dat, param_names)
  by_params <- dplyr::group_by(unnested, !!!rlang::syms(param_names))

  # `estimate` and the trailing unnamed argument are NULL unless the metric
  # set contains a metric that needs them.
  metric(
    by_params,
    truth = !!rlang::sym(outcome_name),
    estimate = !!maybe_estimate(metric),
    case_weights = !!case_weights,
    !!maybe_surv_prob(metric)
  )
}

# Move tuning parameters that are nested inside the `.pred` list column up to
# the outer level of `x`.
#
# `x` is returned unchanged when `params` is NULL, when `x` has no `.pred`
# column, or when none of `params` appear inside `.pred`. Otherwise `.pred` is
# unnested and re-nested so that each row of the result corresponds to one
# original row and one combination of the inner parameters.
unnest_parameters <- function(x, params = NULL) {
  if (is.null(params)) {
    return(x)
  }

  # When multi_predict() is used, .pred will have the tuning parameter values.
  # Other (non-submodel) parameters will be outside of .pred, as regular
  # columns of 'x'.
  # If this happens, pull the submodel parameters out of .pred and put them at
  # the outer level ('x') with the rest (if any)

  if (!any(names(x) == ".pred")) {
    return(x)
  }

  # Only the first element of .pred is inspected; the remaining elements are
  # presumably laid out the same way.
  inner_nms <- names(x$.pred[[1]])

  # Re-nest by the parameters that really are inside .pred. Parameters held
  # as outer columns come back through the join with `others` below; asking
  # for them here with all_of() would fail because they are not in .pred.
  inner_params <- intersect(params, inner_nms)
  if (length(inner_params) == 0L) {
    return(x)
  }

  # A row index ties the re-nested predictions back to their original rows.
  x <- x %>% parsnip::add_rowindex()

  others <- x %>% dplyr::select(-.pred)
  rm_cols <- c(".pred_censored", ".weight_time")
  nest_cols <- c(".eval_time", ".pred_survival", ".weight_censored")
  x <-
    x %>%
    dplyr::select(.pred, .row) %>%
    tidyr::unnest(cols = c(.pred)) %>%
    dplyr::select(-dplyr::any_of(rm_cols)) %>%
    tidyr::nest(
      .pred = c(dplyr::all_of(nest_cols)),
      .by = c(.row, dplyr::all_of(inner_params))
    ) %>%
    dplyr::full_join(others, by = ".row") %>%
    dplyr::select(-.row)
  x
}

# Give the symbol `.pred_time` when the metric set `x` contains a static
# survival metric, and NULL otherwise.
maybe_estimate <- function(x) {
  metric_classes <- tibble::as_tibble(x)$class
  if (any(metric_classes == "static_survival_metric")) {
    return(rlang::sym(".pred_time"))
  }
  NULL
}

# Give the symbol `.pred` when the metric set `x` contains a metric whose
# class is listed in `dyn_inputs`, and NULL otherwise.
maybe_surv_prob <- function(x) {
  metric_classes <- tibble::as_tibble(x)$class
  # dyn_inputs defined in checks.R
  if (any(metric_classes %in% dyn_inputs)) {
    return(rlang::sym(".pred"))
  }
  NULL
}

Try the tune package in your browser

Any scripts or data that you put into this service are public.

tune documentation built on May 29, 2024, 7:32 a.m.