#' #' Linear discriminant analysis
#' #'
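#' #' Performs a linear discriminant analysis (LDA) of `data`, classifying
#' #' observations into the levels of `group`.
#' #'
#' #' @param data A data frame containing the predictors and the grouping column
#' #' @param group Name of the column in `data` that holds the class labels
#' #' @param n_class The number of classes (distinct values of `group`)
#' #' @param n_pred The number of predictor variables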
#' #'
#' #' @importFrom MVN mvn
#' #' @importFrom MASS lda
#' #' @importFrom dplyr bind_cols
#' #'
#' #' @references
#' #' \insertRef{boedeker2019linear}{qpm}
#' #'
#' #' @export
#'
#' linear_discrimnant <-
#' function(data,
#' group = NULL,
#' n_class = 3L,
#' n_pred = 3L) {
#' # NOTE(review): empty stub -- calling this returns NULL and ignores every
#' # argument. The analysis it is meant to perform is currently written as
#' # top-level script code below (hard-coded to `airline` / "JOB"). The name
#' # is also misspelled: "discrimnant" should be "discriminant".
#' }
#'
#' # Draft analysis, run as a top-level script on the `airline` data with
#' # "JOB" as the grouping column. It fits MASS::lda() and also rebuilds the
#' # LDA pieces by hand (class means, pooled covariance, linear classification
#' # function coefficients and intercepts), storing them as attributes of a
#' # "linear_disc" object that predict.linear_disc() reads.
#' data(airline)
#' data <- airline
#' group <- "JOB"
#' n_class <- n_pred <- 3L
#' n_cases <- nrow(data)
#' # NOTE(review): `n_pred` and `n_cases` are not used below; the predictors
#' # are hard-coded as the first three columns via `x[1:3]`.
#'
#' classes <- unique(data[[group]])
#'
#' # Equal prior probability for every class.
#' priors <- rep(1 / n_class, n_class)
#' # priors <- proportion(data[[group]])
#'
#' # One data frame per class (all columns, including the grouping column).
#' # NOTE(review): `%>%` is not among the @importFrom tags at the top of this
#' # file (only MVN::mvn, MASS::lda, dplyr::bind_cols); import it before
#' # turning this into package code.
#' l_data <- data %>%
#' split(data[[group]])
#'
#' # l_data %>%
#' # lapply(function(x) {
#' # mvn(x[1:3], multivariatePlot = "qq", mvnTest = "mardia")
#' # })
#'
#' # NOTE(review): the response is hard-coded as `JOB` rather than taken from
#' # `group`, and `.` uses every other column as a predictor, whereas the
#' # manual computation below uses only `x[1:3]`. Both use the same
#' # predictors only if those three columns are the sole predictors in `data`.
#' output <- lda(JOB ~ ., data, prior = priors)
#' pred <- predict(output) ## posterior values
#' posteriors <- as.data.frame(pred$posterior)
#' colnames(posteriors) <- sprintf("post_%s", levels(pred$class))
#'
#' post_data <- dplyr::bind_cols(data,
#' pred_class = pred$class,
#' posteriors)
#'
#' # Per-class mean vector of the predictors.
#' means <- l_data %>%
#' lapply(function(x) {
#' sapply(x[1:3], mean)
#' })
#'
#' # Per-class covariance matrices.
#' # NOTE(review): `covs` is not referenced anywhere below; the pooled matrix
#' # recomputes cov() for each class instead.
#' covs <- l_data %>%
#' lapply(function(x) {
#' cov(x[1:3])
#' })
#'
#' ## degrees of freedom for the pooled covariance: N - k (cases - classes)
#' ns <- sapply(l_data, nrow)
#' cov_df <- sum(ns) - n_class
#'
#' ## pooled within-class variance-covariance matrix:
#' ## sum over classes of (n_k - 1) / (N - k) * S_k
#'
#' cv_matrix <- Reduce(`+`,
#' lapply(l_data, function(x) {
#' (nrow(x) - 1) / cov_df * cov(x[1:3])
#' }))
#'
#' ## determinant of the pooled variance-covariance matrix
#' det_cv <- det(cv_matrix)
#'
#' ## coefficients of the linear classification functions (LCFs): the
#' ## inverse pooled covariance times each class mean vector
#' # NOTE(review): solve(cv_matrix) is recomputed for every class; it could be
#' # hoisted out of the lapply().
#' coef_preds <- lapply(means, function(m) {
#' solve(cv_matrix) %*% as.matrix(m)
#' })
#'
#' # NOTE(review): the result is not assigned, so it is discarded (at most
#' # auto-printed). c_bind() is defined at the end of this file, so a
#' # top-to-bottom run reaches this call before the function exists unless it
#' # was defined earlier in the session.
#' c_bind(coef_preds)
#'
#' ## intercept of each class's LCF: -0.5 * t(coef) %*% mean (the log prior
#' ## is added later, in predict.linear_disc())
#' intercepts <- mapply(function(x, y)
#' - 0.5 * t(x) %*% y,
#' coef_preds,
#' means)
#'
#' # NOTE(review): placeholder -- this overwrites the lda() fit held in
#' # `output`, so `final` is the string "hello!" carrying the hand-computed
#' # statistics as attributes; `pred` and `post_data` are not kept.
#' output <- "hello!"
#' final <- structure(
#' output,
#' class = "linear_disc",
#' group = group,
#' classes = classes,
#' priors = priors,
#' intercepts = intercepts,
#' coefs = coef_preds,
#' means = means,
#' cv_matrix = cv_matrix,
#' det_cv = det_cv
#' )
#' final
#'
#' #' @export
#' print.linear_disc <- function(x, ...) {
#' # Writes the object's underlying value with cat(); with the draft script
#' # that value is the placeholder string, so none of the attributes are shown.
#' # NOTE(review): print methods usually return invisible(x); this returns the
#' # NULL from cat().
#' cat(x, sep = "")
#' }
#'
#'
#' #' @export
#' predict.linear_disc <-
#' function(object,
#' x,
#' method = c("classify", "bayesian"),
#' ...) {
#' # Scores the rows of `x` (predictor values) against a "linear_disc" object.
#' #   "classify": per class, x %*% coef + intercept + log(prior).
#' #   "bayesian": appears intended to return class densities and posterior
#' #   probabilities, but does not work yet (see the NOTE(review) below).
#' method <- match.arg(method)
#' switch(method,
#' classify = {
#' if (!is.matrix(x))
#' x <- as.matrix(x)
#' # mapply() walks the classes in parallel; the result has one column per
#' # class (a named vector when `x` has a single row).
#' mapply(
#' function(mod_coef_pred,
#' mod_intercept,
#' mod_prior) {
#' x %*% mod_coef_pred + mod_intercept + log(mod_prior)
#' },
#' attr(object, "coefs"),
#' attr(object, "intercepts"),
#' attr(object, "priors")
#' )
#' },
#' bayesian = {
#' # NOTE(review): this branch cannot work as written:
#' # * `tx_m <- t(x) - m` uses `m` before it exists -- `m` is only defined as
#' #   the argument of the function passed to sapply() below, so this reads
#' #   a global `m` or errors. The centring belongs inside that function.
#' # * `p` counts classes, but the normal density's constant uses the number
#' #   of predictors (ncol(cv_matrix)); they coincide in the draft script
#' #   only because n_class == n_pred == 3 there.
#' # * 1 / prod(a, b) is 1 / (a * b), which puts the exponential term in the
#' #   denominator; the density is b / a.
#' # * exp() wraps only -0.5 * t(tx_m), not the whole quadratic form
#' #   t(tx_m) %*% solve(cv_matrix) %*% tx_m.
#' p <- length(attr(object, "classes"))
#' d <- attr(object, "det_cv")
#' cv_matrix <- attr(object, "cv_matrix")
#' tx_m <- t(x) - m
#'
#' fs <- sapply(attr(object, "means"),
#' function(m) {
#' m <- as.matrix(m)
#' f <- 1 / prod((((2 * pi) ^ (p / 2) * (d ^ 0.5))),
#' exp(-0.5 * t(tx_m)) %*% solve(cv_matrix) %*% tx_m)
#' })
#' ## posteriors
#' # Multiply each class's value in `fs` by its prior, then normalise so the
#' # results sum to one.
#' temp <- mapply(prod, fs, attr(object, "priors"))
#' posts <- sapply(temp, function(ts)
#' ts / sum(temp))
#' # NOTE(review): "fs??" is a placeholder row label.
#' rbind(`fs??` = fs, posteriors = posts)
#' })
#' }
#' # Example call on the first row of `airline`; "b" partially matches
#' # "bayesian" through match.arg(). (`x` here is assigned but not used.)
#' x <- airline[1, 1:3]
#' predict(final, airline[1, 1:3], "b")
#'
#'
#' # utils -------------------------------------------------------------------
#'
#'
#' c_bind <- function(ls) {
#' # Column-binds the elements of `ls` and names the resulting columns after
#' # the list's names (in the draft script: one coefficient column per class).
#' # NOTE(review): the draft script calls this before this definition.
#' res <- Reduce(cbind, ls)
#' colnames(res) <- names(ls)
#' res
#' }
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.