R/glmnet-engines.R

Defines functions set_glmnet_penalty_path .check_glmnet_penalty_predict .check_glmnet_penalty_fit format_glmnet_multinom_class format_glmnet_multinom_prob format_glmnet_multi_multinom_reg format_glmnet_multi_logistic_reg format_glmnet_multi_linear_reg multi_predict_glmnet organize_multnet_prob organize_multnet_class organize_glmnet_prob organize_glmnet_class .organize_glmnet_pred predict_raw_glmnet predict_classprob_glmnet predict_class_glmnet predict_numeric_glmnet predict_glmnet

Documented in .check_glmnet_penalty_fit .check_glmnet_penalty_predict .organize_glmnet_pred

# glmnet call stack using `predict()` when object has
# classes "_<glmnet-class>" and "model_fit":
#
#  predict()
#   predict._<glmnet-class>(penalty = NULL)
#    predict_glmnet(penalty = NULL)             <-- checks and sets penalty
#     predict.model_fit()                       <-- checks for extra vars in ...
#      predict_numeric()
#       predict_numeric._<glmnet-class>()
#        predict_numeric_glmnet()
#         predict_numeric.model_fit()
#          predict.<glmnet-class>()


# glmnet call stack using `multi_predict` when object has
# classes "_<glmnet-class>" and "model_fit":
#
#  multi_predict()
#   multi_predict._<glmnet-class>(penalty = NULL)
#    predict._<glmnet-class>(multi = TRUE)
#     predict_glmnet(multi = TRUE)            <-- checks and sets penalty
#      predict.model_fit()                    <-- checks for extra vars in ...
#       predict_raw()
#        predict_raw._<glmnet-class>()
#         predict_raw_glmnet()
#          predict_raw.model_fit(opts = list(s = penalty))
#           predict.<glmnet-class>()


# Shared `predict()` method for glmnet-based model fits.
#
# Resolves the penalty to predict at: an explicit `penalty` argument wins,
# otherwise the value stored in the model specification is used. The result
# is validated by `.check_glmnet_penalty_predict()` (which falls back to the
# fitted `lambda` path when no penalty is available) and written back into the
# specification before dispatching to `predict.model_fit()`.
#
# `multi = TRUE` is passed by `multi_predict()` so that several penalty values
# are accepted; plain `predict()` requires exactly one.
predict_glmnet <- function(object,
                           new_data,
                           type = NULL,
                           opts = list(),
                           penalty = NULL,
                           multi = FALSE,
                           ...) {
  # See discussion in https://github.com/tidymodels/parsnip/issues/195
  # Scalar `&&` (not elementwise `&`) is the correct operator inside `if`.
  if (is.null(penalty) && !is.null(object$spec$args$penalty)) {
    penalty <- object$spec$args$penalty
  }

  object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi)

  object$spec <- eval_args(object$spec)
  predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
}

# `predict_numeric()` method shared by glmnet fits: evaluates the arguments
# stored in the model specification, then defers to the `model_fit` method.
predict_numeric_glmnet <- function(object, new_data, ...) {
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec
  predict_numeric.model_fit(object, new_data = new_data, ...)
}

# `predict_class()` method shared by glmnet fits: evaluates the arguments
# stored in the model specification, then defers to the `model_fit` method.
predict_class_glmnet <- function(object, new_data, ...) {
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec
  predict_class.model_fit(object, new_data = new_data, ...)
}

# `predict_classprob()` method shared by glmnet fits: evaluates the arguments
# stored in the model specification, then defers to the `model_fit` method.
predict_classprob_glmnet <- function(object, new_data, ...) {
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec
  predict_classprob.model_fit(object, new_data = new_data, ...)
}

# `predict_raw()` method shared by glmnet fits. After the specification's
# arguments are evaluated, its penalty is placed in `opts` under the name `s`
# (overwriting any existing `s`), which is how the penalty reaches glmnet's
# own `predict()` method.
predict_raw_glmnet <- function(object, new_data, opts = list(), ...)  {
  object$spec <- eval_args(object$spec)
  spec_penalty <- object$spec$args$penalty
  opts[["s"]] <- spec_penalty
  predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}

