R/predict_discrim.R

Defines functions predict_discrim

Documented in predict_discrim

# TODO: add stuff for other discriminant methods: 
#  library(mda) -> mda() -> predict.mda()
# TODO: use data.frame() for result, rather than cbind()
#  
#' Predicted values for discriminant analysis
#' 
#' `predict_discrim` calculates predicted class membership values for a linear or quadratic discriminant analysis,
#' returning a `data.frame` suitable for graphing or other analysis.
#' 
#' @details
#' The `predict()` methods provided for [MASS::lda()] and [MASS::qda()] are a mess, because they return their results as
#' a list, with components `class`, `posterior` and `x`. This function is designed as a wrapper on those to return
#' results in a more consistent and flexible way.
#' 
#' For use in graphs, where you want to show the classification boundaries or regions, you should supply a `newdata` data frame consisting
#' of two focal variables which are varied over their ranges, with the remaining variables used in the discriminant analysis
#' held fixed at typical values.
#' 
#' Using the `scores` argument, the function also returns the scores on the discriminant functions. This is only available for 
#' linear discriminant analysis with [MASS::lda()].
#' 
#' 
#' @param object   An object of class `"lda"` or `"qda"`  such as results from [MASS::lda()] or [MASS::qda()] 
#' @param newdata  A data frame of cases to be classified or, if `object` has a formula, a data frame with columns of the same names as the variables used. A vector will be interpreted as a row vector. If `newdata` is missing, an attempt will be made to retrieve the data used to fit the `lda` object.
#' @param prior The prior probabilities of the classes. By default, taken to be the proportions in what was set in the call to [MASS::lda()] or [MASS::qda()] 
#' @param dimen The dimension of the space to be used. If this is less than the number of available dimensions, 
#'              \eqn{\min(p, ng-1)}, only the first `dimen` discriminant components are used. (This argument is not yet implemented because [MASS::qda()] does not support this.)
#' @param scores A logical. If `TRUE`, the discriminant scores of the cases in `newdata` are appended as additional columns in the the result, with names `LD1`, `LD2`, ...
#' @param posterior Either a logical or the character string `"max"`. If `TRUE`, the posterior probabilities for all classes are included as columns named for the classes. If `FALSE`, these are omitted. If `"max"`, the maximum value of the probabilities across the classes are included, with the variable name `"maxp"`.
#' @param ...      arguments based from or to other methods, not yet used here
#' @md
#' @returns A `data.frame`, containing the the predicted class of the observations, values of the `newdata` variables and the maximum value of the posterior probabilities of the classes. `rownames()` in the result are inherited from those in `newdata`.
#' @importFrom insight get_modelmatrix find_response
#' @importFrom stats predict
#' @export
#' @examples
#' library(candisc)
#' library(MASS)   # for lda()
#' 
#' iris.lda <- lda(Species ~ ., iris)
#' pred_iris <- predict_discrim(iris.lda)
#' names(pred_iris)
#' 
#' # include scores, exclude posterior
#' pred_iris <- predict_discrim(iris.lda, scores = TRUE, posterior = FALSE)
#' names(pred_iris)
#' 
#' data(peng, package="heplots")
#' peng.lda <- lda(species ~ bill_length + bill_depth + flipper_length + body_mass, 
#'                 data = peng)
#' peng_pred <- predict_discrim(peng.lda, scores = TRUE)
#' str(peng_pred)

predict_discrim <- function(object, 
                            newdata, 
                            prior = object$prior,
                            dimen,
                            scores = FALSE,
                            posterior = "max",
                            ...) {
  cls <- class(object)
  if (!cls %in% c("lda", "qda")) {
    stop(paste('object must be of class "lda" or "qda", not', cls))
  }
  if (missing(newdata)) {
    newdata <- insight::get_modelmatrix(object) |>
      as.data.frame() |>
      dplyr::select(-"(Intercept)") 
  }
  nv <- ncol(newdata)
  pred <- predict(object, newdata, prior=prior, type = "prob")
  class <- pred$class
  post <- pred$posterior
  maxp <- apply(post, 1, max)

  # get response variable name to substitute for `class`
  response <- insight::find_response(object)
  
  # if the class variable is already in newdata, remove it
  if (response %in% colnames(newdata)) 
    newdata <- newdata[, !(names(newdata) %in% response)]
  
  
  ret <- cbind(class, newdata)
  colnames(ret)[1] <- response
  
  if (scores) {
    if (cls != "lda") warning(paste("Discriminant scores are not available for objects of class", cls))
    else {
    scores <- pred$x
    ret <- cbind(ret, scores)
    }
  }
  if (posterior == TRUE)
    ret <- cbind(ret, post)
  else if (posterior == "max")
    ret <- cbind(ret, maxp)
  
  ret
}

Try the candisc package in your browser

Any scripts or data that you put into this service are public.

candisc documentation built on Nov. 5, 2025, 7:08 p.m.