# Source file: R/loop_over_all_stages-helpers.R

# Defines functions: extract_details attach_pkgs parsnip_to_engine engine_to_parsnip reorder_pred_cols determine_pred_types make_config_labs get_config_key replace_reserve_rows initialize_pred_reserve get_output_columns rebind_grid finalize_fit_model finalize_fit_pre predict_all_types finalize_fit_post has_tailor_estimated has_tailor_tuned has_tailor get_sub_param has_sub_param has_mod_param has_pre_param remove_stage update_static get_data_subsets check_static_data make_static

# ------------------------------------------------------------------------------
# Helpers for loop_over_all_stages()

# Note: in loop(), we add more elements for the outcome name(s), and the
# data partitions
make_static <- function(
  workflow,
  param_info,
  grid,
  metrics,
  eval_time,
  split_args,
  control,
  pkgs = "tune",
  strategy = "sequential",
  data = list(fit = NULL, pred = NULL, cal = NULL)
) {
  # Assemble the list of objects that stay fixed while looping over the
  # stages. Classed inputs are validated one at a time, in argument order, so
  # the first offending argument is the one that is reported.
  if (!inherits(workflow, "workflow")) {
    cli::cli_abort("{.arg workflow} should be a {.cls workflow} object.")
  }
  if (!inherits(param_info, "parameters")) {
    cli::cli_abort("{.arg param_info} should be a {.cls parameters} object.")
  }
  if (!inherits(metrics, "metric_set")) {
    cli::cli_abort("{.arg metrics} should be a {.cls metric_set} object.")
  }
  if (!check_class_or_null(eval_time, "numeric")) {
    cli::cli_abort("{.arg eval_time} should be a numeric vector.")
  }

  # Derived pieces, computed in the same order as the list elements below
  config_key <- get_config_key(grid, workflow)
  needs_post_fit <- workflows::.workflow_postprocessor_requires_fit(workflow)
  metric_tbl <- tibble::as_tibble(metrics)
  types <- determine_pred_types(workflow, metrics)

  # A single list() call keeps elements whose value is NULL (e.g. eval_time)
  list(
    wflow = workflow,
    param_info = param_info,
    configs = config_key,
    post_estimation = needs_post_fit,
    metrics = metrics,
    metric_info = metric_tbl,
    pred_types = types,
    eval_time = eval_time,
    split_args = split_args,
    control = control,
    pkgs = pkgs,
    strategy = strategy,
    data = data
  )
}

# Validate one slot of the static `data` list.
#
# x:    NULL, or a list whose names are exactly `data` and `ind`.
# elem: name of the slot being checked; only used in the error messages.
#
# Returns NULL untouched when `x` is NULL. Otherwise aborts unless `ind` is an
# integer vector and `data` is a data frame, and returns `x` with `data`
# coerced to a tibble when it is a plain data frame.
check_static_data <- function(x, elem = "fit") {
  if (is.null(x)) {
    return(x)
  }

  # Sort so the comparison does not depend on the order of the elements
  nms <- sort(names(x))
  if (!identical(nms, c("data", "ind"))) {
    cli::cli_abort(
      "{.arg data_*} arguments should have names {.val data} and
                   {.val ind}, not {.val {nms}} in the {.field {elem}} slot."
    )
  }

  if (!is.integer(x[["ind"]])) {
    cli::cli_abort(
      "Element {.arg ind} should be an integer in the
                   {.field {elem}} slot."
    )
  }

  if (!is.data.frame(x[["data"]])) {
    cli::cli_abort(
      "Element {.arg data} should be a tibble in the {.field {elem}} slot."
    )
  }

  # The coercion previously sat after the abort above, inside the failure
  # branch, and so could never run.
  if (!tibble::is_tibble(x[["data"]])) {
    x[["data"]] <- tibble::as_tibble(x[["data"]])
  }

  x
}

