R/poisson_reg.R

Defines functions multi_predict._fishnet predict_raw._fishnet predict_numeric._fishnet predict._fishnet check_penalty

Documented in multi_predict._fishnet predict_raw._fishnet

# The glmnet-specific `predict` methods below share this penalty check.
# Only a single penalty value is allowed when called by `predict()`;
# `multi_predict()` may request several.

# Validate the penalty value(s) used to predict from a glmnet fit.
#
# @param penalty Numeric penalty value(s), or `NULL` to use every lambda
#   stored in `object$fit$lambda`.
# @param object A fitted model object whose `fit$lambda` holds the lambda
#   value(s) the glmnet model was fit with.
# @param multi Logical. `FALSE` (the `predict()` path) requires exactly one
#   penalty value; `TRUE` (the `multi_predict()` path) allows several.
# @return The validated `penalty`.
check_penalty <- function(penalty = NULL, object, multi = FALSE) {
  if (is.null(penalty)) {
    penalty <- object$fit$lambda
  }

  # when using `predict()`, allow for a single lambda
  if (!multi && length(penalty) != 1) {
    rlang::abort(
      glue::glue(
        "`penalty` should be a single numeric value. `multi_predict()` ",
        "can be used to get multiple predictions per row of data."
      )
    )
  }

  # A glmnet model fit with a single lambda only gives correct predictions at
  # that exact lambda. Check every requested value: with `multi = TRUE` there
  # may be several, which a bare `penalty != lambda` inside `&&` cannot handle.
  # `isTRUE(all(...))` also treats an NA penalty as a mismatch, so it raises
  # the informative error below.
  fit_lambda <- object$fit$lambda
  if (length(fit_lambda) == 1 && !isTRUE(all(penalty == fit_lambda))) {
    rlang::abort(
      glue::glue(
        "The glmnet model was fit with a single penalty value of ",
        "{fit_lambda}. Predicting with a value of ",
        "{paste(penalty, collapse = ', ')} ",
        "will give incorrect results from `glmnet()`."
      )
    )
  }

  penalty
}

# ------------------------------------------------------------------------------
# glmnet call stack for poisson regression using `predict` when object has
# classes "_fishnet" and "model_fit":
#
#  predict()
#   predict._fishnet(penalty = NULL)   <-- checks and sets penalty
#    predict.model_fit()               <-- checks for extra vars in ...
#     predict_numeric()
#      predict_numeric._fishnet()
#       predict_numeric.model_fit()
#        predict.fishnet()


# glmnet call stack for poisson regression using `multi_predict` when object has
# classes "_fishnet" and "model_fit":
#
#  multi_predict()
#   multi_predict._fishnet(penalty = NULL)
#    predict._fishnet(multi = TRUE)    <-- checks and sets penalty
#     predict.model_fit()              <-- checks for extra vars in ...
#      predict_raw()
#       predict_raw._fishnet()
#        predict_raw.model_fit(opts = list(s = penalty))
#         predict.fishnet()


#' @export
predict._fishnet <-
  function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
    # Predict from a poisson glmnet ("fishnet") fit at the requested penalty.
    #
    # If `penalty` is NULL, the penalty stored in the model specification is
    # used when present; otherwise `check_penalty()` falls back to every
    # lambda in `object$fit$lambda`. `multi = TRUE` (used by
    # `multi_predict._fishnet()`) allows more than one penalty value.
    if (any(names(enquos(...)) == "newdata")) {
      rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
    }

    # See discussion in https://github.com/tidymodels/parsnip/issues/195
    # Both operands are scalars, so use the short-circuiting `&&`.
    if (is.null(penalty) && !is.null(object$spec$args$penalty)) {
      penalty <- object$spec$args$penalty
    }

    object$spec$args$penalty <- check_penalty(penalty, object, multi)

    object$spec <- parsnip::eval_args(object$spec)
    predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
  }