# Mapping from glmnet fit classes to parsnip model types:
# elnet ~ linear_reg
# lognet ~ logistic_reg
# multnet ~ multinom_reg
# glmnetfit: a catch-all class for glmnet models fitted with a base-R family
#  object, so it can correspond to any of linear_reg, logistic_reg,
#  multinom_reg, or poisson_reg.
#
# Each parsnip-wrapped class ("_<glmnet-class>") gets S3 methods that are plain
# aliases of the shared glmnet prediction helpers.

# Methods for "_elnet" fits (linear_reg): general, numeric, and raw predictions.
#' @export
predict._elnet <- predict_glmnet

#' @export
predict_numeric._elnet <- predict_numeric_glmnet

#' @export
predict_raw._elnet <- predict_raw_glmnet

# Methods for "_lognet" fits (logistic_reg): general, class, class-probability,
# and raw predictions.
#' @export
predict._lognet <- predict_glmnet

#' @export
predict_class._lognet <- predict_class_glmnet

#' @export
predict_classprob._lognet <- predict_classprob_glmnet

#' @export
predict_raw._lognet <- predict_raw_glmnet

# Methods for "_multnet" fits (multinom_reg): general, class, class-probability,
# and raw predictions.
#' @export
predict._multnet <- predict_glmnet

#' @export
predict_class._multnet <- predict_class_glmnet

#' @export
predict_classprob._multnet <- predict_classprob_glmnet

#' @export
predict_raw._multnet <- predict_raw_glmnet

# Methods for "_glmnetfit" fits. Because glmnetfit is a catch-all class that
# may back regression or classification models, numeric, class,
# class-probability, and raw prediction methods are all registered.
#' @export
predict._glmnetfit <- predict_glmnet

#' @export
predict_numeric._glmnetfit <- predict_numeric_glmnet

#' @export
predict_class._glmnetfit <- predict_class_glmnet

#' @export
predict_classprob._glmnetfit <- predict_classprob_glmnet

#' @export
predict_raw._glmnetfit <- predict_raw_glmnet

#' Organize glmnet predictions
#'
#' This function is for developer use. It takes the first column of the
#' predictions returned by glmnet's `predict()` method and strips its names,
#' yielding a plain vector.
#'
#' @param x Predictions as returned by the `predict()` method for glmnet models.
#' @param object An object of class `model_fit`. Not used by this function.
#' @return An unnamed vector holding the first column of `x`.
#'
#' @rdname glmnet_helpers_prediction
#' @keywords internal
#' @export
.organize_glmnet_pred <- function(x, object) {
  first_col <- x[, 1]
  unname(first_col)
}

# Convert two-class glmnet output to class predictions. The first column of
# `x` (the column `organize_glmnet_prob()` treats as the second level's
# probability) is handed to `prob_to_class_2()` together with the fit.
organize_glmnet_class <- function(x, object) {
  event_prob <- x[, 1]
  prob_to_class_2(event_prob, object)
}

# Two-class probabilities from glmnet output. The first column of `x` becomes
# the probability of the second level in `object$lvl`, and its complement the
# probability of the first level. Columns are named after `object$lvl`.
organize_glmnet_prob <- function(x, object) {
  p_second <- x[, 1]
  probs <- tibble(v1 = 1 - p_second, v2 = p_second)
  colnames(probs) <- object$lvl
  probs
}

# Extract multinomial class predictions from glmnet output. If `x` has at most
# one row (per `vec_size()`), it is coerced with `as.character()`; otherwise
# its first column is returned. `object` is not used.
organize_multnet_class <- function(x, object) {
  if (vec_size(x) <= 1) {
    return(as.character(x))
  }
  x[, 1]
}

