R/fit.R

Defines functions print.model_fit allow_sparse check_xy_interface check_interface inher eval_mod fit_xy.model_spec fit.model_spec

Documented in fit.model_spec fit_xy.model_spec

#' Fit a Model Specification to a Dataset
#'
#' `fit()` and `fit_xy()` take a model specification, translate the required
#'  code by substituting arguments, and execute the model fit
#'  routine.
#'
#' @param object An object of class `model_spec` that has a chosen engine
#'  (via [set_engine()]).
#' @param formula An object of class `formula` (or one that can
#'  be coerced to that class): a symbolic description of the model
#'  to be fitted.
#' @param data Optional, depending on the interface (see Details
#'  below). A data frame containing all relevant variables (e.g.
#'  outcome(s), predictors, case weights, etc). Note: when needed, a
#'  \emph{named argument} should be used.
#' @param case_weights An optional classed vector of numeric case weights. This
#'   must return `TRUE` when [hardhat::is_case_weights()] is run on it. See
#'   [hardhat::frequency_weights()] and [hardhat::importance_weights()] for
#'   examples.
#' @param control A named list with elements `verbosity` and
#'  `catch`. See [control_parsnip()].
#' @param ... Not currently used; values passed here will be
#'  ignored. Other options required to fit the model should be
#'  passed using `set_engine()`.
#' @details  `fit()` and `fit_xy()` substitute the current arguments in the model
#'  specification into the computational engine's code, check them
#'  for validity, then fit the model using the data and the
#'  engine-specific code. Different model functions have different
#'  interfaces (e.g. formula or `x`/`y`) and these functions translate
#'  between the interface used when `fit()` or `fit_xy()` was invoked and the one
#'  required by the underlying model.
#'
#' When possible, these functions attempt to avoid making copies of the
#'  data. For example, if the underlying model uses a formula and
#'  `fit()` is invoked, the original data are referenced
#'  when the model is fit. However, if the underlying model uses
#'  something else, such as `x`/`y`, the formula is evaluated and
#'  the data are converted to the required format. In this case, any
#'  calls in the resulting model objects reference the temporary
#'  objects used to fit the model.
#'
#' If the model engine has not been set, the model's default engine will be used
#'  (as discussed on each model page). If the `verbosity` option of
#'  [control_parsnip()] is greater than zero, a warning will be produced.
#'
#' If you would like to use an alternative method for generating contrasts when
#' supplying a formula to `fit()`, set the global option `contrasts` to your
#' preferred method. For example, you might set it to:
#' `options(contrasts = c(unordered = "contr.helmert", ordered = "contr.poly"))`.
#' See the help page for [stats::contr.treatment()] for more possible contrast
#' types.
#'
#' For models with `"censored regression"` modes, an additional computation is
#' executed and saved in the parsnip object. The `censor_probs` element contains
#' a "reverse Kaplan-Meier" curve that models the probability of censoring. This
#' may be used later to compute inverse probability censoring weights for
#' performance measures.
#' @examplesIf !parsnip:::is_cran_check()
#' # Although `glm()` only has a formula interface, different
#' # methods for specifying the model can be used
#'
#' library(dplyr)
#' library(modeldata)
#' data("lending_club")
#'
#' lr_mod <- logistic_reg()
#'
#' using_formula <-
#'   lr_mod %>%
#'   set_engine("glm") %>%
#'   fit(Class ~ funded_amnt + int_rate, data = lending_club)
#'
#' using_xy <-
#'   lr_mod %>%
#'    set_engine("glm") %>%
#'   fit_xy(x = lending_club[, c("funded_amnt", "int_rate")],
#'          y = lending_club$Class)
#'
#' using_formula
#' using_xy
#' @return A `model_fit` object that contains several elements:
#' \itemize{
#'   \item \code{lvl}: If the outcome is a factor, this contains
#'    the factor levels at the time of model fitting.
#'   \item \code{spec}: The model specification object
#'    (\code{object} in the call to \code{fit})
#'   \item \code{fit}: when the model is executed without error,
#'    this is the model object. Otherwise, it is a \code{try-error}
#'    object with the error message.
#'   \item \code{preproc}: any objects needed to convert between
#'    a formula and non-formula interface (such as the \code{terms}
#'    object)
#' }
#'  The return value will also have a class related to the fitted model (e.g.
#'  `"_glm"`) before the base class of `"model_fit"`.
#'
#' @seealso [set_engine()], [control_parsnip()], `model_spec`, `model_fit`
#' @param x A matrix, sparse matrix, or data frame of predictors. Only some
#' models have support for sparse matrix input. See `parsnip::get_encoding()`
#' for details. `x` should have column names.
#' @param y A vector, matrix or data frame of outcome data.
#' @rdname fit
#' @export
#' @export fit.model_spec
fit.model_spec <-
  function(object,
           formula,
           data,
           case_weights = NULL,
           control = control_parsnip(),
           ...
  ) {
    # --- validate the specification and arguments ---------------------------
    if (object$mode == "unknown") {
      rlang::abort("Please set the mode in the model specification.")
    }
    control <- condense_control(control, control_parsnip())
    check_case_weights(case_weights, object)

    if (!inherits(formula, "formula")) {
      msg <- "The {.arg formula} argument must be a formula, but it is a \\
              {.cls {class(formula)[1]}}."
      # A recipe passed where a formula is expected: point users to workflows.
      if (inherits(formula, "recipe")) {
        msg <- c(
          msg,
          "i" = "To fit a model with a recipe preprocessor, please use a \\
                 {.help [workflow](workflows::workflow)}."
        )
      }
      cli::cli_abort(msg)
    }

    dots <- quos(...)

    # --- resolve the engine -------------------------------------------------
    if (length(possible_engines(object)) == 0) {
      prompt_missing_implementation(
        spec = object,
        prompt = cli::cli_abort,
        call = call2("fit")
      )
    }
    if (is.null(object$engine)) {
      # Fall back to the first registered engine for this model.
      object$engine <- possible_engines(object)[1]
      if (control$verbosity > 0) {
        rlang::warn(glue::glue("Engine set to `{object$engine}`."))
      }
    }

    if (all(c("x", "y") %in% names(dots))) {
      rlang::abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.")
    }
    cl <- match.call(expand.dots = TRUE)

    # --- stage evaluated arguments ------------------------------------------
    # The model call constructed later is evaluated inside `eval_env`, so the
    # data, formula, and numeric weights are bound there.
    eval_env <- rlang::env()
    wts <- weights_to_numeric(case_weights, object)
    formula <- patch_formula_environment_with_case_weights(
      formula = formula,
      data = data,
      case_weights = wts
    )
    eval_env$data <- data
    eval_env$formula <- formula
    eval_env$weights <- wts

    fit_interface <-
      check_interface(eval_env$formula, eval_env$data, cl, object)

    if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) {
      rlang::abort(
        glue::glue(
          "spark objects can only be used with the formula interface to `fit()` ",
          "with a spark data object."
        )
      )
    }

    # Attach the engine-specific fit/predict details, then confirm the
    # required packages are available.
    object <- add_methods(object, engine = object$engine)
    check_installs(object)

    # --- dispatch on (user interface, engine interface) ---------------------
    # A formula engine receives the formula as-is; matrix and data.frame
    # engines need the formula translated to x/y first.
    target <- object$method$fit$interface
    interfaces <- paste(fit_interface, target, sep = "_")
    res <- switch(
      interfaces,
      formula_formula = form_form(
        object = object,
        control = control,
        env = eval_env
      ),
      formula_matrix = ,
      formula_data.frame = form_xy(
        object = object,
        control = control,
        env = eval_env,
        target = target,
        ...
      ),
      rlang::abort(glue::glue("{interfaces} is unknown."))
    )

    res$censor_probs <- reverse_km(object, eval_env)
    class(res) <- c(paste0("_", class(res$fit)[1]), "model_fit")
    res
  }