#' @export
predict_numeric._fishnet <- function(object, new_data, ...) {
  # Numeric predictions for a poisson glmnet fit: evaluate the model
  # specification's arguments, then defer to parsnip's model_fit method.
  # Reject the `newdata` spelling, which would otherwise slip into `...`.
  extra_arg_names <- names(enquos(...))
  if ("newdata" %in% extra_arg_names) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  evaluated_spec <- parsnip::eval_args(object$spec)
  object$spec <- evaluated_spec
  parsnip::predict_numeric.model_fit(object, new_data = new_data, ...)
}

#' Model predictions across many sub-models
#'
#' For some models, predictions can be made on sub-models in the model object.
#' @param object A `model_fit` object.
#' @param new_data A rectangular data object, such as a data frame.
#' @param opts A list of options passed on to `predict_raw.model_fit()`. Its
#'  `s` element is overwritten with the penalty from the model specification.
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
#'  such as `type`.
#' @return A tibble with the same number of rows as the data being predicted.
#'  There is a list-column named `.pred` that contains tibbles with
#'  multiple rows per sub-model.
#' @export
#' @keywords internal
predict_raw._fishnet <- function(object, new_data, opts = list(), ...) {
  if ("newdata" %in% names(enquos(...))) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  spec <- parsnip::eval_args(object$spec)
  object$spec <- spec

  # The underlying glmnet predict() call receives the penalty as `s`.
  raw_opts <- opts
  raw_opts$s <- spec$args$penalty

  parsnip::predict_raw.model_fit(object, new_data = new_data, opts = raw_opts, ...)
}

#' @importFrom dplyr full_join as_tibble arrange
#' @importFrom tidyr gather
#' @export
#' @rdname predict_raw._fishnet
#' @param penalty A numeric vector of penalty values.
#' @param type A single character string giving the prediction scale. `NULL`
#'  (the default) or `"numeric"` returns predictions on the response (count)
#'  scale, matching `predict(type = "numeric")`. Any other value, such as
#'  `"link"`, is passed to glmnet's `predict()` method as its `type` argument.
multi_predict._fishnet <-
  function(object, new_data, type = NULL, penalty = NULL, ...) {
    if (any(names(enquos(...)) == "newdata")) {
      rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
    }

    dots <- list(...)

    # glmnet's predict method for poisson fits returns the linear predictor
    # (log scale) unless asked for `type = "response"`. Previously `type` was
    # ignored, so these predictions were on a different scale than
    # `predict(type = "numeric")`. Map parsnip's "numeric" onto glmnet's
    # "response", and pass any other value straight through.
    if (is.null(type)) {
      type <- "numeric"
    }
    dots$type <- if (identical(type, "numeric")) "response" else type

    object$spec <- eval_args(object$spec)

    if (is.null(penalty)) {
      # See discussion in https://github.com/tidymodels/parsnip/issues/195
      if (!is.null(object$spec$args$penalty)) {
        penalty <- object$spec$args$penalty
      } else {
        penalty <- object$fit$lambda
      }
    }

    pred <- predict._fishnet(
      object,
      new_data = new_data, type = "raw",
      opts = dots, penalty = penalty, multi = TRUE
    )
    # Pair each prediction column with the penalty value it corresponds to.
    param_key <- tibble(group = colnames(pred), penalty = penalty)
    pred <- as_tibble(pred)
    # seq_len() stays correct for zero-row input, where 1:0 would be c(1, 0).
    pred$.row <- seq_len(nrow(pred))
    # Long format: one row per (data row, prediction column) pair.
    pred <- gather(pred, group, .pred, -.row)
    # Newer dplyr versions accept `multiple`; older ones do not. Presumably
    # this silences dplyr's multiple-match warning for this intended
    # one-to-many join (confirm).
    if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
      pred <- full_join(param_key, pred, by = "group", multiple = "all")
    } else {
      pred <- full_join(param_key, pred, by = "group")
    }
    pred$group <- NULL
    pred <- arrange(pred, .row, penalty)
    .row <- pred$.row
    pred$.row <- NULL
    # Nest: one tibble of (penalty, .pred) rows per row of `new_data`.
    pred <- split(pred, .row)
    names(pred) <- NULL
    tibble(.pred = pred)
  }

Try the poissonreg package in your browser

Any scripts or data that you put into this service are public.

poissonreg documentation built on Aug. 22, 2022, 5:07 p.m.