# R/expressions.R
#
# Defines functions get_expressions._elnet get_expressions._lognet get_expressions._multnet get_expressions stack_predict.multnet_prob stack_predict.multnet_class multi_net_helper multi_net_engine stack_predict.lognet_prob stack_predict.lognet_class stack_predict.elnet_numeric stack_predict prediction_eqn._multnet eexp prediction_eqn._elnet prediction_eqn._lognet eqn_constuctor prediction_eqn build_linear_predictor._multnet build_linear_predictor._lognet build_linear_predictor._elnet build_linear_predictor_glmnet build_linear_predictor_eng build_linear_predictor
#
# Documented in build_linear_predictor build_linear_predictor._elnet build_linear_predictor._lognet build_linear_predictor._multnet get_expressions get_expressions._elnet get_expressions._lognet get_expressions._multnet prediction_eqn prediction_eqn._elnet prediction_eqn._lognet prediction_eqn._multnet stack_predict stack_predict.elnet_numeric stack_predict.lognet_class stack_predict.lognet_prob stack_predict.multnet_class stack_predict.multnet_prob

#' Creates an R expression for a linear predictor from a data frame of terms and
#' coefficients
#'
#' Zero-valued coefficients are dropped, and the remaining ones are combined
#' into an unevaluated expression of the form
#' `(Intercept) + (term_1 * estimate_1) + (term_2 * estimate_2) + ...`.
#'
#' @param x An object that uses a [glmnet::glmnet()] model and all numeric predictors.
#' @param ... Not currently used.
#' @return For linear (`_elnet`) and logistic (`_lognet`) models, a single R
#' expression. For multinomial (`_multnet`) models, a tibble with one row per
#' class, the nested coefficients in `coefs`, and that class's expression in
#' the list column `lp`.
#' @export
#' @keywords internal
build_linear_predictor <- function(x, ...) {
  UseMethod("build_linear_predictor")
}

#' @import rlang

# Build a single linear-predictor expression from a coefficient table.
#
# `x` should be a tidy-type format for the data with columns `terms` and
# `estimate`. The "(Intercept)" row, when present, becomes the leading
# constant. Every other row contributes a `(term * estimate)` product, and the
# pieces are joined left-to-right with `+`, e.g. `b0 + (a * b1) + (b * b2)`.
#
# When `x` has no rows (e.g. every coefficient was filtered out as zero), the
# predictor is the constant `0` rather than an error from folding an empty
# list.
# NOTE(review): a constant result (intercept-only or `0`) evaluates to a single
# value, not one value per row of new data; callers may need to recycle it.
build_linear_predictor_eng <- function(x, ...) {
  is_intercept <- x$terms %in% "(Intercept)"
  slopes <- x[!is_intercept, , drop = FALSE]

  # One `(term * estimate)` call per non-intercept coefficient.
  lin_pred <- lapply(seq_len(nrow(slopes)), function(i) {
    bquote((.(as.name(slopes$terms[[i]])) * .(slopes$estimate[[i]])))
  })

  if (any(is_intercept)) {
    lin_pred <- c(x$estimate[is_intercept], lin_pred)
  }

  if (length(lin_pred) == 0) {
    return(0)
  }

  Reduce(function(l, r) bquote(.(l) + .(r)), lin_pred)
}

# This is used for linear and logistic regression.
#
# Pulls the coefficient table for the fitted glmnet model via
# `.get_glmn_coefs()`, passing the spec's `penalty` argument. Drops the
# coefficients that are exactly zero, then combines the rest into one
# linear-predictor expression with `build_linear_predictor_eng()`.
build_linear_predictor_glmnet <- function(x, ...) {
  .get_glmn_coefs(x$fit, x$spec$args$penalty) |>
    dplyr::filter(estimate != 0) |>
    build_linear_predictor_eng()
}

#' @export
#' @rdname build_linear_predictor
build_linear_predictor._elnet <- function(x, ...)
  # Linear regression shares the generic glmnet implementation.
  build_linear_predictor_glmnet(x, ...)

#' @export
#' @rdname build_linear_predictor
build_linear_predictor._lognet <- function(x, ...)
  # Logistic regression shares the generic glmnet implementation.
  build_linear_predictor_glmnet(x, ...)

#' @export
#' @rdname build_linear_predictor
build_linear_predictor._multnet <- function(x, ...) {
  # Multinomial models have one coefficient set per class. Nest the non-zero
  # coefficients by `class` and build one linear predictor per class, stored in
  # the list column `lp`.
  .get_glmn_coefs(x$fit, x$spec$args$penalty) |>
    dplyr::filter(estimate != 0) |>
    dplyr::group_nest(class, .key = "coefs") |>
    dplyr::mutate(lp = purrr::map(coefs, build_linear_predictor_eng))
}

