Nothing
#' Build the linear predictor of a fitted model as an R expression
#'
#' Turns the non-zero coefficients of a fitted model into an unevaluated R
#' expression that computes the linear predictor from the predictor columns.
#' @param x An object that uses a [glmnet::glmnet()] model and all numeric predictors.
#' @param ... Not currently used.
#' @return An R expression for single-response models. For multinomial models,
#' a tibble whose `lp` column holds one expression per class.
#' @export
#' @keywords internal
build_linear_predictor <- function(x, ...) UseMethod("build_linear_predictor")
#' @import rlang
# `x` should be a tidy-type format for the data with columns `terms` and
# `estimate`.
#
# Builds an R expression for a linear predictor. The row(s) whose `terms` is
# "(Intercept)" supply the constant; every other row contributes
# `(<term> * <estimate>)`, with the term name used as a variable name. The
# pieces are summed left to right, e.g. `1 + (a * 2) + (b * 0.5)`.
#
# Returns a call, or a bare number when only the intercept is present. When
# `x` has no rows (e.g. every coefficient was zero and got filtered out) the
# linear predictor is the constant `0` rather than an error.
build_linear_predictor_eng <- function(x, ...) {
  is_intercept <- x$terms == "(Intercept)"
  slopes <- x[!is_intercept, , drop = FALSE]
  # bquote() splices the symbol and the numeric estimate straight into the
  # call, producing the same expression as rlang's `!!` injection.
  lin_pred <- mapply(
    function(term, estimate) {
      bquote((.(as.name(as.character(term))) * .(estimate)))
    },
    slopes$terms,
    slopes$estimate,
    SIMPLIFY = FALSE,
    USE.NAMES = FALSE
  )
  if (any(is_intercept)) {
    lin_pred <- c(as.list(x$estimate[is_intercept]), lin_pred)
  }
  # Reduce() with no `init` would fail on an empty list; an all-zero model
  # has a linear predictor of exactly 0.
  if (length(lin_pred) == 0L) {
    return(0)
  }
  Reduce(function(lhs, rhs) bquote(.(lhs) + .(rhs)), lin_pred)
}
# This is used for linear and logistic regression: glmnet fits that have a
# single linear predictor.
#
# Asks `.get_glmn_coefs()` for the coefficients of `x$fit` at the penalty
# stored in `x$spec$args$penalty`, drops those that are exactly zero, and
# turns the rest into an R expression with `build_linear_predictor_eng()`.
build_linear_predictor_glmnet <- function(x, ...) {
  coefs <- .get_glmn_coefs(x$fit, x$spec$args$penalty)
  build_linear_predictor_eng(dplyr::filter(coefs, estimate != 0))
}
# Linear regression fits share the single-predictor glmnet code path.
#' @export
#' @rdname build_linear_predictor
build_linear_predictor._elnet <- function(x, ...) {
  return(build_linear_predictor_glmnet(x, ...))
}
# Two-class logistic regression fits share the single-predictor glmnet code
# path.
#' @export
#' @rdname build_linear_predictor
build_linear_predictor._lognet <- function(x, ...) {
  return(build_linear_predictor_glmnet(x, ...))
}
# Multinomial regression: one linear predictor per outcome class.
#
# Returns a tibble with one row per class holding that class's non-zero
# coefficients (list column `coefs`) and its linear predictor expression
# (list column `lp`).
#' @export
#' @rdname build_linear_predictor
build_linear_predictor._multnet <- function(x, ...) {
  .get_glmn_coefs(x$fit, x$spec$args$penalty) |>
    # Nest *before* dropping zero coefficients: filtering first would remove
    # every row of a class whose coefficients are all zero, so that class
    # would silently vanish from the prediction equations (and from the
    # probability normalization) instead of getting a linear predictor of 0.
    dplyr::group_nest(class, .key = "coefs") |>
    dplyr::mutate(
      coefs = purrr::map(coefs, \(cf) dplyr::filter(cf, estimate != 0)),
      lp = purrr::map(coefs, \(cf) {
        if (nrow(cf) == 0L) 0 else build_linear_predictor_eng(cf)
      })
    )
}
## -----------------------------------------------------------------------------
#' Convert one or more linear predictors to a format used for prediction
#'
#' Wraps the linear predictor expression(s) of a fitted model into prediction
#' expressions classed by model and prediction type (e.g. `"lognet_prob"`),
#' which is what [stack_predict()] dispatches on.
#'
#' @inheritParams build_linear_predictor
#' @param type The prediction type.
#' @return The return type varies, based on the model and prediction type.
#' @export
#' @keywords internal
prediction_eqn <- function(x, ...) {
  UseMethod("prediction_eqn")
}
# Tag a set of prediction expressions with the S3 class "<model>_<type>"
# (e.g. "lognet_class"), which stack_predict() dispatches on, and store the
# outcome levels in the "levels" attribute. Any existing class of `x`
# (e.g. a tibble's) is replaced. `type` must be "class", "prob" or "numeric".
# NOTE: the misspelled name is kept because callers reference it.
eqn_constuctor <- function(x, model, type, lvls) {
  type <- match.arg(type, c("class", "prob", "numeric"))
  structure(x, class = paste(model, type, sep = "_"), levels = lvls)
}
#' @export
#' @rdname prediction_eqn
prediction_eqn._lognet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  model_class <- class(x$fit)[1]
  lp <- build_linear_predictor(x)
  lvls <- x$lvl

  if (type != "prob") {
    # A positive linear predictor means P(second class) > 0.5.
    class_eqn <- rlang::expr(factor(
      ifelse((!!lp) > 0, !!lvls[2], !!lvls[1]),
      levels = !!lvls
    ))
    return(eqn_constuctor(list(.pred_class = class_eqn), model_class, type, lvls))
  }

  # glmnet models the probability of the _second_ class; the first class
  # uses the negated linear predictor, i.e. 1 - P(second class).
  prob_eqns <- list(
    rlang::expr(stats::binomial()$linkinv(-(!!lp))),
    rlang::expr(stats::binomial()$linkinv(!!lp))
  )
  names(prob_eqns) <- paste0(".pred_", lvls)
  eqn_constuctor(prob_eqns, model_class, type, lvls)
}
#' @export
#' @rdname prediction_eqn
prediction_eqn._elnet <- function(x, type = "numeric", ...) {
  # Regression fits only support numeric predictions; the single linear
  # predictor expression is stored as `.pred`.
  type <- match.arg(type, "numeric")
  eqn_constuctor(
    list(.pred = build_linear_predictor(x)),
    model = class(x$fit)[1],
    type = type,
    lvls = NULL
  )
}
# Wrap an expression in a call to exp(), i.e. build `exp(<x>)` unevaluated.
eexp <- function(x) call("exp", x)
#' @export
#' @rdname prediction_eqn
prediction_eqn._multnet <- function(x, type = "class", ...) {
  type <- match.arg(type, c("class", "prob"))
  # Each class gets an exp(linear predictor) expression, named
  # `.pred_<class>`; the multnet stack_predict() methods normalize these by
  # their row sums.
  eqns <- dplyr::mutate(
    build_linear_predictor(x),
    .pred = purrr::map(lp, eexp)
  )
  names(eqns$.pred) <- paste0(".pred_", eqns$class)
  eqn_constuctor(eqns, class(x$fit)[1], type, x$lvl)
}
## -----------------------------------------------------------------------------
#' Evaluate prediction equations on new data
#'
#' Evaluates the expressions produced by [prediction_eqn()] against the
#' columns of `data` to produce predictions.
#'
#' @inheritParams build_linear_predictor
#' @param x A set of model expressions generated by [prediction_eqn()].
#' @param data A data frame containing the variables referenced by the
#'   expressions in `x`.
#' @return The return type varies, based on the model and prediction type.
#' @export
#' @keywords internal
stack_predict <- function(x, ...) {
  UseMethod("stack_predict")
}
#' @export
#' @rdname stack_predict
stack_predict.elnet_numeric <- function(x, data, ...) {
  # `.rows` recycles a length-one result (an intercept-only model evaluates
  # to a single number) to one prediction per row of `data`.
  tibble::tibble(
    .pred = rlang::eval_tidy(x$.pred, data),
    .rows = nrow(data)
  )
}
#' @export
#' @rdname stack_predict
stack_predict.lognet_class <- function(x, data, ...) {
  # `.rows` recycles a length-one result (an intercept-only model predicts
  # the same class everywhere) to one prediction per row of `data`.
  tibble::tibble(
    .pred_class = rlang::eval_tidy(x$.pred_class, data),
    .rows = nrow(data)
  )
}
#' @export
#' @rdname stack_predict
stack_predict.lognet_prob <- function(x, data, ...) {
  # One probability column per `.pred_<level>` expression in `x`.
  probs <- purrr::map(x, rlang::eval_tidy, data = data)
  # `.rows` recycles length-one results (intercept-only models evaluate to a
  # single probability) to one row per row of `data`.
  tibble::as_tibble(probs, .rows = nrow(data))
}
# Evaluate the per-class `exp(linear predictor)` expressions stored in
# `x$.pred` on `data`, then normalize them with multi_net_helper().
multi_net_engine <- function(x, data, ...) {
  scores <- purrr::map(x$.pred, rlang::eval_tidy, data = data)
  # `.rows` recycles length-one scores (a class with only an intercept
  # evaluates to a single number) so there is one row per row of `data`.
  scores <- tibble::as_tibble(scores, .rows = nrow(data))
  multi_net_helper(scores)
}
# Row-wise normalization of per-class scores.
#
# `data` holds one `.pred_*` column per class (presumably non-negative scores
# such as exp(linear predictor) — confirm with callers). Each row is divided
# by its sum (kept in the `.sum` column), and `idx` records the position of
# the largest `.pred_*` value in that row (the first one on ties, as
# which.max() does).
multi_net_helper <- function(data, ...) {
  by_row <- dplyr::rowwise(data)
  by_row <- dplyr::mutate(
    by_row,
    .sum = sum(dplyr::c_across(dplyr::starts_with(".pred_")))
  )
  by_row <- dplyr::mutate(
    by_row,
    dplyr::across(dplyr::starts_with(".pred_"), \(score) score / .sum),
    idx = which.max(dplyr::c_across(dplyr::starts_with(".pred_")))
  )
  dplyr::ungroup(by_row)
}
#' @export
#' @rdname stack_predict
stack_predict.multnet_class <- function(x, data, ...) {
  lvls <- attr(x, "levels")
  res <- multi_net_engine(x, data)
  # `idx` is the position of the largest `.pred_*` column in each row. Those
  # columns come first in `res`, so `idx` also indexes `colnames(res)`.
  res |>
    dplyr::mutate(
      # Strip only the literal leading ".pred_" prefix. gsub(".pred_", ...)
      # would treat "." as a regex wildcard and also rewrite matches inside
      # level names (e.g. "high_pred_risk"), which then become NA.
      .pred_class = sub("^\\.pred_", "", colnames(res)[idx]),
      .pred_class = factor(.pred_class, levels = lvls)
    ) |>
    dplyr::select(.pred_class)
}
#' @export
#' @rdname stack_predict
stack_predict.multnet_prob <- function(x, data, ...) {
  # Keep only the normalized per-class probabilities and drop the helper
  # columns (`.sum`, `idx`) added by multi_net_helper().
  probs <- multi_net_engine(x, data)
  dplyr::select(probs, dplyr::starts_with(".pred_"))
}
## -----------------------------------------------------------------------------
#' Obtain prediction equations for all possible values of type
#'
#' @param x A `parsnip` model with the `glmnet` engine.
#' @param ... Not used.
#' @return A named list with prediction equations for each possible type,
#'   named by type (`class` and `prob` for classification models, `numeric`
#'   for regression models).
#' @export
get_expressions <- function(x, ...) {
  UseMethod("get_expressions")
}
#' @export
#' @rdname get_expressions
get_expressions._multnet <- function(x, ...) {
  # Multinomial models support both class and probability predictions.
  types <- c(class = "class", prob = "prob")
  lapply(types, function(type) prediction_eqn(x, type = type))
}
#' @export
#' @rdname get_expressions
get_expressions._lognet <- function(x, ...) {
  # Two-class models support both class and probability predictions.
  class_eqn <- prediction_eqn(x, type = "class")
  prob_eqn <- prediction_eqn(x, type = "prob")
  list(class = class_eqn, prob = prob_eqn)
}
#' @export
#' @rdname get_expressions
get_expressions._elnet <- function(x, ...) {
  # Regression models only support numeric predictions.
  numeric_eqn <- prediction_eqn(x, type = "numeric")
  list(numeric = numeric_eqn)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.