# R/linear-discriminant.R

#' #' Linear discriminant analysis
#' #'
#' #' Performs a linear discriminant analysis
#' #'
#' #' @param data A data set. NOTE(review): the draft script in this file
#' #'   splits it with `data[[group]]` and takes columns `1:3` as the
#' #'   predictors -- confirm the intended layout.
#' #' @param group The group. Presumably the name of the grouping column in
#' #'   `data` (the draft script uses `group <- "JOB"`) -- TODO confirm.
#' #' @param n_class The number of classifications
#' #' @param n_pred The number of predictors
#' #'
#' #' @importFrom MVN mvn
#' #' @importFrom MASS lda
#' #' @importFrom dplyr bind_cols
#' #'
#' #' @references
#' #' \insertRef{boedeker2019linear}{qpm}
#' #'
#' #' @export
#'
#' linear_discrimnant <-
#'   function(data,
#'            group = NULL,
#'            n_class = 3L,
#'            n_pred = 3L) {
#'     # NOTE(review): empty stub -- the function body has not been written;
#'     # the draft computation currently sits at top level after this
#'     # definition.
#'     # NOTE(review): the name is spelled "discrimnant" (missing "i") while
#'     # the file is linear-discriminant.R; it is tagged @export, so confirm
#'     # before renaming.
#'
#'   }
#'
#' # Draft working script (top level, not inside a function): a manual linear
#' # discriminant analysis of the `airline` data, alongside an lda() fit.
#' data(airline)
#' data <- airline
#' group <- "JOB"
#' # NOTE(review): n_pred is assigned here but never read again in this file
#' # as shown; the predictor columns are hard-coded as `1:3` further down.
#' n_class <- n_pred <- 3L
#' # NOTE(review): n_cases is not referenced again in this file as shown.
#' n_cases <- nrow(data)
#'
#' classes <- unique(data[[group]])
#'
#' # equal prior probability for every class
#' priors <- rep(1 / n_class, n_class)
#' # priors <- proportion(data[[group]])
#'
#' # one data frame per value of the grouping column
#' # NOTE(review): `%>%` is used here but no import for it is visible among
#' # this file's roxygen tags -- confirm it is available where this runs.
#' l_data <- data %>%
#'   split(data[[group]])
#'
#' # l_data %>%
#' #   lapply(function(x) {
#' #     mvn(x[1:3], multivariatePlot = "qq", mvnTest = "mardia")
#' #   })
#'
#' # NOTE(review): the formula hard-codes JOB instead of using `group`.
#' output <- lda(JOB ~ ., data, prior = priors)
#' pred <- predict(output) ## posterior values
#' posteriors <- as.data.frame(pred$posterior)
#' colnames(posteriors) <- sprintf("post_%s", levels(pred$class))
#'
#' # original data + predicted class + one posterior column per class
#' # NOTE(review): post_data is not referenced again in this file as shown.
#' post_data <- dplyr::bind_cols(data,
#'                               pred_class = pred$class,
#'                               posteriors)
#'
#' # per-class mean of each of the first three columns
#' means <- l_data %>%
#'   lapply(function(x) {
#'     sapply(x[1:3], mean)
#'   })
#'
#' # per-class covariance matrix of the first three columns
#' # NOTE(review): covs is not referenced again in this file as shown; the
#' # pooled matrix below recomputes cov(x[1:3]) itself.
#' covs <- l_data %>%
#'   lapply(function(x) {
#'     cov(x[1:3])
#'   })
#'
#' ## degrees of freedom for the pooled covariance: total cases - n_class
#' ns <- sapply(l_data, nrow)
#' cov_df <- sum(ns) - n_class
#'
#' ## pooled variance-covariance matrix: each class covariance weighted by
#' ## (n_k - 1) / cov_df, then summed
#'
#' cv_matrix <- Reduce(`+`,
#'                     lapply(l_data, function(x) {
#'                       (nrow(x) - 1) / cov_df * cov(x[1:3])
#'                     }))
#'
#' ## determinant of the pooled variance-covariance matrix
#' det_cv <- det(cv_matrix)
#'
#' ## coefficients of LCFs for predictors within each classification
#' ## (inverse of the pooled matrix times the class mean vector)
#' coef_preds <- lapply(means, function(m) {
#'   solve(cv_matrix) %*% as.matrix(m)
#' })
#'
#' # NOTE(review): the result of this call is not assigned, and c_bind() is
#' # defined only at the bottom of this file.
#' c_bind(coef_preds)
#'
#' ## intercept of LCFs within each classification
#' ## (-0.5 * t(coefficients) %*% class means)
#' intercepts <- mapply(function(x, y)
#'   - 0.5 * t(x) %*% y,
#'   coef_preds,
#'   means)
#'
#' # NOTE(review): this overwrites the lda() fit stored in `output` above with
#' # a placeholder string, which becomes the object's underlying value.
#' output <- "hello!"
#' # everything computed above is attached as attributes of a "linear_disc"
#' # object
#' final <- structure(
#'   output,
#'   class      = "linear_disc",
#'   group      = group,
#'   classes    = classes,
#'   priors     = priors,
#'   intercepts = intercepts,
#'   coefs      = coef_preds,
#'   means      = means,
#'   cv_matrix  = cv_matrix,
#'   det_cv     = det_cv
#' )
#' final
#'
#' # S3 print method for "linear_disc" objects: writes `x` with cat(), using
#' # an empty separator. `...` is accepted but not used.
#' # NOTE(review): the value returned is whatever cat() returns, not `x`;
#' # confirm whether invisible(x) was intended.
#' #' @export
#' print.linear_disc <- function(x, ...) {
#'   cat(x, sep = "")
#' }
#'
#'
#' # S3 predict method for "linear_disc" objects.
#' #   object  a "linear_disc" object; its attributes "coefs", "intercepts",
#' #           "priors", "classes", "det_cv", "cv_matrix" and "means" are read
#' #   x       new observation(s); coerced with as.matrix() in the "classify"
#' #           branch only
#' #   method  "classify" or "bayesian", resolved with match.arg()
#' #   ...     accepted but not used
#' #' @export
#' predict.linear_disc <-
#'   function(object,
#'            x,
#'            method = c("classify", "bayesian"),
#'            ...) {
#'     method <- match.arg(method)
#'     switch(method,
#'            classify = {
#'              if (!is.matrix(x))
#'                x <- as.matrix(x)
#'              # one score per class:
#'              # x %*% coefficients + intercept + log(prior)
#'              mapply(
#'                function(mod_coef_pred,
#'                         mod_intercept,
#'                         mod_prior) {
#'                  x %*% mod_coef_pred + mod_intercept + log(mod_prior)
#'                },
#'                attr(object, "coefs"),
#'                attr(object, "intercepts"),
#'                attr(object, "priors")
#'              )
#'            },
#'            bayesian = {
#'              # NOTE(review): this branch looks unfinished -- see the notes
#'              # below before relying on it.
#'              # NOTE(review): `p` is the number of classes; the usual
#'              # multivariate normal density uses the number of variables
#'              # in the (2 * pi) power -- confirm.
#'              p <- length(attr(object, "classes"))
#'              d <- attr(object, "det_cv")
#'              cv_matrix <- attr(object, "cv_matrix")
#'              # NOTE(review): `m` is not defined at this point inside the
#'              # function; it only exists as the parameter of the anonymous
#'              # function below, where `tx_m` is not recomputed per class.
#'              tx_m <- t(x) - m
#'
#'              fs <- sapply(attr(object, "means"),
#'                           function(m) {
#'                             m <- as.matrix(m)
#'                             # NOTE(review): exp() is closed before the
#'                             # `%*%` chain, so it is applied only to
#'                             # `-0.5 * t(tx_m)` rather than to the whole
#'                             # quadratic form, and the result sits inside
#'                             # the denominator of `1 / prod(...)` -- confirm
#'                             # against the intended density formula.
#'                             f <- 1 / prod((((2 * pi) ^ (p / 2) * (d ^ 0.5))),
#'                                           exp(-0.5 * t(tx_m)) %*% solve(cv_matrix) %*% tx_m)
#'                           })
#'              ## posteriors
#'              ## (each class value times its prior, divided by the sum over
#'              ## all classes)
#'              temp <- mapply(prod, fs, attr(object, "priors"))
#'              posts <- sapply(temp, function(ts)
#'                ts / sum(temp))
#'              rbind(`fs??` = fs, posteriors = posts)
#'            })
#'   }
#' # Draft check on the first row of `airline`; "b" is matched to "bayesian"
#' # by match.arg().
#' x <- airline[1, 1:3]
#' predict(final, airline[1, 1:3], "b")
#'
#'
#' # utils -------------------------------------------------------------------
#'
#'
#' # Column-binds the elements of the list `ls` with Reduce(cbind, ls) and
#' # labels the resulting columns with names(ls).
#' # NOTE(review): presumably every element contributes exactly one column;
#' # otherwise the length of names(ls) would not match the number of
#' # columns -- confirm with callers.
#' c_bind <- function(ls) {
#'   res <- Reduce(cbind, ls)
#'   colnames(res) <- names(ls)
#'   res
#' }
# jmbarbone/qpm documentation built on July 25, 2020, 10:41 p.m.