R/loop_over_all_stages.R

Defines functions add_configs get_row_wise_grid loop_over_all_stages2 loop_over_all_stages

# Notes on debugging:
# 1. You can set `options(future.debug = TRUE)` to get verbose debugging
#    output from the future framework.
# 2. If you are debugging loop_over_all_stages(), use the control option
#    `allow_par = FALSE`; that will use `lapply()` so that you can see output.

# Fit and evaluate every candidate in `grid` on a single resample.
#
# The work follows the schedule produced by `schedule_grid()` as four nested
# loops: preprocessors -> model parameters -> predictions (one per submodel
# value, if any) -> postprocessors. Errors at any stage are captured by
# `.catch_and_log()` into `notes`, and the affected candidate is skipped
# instead of aborting the whole resample.
#
# Arguments:
#   resamples: a resample tibble; only the first element of its `splits`
#     (and, when present, `.seeds`) list column is used, plus its `id*`
#     columns as labels.
#   grid: the candidate tuning parameter values to evaluate.
#   static: a list of settings shared by all candidates (workflow, metrics,
#     control object, config table, parameter info, ...).
#
# Returns a tibble with `.metrics`, `.notes`, `outcome_names`, the resample
# `id*` columns, and, depending on the control options, `.extracts` and
# `.predictions` list columns.
loop_over_all_stages <- function(resamples, grid, static) {
  # Some packages may use random numbers so attach them prior to initializing
  # the RNG seed
  attach_pkgs(static$pkgs, strategy = static$strategy)

  seed_length <- length(resamples$.seeds[[1]])

  # If we are using last_fit() (= zero seed length), don't mess with the RNG
  # stream; otherwise set everything up.
  if (seed_length > 0) {
    orig_seed <- .Random.seed
    # Set seed within the worker process; the previous RNG state is restored
    # when this function exits.
    assign(".Random.seed", resamples$.seeds[[1]], envir = .GlobalEnv)
    resamples$.seeds <- NULL
    withr::defer(assign(".Random.seed", orig_seed, envir = .GlobalEnv))
  }

  split <- resamples$splits[[1]]
  split_labs <- resamples |>
    dplyr::select(dplyr::starts_with("id"))

  pred_reserve <- NULL
  notes <- new_note()
  extracts <- NULL

  sched <- schedule_grid(grid, static$wflow)

  # Append data partitions here; these are the same for the duration of this
  # function
  data_splits <- get_data_subsets(static$wflow, split, static$split_args)
  static <- update_static(static, data_splits)

  # Now that we have data, determine the names of the outcome data. NOTE that
  # if an inline function is used (e.g. add_formula(log(mpg) ~ .)), we will
  # potentially change it later. See #1024
  static$y_name <- outcome_names(static$wflow, data = split$data)

  # ----------------------------------------------------------------------------
  # Iterate over preprocessors

  num_iterations_pre <- max(nrow(sched), 1)

  for (iter_pre in seq_len(num_iterations_pre)) {
    current_sched_pre <- sched[iter_pre, ]
    location <- glue::glue("preprocessor {iter_pre}/{num_iterations_pre}")
    current_wflow <- .catch_and_log(
      finalize_fit_pre(static$wflow, current_sched_pre, static),
      control = static$control,
      split_labels = split_labs,
      location = location,
      notes = notes
    )

    if (is_failure(current_wflow)) {
      next
    }
    # Update y_name in case the workflow had an inline function like `log(mpg) ~ .`
    static$y_name <- outcome_names(current_wflow)

    num_iterations_model <- max(nrow(current_sched_pre$model_stage[[1]]), 1)

    # --------------------------------------------------------------------------
    # Iterate over model parameters

    # Make a copy of the current workflow so that we can finalize it multiple
    # times, since finalize_*() functions will only update parameters whose
    # values currently are tune()
    wflow_with_fitted_pre <- current_wflow

    for (iter_model in seq_len(num_iterations_model)) {
      current_sched_model <- current_sched_pre$model_stage[[1]][iter_model, ]

      # Splice in any parameters marked for tuning and fit the model
      location <- glue::glue(
        "preprocessor {iter_pre}/{num_iterations_pre}, model {iter_model}/{num_iterations_model}"
      )
      current_wflow <- .catch_and_log(
        finalize_fit_model(wflow_with_fitted_pre, current_sched_model),
        control = static$control,
        split_labels = split_labs,
        location = location,
        notes = notes
      )

      if (is_failure(current_wflow)) {
        next
      }

      # Keep an untouched copy of the fitted workflow. `current_wflow` is
      # reassigned further down (to a postprocessed workflow, a finalized
      # workflow, or a failure object), so every prediction and
      # postprocessing iteration must start from this copy instead of from
      # whatever the previous iteration left behind.
      wflow_with_fitted_pre_and_model <- current_wflow
      has_post <- has_tailor(wflow_with_fitted_pre_and_model)

      current_grid <- rebind_grid(current_sched_pre, current_sched_model)

      has_submodel <- has_sub_param(current_sched_model$predict_stage[[1]])
      num_iterations_pred <- max(
        nrow(current_sched_model$predict_stage[[1]]),
        1
      )

      # --------------------------------------------------------------------------
      # Iterate over prediction submodels; no multipredict, just one at a time
      # without retraining the model

      for (iter_pred in seq_len(num_iterations_pred)) {
        current_wflow <- wflow_with_fitted_pre_and_model
        current_sched_pred <- current_sched_model$predict_stage[[1]][
          iter_pred,
        ]

        location <- glue::glue(
          "preprocessor {iter_pre}/{num_iterations_pre}, model {iter_model}/{num_iterations_model} (predictions)"
        )

        if (has_submodel) {
          sub_nm <- get_sub_param(current_sched_pred)
          sub_grid <- current_sched_pred[, sub_nm]

          # The assigned submodel parameter (from min_grid()) is in the
          # current grid. Remove that and add the one that we are predicting on
          current_grid <- current_grid |>
            dplyr::select(-dplyr::all_of(sub_nm)) |>
            rebind_grid(current_sched_pred)

          # Remove the submodel column from the predictions since it is
          # already in the current grid.
          current_pred <- .catch_and_log(
            predict_all_types(current_wflow, static, sub_grid) |>
              dplyr::select(-dplyr::all_of(sub_nm)),
            control = static$control,
            split_labels = split_labs,
            location = location,
            notes = notes
          )
        } else {
          current_pred <- .catch_and_log(
            predict_all_types(current_wflow, static),
            control = static$control,
            split_labels = split_labs,
            location = location,
            notes = notes
          )
        }

        if (is_failure(current_pred)) {
          next
        }
        current_pred <- remove_log_notes(current_pred)

        num_iterations_post <- max(nrow(current_sched_pred$post_stage[[1]]), 1)

        # ----------------------------------------------------------------------
        # Iterate over postprocessors; each one is finalized from the fitted
        # model copy since finalize_*() functions will only update parameters
        # whose values currently are tune()

        current_predict_grid <- current_grid

        for (iter_post in seq_len(num_iterations_post)) {
          if (has_post) {
            current_sched_post <-
              current_sched_pred$post_stage[[1]][iter_post, ]
            post_grid <- current_sched_post

            current_post_grid <- rebind_grid(
              current_predict_grid,
              current_sched_post
            )

            # make data for post-processor. Query the fitted-model copy
            # because `current_wflow` may hold a failure object from the
            # previous postprocessing iteration.
            if (
              workflows::.workflow_postprocessor_requires_fit(
                wflow_with_fitted_pre_and_model
              )
            ) {
              tailor_train_data <- static$data$cal$data
            } else {
              # if the postprocessor does not require a fit,
              # this does not cause data leakage
              tailor_train_data <- static$data$fit$data
            }

            location <- glue::glue(
              "preprocessor {iter_pre}/{num_iterations_pre}, model {iter_model}/{num_iterations_model}, postprocessing {iter_post}/{num_iterations_post}"
            )
            current_wflow <- .catch_and_log(
              finalize_fit_post(
                wflow_with_fitted_pre_and_model,
                data_calibration = tailor_train_data,
                grid = post_grid
              ),
              control = static$control,
              split_labels = split_labs,
              location = location,
              notes = notes
            )
            if (is_failure(current_wflow)) {
              next
            }

            # to predict, use the post-processor directly rather than the
            # workflow so that we don't have to generate the model predictions
            # a second time
            post_fit <- extract_postprocessor(current_wflow, estimated = TRUE)
            post_pred <- .catch_and_log(
              predict(post_fit, current_pred),
              control = static$control,
              split_labels = split_labs,
              location = location,
              notes = notes
            )
            if (is_failure(post_pred)) {
              next
            }

            final_pred <- dplyr::bind_cols(post_pred, current_post_grid)
            current_extract_grid <- current_post_grid
          } else {
            # No postprocessor so just use what we have
            final_pred <- dplyr::bind_cols(current_pred, current_predict_grid)
            current_extract_grid <- current_predict_grid
          }

          current_wflow <- workflows::.fit_finalize(current_wflow)

          # --------------------------------------------------------------------
          # Allocate predictions to an overall object

          pred_reserve <- dplyr::bind_rows(pred_reserve, final_pred)

          # --------------------------------------------------------------------
          # Extractions

          # TODO modularize this:
          if (!is.null(static$control$extract)) {
            location <- glue::glue(
              "preprocessor {iter_pre}/{num_iterations_pre}, model {iter_model}/{num_iterations_model} (extracts)"
            )
            elt_extract <- .catch_and_log(
              extract_details(current_wflow, static$control$extract),
              control = static$control,
              split_labels = split_labs,
              location = location,
              notes = notes
            )

            # On first use, create a zero-row template with the right columns
            if (is.null(extracts)) {
              extracts <- tibble::tibble(.extracts = list(1))
              if (nrow(static$param_info) > 0) {
                extracts <- tibble::add_column(
                  current_extract_grid,
                  .extracts = list(1)
                )
              }
              extracts <- extracts[integer(), ]
            }

            if (nrow(static$param_info) > 0) {
              extracts <- tibble::add_row(
                extracts,
                tibble::add_column(
                  current_extract_grid,
                  .extracts = list(elt_extract)
                )
              )
            } else {
              extracts <- tibble::add_row(
                extracts,
                tibble::tibble(.extracts = list(elt_extract))
              )
            }
          }

          # Output for these loops:
          # - pred_reserve (probably not null)
          # - extracts (may be null)
          # - notes
        } # post loop
      } # predict loop
    } # model loop
  } # pre loop

  # ----------------------------------------------------------------------------
  # Compute metrics on each config and eval_time

  if (is.null(pred_reserve)) {
    all_metrics <- NULL
  } else {
    location <- glue::glue("internal")
    all_metrics <- .catch_and_log(
      pred_reserve |>
        dplyr::group_by(!!!rlang::syms(static$param_info$id)) |>
        .estimate_metrics(
          metric = static$metrics,
          param_names = static$param_info$id,
          outcome_name = static$y_name,
          event_level = static$control$event_level,
          metrics_info = metrics_info(static$metrics)
        ) |>
        add_configs(static),
      control = static$control,
      split_labels = split_labs,
      location = location,
      notes = notes
    )
  }

  if (!is.null(extracts)) {
    extracts <- add_configs(extracts, static) |>
      dplyr::relocate(.config, .after = .extracts) |>
      dplyr::relocate(names(grid))

    # Failing rows are not in the output:
    empty_extract <- purrr::map_lgl(extracts$.extracts, is.null)
    extracts <- extracts[!empty_extract, ]
  }

  # ----------------------------------------------------------------------------
  # Return the results

  return_tbl <- tibble::tibble(
    .metrics = list(all_metrics),
    .notes = list(notes),
    outcome_names = static$y_name
  )

  if (!is.null(static$control$extract)) {
    if (is.null(extracts)) {
      # Everything failed; return NULL for each row
      return_tbl$.extracts <- purrr::map(seq_len(nrow(return_tbl)), \(x) NULL)
    } else {
      return_tbl <- dplyr::mutate(return_tbl, .extracts = list(extracts))
    }
  }

  return_tbl <- vctrs::vec_cbind(return_tbl, split_labs)

  if (static$control$save_pred) {
    if (is.null(pred_reserve)) {
      # Everything failed; return NULL for each row
      return_tbl$.predictions <- purrr::map(
        seq_len(nrow(return_tbl)),
        \(x) NULL
      )
    } else {
      return_tbl$.predictions <-
        list(
          add_configs(pred_reserve, static) |>
            # Filter out joined rows that corresponded to a config that failed
            dplyr::filter(!is.na(.row)) |>
            reorder_pred_cols(static$y_name, static$param_info$id)
        )
    }
  }

  return_tbl
}