# Split a resample into the data used for fitting (`fit`), for performance
# estimation (`pred`) and, when the postprocessor must be trained, for
# calibration (`cal`). Each non-NULL element is a list with `data` and `ind`.
get_data_subsets <- function(wflow, split, split_args = NULL) {
  # The assessment set always comes from the original (outer) split
  pred_lst <- list(
    data = rsample::assessment(split),
    ind = as.integer(split, data = "assessment")
  )

  cal_lst <- NULL
  if (workflows::.workflow_postprocessor_requires_fit(wflow)) {
    # A postprocessor that needs training (i.e. calibration) requires a
    # further, inner split of the analysis data:
    # * the preprocessor and model (without the postprocessor) are fitted on
    #   the inner analysis set,
    # * that model predicts the calibration set,
    # * the postprocessor is trained on those calibration predictions,
    # * the complete workflow then predicts the outer assessment set, and
    #   those predictions are evaluated with the performance metrics.
    split <- rsample::internal_calibration_split(split, split_args = split_args)

    cal_ind <- as.integer(split, data = "assessment")
    cal_lst <- list(
      data = vctrs::vec_slice(split$data, cal_ind),
      ind = cal_ind
    )
  }

  # `split` is the inner split here whenever calibration data were reserved
  fit_lst <- list(
    data = rsample::analysis(split),
    ind = as.integer(split, data = "analysis")
  )

  list(fit = fit_lst, pred = pred_lst, cal = cal_lst)
}

# Validate the data partitions and store them in the static object.
#
# x:    the static list.
# data: a list with elements `fit`, `pred`, and `cal` (each NULL or a list
#       with `data` and `ind`).
#
# Each slot is checked under its own name so that an error message points at
# the slot that is actually malformed. Returns `x` with `x$data` replaced.
update_static <- function(x, data) {
  check_static_data(data$fit, elem = "fit")
  check_static_data(data$pred, elem = "pred")
  check_static_data(data$cal, elem = "cal")
  x$data <- data
  x
}

# ------------------------------------------------------------------------------
# For stages

# Drop the nested stage columns (`model_stage`, `predict_stage`, `post_stage`)
# from a grid, keeping only the parameter columns.
#
# `drop = FALSE` guarantees a data frame is returned even when a plain
# data.frame is left with a single column (tibbles never drop).
remove_stage <- function(x) {
  stages <- c("model_stage", "predict_stage", "post_stage")
  x[, !(names(x) %in% stages), drop = FALSE]
}

# TRUE when `x` has at least one name other than "model_stage".
has_pre_param <- function(x) {
  only_stage <- all(names(x) == "model_stage")
  !only_stage
}

# TRUE when `x` has at least one name other than "predict_stage".
has_mod_param <- function(x) {
  only_stage <- all(names(x) == "predict_stage")
  !only_stage
}

# ------------------------------------------------------------------------------

# This is run on a `predict_stage` column:
# Run on the contents of a `predict_stage` column: is there a column other
# than `post_stage`, and does it hold more than one value?
has_sub_param <- function(x) {
  nms <- names(x)
  is_param <- nms != "post_stage"
  if (any(is_param)) {
    length(x[[nms[is_param]]]) > 1
  } else {
    FALSE
  }
}

# This is run on a `predict_stage` column:
# Run on the contents of a `predict_stage` column: the names of every column
# except `post_stage`.
get_sub_param <- function(x) {
  col_names <- names(x)
  col_names[col_names != "post_stage"]
}

# ------------------------------------------------------------------------------
# from workflows
# nocov start
# Does the workflow's postprocessing stage contain a "tailor" action?
has_tailor <- function(x) {
  action_names <- names(x$post$actions)
  match("tailor", action_names, nomatch = 0L) > 0L
}

# Does the workflow have a tailor with at least one argument marked for
# tuning? Always FALSE when there is no tailor.
has_tailor_tuned <- function(x) {
  if (has_tailor(x)) {
    return(any(tune_args(x)$source == "tailor"))
  }
  FALSE
}
# Does the workflow have a tailor that must be fitted, as reported by
# tailor::tailor_requires_fit()? Always FALSE when there is no tailor.
has_tailor_estimated <- function(x) {
  if (has_tailor(x)) {
    return(tailor::tailor_requires_fit(hardhat::extract_postprocessor(x)))
  }
  FALSE
}
# nocov end

# ------------------------------------------------------------------------------
# Prediction and postprocessing

# Finalize the workflow's postprocessor with the values in `grid` (an empty
# tibble when `grid` is NULL), put it back into the workflow, and fit the
# postprocessing stage on `data_calibration`.
finalize_fit_post <- function(wflow_current, data_calibration, grid = NULL) {
  grid <- if (is.null(grid)) dplyr::tibble() else grid

  tailor_final <- finalize_tailor(
    hardhat::extract_postprocessor(wflow_current),
    grid
  )

  workflows::.fit_post(
    set_workflow_tailor(wflow_current, tailor_final),
    data_calibration
  )
}

