R/expressions.R

Defines functions stack_predict.multnet_class multi_net_helper multi_net_engine stack_predict.lognet_prob stack_predict.lognet_class stack_predict.elnet_numeric stack_predict prediction_eqn._multnet eexp prediction_eqn._elnet prediction_eqn._lognet eqn_constuctor prediction_eqn build_linear_predictor._multnet build_linear_predictor._lognet build_linear_predictor._elnet build_linear_predictor_glmnet build_linear_predictor_eng build_linear_predictor

Documented in build_linear_predictor build_linear_predictor._elnet build_linear_predictor._lognet build_linear_predictor._multnet prediction_eqn prediction_eqn._elnet prediction_eqn._lognet prediction_eqn._multnet stack_predict stack_predict.elnet_numeric stack_predict.lognet_class stack_predict.lognet_prob stack_predict.multnet_class

#' Creates an R expression for a linear predictor from a data frame of terms and
#' coefficients
#'
#' Only terms whose estimated coefficient is non-zero are included in the
#' resulting expression(s).
#'
#' @param x An object that uses a [glmnet::glmnet()] model and all numeric predictors. 
#' @param ... Not currently used. 
#' @return For `_elnet` and `_lognet` models, a single R expression. For
#' `_multnet` models, a data frame with one row per class that holds the
#' coefficients (`coefs`) and the corresponding expression (`lp`).
#' @export
#' @keywords internal
build_linear_predictor <- function(x, ...) {
  UseMethod("build_linear_predictor")
}

#' @import rlang

# Builds one additive expression from a table of coefficients.
#
# `x` should be a tidy-type format for the data with columns `terms` and
# `estimate`. Every non-intercept row becomes `(term * estimate)`; the
# intercept, when present, is placed first as a bare number. The pieces are then
# folded together with `+`.
#
# When `x` contributes no pieces at all (no intercept and no slopes, e.g. every
# coefficient was filtered out upstream), the constant `0` is returned instead
# of failing while reducing an empty list.
build_linear_predictor_eng <- function(x, ...) {
  slopes <- dplyr::filter(x, terms != "(Intercept)")
  # Convert the term names to symbols up front so the quasi-quotation below only
  # has to unquote two ready-made values.
  lin_pred <- purrr::map2(
    rlang::syms(slopes$terms),
    slopes$estimate,
    ~ rlang::expr((!!.x * !!.y))
  )

  is_intercept <- x$terms == "(Intercept)"
  if (any(is_intercept)) {
    lin_pred <- c(x$estimate[is_intercept], lin_pred)
  }

  # `purrr::reduce()` without `.init` errors on empty input.
  if (length(lin_pred) == 0) {
    return(0)
  }

  purrr::reduce(lin_pred, function(l, r) rlang::expr(!!l + !!r))
}

# This is used for linear and logistic regression.
#
# Obtains the coefficient table from `.get_glmn_coefs()` using the fitted model
# and the penalty stored in the model specification, drops the terms whose
# estimate is exactly zero, and converts what is left into a single linear
# predictor expression.
build_linear_predictor_glmnet <- function(x, ...) {
  coefs <-
    .get_glmn_coefs(x$fit, x$spec$args$penalty) %>%
    dplyr::filter(estimate != 0)
  build_linear_predictor_eng(coefs)
}

#' @export
#' @rdname build_linear_predictor
build_linear_predictor._elnet <- function(x, ...) {
  # Delegates to the shared single-equation implementation.
  build_linear_predictor_glmnet(x, ...)
}

#' @export
#' @rdname build_linear_predictor
build_linear_predictor._lognet <- function(x, ...) {
  # Delegates to the shared single-equation implementation.
  build_linear_predictor_glmnet(x, ...)
}

#' @export
#' @rdname build_linear_predictor
build_linear_predictor._multnet <- function(x, ...) {
  # One linear predictor per class: nest the non-zero coefficients by `class`
  # and build an expression from each nested coefficient table.
  .get_glmn_coefs(x$fit, x$spec$args$penalty) %>%
    dplyr::filter(estimate != 0) %>%
    dplyr::group_nest(class, .key = "coefs") %>%
    dplyr::mutate(lp = purrr::map(coefs, build_linear_predictor_eng))
}