# Extract multinomial class probabilities from glmnet output as a tibble,
# using the first slice along the third dimension of `x` (assumed to be the
# first penalty value; TODO confirm). A single-row input is converted with
# `as_tibble_row()` because the slice has been dropped to a vector.
organize_multnet_prob <- function(x, object) {
  single_row <- vec_size(x) <= 1
  first_slice <- x[, , 1]
  if (single_row) {
    tibble::as_tibble_row(first_slice)
  } else {
    as_tibble(first_slice)
  }
}

# -------------------------------------------------------------------------

# Shared `multi_predict()` method for glmnet-based model fits.
#
# Predicts at several penalty values in one call and returns a tibble with one
# row per row of `new_data`, whose list column `.pred` holds the per-penalty
# results sorted by penalty.
#
# Penalty resolution: an explicit `penalty` wins; otherwise the value in the
# model specification is used, and failing that the fitted `lambda` path.
#
# Errors if the model type has no formatter (previously such cases silently
# returned `NULL`).
multi_predict_glmnet <- function(object,
                                 new_data,
                                 type = NULL,
                                 penalty = NULL,
                                 ...) {
  type <- check_pred_type(object, type)
  check_spec_pred_type(object, type)
  if (type == "prob") {
    check_spec_levels(object)
  }

  # Extra arguments are forwarded to glmnet's predict() through `opts`.
  dots <- list(...)

  # A quosure penalty is evaluated up front, regardless of mode, so that
  # numeric values are used for prediction and for labelling the results.
  if (is_quosure(penalty)) {
    penalty <- eval_tidy(penalty)
  }

  object$spec <- eval_args(object$spec)

  if (is.null(penalty)) {
    # See discussion in https://github.com/tidymodels/parsnip/issues/195
    if (!is.null(object$spec$args$penalty)) {
      penalty <- object$spec$args$penalty
    } else {
      penalty <- object$fit$lambda
    }
  }

  model_type <- class(object$spec)[1]

  if (object$spec$mode == "classification") {
    # Probabilities, and logistic classes (thresholded at 0.5 when formatted),
    # are derived from glmnet's "response" predictions.
    if (type == "prob" || model_type == "logistic_reg") {
      dots$type <- "response"
    } else {
      dots$type <- type
    }
  }

  pred <- predict(object, new_data = new_data, type = "raw",
                  opts = dots, penalty = penalty, multi = TRUE)

  switch(
    model_type,
    "linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty),
    "logistic_reg" = format_glmnet_multi_logistic_reg(pred,
                                                      penalty = penalty,
                                                      type = type,
                                                      lvl = object$lvl),
    "multinom_reg" = format_glmnet_multi_multinom_reg(pred,
                                                      penalty = penalty,
                                                      type = type,
                                                      lvl = object$lvl,
                                                      n_obs = nrow(new_data)),
    rlang::abort(
      glue::glue(
        "`multi_predict()` is not supported for glmnet models of type ",
        "`{model_type}`."
      )
    )
  )
}

# `multi_predict()` methods for every glmnet fit class; all are aliases of
# multi_predict_glmnet().
#' @export
#' @rdname multi_predict
#' @param penalty A numeric vector of penalty values. If `NULL`, the penalty in
#'   the model specification is used when present, otherwise the `lambda` path
#'   of the fitted glmnet model.
multi_predict._elnet <- multi_predict_glmnet

#' @export
#' @rdname multi_predict
multi_predict._lognet <- multi_predict_glmnet

#' @export
#' @rdname multi_predict
multi_predict._multnet <- multi_predict_glmnet

#' @export
#' @rdname multi_predict
multi_predict._glmnetfit <- multi_predict_glmnet

# Reshape linear-regression multi-penalty predictions into a nested tibble.
#
# `pred` has one column per penalty value, in the same order as `penalty`.
# The result has one row per observation; each `.pred` element is a tibble of
# `penalty` and `.pred`, sorted by penalty.
format_glmnet_multi_linear_reg <- function(pred, penalty) {
  # Links glmnet's per-penalty column names back to the penalty values.
  penalty_key <- tibble(s = colnames(pred), penalty = penalty)

  long <- as_tibble(pred)
  long$.row <- seq_len(nrow(long))
  long <- tidyr::pivot_longer(long, -.row, names_to = "s", values_to = ".pred")

  joined <- dplyr::full_join(penalty_key, long, by = "s", multiple = "all")
  joined <- dplyr::select(joined, -s)
  joined <- dplyr::arrange(joined, penalty)
  joined <- tidyr::nest(joined, .by = .row, .key = ".pred")
  dplyr::select(joined, -.row)
}

