R/predict.R

Defines functions predict.MultinomialSLOPE predict.PoissonSLOPE predict.BinomialSLOPE predict.GaussianSLOPE predict.SLOPE

Documented in predict.BinomialSLOPE predict.GaussianSLOPE predict.MultinomialSLOPE predict.PoissonSLOPE predict.SLOPE

#' Generate predictions from SLOPE models
#'
#' Return predictions from models fit by [SLOPE()].
#'
#' @param object an object of class `"SLOPE"`, typically the result of
#'   a call to [SLOPE()]
#' @param x new data
#' @param type type of prediction; `"link"` returns the linear predictors,
#'   `"response"` returns the result of applying the link function,
#'    and `"class"` returns class predictions.
#' @param ... ignored and only here for method consistency
#' @param alpha penalty parameter for SLOPE models; if `NULL`, the
#'   values used in the original fit will be used
#' @param simplify if `TRUE`, [base::drop()] will be called before returning
#'   the coefficients to drop extraneous dimensions
#' @param exact if `TRUE` and the given parameter values differ from those in
#'   the original fit, the model will be refit by calling [stats::update()] on
#'   the object with the new parameters. If `FALSE`, the predicted values
#'   will be based on interpolated coefficients from the original
#'   penalty path.
#' @param sigma deprecated. Please use `alpha` instead.
#'
#' @seealso [stats::predict()], [stats::predict.glm()], [coef.SLOPE()]
#' @family SLOPE-methods
#'
#' @return Predictions from the model with scale determined by `type`.
#'
#' @examples
#' fit <- with(mtcars, SLOPE(cbind(mpg, hp), vs, family = "binomial"))
#' predict(fit, with(mtcars, cbind(mpg, hp)), type = "class")
#' @export
predict.SLOPE <- function(object,
                          x,
                          alpha = NULL,
                          type = "link",
                          simplify = TRUE,
                          sigma,
                          ...) {
  # This method (the base method) only generates linear predictors

  if (inherits(x, "sparseMatrix")) {
    x <- methods::as(x, "dgCMatrix")
  }

  if (inherits(x, "data.frame")) {
    x <- as.matrix(x)
  }

  if (!missing(sigma)) {
    warning("`sigma` is deprecated. Please use `alpha` instead.")
    alpha <- sigma
  }

  beta <- stats::coef(object, alpha = alpha, simplify = FALSE)

  intercept <- "(Intercept)" %in% dimnames(beta)[[1]]

  if (intercept) {
    x <- methods::cbind2(1, x)
  }

  n <- NROW(x)
  p <- NROW(beta)
  m <- NCOL(beta)
  n_penalties <- dim(beta)[3]

  stopifnot(p == NCOL(x))

  lin_pred <- array(
    dim = c(n, m, n_penalties),
    dimnames = list(
      rownames(x),
      dimnames(beta)[[2]],
      dimnames(beta)[[3]]
    )
  )

  for (i in seq_len(n_penalties)) {
    lin_pred[, , i] <- as.matrix(x %*% beta[, , i])
  }

  lin_pred
}

#' @rdname predict.SLOPE
#' @export
predict.GaussianSLOPE <- function(object,
                                  x,
                                  sigma = NULL,
                                  type = c("link", "response"),
                                  simplify = TRUE,
                                  ...) {
  type <- match.arg(type)

  out <- NextMethod(object, type = type) # always linear predictors

  if (simplify) {
    out <- drop(out)
  }

  out
}

#' @rdname predict.SLOPE
#' @export
predict.BinomialSLOPE <- function(object,
                                  x,
                                  sigma = NULL,
                                  type = c("link", "response", "class"),
                                  simplify = TRUE,
                                  ...) {
  type <- match.arg(type)

  lin_pred <- NextMethod(object, type = type)

  out <- switch(
    type,
    link = lin_pred,
    response = 1 / (1 + exp(-lin_pred)),
    class = {
      cnum <- ifelse(lin_pred > 0, 2, 1)
      clet <- object$class_names[cnum]

      if (is.matrix(cnum)) {
        clet <- array(clet, dim(cnum), dimnames(cnum))
      }

      clet
    }
  )

  if (simplify) {
    out <- drop(out)
  }

  out
}

#' @rdname predict.SLOPE
#' @export
predict.PoissonSLOPE <- function(object,
                                 x,
                                 sigma = NULL,
                                 type = c("link", "response"),
                                 exact = FALSE,
                                 simplify = TRUE,
                                 ...) {
  type <- match.arg(type)

  lin_pred <- NextMethod(object, type = type)

  out <- switch(
    type,
    link = lin_pred,
    response = exp(lin_pred)
  )

  if (simplify) {
    out <- drop(out)
  }

  out
}

#' @export
#' @rdname predict.SLOPE
predict.MultinomialSLOPE <- function(object,
                                     x,
                                     sigma = NULL,
                                     type = c("link", "response", "class"),
                                     exact = FALSE,
                                     simplify = TRUE,
                                     ...) {
  type <- match.arg(type)

  lin_pred <- NextMethod(object, type = type, simplify = FALSE)
  m <- NCOL(lin_pred)

  out <- switch(
    type,
    response = {
      n <- nrow(lin_pred)
      m <- ncol(lin_pred)
      path_length <- dim(lin_pred)[3]

      tmp <- array(0, c(n, m + 1, path_length))
      tmp[, 1:m, ] <- lin_pred

      aperm(apply(tmp, c(1, 3), function(x) exp(x) / sum(exp(x))), c(2, 1, 3))
    },
    link = lin_pred,
    class = {
      response <- stats::predict(object, x, type = "response", simplify = FALSE)
      tmp <- apply(response, c(1, 3), which.max)
      class_names <- object$class_names

      predicted_classes <-
        apply(
          tmp,
          2,
          function(a) factor(a, levels = 1:(m + 1), labels = class_names)
        )
      colnames(predicted_classes) <- dimnames(lin_pred)[[3]]
      predicted_classes
    }
  )

  if (simplify) {
    out <- drop(out)
  }

  out
}

Try the SLOPE package in your browser

Any scripts or data that you put into this service are public.

SLOPE documentation built on June 10, 2022, 1:05 a.m.