## -----------------------------------------------------------------------------

#' Convert one or more linear predictor to a format used for prediction
#' 
#' @inheritParams build_linear_predictor
#' @param type The prediction type: `"class"` or `"prob"` for `_lognet` and
#' `_multnet` models, `"numeric"` for `_elnet` models.
#' @return The return type varies, based on the model and prediction type. The
#' result carries the class `<model>_<type>`, where `<model>` is the first class
#' of the underlying fit, and a `levels` attribute.
#' @export
#' @keywords internal
prediction_eqn <- function(x, ...) {
  UseMethod("prediction_eqn")
}

# Tags a set of prediction equations with the class `<model>_<type>` and stores
# the outcome levels as the `levels` attribute. `type` must match one of
# "class", "prob", or "numeric".
eqn_constuctor <- function(x, model, type, lvls) {
  valid_types <- c("class", "prob", "numeric")
  type <- match.arg(type, valid_types)
  structure(x, class = paste0(model, "_", type), levels = lvls)
}

#' @export
#' @rdname prediction_eqn
prediction_eqn._lognet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  model_class <- class(x$fit)[1]
  lp <- build_linear_predictor(x)
  lvls <- x$lvl

  if (type == "class") {
    # A positive linear predictor selects the second level.
    eqns <-
      list(
        .pred_class =
          rlang::expr(factor(ifelse((!!lp) > 0, !!lvls[2], !!lvls[1]), levels = !!lvls))
      )
  } else {
    # glmnet models the probability of the _second_ class, so the first class
    # uses the inverse link of the negated linear predictor.
    eqns <- list(
      rlang::expr(stats::binomial()$linkinv(-(!!lp))),
      rlang::expr(stats::binomial()$linkinv(!!lp))
    )
    names(eqns) <- paste0(".pred_", lvls)
  }

  eqn_constuctor(eqns, model_class, type, lvls)
}

#' @export
#' @rdname prediction_eqn
prediction_eqn._elnet <- function(x, type = "numeric", ...) {
  type <- match.arg(type, "numeric")
  model_class <- class(x$fit)[1]
  # The numeric prediction is the linear predictor itself; there are no levels.
  lin_pred <- build_linear_predictor(x)
  eqn_constuctor(list(.pred = lin_pred), model_class, type, NULL)
}

# Wraps an expression in a call to `exp()`.
eexp <- function(x) {
  rlang::expr(exp(!!x))
}

#' @export
#' @rdname prediction_eqn
prediction_eqn._multnet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  model_class <- class(x$fit)[1]
  lvls <- x$lvl

  # Exponentiate each class's linear predictor and label the result with the
  # `.pred_<class>` name it will have as a column.
  per_class <- build_linear_predictor(x)
  exp_eqns <- purrr::map(per_class$lp, eexp)
  names(exp_eqns) <- paste0(".pred_", per_class$class)
  per_class$.pred <- exp_eqns

  eqn_constuctor(per_class, model_class, type, lvls)
}


## -----------------------------------------------------------------------------

#' Evaluate prediction equations on a data set
#' 
#' @inheritParams build_linear_predictor
#' @param x A set of model expressions generated by [prediction_eqn()].
#' @param data The data in which the expressions are evaluated (passed to
#' [rlang::eval_tidy()]).
#' @return The return type varies, based on the model and prediction type. 
#' @export
#' @keywords internal
stack_predict <- function(x, ...) {
  UseMethod("stack_predict")
}

#' @export
#' @rdname stack_predict
stack_predict.elnet_numeric <- function(x, data, ...) {
  # Evaluate the stored equation against `data` and return it as `.pred`.
  predicted <- rlang::eval_tidy(x$.pred, data)
  tibble::tibble(.pred = predicted)
}

