#' Build linear-predictor expressions for a glmnet model
#'
#' Creates an R expression for a linear predictor from a data frame of terms
#' and coefficients.
#'
#' @param x An object that uses a [glmnet::glmnet()] model and all numeric
#'   predictors.
#' @param ... Not currently used.
#' @return For linear and logistic models, a single R expression (or a bare
#'   number for an intercept-only fit). For multinomial models, a tibble with
#'   one row per class and a list column `lp` holding each class's expression.
#' @export
#' @keywords internal
build_linear_predictor <- function(x, ...) UseMethod("build_linear_predictor")
#' @import rlang
# Fold a tidy coefficient table into one linear-predictor expression.
#
# `x` should be a tidy-type data frame with columns `terms` and `estimate`.
# The "(Intercept)" row (if present) becomes the leading constant; every other
# row becomes `(<term> * <estimate>)`, and the pieces are chained with `+`.
#
# Returns an unevaluated call, a bare number when only the intercept is
# present, or `0` when `x` contains no terms at all (for example, when every
# coefficient was filtered out as exactly zero). Previously the empty case
# made `purrr::reduce()` error because it had nothing to fold.
build_linear_predictor_eng <- function(x, ...) {
  slopes <- dplyr::filter(x, terms != "(Intercept)")
  lin_pred <- purrr::map2(
    slopes$terms, slopes$estimate,
    ~ rlang::expr((!!rlang::sym(.x) * !!.y))
  )
  is_intercept <- x$terms == "(Intercept)"
  if (any(is_intercept)) {
    lin_pred <- c(x$estimate[is_intercept], lin_pred)
  }
  if (length(lin_pred) == 0L) {
    # Nothing left to combine: the linear predictor is identically zero.
    return(0)
  }
  purrr::reduce(lin_pred, function(l, r) rlang::expr(!!l + !!r))
}
# Shared implementation for linear (`_elnet`) and logistic (`_lognet`)
# regression. It fetches coefficients via `.get_glmn_coefs()` for the penalty
# stored in `x$spec$args$penalty`, drops those that are exactly zero, and folds
# the remainder into a single linear-predictor expression.
build_linear_predictor_glmnet <- function(x, ...) {
  coefs <-
    .get_glmn_coefs(x$fit, x$spec$args$penalty) %>%
    dplyr::filter(estimate != 0)
  build_linear_predictor_eng(coefs)
}
#' @export
#' @rdname build_linear_predictor
build_linear_predictor._elnet <- function(x, ...) build_linear_predictor_glmnet(x, ...)
#' @export
#' @rdname build_linear_predictor
build_linear_predictor._lognet <- function(x, ...) build_linear_predictor_glmnet(x, ...)
#' @export
#' @rdname build_linear_predictor
build_linear_predictor._multnet <- function(x, ...) {
  # Turn one class's (already zero-filtered) coefficient table into a linear
  # predictor; a class with no non-zero coefficients gets the constant 0.
  class_lp <- function(cf) {
    if (nrow(cf) == 0L) {
      return(0)
    }
    build_linear_predictor_eng(cf)
  }
  # Nest by class *before* dropping zero coefficients. Filtering first would
  # silently remove any class whose coefficients are all exactly zero, so that
  # class would be missing from the result (and from the `.pred_<class>`
  # equations built from it).
  .get_glmn_coefs(x$fit, x$spec$args$penalty) %>%
    dplyr::group_nest(class, .key = "coefs") %>%
    dplyr::mutate(
      coefs = purrr::map(coefs, ~ dplyr::filter(.x, estimate != 0)),
      lp = purrr::map(coefs, class_lp)
    )
}
## -----------------------------------------------------------------------------
#' Convert one or more linear predictors to the format used for prediction
#'
#' @inheritParams build_linear_predictor
#' @param type The prediction type: `"numeric"` for linear models, or
#'   `"class"` / `"prob"` for logistic and multinomial models.
#' @return The return type varies, based on the model and prediction type.
#' @export
#' @keywords internal
prediction_eqn <- function(x, ...) UseMethod("prediction_eqn")
# Tag a set of prediction equations with an S3 class "<model>_<type>" (used
# by the `stack_predict()` methods) and store the outcome levels in the
# "levels" attribute. `type` must (partially) match "class", "prob" or
# "numeric".
eqn_constuctor <- function(x, model, type, lvls) {
  type <- match.arg(type, c("class", "prob", "numeric"))
  attr(x, "levels") <- lvls
  class(x) <- paste(model, type, sep = "_")
  x
}
#' @export
#' @rdname prediction_eqn
prediction_eqn._lognet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  fit_class <- class(x$fit)[1]
  eta <- build_linear_predictor(x)
  levs <- x$lvl
  # glmnet models the probability of the _second_ class, so the first class
  # uses the negated linear predictor and the hard class call thresholds at 0.
  if (type == "class") {
    out <- list(
      .pred_class =
        rlang::expr(factor(ifelse((!!eta) > 0, !!levs[2], !!levs[1]), levels = !!levs))
    )
  } else {
    out <- list(
      rlang::expr(stats::binomial()$linkinv(-(!!eta))),
      rlang::expr(stats::binomial()$linkinv(!!eta))
    )
    names(out) <- paste0(".pred_", levs)
  }
  eqn_constuctor(out, fit_class, type, levs)
}
#' @export
#' @rdname prediction_eqn
prediction_eqn._elnet <- function(x, type = "numeric", ...) {
  # Only numeric predictions exist for a linear model.
  type <- match.arg(type, "numeric")
  fit_class <- class(x$fit)[1]
  eqns <- list(.pred = build_linear_predictor(x))
  eqn_constuctor(eqns, model = fit_class, type = type, lvls = NULL)
}
# Wrap an expression (or constant) in a call to `exp()`, without evaluating it.
eexp <- function(x) {
  bquote(exp(.(x)))
}
#' @export
#' @rdname prediction_eqn
prediction_eqn._multnet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  fit_class <- class(x$fit)[1]
  levs <- x$lvl
  eqns <- build_linear_predictor(x)
  # Un-normalized class scores exp(lp), named ".pred_<class>"; they are
  # normalized when the equations are evaluated.
  eqns$.pred <- purrr::map(eqns$lp, eexp)
  names(eqns$.pred) <- paste0(".pred_", eqns$class)
  eqn_constuctor(eqns, fit_class, type, levs)
}
## -----------------------------------------------------------------------------
#' Evaluate prediction equations on new data
#'
#' @inheritParams build_linear_predictor
#' @param x A set of model expressions generated by [prediction_eqn()].
#' @param data A data frame holding the columns referenced by the expressions
#'   in `x`; it is used as the data mask for [rlang::eval_tidy()].
#' @return The return type varies, based on the model and prediction type.
#' @export
#' @keywords internal
stack_predict <- function(x, ...) UseMethod("stack_predict")
#' @export
#' @rdname stack_predict
stack_predict.elnet_numeric <- function(x, data, ...) {
  pred <- rlang::eval_tidy(x$.pred, data)
  # An intercept-only equation is a bare constant and evaluates to a single
  # value; recycle it so there is one prediction per row of `data`.
  if (length(pred) == 1L && is.data.frame(data)) {
    pred <- rep(pred, nrow(data))
  }
  tibble::tibble(.pred = pred)
}
#' @export
#' @rdname stack_predict
stack_predict.lognet_class <- function(x, data, ...) {
  pred <- rlang::eval_tidy(x$.pred_class, data)
  # A constant linear predictor (intercept-only fit) yields a single class;
  # recycle it (rep() keeps factor levels) to one value per row of `data`.
  if (length(pred) == 1L && is.data.frame(data)) {
    pred <- rep(pred, nrow(data))
  }
  tibble::tibble(.pred_class = pred)
}
#' @export
#' @rdname stack_predict
stack_predict.lognet_prob <- function(x, data, ...) {
  n_rows <- if (is.data.frame(data)) nrow(data) else NA_integer_
  # Evaluate one probability equation; a constant (intercept-only) result is
  # recycled so every column has one value per row of `data`.
  eval_eqn <- function(eqn) {
    val <- rlang::eval_tidy(eqn, data = data)
    if (length(val) == 1L && !is.na(n_rows)) {
      val <- rep(val, n_rows)
    }
    val
  }
  purrr::map_dfc(x, eval_eqn)
}
# Evaluate the un-normalized multinomial class scores (`x$.pred`) against
# `data`, then hand them to `multi_net_helper()` for row-wise normalization.
# The result is returned visibly; previously it went into an unused local,
# so the value was returned invisibly.
multi_net_engine <- function(x, data, ...) {
  n_rows <- if (is.data.frame(data)) nrow(data) else NA_integer_
  # A score that is a bare constant (e.g. `exp(<number>)` for an
  # intercept-only class) evaluates to one value; recycle it per row so an
  # all-constant model does not collapse to a single-row result.
  eval_eqn <- function(eqn) {
    val <- rlang::eval_tidy(eqn, data = data)
    if (length(val) == 1L && !is.na(n_rows)) {
      val <- rep(val, n_rows)
    }
    val
  }
  purrr::map_dfc(x$.pred, eval_eqn) %>%
    multi_net_helper()
}
# Row-wise normalization of class scores. Each `.pred_*` column is divided by
# the row total (kept in `.sum`), and `idx` records the position of the
# largest `.pred_*` value within the row (first one on ties).
multi_net_helper <- function(data, ...) {
  per_row <- dplyr::rowwise(data)
  per_row <- dplyr::mutate(
    per_row,
    .sum = sum(dplyr::c_across(dplyr::starts_with(".pred_")))
  )
  per_row <- dplyr::mutate(
    per_row,
    dplyr::across(dplyr::starts_with(".pred_"), ~ .x / .sum),
    idx = which.max(dplyr::c_across(dplyr::starts_with(".pred_")))
  )
  dplyr::ungroup(per_row)
}
#' @export
#' @rdname stack_predict
stack_predict.multnet_class <- function(x, data, ...) {
  lvls <- attr(x, "levels")
  multi_net_engine(x, data) %>%
    dplyr::mutate(
      # `idx` indexes the `.pred_*` columns, which are the leading columns of
      # the engine output. Strip only the literal leading ".pred_" prefix: the
      # previous unanchored, unescaped `gsub(".pred_", ...)` also removed any
      # "<char>pred_" inside the class name itself.
      .pred_class = sub("^\\.pred_", "", colnames(.)[idx]),
      .pred_class = factor(.pred_class, levels = lvls)
    ) %>%
    dplyr::select(.pred_class)
}
#' @export
#' @rdname stack_predict
stack_predict.multnet_prob <- function(x, data, ...) {
  # Keep only the normalized class probabilities, dropping helper columns.
  scored <- multi_net_engine(x, data)
  dplyr::select(scored, dplyr::starts_with(".pred_"))
}
## -----------------------------------------------------------------------------
#' Obtain prediction equations for all possible values of type
#' @param x A `parsnip` model with the \code{glmnet} engine.
#' @param ... Not used
#' @return A named list with prediction equations for each possible type.
#' @export
get_expressions <- function(x, ...) UseMethod("get_expressions")
#' @export
#' @rdname get_expressions
get_expressions._multnet <- function(x, ...) {
  # lapply() keeps the vector names, giving list(class = ..., prob = ...).
  lapply(
    c(class = "class", prob = "prob"),
    function(pred_type) prediction_eqn(x, type = pred_type)
  )
}
#' @export
#' @rdname get_expressions
get_expressions._lognet <- function(x, ...) {
  # lapply() keeps the vector names, giving list(class = ..., prob = ...).
  lapply(
    c(class = "class", prob = "prob"),
    function(pred_type) prediction_eqn(x, type = pred_type)
  )
}
#' @export
#' @rdname get_expressions
get_expressions._elnet <- function(x, ...) {
  # Linear models support a single prediction type.
  numeric_eqn <- prediction_eqn(x, type = "numeric")
  list(numeric = numeric_eqn)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.