R/methods_predict.R

Defines functions predict.beset_rf predict.nested model.matrix.beset predict.beset

Documented in predict.beset

#' Predict Methods for \code{beset} Objects
#'
#' Generate predictions from the model that \code{get_best} selects out of a
#' \code{beset} object.
#'
#' @param object A \code{beset} object. Objects inheriting from \code{"elnet"}
#' are predicted with the elastic-net model chosen by \code{get_best}; objects
#' inheriting from \code{"rf"} are delegated to the random-forest method (always
#' on the "response" scale); all others use the best \code{glm} chosen by
#' \code{get_best}.
#'
#' @param type Type of prediction required. Type "link" gives the linear
#' predictors for "binomial" and "poisson" models; for "gaussian" models it
#' gives the fitted values. Type "response" gives the fitted probabilities for
#' "binomial", fitted mean for "poisson", and fitted values for "gaussian".
#' Type "coefficients" computes the coefficients. Note that for "binomial"
#' models, results are returned only for the class corresponding to the second
#' level of the factor response. Type "class" applies only to "binomial" models,
#' and produces the class label corresponding to the maximum probability. Type
#' "nonzero" returns a list of the indices of the nonzero coefficients. For
#' non-elastic-net (glm) models only "link" is distinguished; any other value
#' returns predictions on the response scale.
#'
#' @param tt Optional \code{terms} object used to build the model matrix for
#' \code{newdata}. Defaults to \code{terms(object)}.
#'
#' @return A vector of predictions, flattened with \code{as.vector}. When
#' \code{newdata} is omitted, predictions are made for the training model
#' matrix.
#'
#' @inheritParams stats::predict.lm
#' @inheritParams glmnet::predict.glmnet
#' @inheritParams get_best
#' @export
predict.beset <- function(object, newdata, type = "response",
                          newoffset = NULL, alpha = NULL, lambda = NULL,
                          n_pred = NULL, metric = "auto", oneSE = TRUE,
                          na.action = na.pass, tt = NULL, ...){
  # Random-forest fits have their own averaging method; they are always
  # predicted on the "response" scale.
  if (inherits(object, "rf")) {
    return(predict.beset_rf(object, newdata, type = "response", ...))
  }
  metric <- tryCatch(
    match.arg(metric, c("auto", "auc", "mae", "mce", "mse", "rsq")),
    error = function(c){
      # Report the user-facing argument name instead of match.arg's "arg".
      c$message <- gsub("arg", "metric", c$message)
      c$call <- NULL
      stop(c)
    }
  )
  # Reject unsupported metric/family combinations up front: "auc" requires a
  # binomial model and "mae" is unavailable for binomial models.
  if ((metric == "auc" && object$family != "binomial") ||
      (metric == "mae" && object$family == "binomial")) {
    stop(paste(metric, "not available for", object$family, "models"),
         call. = FALSE)
  }
  if (metric == "auto") {
    metric <- if (object$family == "gaussian") "mse" else "mce"
  }
  if (is.null(tt)) tt <- terms(object)
  if (missing(newdata) || is.null(newdata)) {
    # No new data: predict on the stored model matrix with the stored offset.
    X <- model.matrix(object)
    newoffset <- object$parameters$offset
  } else {
    Terms <- delete.response(tt)
    m <- model.frame(Terms, newdata, na.action = na.action,
                     xlev = object$xlevels)
    X <- model.matrix(Terms, m, contrasts.arg = object$contrasts)
  }
  if (is.null(newoffset)) {
    # A zero offset is only a safe default when the fit used no offset;
    # otherwise a NULL offset would silently produce empty predictions.
    if (!all(object$parameters$fit$offset == 0)) {
      stop("model was fit with an offset; 'newoffset' must be supplied",
           call. = FALSE)
    }
    newoffset <- rep(0, nrow(X))
  }
  if (inherits(object, "elnet")) {
    model <- get_best.elnet(
      object, alpha = alpha, lambda = lambda, metric = metric, oneSE = oneSE,
      ...
    )
    # The elastic-net predict method takes newx without an intercept column.
    if ("(Intercept)" %in% colnames(X)) X <- X[, -1, drop = FALSE]
    yhat <- predict(
      model, newx = X, s = model$best_lambda, type = type,
      newoffset = newoffset, ...)
  } else {
    model <- get_best.glm(
      object, n_pred = n_pred, metric = metric, oneSE = oneSE, ...
    )
    beta <- coef(model)
    eta <- X[, names(beta), drop = FALSE] %*% beta + newoffset
    # Honour type = "link"; every other type is returned on the response scale.
    yhat <- if (identical(type, "link")) eta else model$family$linkinv(eta)
  }
  as.vector(yhat)
}

#' @export
model.matrix.beset <- function(object, ...){
  # Return the model matrix stored on the beset object; `...` is unused.
  params <- object$parameters
  params$x
}

#' @export
predict.nested <- function(object, newdata, type = "response",
                           newoffset = NULL, alpha = NULL, lambda = NULL,
                           n_pred = NULL, metric = "auto", oneSE = TRUE,
                           na.action = na.pass, ...){
  # Average, observation by observation, the predictions of every model in
  # `object$beset`. Without `newdata`, the data stored in `object$data` is
  # used. Arguments are passed by name so they cannot be silently mismatched
  # if a method's formal order differs; the shared `object$terms` is supplied
  # as `tt`.
  if (missing(newdata) || is.null(newdata)) {
    newdata <- object$data
  }
  map(
    object$beset,
    function(fit) predict(
      fit, newdata = newdata, type = type, newoffset = newoffset,
      alpha = alpha, lambda = lambda, n_pred = n_pred, metric = metric,
      oneSE = oneSE, na.action = na.action, tt = object$terms
    )
  ) %>% transpose %>% simplify_all %>% map_dbl(mean)
}

#' @export
predict.beset_rf <- function(object, newdata, type = "response", ...){
  # Average predictions across every forest in `object$forests`, observation
  # by observation. A regular closure is used instead of a purrr `~` lambda:
  # inside a formula lambda `...` refers to the lambda's own arguments (the
  # current forest), so the caller's `...` would never reach predict() and
  # the forest would be passed to it twice.
  per_forest <- map(
    object$forests,
    function(forest) predict(forest, newdata, type = type, ...)
  )
  per_forest %>% transpose %>% simplify_all %>% map_dbl(mean)
}
jashu/beset documentation built on April 20, 2023, 5:28 a.m.