#' @export
#' @rdname stack_predict
stack_predict.lognet_class <- function(x, data, ...) {
  # Evaluate the stored equation against `data` and return it as `.pred_class`.
  predicted <- rlang::eval_tidy(x$.pred_class, data)
  tibble::tibble(.pred_class = predicted)
}

#' @export
#' @rdname stack_predict
stack_predict.lognet_prob <- function(x, data, ...) {
  # Evaluate every equation in `x` against `data` and bind the results as
  # columns; the column names are the names of the elements of `x`.
  purrr::map_dfc(x, rlang::eval_tidy, data = data)
}


# Evaluates each per-class equation in `x$.pred` against `data`, binds the
# results as columns, and hands them to `multi_net_helper()` for normalisation
# and selection of the winning column index.
#
# The pipeline is the last expression, so the result is returned visibly.
multi_net_engine <- function(x, data, ...) {
  purrr::map_dfc(x$.pred, rlang::eval_tidy, data = data) %>%
    multi_net_helper()
}

# Row by row: totals the `.pred_*` columns into `.sum`, rescales each `.pred_*`
# column by that total, and records in `idx` the position (among the `.pred_*`
# columns) of the largest rescaled value. The result is ungrouped.
multi_net_helper <- function(data, ...) {
  by_row <- dplyr::rowwise(data)

  with_total <- dplyr::mutate(
    by_row,
    .sum = sum(dplyr::c_across(dplyr::starts_with(".pred_")))
  )

  normalised <- dplyr::mutate(
    with_total,
    dplyr::across(dplyr::starts_with(".pred_"), ~ .x/.sum),
    idx = which.max(dplyr::c_across(dplyr::starts_with(".pred_")))
  )

  dplyr::ungroup(normalised)
}

#' @export
#' @rdname stack_predict
stack_predict.multnet_class <- function(x, data, ...) {
  lvls <- attr(x, "levels")
  prefix <- ".pred_"
  probs <- multi_net_engine(x, data)

  # `idx` is a position among the `.pred_*` columns, so index into exactly
  # those column names.
  col_names <- colnames(probs)
  pred_cols <- col_names[startsWith(col_names, prefix)]

  # Strip only the leading prefix, as fixed text. A regular expression such as
  # ".pred_" would also match inside a class name and treat "." as a wildcard.
  winner <- substring(pred_cols[probs$idx], nchar(prefix) + 1L)

  tibble::tibble(.pred_class = factor(winner, levels = lvls))
}

#' @export
#' @rdname stack_predict
stack_predict.multnet_prob <- function(x, data, ...) {
  # Keep only the `.pred_*` columns; drop the helper columns.
  class_probs <- multi_net_engine(x, data)
  dplyr::select(class_probs, dplyr::starts_with(".pred_"))
}


## -----------------------------------------------------------------------------

#' Obtain prediction equations for all possible values of type
#' @param x A `parsnip` model with the \code{glmnet} engine. 
#' @param ... Not used
#' @return A named list with prediction equations for each possible type.
#' @export
get_expressions <- function(x, ...) {
  UseMethod("get_expressions")
}

#' @export
#' @rdname get_expressions
get_expressions._multnet <- function(x, ...) {
  # One set of equations per supported prediction type, named by type.
  types <- c(class = "class", prob = "prob")
  lapply(types, function(type) prediction_eqn(x, type = type))
}

#' @export
#' @rdname get_expressions
get_expressions._lognet <- function(x, ...) {
  # One set of equations per supported prediction type, named by type.
  types <- c(class = "class", prob = "prob")
  lapply(types, function(type) prediction_eqn(x, type = type))
}

#' @export
#' @rdname get_expressions
get_expressions._elnet <- function(x, ...) {
  # Only the numeric prediction type is produced for this model.
  numeric_eqn <- prediction_eqn(x, type = "numeric")
  list(numeric = numeric_eqn)
}

Try the stacks package in your browser

Any scripts or data that you put into this service are public.

stacks documentation built on Nov. 6, 2023, 5:08 p.m.