## R/ZIPLN.R

## Defines functions ZIPLN_param ZIPLN

## Documented in ZIPLN ZIPLN_param

#' Zero Inflated Poisson lognormal model
#'
#' Fit the multivariate Zero Inflated Poisson lognormal model with a variational algorithm. Use the (g)lm syntax for model specification (covariates, offsets, subset).
#'
#' @inheritParams PLN
#' @param control a list-like structure for controlling the optimization, with default generated by [ZIPLN_param()]. See the associated documentation
#' for details.
#' @param zi a character describing the model used for zero inflation, either of
#' - "single" (default, one parameter shared by all counts)
#' - "row" (one parameter per sample / individual)
#' - "col" (one parameter per variable / feature).
#'
#' If covariates are specified in the formula RHS (see details) this parameter is ignored.
#'
#' @details
#' Covariates for the Zero-Inflation parameter (using a logistic regression model) can be specified in the formula RHS using the pipe
#' (`~ PLN effect | ZI effect`) to separate covariates for the PLN part of the model from those for the Zero-Inflation part.
#' Note that different covariates can be used for each part.
#'
#' @return an R6 object with class [`ZIPLNfit`]
#'
#' @rdname ZIPLN
#' @include ZIPLNfit-class.R
#' @examples
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' ## Use different models for zero-inflation...
#' myZIPLN_single <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "single")
#' \dontrun{
#' myZIPLN_row    <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "row")
#' myZIPLN_col    <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "col")
#' ## ...including logistic regression on covariates
#' myZIPLN_covar  <- ZIPLN(Abundance ~ 1 | 1 + Wind, data = trichoptera)
#' }
#' @seealso The class [`ZIPLNfit`]
#' @importFrom stats model.frame model.matrix model.response model.offset terms as.formula
#' @export
ZIPLN <- function(formula, data, subset, zi = c("single", "row", "col"), control = ZIPLN_param()) {

  ## extract the data matrices and weights
  data_ <- extract_model_zi(match.call(expand.dots = FALSE), parent.frame())

  ## Choose the zero-inflation parametrization: when `data_$zicovar` is set
  ## the "covar" model is used and `zi` is ignored (it is then not even
  ## matched, hence not validated). Plain if/else rather than ifelse(): the
  ## test is a scalar flag, and a malformed flag (NA, NULL) now fails here
  ## instead of silently storing NA / logical(0) in the control list.
  control$ziparam <- if (data_$zicovar) "covar" else match.arg(zi)

  ## initialization: pick the R6 generator matching the covariance model
  if (control$trace > 0) cat("\n Initialization...")
  myZIPLN <- switch(control$covariance,
                    "diagonal"  = ZIPLNfit_diagonal$new(data_, control),
                    "spherical" = ZIPLNfit_spherical$new(data_, control),
                    "fixed"     = ZIPLNfit_fixed$new(data_, control),
                    "sparse"    = ZIPLNfit_sparse$new(data_, control),
                    ZIPLNfit$new(data_, control)) # default: full covariance

  ## optimization
  if (control$trace > 0) cat("\n Adjusting a ZI-PLN model with",
                  control$covariance,"covariance model and",
                  control$ziparam, "specific parameter(s) in Zero inflation component.")
  myZIPLN$optimize(data_, control$config_optim)

  if (control$trace > 0) cat("\n DONE!\n")
  myZIPLN
}

## -----------------------------------------------------------------
##  Setters providing default parameters for the user-facing main functions

#' Control of a ZIPLN fit
#'
#' Helper to define list of parameters to control the ZIPLN fit. All arguments have defaults.
#'
#' @inheritParams PLN_param
#' @inheritParams PLNnetwork_param
#' @param penalty a user-defined penalty to sparsify the residual covariance. Defaults to 0 (no sparsity).
#' @param inception optional pretrained ZIPLNfit used as initialization. When not `NULL`, it must pass `isZIPLNfit()`.
#' @return list of parameters used during the fit and post-processing steps
#'
#' @inherit PLN_param details
#' @details See [PLN_param()] and [PLNnetwork_param()] for a full description of the generic optimization parameters. Like [PLNnetwork_param()], ZIPLN_param() has two parameters controlling the optimization due to the inner-outer loop structure of the optimizer:
#' * "ftol_out" outer solver stops when an optimization step changes the objective function by less than `ftol_out` multiplied by the absolute value of the parameter. Default is 1e-6
#' * "maxit_out" outer solver stops when the number of iterations exceeds `maxit_out`. Default is 100
#'
#' and one additional parameter controlling the form of the variational approximation of the zero inflation:
#' * "approx_ZI" either uses an exact or approximated conditional distribution for the zero inflation. Default is TRUE
#'
#' The covariance model may be overridden by the other arguments: supplying a matrix `Omega` switches it to "fixed",
#' and a positive `penalty` switches it to "sparse" (the penalty rule is applied last).
#'
#' @export
ZIPLN_param <- function(
    backend       = c("nlopt"),
    trace         = 1,
    covariance    = c("full", "diagonal", "spherical", "fixed", "sparse"),
    Omega         = NULL,
    penalty       = 0,
    penalize_diagonal = TRUE   ,
    penalty_weights   = NULL   ,
    config_post   = list(),
    config_optim  = list(),
    inception     = NULL     # pretrained ZIPLNfit used as initialization
) {

  covariance <- match.arg(covariance)
  ## TRUE when Omega is a base "matrix" or a "Matrix" object
  has_Omega <- inherits(Omega, c("matrix", "Matrix"))

  ## NOTE(review): the two consistency checks below are wrapped in try(), so a
  ## failure is only reported (error message printed) and does NOT abort the
  ## call. This mirrors the historical behaviour; confirm whether a hard stop
  ## is wanted before removing the try().
  if (covariance == "fixed") try(stopifnot("Omega must be provied for fixed covariance" = has_Omega))
  if (has_Omega) covariance <- "fixed"
  if (covariance == "sparse") try(stopifnot("You should provide a positive penalty when chosing 'sparse' covariance" = penalty > 0))
  if (penalty > 0) covariance <- "sparse"
  if (!is.null(inception)) stopifnot(isZIPLNfit(inception))

  ## post-treatment config: package defaults, overridden by user entries
  config_pst <- config_post_default_PLN
  config_pst[names(config_post)] <- config_post
  config_pst$trace <- trace

  ## optimization config
  stopifnot(backend %in% c("nlopt"))
  ## passes when no algorithm is supplied (empty test is vacuously true)
  stopifnot(config_optim$algorithm %in% available_algorithms_nlopt)
  config_opt <- config_default_nlopt
  config_opt$trace <- trace
  config_opt$ftol_out  <- 1e-6
  config_opt$maxit_out <- 100
  config_opt$approx_ZI <- TRUE
  ## user-supplied entries come last so that they override every default above
  config_opt[names(config_optim)] <- config_optim

  structure(list(
    backend       = backend   ,
    trace         = trace     ,
    covariance    = covariance,
    Omega         = Omega     ,
    penalty       = penalty   ,
    penalize_diagonal = penalize_diagonal,
    penalty_weights   = penalty_weights  ,
    config_post   = config_pst,
    config_optim  = config_opt,
    inception     = inception), class = "PLNmodels_param")

}
## PLN-team/PLNmodels documentation built on April 15, 2024, 9:01 a.m.