# R/PLNLDA.R

# Defines functions: PLNLDA_param, PLNLDA

# Documented in: PLNLDA, PLNLDA_param

#' Poisson lognormal model towards Linear Discriminant Analysis
#'
#' Fit the Poisson lognormal model for LDA with a variational algorithm. Use the (g)lm syntax for model specification (covariates, offsets).
#'
#' @param formula an object of class "formula": a symbolic description of the model to be fitted. An intercept is always added (and then dropped from the covariates) so that it does not interfere with the group means.
#' @param data an optional data frame, list or environment (or object coercible by as.data.frame to a data frame) containing the variables in the model. If not found in data, the variables are taken from environment(formula), typically the environment from which lm is called.
#' @param subset an optional vector specifying a subset of observations to be used in the fitting process.
#' @param weights an optional vector of observation weights to be used in the fitting process.
#' @param grouping a factor (or object coercible to a factor) specifying the class of each observation used for discriminant analysis. It is first looked up in the calling environment, then in `data`. Its length must match the number of observations used in the fit.
#' @param control a list-like structure of class "PLNmodels_param" controlling the optimization, with default generated by [PLNLDA_param()]. See the associated documentation.
#'
#' @return an R6 object with class [PLNLDAfit()]
#'
#' @details The parameter `control` should be generated by [PLNLDA_param()]. Its entries are:
#' * "covariance" character setting the model for the covariance matrix. Either "full", "diagonal" or "spherical". Default is "full".
#' * "trace" integer for verbosity.
#' * "backend" optimization backend, either "nlopt" or "torch". Default is "nlopt".
#' * "inception" Set up the initialization. By default, the model is initialized with a multivariate linear model applied on log-transformed data. However, the user can provide a PLNfit (typically obtained from a previous fit), which often speeds up the inference.
#' * "config_optim" list controlling the optimizer (stopping criteria, algorithm, ...). See [PLNLDA_param()] and [PLN_param()].
#' * "config_post" list controlling the post-treatments. See [PLNLDA_param()].
#'
#' @rdname PLNLDA
#' @examples
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' myPLNLDA <- PLNLDA(Abundance ~ 1, grouping = Group, data = trichoptera)
#' @seealso The class [`PLNLDAfit`]
#' @importFrom stats model.frame model.matrix model.response model.offset update.formula
#' @export
PLNLDA <- function(formula, data, subset, weights, grouping, control = PLNLDA_param()) {

  ## Temporary test for deprecated use of list()
  if (!inherits(control, "PLNmodels_param"))
    stop("We now use the function PLNLDA_param() to generate the list of parameters that controls the fit:
    replace 'list(my_arg = xx)' by PLNLDA_param(my_arg = xx) and see the documentation of PLNLDA_param().")

  ## Environment PLNLDA was called from: used to resolve free variables in
  ## `grouping` and to build the model frame.
  caller_env <- parent.frame()

  ## look for grouping in the calling environment first, then in `data`
  ## (free variables of the expression fall back to the caller's environment)
  if (inherits(try(eval(grouping), silent = TRUE), "try-error")) {
    grouping <- try(eval(substitute(grouping), data, caller_env), silent = TRUE)
    if (inherits(grouping, "try-error")) stop("invalid grouping")
  }
  grouping <- as.factor(grouping)

  ## force the intercept term if excluded (prevent interferences with group
  ## means when coding discrete variables). The evaluated `formula` argument is
  ## used (rather than re-evaluating the call's formula expression elsewhere)
  ## so that formulas held in variables resolve in the caller's scope and keep
  ## their own environment.
  the_call <- match.call(expand.dots = FALSE)
  the_call$formula <- update.formula(formula, ~ . + 1)

  ## extract the data matrices and weight and remove the intercept
  args <- extract_model(the_call, caller_env)
  args$X <- args$X[ , colnames(args$X) != "(Intercept)", drop = FALSE]

  ## `grouping` is evaluated above, outside the model frame built by
  ## extract_model(); subsetting / NA handling applied there may not be
  ## reflected in it, so fail early on a length mismatch.
  if (length(grouping) != NROW(args$Y))
    stop("'grouping' must have one entry per observation used in the fit (",
         NROW(args$Y), " expected, ", length(grouping), " found).")

  ## Initialize LDA with appropriate covariance model ("full" is the fallback)
  lda_class <- switch(control$covariance,
    "spherical" = PLNLDAfit_spherical,
    "diagonal"  = PLNLDAfit_diagonal,
    PLNLDAfit)
  myLDA <- lda_class$new(grouping, args$Y, args$X, args$O, args$w, args$formula, control)

  ## Compute the group means
  if (control$trace > 0) cat("\n Performing discriminant Analysis...")
  myLDA$optimize(grouping, args$Y, args$X, args$O, args$w, control$config_optim)

  ## Post-treatment: prepare LDA visualization
  myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post, control$config_optim)

  if (control$trace > 0) cat("\n DONE!\n")
  myLDA
}