# Reshape logistic-regression multi-penalty predictions into a nested tibble.
#
# Each column of `pred` corresponds to one penalty value. Its entries are
# treated as the probability of the second level in `lvl`.
# - `type = "class"`: the probability is thresholded at 0.5 into a factor.
# - `type = "prob"`: both level probabilities are returned as `.pred_<level>`.
# The result has one row per observation with a nested `.pred` tibble sorted
# by penalty.
format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
  type <- rlang::arg_match(type, c("class", "prob"))

  # Links glmnet's per-penalty column names back to the penalty values.
  penalty_key <- tibble(s = colnames(pred), penalty = penalty)

  long <- as_tibble(pred)
  long$.row <- seq_len(nrow(long))
  long <- tidyr::pivot_longer(long, -.row, names_to = "s", values_to = ".pred")

  if (type == "prob") {
    # `.pred` belongs to the second level and its complement to the first,
    # hence the reversed level order when naming the two columns.
    long <- dplyr::mutate(long, .pred_class_2 = 1 - .pred)
    long <- rlang::set_names(long, c(".row", "s", paste0(".pred_", rev(lvl))))
    long <- dplyr::select(long, c(".row", "s", paste0(".pred_", lvl)))
  } else {
    long <- dplyr::mutate(
      long,
      .pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]),
      .pred_class = factor(.pred_class, levels = lvl),
      .keep = "unused"
    )
  }

  joined <- dplyr::full_join(penalty_key, long, by = "s", multiple = "all")
  joined <- dplyr::select(joined, -s)
  joined <- dplyr::arrange(joined, penalty)
  joined <- tidyr::nest(joined, .by = .row, .key = ".pred")
  dplyr::select(joined, -.row)
}

# Reshape multinomial multi-penalty predictions into a nested tibble.
#
# Delegates to the probability or class formatter, then nests the long result
# so that there is one row per observation. Each `.pred` element is sorted by
# penalty.
format_glmnet_multi_multinom_reg <- function(pred, penalty, type, lvl, n_obs) {
  type <- rlang::arg_match(type, c("class", "prob"))

  long <- if (type == "prob") {
    format_glmnet_multinom_prob(pred, penalty, lvl, n_obs)
  } else {
    format_glmnet_multinom_class(pred, penalty, lvl, n_obs)
  }

  long <- dplyr::arrange(long, .row, penalty)
  nested <- tidyr::nest(long, .by = .row, .key = ".pred")
  dplyr::select(nested, -.row)
}

# Long tibble of multinomial class probabilities across penalties.
#
# `pred` is an array with
#   dim 1 = observations
#   dim 2 = levels of the response
#   dim 3 = penalty values
# Each penalty slice becomes a block of `n_obs` rows. The blocks are stacked,
# columns are named `.pred_<level>`, and `.row` and a leading `penalty` column
# are added.
format_glmnet_multinom_prob <- function(pred, penalty, lvl, n_obs) {
  per_penalty <- apply(pred, 3, as_tibble)
  stacked <- purrr::list_rbind(per_penalty)
  stacked <- rlang::set_names(stacked, paste0(".pred_", lvl))
  stacked <- dplyr::mutate(
    stacked,
    .row = rep(seq_len(n_obs), times = length(penalty)),
    penalty = rep(penalty, each = n_obs)
  )
  dplyr::relocate(stacked, penalty)
}