# Run loop_over_all_stages() on one (resample, grid subset) combination.
# `index$b` selects the element of `resamples` and `index$s` selects the
# element of `grid`; `static` is passed through unchanged.
loop_over_all_stages2 <- function(index, resamples, grid, static) {
  loop_over_all_stages(
    resamples = resamples[[index$b]],
    grid = grid[[index$s]],
    static = static
  )
}

# ------------------------------------------------------------------------------

# Split `grid` into the list of subgrids used when parallel processing over
# grid candidates. Without a submodel parameter, every row becomes its own
# 1-row grid. With a submodel parameter, rows sharing the same values of the
# non-submodel (constant) parameters are kept together in one subgrid that
# holds all of the associated submodel values.
get_row_wise_grid <- function(wflow, grid) {
  tuned_ids <- tune_args(wflow)$id

  model_spec <- hardhat::extract_spec_parsnip(wflow)
  submodel_tbl <- dplyr::filter(get_submodel_info(model_spec), has_submodel)
  submodel_ids <- purrr::pluck(submodel_tbl, "id")

  if (length(submodel_ids) == 0) {
    # One subgrid per candidate row
    return(vctrs::vec_split(grid, seq_len(nrow(grid)))$val)
  }

  # Each distinct combination of the constant parameters gets one split index
  group_syms <- rlang::syms(setdiff(tuned_ids, submodel_ids))
  nested <- dplyr::group_nest(parsnip::add_rowindex(grid), !!!group_syms)
  nested <- dplyr::mutate(nested, inds = dplyr::row_number())
  flat <- dplyr::select(tidyr::unnest(nested, c(data)), -.row)

  vctrs::vec_split(flat[, tuned_ids], flat$inds)$val
}

# ------------------------------------------------------------------------------

# Attach the `.config` labels from `static$configs` to `x` and sort by them.
# When there are tuning parameters, rows are matched on the parameter columns
# named in `static$param_info$id`; otherwise the config table is column-bound
# to `x` as-is.
add_configs <- function(x, static) {
  param_ids <- static$param_info$id
  labelled <- if (length(param_ids) == 0) {
    dplyr::bind_cols(x, static$configs)
  } else {
    dplyr::left_join(x, static$configs, by = param_ids)
  }
  dplyr::arrange(labelled, .config)
}

Try the tune package in your browser

Any scripts or data that you put into this service are public.

tune documentation built on Sept. 1, 2025, 5:10 p.m.