# ------------------------------------------------------------------------------

# Predict with a fitted workflow for every prediction type listed in
# `static$pred_types` and combine the results into one data frame.
#
# Arguments:
# - wflow_fit: a fitted workflow; its parsnip fit is extracted for prediction.
# - static: the list of fixed objects. This function reads its
#   `post_estimation`, `data`, `wflow`, `pred_types`, `eval_time`, and
#   `metrics` elements.
# - submodel_grid: optional grid of submodel parameter values, or NULL.
# - predictee: "assessment" or "calibration"; selects which reserved data set
#   is predicted.
#
# Returns the predictions joined (by `.row`, plus any submodel parameter
# columns) across prediction types, together with the outcome columns and,
# when applicable, metric grouping columns and case weights. The result is
# passed through maybe_add_ipcw() before being returned.
predict_all_types <- function(
  wflow_fit,
  static,
  submodel_grid = NULL,
  predictee = "assessment"
) {
  predictee <- rlang::arg_match(predictee, c("assessment", "calibration"))
  # NOTE(review): `outputs` is not referenced again in this function.
  outputs <- get_output_columns(wflow_fit)

  # The calibration data are only used when they are requested AND the
  # postprocessor requires estimation; every other case predicts on the
  # `pred` partition.
  if (predictee == "calibration" && static$post_estimation) {
    if (is.null(static$data$cal)) {
      cli::cli_abort(
        "Calibration data were requested but not reserved.",
        call = NULL
      )
    }
    .data <- static$data$cal$data
    .ind <- static$data$cal$ind
  } else {
    .data <- static$data$pred$data
    .ind <- static$data$pred$ind
  }

  # Process the raw data with the fitted workflow and tag the outcomes with
  # their row indices so they can be joined to the predictions later.
  processed_data_pred <- forge_from_workflow(.data, wflow_fit)
  processed_data_pred$outcomes <- processed_data_pred$outcomes |>
    dplyr::mutate(.row = .ind)

  model_fit <- wflow_fit |> hardhat::extract_fit_parsnip()

  # TODO tune::predict_model has some pre-prediction checks

  # Captured before renaming, so these are the names as given in the grid
  sub_param <- names(submodel_grid)

  # Convert argument names to parsnip format see #1011
  submodel_grid <- engine_to_parsnip(static$wflow, submodel_grid)

  # Accumulate one set of prediction columns per prediction type
  pred <- NULL
  for (type_iter in static$pred_types) {
    tmp_res <- predict_wrapper(
      model = model_fit,
      new_data = processed_data_pred$predictors,
      type = type_iter,
      eval_time = static$eval_time,
      subgrid = submodel_grid
    )
    tmp_res$.row <- .ind

    # predict_wrapper() is designed to predict all submodels at once; we get a
    # list column back called .pred with a single row. Collapse that and move
    # the submodel column.
    if (length(sub_param) > 0) {
      tmp_res <- tidyr::unnest(tmp_res, cols = c(.pred))

      # For censored regression (and eventually quantile regression),
      # dynamic predictions will generate a list column that can have multiple
      # predictions per training set row so those end up in a list column.
      # Since we unnested, we need to re-nest to get back to the original
      # format.
      if (type_iter %in% dyn_types) {
        tmp_res <-
          tmp_res |>
          tidyr::nest(.pred = c(-dplyr::all_of(sub_param), -.row))
      }
    }

    # Now go back to engine names
    tmp_res <- parsnip_to_engine(static$wflow, tmp_res)

    # The first type initializes the result; later types are joined on the
    # submodel parameter(s) and the row index.
    if (is.null(pred)) {
      pred <- tmp_res
    } else {
      pred <- dplyr::full_join(pred, tmp_res, by = c(sub_param, ".row"))
    }
  }

  # Attach the outcome column(s)
  pred <- pred |>
    dplyr::full_join(processed_data_pred$outcomes, by = ".row")

  # Add implicitly grouped metric data, if applicable
  metrics_by <- get_metrics_by(static$metrics)
  if (has_metrics_by(metrics_by)) {
    .data$.row <- .ind
    pred <- dplyr::full_join(pred, .data[c(metrics_by, ".row")], by = ".row")
  }

  # Add case weights (if needed)
  if (has_case_weights(wflow_fit)) {
    case_weights <- extract_case_weights(.data, wflow_fit)
    if (.use_case_weights_with_yardstick(case_weights[[1]])) {
      case_weights <- dplyr::mutate(case_weights, .row = .ind)
      pred <- dplyr::full_join(pred, case_weights, by = ".row")
    }
  }

  pred <- maybe_add_ipcw(pred, model_fit, static$pred_types)

  pred
}