# Long tibble of multinomial class predictions across penalties.
#
# `pred` is a matrix of n_obs x n_penalty class labels (or, when n_obs == 1, a
# vector of length n_penalty). It is read column-wise so that each penalty
# contributes a block of `n_obs` rows.
format_glmnet_multinom_class <- function(pred, penalty, lvl, n_obs) {
  n_pen <- length(penalty)
  row_id <- rep(seq_len(n_obs), times = n_pen)
  pen_col <- rep(penalty, each = n_obs)
  classes <- factor(as.vector(pred), levels = lvl)
  tibble(.row = row_id, penalty = pen_col, .pred_class = classes)
}

# -------------------------------------------------------------------------

#' Helper functions for checking the penalty of glmnet models
#'
#' @description
#' These functions are for developer use.
#'
#' `.check_glmnet_penalty_fit()` checks that the model specification for
#' fitting a glmnet model contains a single penalty value.
#'
#' `.check_glmnet_penalty_predict()` checks that the penalty value used for
#' prediction is valid. If called by `predict()`, it needs to be a single value.
#' Multiple values are allowed for `multi_predict()`.
#'
#' @param x An object of class `model_spec`.
#' @rdname glmnet_helpers
#' @keywords internal
#' @export
.check_glmnet_penalty_fit <- function(x) {
  pen <- rlang::eval_tidy(x$args$penalty)

  # A single evaluated value passes silently; anything else is an error.
  if (length(pen) == 1) {
    return(invisible(NULL))
  }

  rlang::abort(c(
    "For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
    glue::glue("There are {length(pen)} values for `penalty`."),
    "To try multiple values for total regularization, use the tune package.",
    "To predict multiple penalties, use `multi_predict()`"
  ))
}

#' @param penalty A penalty value to check. If `NULL`, the `lambda` values of
#'   the fitted glmnet model are used.
#' @param object An object of class `model_fit`.
#' @param multi A logical indicating if multiple values are allowed.
#' @return The validated penalty value(s).
#'
#' @rdname glmnet_helpers
#' @keywords internal
#' @export
.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) {
  fit_lambda <- object$fit$lambda

  if (is.null(penalty)) {
    penalty <- fit_lambda
  }

  # when using `predict()`, allow for a single lambda
  if (!multi && length(penalty) != 1) {
    rlang::abort(
      glue::glue(
        "`penalty` should be a single numeric value. `multi_predict()` ",
        "can be used to get multiple predictions per row of data."
      )
    )
  }

  # A model fit with a single lambda can only be predicted at that lambda.
  # `any()` keeps the condition scalar when `multi_predict()` supplies several
  # penalties; a bare `!=` would pass a vector to `&&`, an error in R >= 4.3.
  if (length(fit_lambda) == 1 && any(penalty != fit_lambda)) {
    # Collapse so several penalties produce one message, not one per value.
    penalty_txt <- paste(penalty, collapse = ", ")
    rlang::abort(
      glue::glue(
        "The glmnet model was fit with a single penalty value of ",
        "{fit_lambda}. Predicting with a value of {penalty_txt} ",
        "will give incorrect results from `glmnet()`."
      )
    )
  }

  penalty
}

# Decide which `lambda` path glmnet is fit with.
#
# If the engine arguments include `path_values`, they become glmnet's `lambda`
# argument, and `path_values` is removed from both the engine arguments and the
# fit call. Otherwise any `lambda` in the fit call is dropped so glmnet uses
# its default path. Returns the modified `x`.
set_glmnet_penalty_path <- function(x) {
  if ("path_values" %in% names(x$eng_args)) {
    # Since we decouple the parsnip `penalty` argument from being the same
    # as the glmnet `lambda` value, `path_values` allows users to set the
    # path differently from the default that glmnet uses. See
    # https://github.com/tidymodels/parsnip/issues/431
    x$method$fit$args$lambda <- x$eng_args$path_values
    x$eng_args$path_values <- NULL
    x$method$fit$args$path_values <- NULL
  } else {
    # See discussion in https://github.com/tidymodels/parsnip/issues/195
    x$method$fit$args$lambda <- NULL
  }
  x
}

Try the parsnip package in your browser

Any scripts or data that you put into this service are public.

parsnip documentation built on Aug. 18, 2023, 1:07 a.m.