#' Control of a PLNLDA fit
#'
#' Helper to define list of parameters to control the PLNLDA fit. All arguments have defaults.
#'
#' @param backend optimization backend used, either "nlopt" or "torch". Default is "nlopt".
#' @param covariance character setting the model for the covariance matrix. Either "full", "diagonal" or "spherical". Default is "full".
#' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). Entries override the backend defaults; a supplied "algorithm" must be supported by the chosen backend. See details
#' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). Entries override the defaults. See details
#' @param trace an integer for verbosity. It is also copied into both `config_optim` and `config_post`.
#' @param inception Set up the parameters initialization: by default, the model is initialized with a multivariate linear model applied on
#'    log-transformed data, and with the same formula as the one provided by the user. However, the user can provide a PLNfit (typically obtained from a previous fit),
#'    which sometimes speeds up the inference.
#'
#' @return list of parameters configuring the fit, with class "PLNmodels_param".
#' @inherit PLN_param details
#' @export
PLNLDA_param <- function(
    backend       = c("nlopt", "torch"),
    trace         = 1,
    covariance    = c("full", "diagonal", "spherical"),
    config_post   = list(),
    config_optim  = list(),
    inception     = NULL     # pretrained PLNfit used as initialization
) {
  ## match.arg() guarantees backend / covariance are among the listed choices
  backend    <- match.arg(backend)
  covariance <- match.arg(covariance)
  if (!is.null(inception)) stopifnot(isPLNfit(inception))

  ## post-treatment config: package defaults overridden by user entries
  config_pst <- config_post_default_PLNLDA
  config_pst[names(config_post)] <- config_post
  config_pst$trace <- trace

  ## optimization config: backend-specific defaults and admissible algorithms
  config_opt <- switch(backend,
    nlopt = config_default_nlopt,
    torch = config_default_torch)
  allowed_algorithms <- switch(backend,
    nlopt = available_algorithms_nlopt,
    torch = available_algorithms_torch)

  ## validate a user-supplied algorithm only when one is given
  ## ([[ ]] avoids the partial name matching of `$`)
  user_algorithm <- config_optim[["algorithm"]]
  if (!is.null(user_algorithm) && !all(user_algorithm %in% allowed_algorithms))
    stop("'config_optim$algorithm' must be one of: ",
         paste(allowed_algorithms, collapse = ", "),
         " for the '", backend, "' backend.", call. = FALSE)

  config_opt[names(config_optim)] <- config_optim
  config_opt$trace <- trace

  structure(list(
    backend       = backend   ,
    trace         = trace     ,
    covariance    = covariance,
    config_post   = config_pst,
    config_optim  = config_opt,
    inception     = inception), class = "PLNmodels_param")
}
# PLN-team/PLNmodels documentation built on April 15, 2024, 9:01 a.m.