# ------------------------------------------------------------------------------
# Fitting/training functions

# Finalize the preprocessor with the current grid values (only when it is a
# recipe that has tunable parameters) and fit the preprocessing stage on the
# data reserved for fitting.
finalize_fit_pre <- function(wflow_current, grid, static) {
  pre_proc <- hardhat::extract_preprocessor(wflow_current)

  if (inherits(pre_proc, "recipe")) {
    grid <- remove_stage(grid)
    pre_proc_id <- hardhat::extract_parameter_set_dials(pre_proc)$id

    if (length(pre_proc_id) > 0) {
      # Only the recipe's own parameters are used to finalize it
      final_recipe <- finalize_recipe(pre_proc, grid[, pre_proc_id])
      wflow_current <- set_workflow_recipe(wflow_current, final_recipe)
    }
  }

  fit_data <- static$data$fit$data
  workflows::.fit_pre(wflow_current, fit_data)
}

# Finalize the model specification with the current grid values (when it has
# tunable parameters) and fit the model stage of the workflow.
finalize_fit_model <- function(wflow_current, grid) {
  mod_spec <- hardhat::extract_spec_parsnip(wflow_current)

  grid <- remove_stage(grid)
  mod_id <- hardhat::extract_parameter_set_dials(mod_spec)$id

  if (length(mod_id) > 0) {
    # Only the model's own parameters are used to finalize it
    final_spec <- finalize_model(mod_spec, grid[, mod_id])
    wflow_current <- set_workflow_spec(wflow_current, final_spec)
  }

  # .catch_and_log_fit()
  .fit_model(wflow_current, workflows::control_workflow())
}

# ------------------------------------------------------------------------------
# Misc functions

# Strip the stage columns from each grid passed in `...` and column-bind what
# is left.
rebind_grid <- function(...) {
  stripped <- purrr::map(list(...), remove_stage)
  purrr::list_cbind(stripped)
}

# Collect the outcome and prediction column names, as symbols, for a fitted
# model or workflow. The outcome symbols come first, under `outcome`.
get_output_columns <- function(x) {
  # This needs a fitted model or workflow
  pred_syms <- parsnip::.get_prediction_column_names(x, syms = TRUE)
  outcome_syms <- rlang::syms(outcome_names(x))
  c(list(outcome = outcome_syms), pred_syms)
}

# ------------------------------------------------------------------------------
# pre-allocating predictions

# Pre-allocate a tibble that can hold the predictions for every grid point.
#
# predictions: a data frame of predictions for a single candidate; it supplies
#              the column names and types.
# grid_size:   number of candidates; values below 1 are treated as 1.
#
# Returns a tibble with the same columns as `predictions` and
# `nrow(predictions) * grid_size` rows of missing values.
initialize_pred_reserve <- function(predictions, grid_size) {
  if (!tibble::is_tibble(predictions)) {
    predictions <- dplyr::as_tibble(predictions)
  }
  grid_size <- max(1, grid_size)
  ptype <- predictions[0, ]
  size <- nrow(predictions) * grid_size
  # seq_len() rather than 1:size: when there are no predictions, 1:0 is
  # c(1, 0) and would allocate a spurious row instead of none.
  res <- ptype[seq_len(size), ]
  dplyr::as_tibble(res)
}

# Row positions, within the pre-allocated predictions, that belong to the
# `iter`-th block of `chunk` rows.
#
# Returns `((iter - 1) * chunk + 1):(iter * chunk)`, or an empty integer
# vector when the chunk holds no rows (the `:` operator would otherwise count
# down and return two positions).
replace_reserve_rows <- function(iter, chunk) {
  if (chunk < 1) {
    return(integer(0))
  }
  start_loc <- (iter - 1) * chunk + 1
  end_loc <- iter * chunk
  start_loc:end_loc
}

# ------------------------------------------------------------------------------
# Add .config to grid