# ------------------------------------------------------------------------------

#' @rdname fit
#' @export
#' @export fit_xy.model_spec
fit_xy.model_spec <-
  function(object,
           x,
           y,
           case_weights = NULL,
           control = control_parsnip(),
           ...
  ) {
    if (object$mode == "unknown") {
      rlang::abort("Please set the mode in the model specification.")
    }

    if (inherits(object, "surv_reg")) {
      rlang::abort("Survival models must use the formula interface.")
    }

    control <- condense_control(control, control_parsnip())

    if (is.null(colnames(x))) {
      rlang::abort("'x' should have column names.")
    }
    check_case_weights(case_weights, object)

    # Same guard as `fit()`: without any registered engine for this model the
    # default below would be `NA` and fitting would fail later with an
    # unhelpful error, so report the missing implementation up front.
    if (length(possible_engines(object)) == 0) {
      prompt_missing_implementation(
        spec = object,
        prompt = cli::cli_abort,
        call = call2("fit_xy")
      )
    }
    if (is.null(object$engine)) {
      # Fall back to the first registered engine for this model.
      eng_vals <- possible_engines(object)
      object$engine <- eng_vals[1]
      if (control$verbosity > 0) {
        rlang::warn(glue::glue("Engine set to `{object$engine}`."))
      }
    }
    # Outcome column name(s), recorded before `y` may be reduced to a vector.
    y_var <- colnames(y)

    # A single-column matrix/data frame outcome is reduced to a plain vector
    # (spark objects are left untouched).
    if (object$engine != "spark" && NCOL(y) == 1 && !is.atomic(y)) {
      if (is.matrix(y)) {
        y <- y[, 1]
      } else {
        y <- y[[1]]
      }
    }

    cl <- match.call(expand.dots = TRUE)
    # The model call constructed later is evaluated inside `eval_env`.
    eval_env <- rlang::env()
    eval_env$x <- x
    eval_env$y <- y
    eval_env$y_var <- y_var
    eval_env$weights <- weights_to_numeric(case_weights, object)

    # TODO case weights: pass in eval_env not individual elements
    fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)

    if (object$engine == "spark") {
      rlang::abort(
        glue::glue(
          "spark objects can only be used with the formula interface to `fit()` ",
          "with a spark data object."
        )
      )
    }

    # populate `method` with the details for this model type
    object <- add_methods(object, engine = object$engine)

    check_installs(object)

    interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")

    # Dispatch on "<user interface>_<engine interface>". Matrix and
    # data.frame engines are fed directly by `xy_xy()` (converting between
    # matrix and data frame when needed). Formula engines get a formula
    # built from x/y by `xy_form()`.
    res <-
      switch(
        interfaces,
        # homogeneous combinations:
        matrix_matrix = , data.frame_matrix =
          xy_xy(
            object = object,
            env = eval_env,
            control = control,
            target = "matrix",
            ...
          ),

        data.frame_data.frame = , matrix_data.frame =
          xy_xy(
            object = object,
            env = eval_env,
            control = control,
            target = "data.frame",
            ...
          ),

        # heterogenous combinations
        matrix_formula = ,  data.frame_formula =
          xy_form(
            object = object,
            env = eval_env,
            control = control,
            ...
          ),
        rlang::abort(glue::glue("{interfaces} is unknown."))
      )
    res$censor_probs <- reverse_km(object, eval_env)
    model_classes <- class(res$fit)
    class(res) <- c(paste0("_", model_classes[1]), "model_fit")
    res
  }

