R/stage.R

Defines functions new_named_list has_action is_stage new_stage new_stage_post new_stage_fit new_stage_pre

# Construct the pre-processing stage.
#
# `actions`      - named list of actions (validated by new_stage()).
# `mold`         - optional; when non-NULL it must be a list (only list-ness
#                  is checked here).
# `case_weights` - optional; when non-NULL it must pass
#                  `hardhat::is_case_weights()`.
#
# Returns a "stage_pre" / "stage" object holding `actions`, `mold` and
# `case_weights`.
new_stage_pre <- function(actions = new_named_list(),
                          mold = NULL,
                          case_weights = NULL) {
  mold_ok <- is.null(mold) || is.list(mold)
  if (!mold_ok) {
    abort("`mold` must be a result of calling `hardhat::mold()`.", .internal = TRUE)
  }

  # `is_case_weights()` is only consulted when weights were actually supplied.
  weights_ok <- is.null(case_weights) || hardhat::is_case_weights(case_weights)
  if (!weights_ok) {
    abort("`case_weights` must be a true case weights column.", .internal = TRUE)
  }

  new_stage(
    actions = actions,
    mold = mold,
    case_weights = case_weights,
    subclass = "stage_pre"
  )
}

# Construct the model-fitting stage.
#
# `actions` - named list of actions (validated by new_stage()).
# `fit`     - optional; when non-NULL it must satisfy `is_model_fit()`.
#
# Returns a "stage_fit" / "stage" object holding `actions` and `fit`.
new_stage_fit <- function(actions = new_named_list(), fit = NULL) {
  fit_ok <- is.null(fit) || is_model_fit(fit)
  if (!fit_ok) {
    abort("`fit` must be a `model_fit`.", .internal = TRUE)
  }

  new_stage(
    actions = actions,
    fit = fit,
    subclass = "stage_fit"
  )
}

# Construct the post-processing stage. Unlike the pre and fit stages it
# stores no extra fields beyond its `actions`.
new_stage_post <- function(actions = new_named_list()) {
  new_stage(actions = actions, subclass = "stage_post")
}

# ------------------------------------------------------------------------------

# A `stage` is a collection of `action`s

# There are 3 stages that actions can fall into:
# - pre
# - fit
# - post

# Low-level constructor shared by every stage.
#
# `actions`  - list that must pass `is_list_of_actions()` and be uniquely
#              named.
# `...`      - stage-specific fields (e.g. `mold`, `case_weights`, `fit`),
#              which must also be uniquely named.
# `subclass` - class(es) prepended before "stage".
#
# Returns a list with `actions` first, followed by the `...` fields, with
# class `c(subclass, "stage")`.
new_stage <- function(actions = new_named_list(),
                      ...,
                      subclass = character()) {
  # Validate the action collection before looking at the extra fields.
  if (!is_list_of_actions(actions)) {
    abort("`actions` must be a list of actions.", .internal = TRUE)
  }
  if (!is_uniquely_named(actions)) {
    abort("`actions` must be uniquely named.", .internal = TRUE)
  }

  extra <- list2(...)
  if (!is_uniquely_named(extra)) {
    abort("`...` must be uniquely named.", .internal = TRUE)
  }

  stage <- list2(actions = actions, !!!extra)
  class(stage) <- c(subclass, "stage")
  stage
}

# ------------------------------------------------------------------------------

# Predicate: TRUE when `x` carries the "stage" class. new_stage() always
# appends "stage" after any subclass, so every stage object matches.
is_stage <- function(x) {
  inherits(x, what = "stage")
}

# Does `stage` contain an action registered under `name`?
#
# `stage` - a stage object (a list whose `actions` element is a named list).
# `name`  - character vector of action names; one logical is returned per
#           element.
#
# The `actions` field is extracted with `[[`, which matches names exactly.
# `$` partially matches names on lists, so a list with no `actions` element
# but with, for example, `actions_extra` would be searched by mistake.
has_action <- function(stage, name) {
  name %in% names(stage[["actions"]])
}

# ------------------------------------------------------------------------------

# Return an empty list that still carries a zero-length names attribute
# (R prints it as `named list()`).
#
# This standardizes results for testing. When `[[<-` removes every element
# from a named list, an empty *named* list is left behind, and this value
# compares equal to it.
new_named_list <- function() {
  out <- list()
  names(out) <- character()
  out
}

Try the workflows package in your browser

Any scripts or data that you put into this service are public.

workflows documentation built on May 29, 2024, 3:57 a.m.