# R/fit.R

#' Fit a workflow object
#'
#' @description
#' Fitting a workflow currently involves two main steps:
#'
#' - Preprocessing the data using a formula preprocessor, or by calling
#'   [recipes::prep()] on a recipe.
#'
#' - Fitting the underlying parsnip model using [parsnip::fit.model_spec()].
#'
#' @details
#' In the future, there will also be _postprocessing_ steps that can be added
#' after the model has been fit.
#'
#' @includeRmd man/rmd/indicators.Rmd details
#'
#' @param object A workflow
#'
#' @param data A data frame of predictors and outcomes to use when fitting the
#'   workflow
#'
#' @param ... Not used
#'
#' @param control A [control_workflow()] object
#'
#' @return
#' The workflow `object`, updated with a fitted parsnip model in the
#' `object$fit$fit` slot.
#'
#' @name fit-workflow
#' @export
#' @examples
#' library(parsnip)
#' library(recipes)
#' library(magrittr)
#'
#' model <- linear_reg() %>%
#'   set_engine("lm")
#'
#' base_wf <- workflow() %>%
#'   add_model(model)
#'
#' formula_wf <- base_wf %>%
#'   add_formula(mpg ~ cyl + log(disp))
#'
#' fit(formula_wf, mtcars)
#'
#' recipe <- recipe(mpg ~ cyl + disp, mtcars) %>%
#'   step_log(disp)
#'
#' recipe_wf <- base_wf %>%
#'   add_recipe(recipe)
#'
#' fit(recipe_wf, mtcars)
fit.workflow <- function(object, data, ..., control = control_workflow()) {
  check_dots_empty()

  if (is_missing(data)) {
    abort("`data` must be provided to fit a workflow.")
  }

  workflow <- object
  workflow <- .fit_pre(workflow, data)
  workflow <- .fit_model(workflow, control)
  workflow <- .fit_finalize(workflow)

  # TODO: Post-processing before `.fit_finalize()`?

  workflow
}
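
# Illustrative sketch (kept behind `if (FALSE)` so it is never executed as
# package code) of inspecting the return value: the fitted parsnip model
# stored in `object$fit$fit` can be pulled back out with
# `extract_fit_parsnip()`.
if (FALSE) {
  library(parsnip)
  library(magrittr)

  wf <- workflow() %>%
    add_model(linear_reg() %>% set_engine("lm")) %>%
    add_formula(mpg ~ cyl + log(disp))

  wf_fit <- fit(wf, mtcars)

  # The underlying parsnip model fit that `fit.workflow()` stored
  extract_fit_parsnip(wf_fit)
}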

# ------------------------------------------------------------------------------

#' Internal workflow functions
#'
#' `.fit_pre()`, `.fit_model()`, and `.fit_finalize()` are internal workflow
#' functions for _partially_ fitting a workflow object. They are only exported
#' for use by the tuning package, [tune](https://github.com/tidymodels/tune),
#' and typical users should never need to call them directly.
#'
#' @param workflow A workflow
#'
#'   For `.fit_pre()`, this should be a fresh workflow.
#'
#'   For `.fit_model()`, this should be a workflow that has already been trained
#'   through `.fit_pre()`.
#'
#'   For `.fit_finalize()`, this should be a workflow that has been through
#'   both `.fit_pre()` and `.fit_model()`.
#'
#' @param data A data frame of predictors and outcomes to use when fitting the
#'   workflow
#'
#' @param control A [control_workflow()] object
#'
#' @name workflows-internals
#' @keywords internal
#' @export
#' @examples
#' library(parsnip)
#' library(recipes)
#' library(magrittr)
#'
#' model <- linear_reg() %>%
#'   set_engine("lm")
#'
#' wf_unfit <- workflow() %>%
#'   add_model(model) %>%
#'   add_formula(mpg ~ cyl + log(disp))
#'
#' wf_fit_pre <- .fit_pre(wf_unfit, mtcars)
#' wf_fit_model <- .fit_model(wf_fit_pre, control_workflow())
#' wf_fit <- .fit_finalize(wf_fit_model)
#'
#' # Notice that fitting through the model doesn't mark the
#' # workflow as being "trained"
#' wf_fit_model
#'
#' # Finalizing the workflow marks it as "trained"
#' wf_fit
#'
#' # Which allows you to predict from it
#' try(predict(wf_fit_model, mtcars))
#'
#' predict(wf_fit, mtcars)
.fit_pre <- function(workflow, data) {
  validate_has_preprocessor(workflow)
  # A model spec is required to ensure that we can always
  # finalize the blueprint, no matter the preprocessor
  validate_has_model(workflow)

  workflow <- finalize_blueprint(workflow)

  n <- length(workflow[["pre"]]$actions)

  for (i in seq_len(n)) {
    action <- workflow[["pre"]]$actions[[i]]

    # Update both the `workflow` and the `data` as we iterate through pre steps
    result <- fit(action, workflow = workflow, data = data)
    workflow <- result$workflow
    data <- result$data
  }

  # Only return the workflow; it contains the final set of preprocessed data in `mold`
  workflow
}
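
# Small sketch (never run as package code) of the point made in the comment
# above: after fitting, the preprocessed data lives in the workflow's mold and
# can be retrieved with `extract_mold()`.
if (FALSE) {
  library(parsnip)
  library(magrittr)

  wf <- workflow() %>%
    add_model(linear_reg() %>% set_engine("lm")) %>%
    add_formula(mpg ~ cyl + log(disp))

  wf_fit <- fit(wf, mtcars)

  # Molded predictors and outcomes produced by the formula preprocessor
  extract_mold(wf_fit)$predictors
  extract_mold(wf_fit)$outcomes
}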

#' @rdname workflows-internals
#' @export
.fit_model <- function(workflow, control) {
  action_model <- workflow[["fit"]][["actions"]][["model"]]
  fit(action_model, workflow = workflow, control = control)
}

#' @rdname workflows-internals
#' @export
.fit_finalize <- function(workflow) {
  set_trained(workflow, TRUE)
}

# ------------------------------------------------------------------------------

validate_has_preprocessor <- function(x, ..., call = caller_env()) {
  check_dots_empty()

  has_preprocessor <-
    has_preprocessor_formula(x) ||
      has_preprocessor_recipe(x) ||
      has_preprocessor_variables(x)

  if (!has_preprocessor) {
    message <- c(
      "The workflow must have a formula, recipe, or variables preprocessor.",
      i = "Provide one with `add_formula()`, `add_recipe()`, or `add_variables()`."
    )
    abort(message, call = call)
  }

  invisible(x)
}

validate_has_model <- function(x, ..., call = caller_env()) {
  check_dots_empty()

  has_model <- has_action(x$fit, "model")

  if (!has_model) {
    message <- c(
      "The workflow must have a model.",
      i = "Provide one with `add_model()`."
    )
    abort(message, call = call)
  }

  invisible(x)
}
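
# Hedged sketch (guarded with `if (FALSE)`) of the errors these validators
# raise when `fit()` is called on an incomplete workflow.
if (FALSE) {
  library(parsnip)
  library(magrittr)

  spec <- linear_reg() %>% set_engine("lm")

  # Missing preprocessor: `validate_has_preprocessor()` aborts
  try(fit(workflow() %>% add_model(spec), mtcars))

  # Missing model: `validate_has_model()` aborts
  try(fit(workflow() %>% add_formula(mpg ~ cyl), mtcars))
}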

# ------------------------------------------------------------------------------

finalize_blueprint <- function(workflow) {
  # Use user supplied blueprint if provided
  if (has_blueprint(workflow)) {
    return(workflow)
  }

  if (has_preprocessor_recipe(workflow)) {
    finalize_blueprint_recipe(workflow)
  } else if (has_preprocessor_formula(workflow)) {
    finalize_blueprint_formula(workflow)
  } else if (has_preprocessor_variables(workflow)) {
    finalize_blueprint_variables(workflow)
  } else {
    abort("`workflow` should have a preprocessor at this point.", .internal = TRUE)
  }
}
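
# Hedged sketch of the user-supplied blueprint branch above: a blueprint
# passed to `add_recipe()` (or `add_formula()` / `add_variables()`) is left
# untouched by `finalize_blueprint()`. The `allow_novel_levels` option used
# here is assumed to be available in `hardhat::default_recipe_blueprint()`.
if (FALSE) {
  library(parsnip)
  library(recipes)
  library(magrittr)

  blueprint <- hardhat::default_recipe_blueprint(allow_novel_levels = TRUE)

  rec <- recipe(mpg ~ cyl + disp, mtcars) %>%
    step_log(disp)

  wf <- workflow() %>%
    add_model(linear_reg() %>% set_engine("lm")) %>%
    add_recipe(rec, blueprint = blueprint)

  fit(wf, mtcars)
}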

finalize_blueprint_recipe <- function(workflow) {
  # Use the default blueprint, no parsnip model encoding info is used here
  blueprint <- hardhat::default_recipe_blueprint()

  recipe <- extract_preprocessor(workflow)

  update_recipe(workflow, recipe = recipe, blueprint = blueprint)
}

finalize_blueprint_formula <- function(workflow) {
  tbl_encodings <- pull_workflow_spec_encoding_tbl(workflow)

  indicators <- tbl_encodings$predictor_indicators
  intercept <- tbl_encodings$compute_intercept

  if (!is_string(indicators)) {
    abort("`indicators` encoding from parsnip should be a string.", .internal = TRUE)
  }
  if (!is_bool(intercept)) {
    abort("`intercept` encoding from parsnip should be a bool.", .internal = TRUE)
  }

  # Use model specific information to construct the blueprint
  blueprint <- hardhat::default_formula_blueprint(
    indicators = indicators,
    intercept = intercept
  )

  formula <- extract_preprocessor(workflow)

  update_formula(workflow, formula = formula, blueprint = blueprint)
}
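
# Hedged illustration of what these encodings do in practice: a tree-based
# engine such as ranger requests `predictor_indicators = "none"`, so a factor
# passed through `add_formula()` is kept as-is rather than expanded into dummy
# variables (assumes the ranger package is installed).
if (FALSE) {
  library(parsnip)
  library(magrittr)

  wf <- workflow() %>%
    add_model(rand_forest(mode = "regression") %>% set_engine("ranger")) %>%
    add_formula(Sepal.Length ~ Species + Sepal.Width)

  wf_fit <- fit(wf, iris)

  # `Species` remains a factor column in the molded predictors
  extract_mold(wf_fit)$predictors
}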

pull_workflow_spec_encoding_tbl <- function(workflow) {
  spec <- extract_spec_parsnip(workflow)
  spec_cls <- class(spec)[[1]]

  if (modelenv::is_unsupervised_spec(spec)) {
    tbl_encodings <- modelenv::get_encoding(spec_cls)
  } else {
    tbl_encodings <- parsnip::get_encoding(spec_cls)
  }

  indicator_engine <- tbl_encodings$engine == spec$engine
  indicator_mode <- tbl_encodings$mode == spec$mode
  indicator_spec <- indicator_engine & indicator_mode

  out <- tbl_encodings[indicator_spec, , drop = FALSE]

  if (nrow(out) != 1L) {
    abort("Exactly 1 model/engine/mode combination must be located.", .internal = TRUE)
  }

  out
}
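
# Illustrative sketch of the encoding table this helper filters: for a given
# model class, `parsnip::get_encoding()` returns one row per engine/mode
# combination, including the `predictor_indicators` and `compute_intercept`
# columns used to build the formula blueprint above.
if (FALSE) {
  library(parsnip)

  parsnip::get_encoding("linear_reg")
  parsnip::get_encoding("rand_forest")
}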

finalize_blueprint_variables <- function(workflow) {
  # Use the default blueprint, no parsnip model encoding info is used here
  blueprint <- hardhat::default_xy_blueprint()

  variables <- extract_preprocessor(workflow)

  update_variables(
    workflow,
    blueprint = blueprint,
    variables = variables
  )
}
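
# Small sketch of the variables preprocessor path: `add_variables()` selects
# outcomes and predictors directly, and the default hardhat XY blueprint is
# used because no model-specific encoding information is required.
if (FALSE) {
  library(parsnip)
  library(magrittr)

  wf <- workflow() %>%
    add_model(linear_reg() %>% set_engine("lm")) %>%
    add_variables(outcomes = mpg, predictors = c(cyl, disp))

  fit(wf, mtcars)
}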