# ------------------------------------------------------------------------------

eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) {
  # Evaluate the (quoted) model-fitting expression `e` in `envir`.
  #   catch = TRUE   -> wrap in `try(silent = TRUE)`; an error is returned as a
  #                     `try-error` object instead of being signalled.
  #   capture = TRUE -> anything printed to stdout during evaluation is
  #                     captured and discarded.
  # `res` is assigned inside the `capture.output()` argument so the value
  # lands in this frame while the printed output is thrown away.
  if (!capture && !catch) {
    res <- eval_tidy(e, env = envir, ...)
  } else if (!capture) {
    res <- try(eval_tidy(e, env = envir, ...), silent = TRUE)
  } else if (!catch) {
    capture.output(res <- eval_tidy(e, env = envir, ...))
  } else {
    capture.output(res <- try(eval_tidy(e, env = envir, ...), silent = TRUE))
  }
  res
}

# ------------------------------------------------------------------------------

inher <- function(x, cls, cl) {
  # Check that `x` inherits from one of `cls`. `NULL` is always accepted.
  # On success `x` is returned invisibly. On failure the error message names
  # the expression the caller passed as `x`. `cl` is accepted for interface
  # compatibility and is not used here.
  if (is.null(x) || inherits(x, cls)) {
    return(invisible(x))
  }

  # Recover the caller's expression for `x` so the message can name it.
  obj <- deparse(match.call()[["x"]])

  msg <- if (length(cls) > 1) {
    glue::glue(
      "`{obj}` should be one of the following classes: ",
      glue::glue_collapse(glue::glue("'{cls}'"), sep = ", ")
    )
  } else {
    glue::glue("`{obj}` should be a {cls} object")
  }
  rlang::abort(msg)
}