# Build the key that maps each grid row to its `.config` label. The label is
# made of one part per stage (`pre`, `mod`, `post`) joined by underscores; a
# stage without tuning parameters gets the index 0.
get_config_key <- function(grid, wflow) {
  info <- tune_args(wflow)

  # The grid columns and the parameters tagged for tuning must agree
  only_param <- setdiff(info$id, names(grid))
  if (length(only_param) > 0) {
    cli::cli_abort(
      "Some parameters are tagged for tuning but are not in the grid:
      {.arg {only_param}}",
      call = NULL
    )
  }

  only_grid <- setdiff(names(grid), info$id)
  if (length(only_grid) > 0) {
    cli::cli_abort(
      "Some parameters are in the grid but are not tagged for tuning:
      {.arg {only_grid}}",
      call = NULL
    )
  }

  key <- grid

  # Preprocessing stage
  pre_param <- info$id[info$source == "recipe"]
  key <- if (length(pre_param) > 0) {
    dplyr::full_join(make_config_labs(grid, pre_param), key, by = pre_param)
  } else {
    dplyr::mutate(key, pre = "pre0")
  }

  # Model stage
  mod_param <- info$id[info$source == "model_spec"]
  key <- if (length(mod_param) > 0) {
    dplyr::full_join(
      make_config_labs(grid, mod_param, "mod"),
      key,
      by = mod_param
    )
  } else {
    dplyr::mutate(key, mod = "mod0")
  }

  # Postprocessing stage
  post_param <- info$id[info$source == "tailor"]
  key <- if (length(post_param) > 0) {
    dplyr::full_join(
      make_config_labs(grid, post_param, "post"),
      key,
      by = post_param
    )
  } else {
    dplyr::mutate(key, post = "post0")
  }

  # in the case of resampling without tuning, grid and thus key are 0-row tibbles
  if (nrow(key) < 1) {
    key <- dplyr::tibble(pre = "pre0", mod = "mod0", post = "post0")
  }

  # Glue the stage labels together and strip any trailing underscore
  config_labs <- paste(key$pre, key$mod, key$post, sep = "_")
  key$.config <- gsub("_$", "", config_labs)

  sorted_key <- dplyr::arrange(key, .config)
  dplyr::select(sorted_key, dplyr::all_of(info$id), .config)
}

# Create the per-stage configuration labels for a grid.
#
# grid:  a data frame of candidate parameter values.
# param: character vector with the names of the columns that belong to the
#        stage.
# val:   a single string; used both as the name of the label column and as
#        the prefix of each label (e.g. "pre" gives "pre1", "pre2", ...).
#
# Returns the distinct combinations of the `param` columns, sorted, with one
# additional label column. Labels are zero-padded to a common width.
make_config_labs <- function(grid, param, val = "pre") {
  grid |>
    dplyr::select(dplyr::all_of(param)) |>
    dplyr::distinct() |>
    dplyr::arrange(!!!rlang::syms(param)) |>
    dplyr::mutate(
      # format() pads with spaces to a common width; turn the pad into zeros
      num = format(dplyr::row_number()),
      num = gsub(" ", "0", num),
      # Inject the *value* of `val` as the column name. Embracing the
      # argument instead would take the name from the caller's expression,
      # which is only right when a string literal is passed.
      !!val := paste0(val, num)
    ) |>
    dplyr::select(-num)
}

# Work out which prediction types must be generated for a workflow.
#
# wflow:   a workflow; its model mode and (if present) its tailor are used.
# metrics: a metric set; each metric contributes its `type`.
#
# The types required by the metrics are combined with the inputs/outputs of
# any tailor adjustments. The placeholder type "everything" is replaced by the
# types that apply to the model's mode. Returns a sorted character vector of
# unique types, and aborts when "everything" is requested for a mode with no
# known types.
determine_pred_types <- function(wflow, metrics) {
  model_mode <- extract_spec_parsnip(wflow)$mode

  pred_types <- unique(metrics_info(metrics)$type)
  if (has_tailor(wflow)) {
    post <- extract_postprocessor(wflow)
    post_out <- purrr::map(post$adjustments, "outputs")
    post_in <- purrr::map(post$adjustments, "inputs")
    post_types <- unlist(c(post_out, post_in))
    # Any probability-related input/output maps to the "prob" type
    post_types[grepl("probability", post_types)] <- "prob"
    post_cls <- unlist(purrr::map(post$adjustments, class))
    if (any(post_cls == "probability_calibration")) {
      post_types <- c(post_types, "class", "prob")
    }
    pred_types <- unique(c(pred_types, post_types))
  }

  if (any(pred_types == "everything")) {
    if (model_mode == "regression") {
      pred_types <- c(pred_types, "numeric")
    } else if (model_mode == "classification") {
      pred_types <- c(pred_types, "class", "prob")
    } else if (model_mode == "censored regression") {
      pred_types <- c(
        pred_types,
        "static_survival_metric",
        "dynamic_survival_metric"
      )
    } else {
      # Inner braces are required so that the mode's value (not the literal
      # text "model_mode") is interpolated into the message.
      cli::cli_abort(
        "No prediction types are known for mode {.val {model_mode}}."
      )
    }

    pred_types <- pred_types[pred_types != "everything"]
  }

  sort(unique(pred_types))
}