## -----------------------------------------------------------------------------

#' Convert one or more linear predictors to a format used for prediction
#'
#' Wraps the output of [build_linear_predictor()] into the expressions needed
#' for the requested prediction type, and tags the result with a
#' `{model}_{type}` class (e.g. `lognet_prob`) so that [stack_predict()] can
#' dispatch on it.
#'
#' @inheritParams build_linear_predictor
#' @param type The prediction type: `"numeric"` for linear (`_elnet`) models;
#' `"class"` or `"prob"` for logistic (`_lognet`) and multinomial (`_multnet`)
#' models.
#' @return The return type varies, based on the model and prediction type.
#' @export
#' @keywords internal
prediction_eqn <- function(x, ...) {
  UseMethod("prediction_eqn")
}

# Tag a set of prediction expressions for `stack_predict()` dispatch.
#
# Replaces the class of `x` with `"{model}_{type}"` (e.g. "lognet_prob") and
# stores the outcome levels in the `levels` attribute. `type` is matched
# (partially) against "class", "prob" and "numeric"; anything else errors.
eqn_constuctor <- function(x, model, type, lvls) {
  matched_type <- match.arg(type, c("class", "prob", "numeric"))
  class(x) <- paste0(model, "_", matched_type)
  attr(x, "levels") <- lvls
  x
}

#' @export
#' @rdname prediction_eqn
prediction_eqn._lognet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  fit_class <- class(x$fit)[1]
  lin_pred <- build_linear_predictor(x)

  # glmnet models the probability of the _second_ class, so a positive linear
  # predictor favors `lvls[2]`.
  lvls <- x$lvl

  if (type == "class") {
    eqns <- list(
      .pred_class = rlang::expr(factor(
        ifelse((!!lin_pred) > 0, !!lvls[2], !!lvls[1]),
        levels = !!lvls
      ))
    )
  } else {
    # First level: P = 1 - inv_logit(lp) = inv_logit(-lp); second: inv_logit(lp).
    eqns <- list(
      rlang::expr(stats::binomial()$linkinv(-(!!lin_pred))),
      rlang::expr(stats::binomial()$linkinv(!!lin_pred))
    )
    names(eqns) <- paste0(".pred_", lvls)
  }

  eqn_constuctor(eqns, fit_class, type, lvls)
}

#' @export
#' @rdname prediction_eqn
prediction_eqn._elnet <- function(x, type = "numeric", ...) {
  # Linear models support only numeric predictions: the linear predictor
  # itself is the prediction.
  type <- match.arg(type, "numeric")
  eqn_constuctor(
    list(.pred = build_linear_predictor(x)),
    model = class(x$fit)[1],
    type = type,
    lvls = NULL
  )
}

# Wrap an (unevaluated) expression in a call to `exp()`, e.g. `a + 1` becomes
# `exp(a + 1)`. Used to turn each class's linear predictor into a score.
eexp <- function(x) {
  bquote(exp(.(x)))
}

#' @export
#' @rdname prediction_eqn
prediction_eqn._multnet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  # One row per class. `.pred` holds exp(linear predictor), an unnormalized
  # score that is divided by the per-row total at prediction time.
  per_class <- dplyr::mutate(
    build_linear_predictor(x),
    .pred = purrr::map(lp, eexp)
  )
  names(per_class$.pred) <- paste0(".pred_", per_class$class)
  eqn_constuctor(per_class, class(x$fit)[1], type, x$lvl)
}


## -----------------------------------------------------------------------------

#' Evaluate prediction expressions on new data
#'
#' Evaluates the expressions created by [prediction_eqn()] against `data` and
#' returns the predictions as a tibble.
#'
#' @param x A set of model expressions generated by [prediction_eqn()].
#' @param data A data frame containing the columns referenced by the
#'   expressions in `x`.
#' @param ... Not currently used.
#' @return A tibble of predictions. The columns depend on the model and
#'   prediction type: `.pred` (numeric), `.pred_class` (class), or one
#'   `.pred_{level}` column per outcome level (prob).
#' @export
#' @keywords internal
stack_predict <- function(x, ...) {
  UseMethod("stack_predict")
}

#' @export
#' @rdname stack_predict
stack_predict.elnet_numeric <- function(x, data, ...) {
  # Evaluate the linear predictor with `data` as the data mask.
  numeric_pred <- rlang::eval_tidy(x$.pred, data = data)
  tibble::tibble(.pred = numeric_pred)
}