# ------------------------------------------------------------------------------

check_interface <- function(formula, data, cl, model) {
  # Determine the interface used in a call to `fit()`. Only the formula
  # interface is supported, so "formula" is returned when both a formula and
  # data are supplied; otherwise an error is raised. `model` is unused here
  # and kept for signature compatibility.

  # Class checks: `inher()` aborts on a mismatch and lets `NULL` through.
  inher(formula, "formula", cl)
  inher(data, c("data.frame", "tbl_spark"), cl)

  # Scalar condition, so use short-circuiting `&&` rather than elementwise `&`.
  if (!is.null(formula) && !is.null(data)) {
    return("formula")
  }
  rlang::abort("Error when checking the interface.")
}

check_xy_interface <- function(x, y, cl, model) {
  # Determine the interface used in a call to `fit_xy()`: "matrix" (dense
  # matrix, or dgCMatrix when the engine allows sparse input) or
  # "data.frame". Aborts for unsupported inputs.

  # Sparse predictors are only valid when the engine's encoding allows them.
  sparse_ok <- allow_sparse(model)
  sparse_x <- inherits(x, "dgCMatrix")
  if (!sparse_ok && sparse_x) {
    rlang::abort("Sparse matrices not supported by this model/engine combination.")
  }

  # Rule out spark data sets BEFORE the class checks. `tbl_spark` objects are
  # neither data frames nor matrices, so the class checks below would
  # otherwise reject them first with a generic class error, making this
  # spark-specific message unreachable.
  if (inherits(x, "tbl_spark") || inherits(y, "tbl_spark")) {
    rlang::abort(
      glue::glue(
        "spark objects can only be used with the formula interface via `fit()` ",
        "with a spark data object."
        )
      )
  }

  if (sparse_ok) {
    inher(x, c("data.frame", "matrix", "dgCMatrix"), cl)
  } else {
    inher(x, c("data.frame", "matrix"), cl)
  }

  # Atomic outcomes (vectors/factors) need no class check.
  if (!is.null(y) && !is.atomic(y)) {
    inher(y, c("data.frame", "matrix"), cl)
  }

  has_xy <- !is.null(x) && !is.null(y)

  # A dgCMatrix only reaches this point when `sparse_ok` is TRUE (see above).
  if (has_xy && (is.matrix(x) || (sparse_ok && sparse_x))) {
    return("matrix")
  }
  if (has_xy && is.data.frame(x)) {
    return("data.frame")
  }

  check_outcome(y, model)

  rlang::abort("Error when checking the interface")
}

allow_sparse <- function(x) {
  # Whether the model spec's engine accepts sparse `x`, based on the
  # `allow_sparse_x` column of the "<model class>_encoding" table.
  # NOTE(review): `all()` over zero matching rows is TRUE, so an engine with
  # no encoding row is treated as allowing sparse input.
  encoding_tbl <- get_from_env(paste0(class(x)[1], "_encoding"))
  matches_engine <- encoding_tbl$engine == x$engine
  all(encoding_tbl$allow_sparse_x[matches_engine])
}

#' @method print model_fit
#' @export
print.model_fit <- function(x, ...) {
  # Print a fitted parsnip model: a header, an optional fit time, then
  # either the underlying engine fit or the captured error message.
  cat("parsnip model object\n\n")

  elapsed_info <- x$elapsed

  # `elapsed` stored as a list without a `print` flag: show the time when it
  # is a single non-missing value. The guards prevent `if (NA)` /
  # `$`-on-atomic errors when `elapsed` is NULL, a bare `proc_time`, or
  # lacks a scalar `elapsed` entry. Previously these cases made `print()`
  # fail.
  if (is.list(elapsed_info) && is.null(elapsed_info$print)) {
    elapsed <- elapsed_info[["elapsed"]]
    if (length(elapsed) == 1 && !is.na(elapsed)) {
      cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
    }
  }

  # `elapsed` stored as list(elapsed = <proc_time>, print = <flag>).
  if (is.list(elapsed_info) && isTRUE(elapsed_info$print)) {
    elapsed <- elapsed_info$elapsed[["elapsed"]]
    cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
  }

  if (inherits(x$fit, "try-error")) {
    cat("Model fit failed with error:\n", x$fit, "\n")
  } else {
    print(x$fit, ...)
  }
  invisible(x)
}
tidymodels/parsnip documentation built on April 12, 2024, 2:14 a.m.