# Put the columns of a prediction data frame into a standard order.
#
# x:       a data frame of predictions that contains a `.config` column.
# outcome: character vector of outcome column names (must exist in `x`).
# param:   character vector of tuning parameter column names (must exist).
#
# Resulting order: outcome(s), `.eval_time`, prediction columns, `.row`,
# resampling id column(s), any remaining columns, tuning parameters, and
# finally `.config`.
reorder_pred_cols <- function(x, outcome = character(0), param = character(0)) {
  x |>
    dplyr::relocate(
      # Outcome first
      dplyr::all_of(outcome),
      # Dynamic columns next. any_of() matches literal names, not regular
      # expressions, so the plain column name is given.
      dplyr::any_of(".eval_time"),
      # dplyr::any_of(".quantile_level"),     # placeholder for future value
      # Prediction columns
      dplyr::matches("^\\.pred_time$"),
      dplyr::matches("^\\.pred$"),
      dplyr::matches("^\\.pred_class$"),
      dplyr::matches("^\\.pred_."),
      # Row indicator
      dplyr::any_of(".row"),
      # Resample indicator(s)
      dplyr::matches("^id$"),
      dplyr::matches("^id[1-9]$"),
      .before = dplyr::everything()
    ) |>
    # Put any tuning parameters at the end, .config at the very end
    dplyr::relocate(
      dplyr::all_of(param),
      .config,
      .after = dplyr::everything()
    )
}


# Rename grid columns from the user-facing (engine) argument names to the
# parsnip argument names. A NULL grid, or one with nothing to rename, is
# returned unchanged.
engine_to_parsnip <- function(wflow, grid) {
  if (is.null(grid)) {
    return(grid)
  }

  grid_nm <- names(grid)
  name_key <- parsnip::.model_param_name_key(wflow)
  # Keep only the parameters whose names differ and that are in the grid
  key <- dplyr::select(
    dplyr::filter(name_key, user != parsnip & user %in% grid_nm),
    -engine
  )

  if (nrow(key) > 0) {
    # new name = old name
    renames <- structure(key$user, names = key$parsnip)
    grid <- dplyr::rename(grid, dplyr::all_of(renames))
  }
  grid
}

# Rename grid columns from the parsnip argument names back to the user-facing
# (engine) names. A grid with nothing to rename is returned unchanged.
parsnip_to_engine <- function(wflow, grid) {
  grid_nm <- names(grid)
  name_key <- parsnip::.model_param_name_key(wflow)
  # Keep only the parameters whose names differ and that are in the grid
  key <- dplyr::select(
    dplyr::filter(name_key, user != parsnip & parsnip %in% grid_nm),
    -engine
  )

  if (nrow(key) > 0) {
    # new name = old name
    renames <- structure(key$parsnip, names = key$user)
    grid <- dplyr::rename(grid, dplyr::all_of(renames))
  }
  grid
}

# ------------------------------------------------------------------------------

# Attach packages, but only when not running sequentially.
#
# pkgs:     character vector of package names.
# strategy: a single string; anything other than "sequential" triggers
#           loading.
#
# Packages are attached with library(), with its output, messages, and
# warnings captured by purrr::quietly(). Returns `pkgs` invisibly.
attach_pkgs <- function(pkgs, strategy = "sequential") {
  # `&&` short-circuits, so `strategy` is not inspected when there is nothing
  # to load, and the condition is always a single TRUE/FALSE.
  if (length(pkgs) > 0 && strategy != "sequential") {
    # In parallel, load it all
    sshh_load <- purrr::quietly(library)
    purrr::map(pkgs, \(pkg) sshh_load(pkg, character.only = TRUE))
  }

  invisible(pkgs)
}

# Apply `extractor` to `object`; with no extractor, return an empty list.
extract_details <- function(object, extractor) {
  if (is.null(extractor)) list() else extractor(object)
}

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on Sept. 1, 2025, 5:10 p.m.