#' @export
#' @rdname stack_predict
stack_predict.lognet_class <- function(x, data, ...) {
  # The stored expression already yields a factor with the outcome levels.
  class_pred <- rlang::eval_tidy(x$.pred_class, data = data)
  tibble::tibble(.pred_class = class_pred)
}

#' @export
#' @rdname stack_predict
stack_predict.lognet_prob <- function(x, data, ...) {
  # Evaluate each `.pred_{level}` expression, then bind them as columns.
  purrr::map(x, rlang::eval_tidy, data = data) |>
    dplyr::bind_cols()
}


# Evaluate the per-class expressions in `x$.pred` against `data` and normalize.
#
# `x$.pred` is expected to be the named list built by
# `prediction_eqn._multnet()`, with one `.pred_{class}` expression per class.
# Each expression is evaluated into its own column, and the resulting tibble
# is passed to `multi_net_helper()`. The result is returned visibly; the
# previous trailing assignment returned it invisibly.
multi_net_engine <- function(x, data, ...) {
  purrr::map_dfc(x$.pred, rlang::eval_tidy, data = data) |>
    multi_net_helper()
}

# Normalize per-class scores into probabilities and locate each row's winner.
#
# `data` holds one `.pred_*` column of scores per class. In this file these
# are the evaluated `exp()` expressions from `prediction_eqn._multnet()`.
# Each `.pred_*` column is divided by the row total, which is also kept in
# `.sum`. `idx` is the position, among the `.pred_*` columns, of the row's
# largest value; ties go to the first column.
#
# Vectorized with `rowSums()` / `max.col()` instead of `dplyr::rowwise()`. The
# row-wise version ran R code once per row and errored on any row with an
# `NA` score, because `which.max()` returns a zero-length result there. Such
# rows now get `NA` probabilities and an `NA` index.
multi_net_helper <- function(data, ...) {
  pred_cols <- names(data)[startsWith(names(data), ".pred_")]
  totals <- unname(rowSums(data[pred_cols]))
  data[pred_cols] <- data[pred_cols] / totals
  data$.sum <- totals
  data$idx <- max.col(as.matrix(data[pred_cols]), ties.method = "first")
  data
}

#' @export
#' @rdname stack_predict
stack_predict.multnet_class <- function(x, data, ...) {
  lvls <- attr(x, "levels")
  res <- multi_net_engine(x, data)
  # `idx` indexes the `.pred_*` columns, so look it up among those columns
  # only. Then strip the literal ".pred_" prefix with an anchored pattern. The
  # former `gsub(".pred_", ...)` treated "." as a wildcard and removed every
  # match, turning a level such as "a_pred_b" into "ab".
  pred_cols <- colnames(res)[startsWith(colnames(res), ".pred_")]
  pred_class <- sub("^\\.pred_", "", pred_cols[res$idx])
  tibble::tibble(.pred_class = factor(pred_class, levels = lvls))
}

#' @export
#' @rdname stack_predict
stack_predict.multnet_prob <- function(x, data, ...) {
  # Keep only the normalized per-class probabilities; drop the helper columns.
  probs <- multi_net_engine(x, data)
  dplyr::select(probs, dplyr::starts_with(".pred_"))
}


## -----------------------------------------------------------------------------

#' Obtain prediction equations for all possible values of type
#'
#' @param x A `parsnip` model with the \code{glmnet} engine.
#' @param ... Not used
#' @return A named list with prediction equations for each possible type:
#'   `numeric` for linear models, and `class` and `prob` for logistic and
#'   multinomial models.
#' @export
get_expressions <- function(x, ...) {
  UseMethod("get_expressions")
}

#' @export
#' @rdname get_expressions
get_expressions._multnet <- function(x, ...) {
  # `lapply()` over a named vector keeps the names: list(class = , prob = ).
  pred_types <- c(class = "class", prob = "prob")
  lapply(pred_types, function(pred_type) prediction_eqn(x, type = pred_type))
}

#' @export
#' @rdname get_expressions
get_expressions._lognet <- function(x, ...) {
  # `lapply()` over a named vector keeps the names: list(class = , prob = ).
  pred_types <- c(class = "class", prob = "prob")
  lapply(pred_types, function(pred_type) prediction_eqn(x, type = pred_type))
}

#' @export
#' @rdname get_expressions
get_expressions._elnet <- function(x, ...) {
  # Linear models only have a numeric prediction type.
  numeric_eqn <- prediction_eqn(x, type = "numeric")
  list(numeric = numeric_eqn)
}

# Try the stacks package in your browser
#
# Any scripts or data that you put into this service are public.
#
# stacks documentation built on June 10, 2025, 9:14 a.m.