R/PLNmixture.R

Defines functions PLNmixture_param PLNmixture

Documented in PLNmixture PLNmixture_param

#' Poisson lognormal mixture model
#'
#' Fit the mixture variants of the Poisson lognormal with a variational algorithm. Use the (g)lm syntax for model specification (covariates, offsets).
#'
#' @param formula an object of class "formula": a symbolic description of the model to be fitted. Any intercept term is removed
#' from the formula before the model matrices are extracted (it is used to deal with group means).
#' @param data an optional data frame, list or environment (or object coercible by as.data.frame to a data frame) containing the variables in the model. If not found in data, the variables are taken from environment(formula), typically the environment from which PLNmixture is called.
#' @param subset an optional vector specifying a subset of observations to be used in the fitting process.
#' @param clusters a vector of integer containing the successive number of clusters (or components) to be considered
#' @param control a list-like structure for controlling the optimization, with default generated by [PLNmixture_param()]. See the associated documentation
#' for details. An object that does not inherit from class "PLNmodels_param" (e.g. a plain `list()`) is rejected with an error.
#'
#' @return an R6 object with class [`PLNmixturefamily`], which contains
#' a collection of models with class [`PLNmixturefit`]
#'
#' @rdname PLNmixture
#' @examples
#' ## Use future to dispatch the computations on 2 workers
#' \dontrun{
#' future::plan("multisession", workers = 2)
#' }
#'
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' myMixtures <- PLNmixture(Abundance ~ 1 + offset(log(Offset)), clusters = 1:4, data = trichoptera,
#'                          control = PLNmixture_param(smoothing = 'none'))
#'
#' # Shut down parallel workers
#' \dontrun{
#' future::plan("sequential")
#' }
#' @seealso The classes [`PLNmixturefamily`], [`PLNmixturefit`] and [PLNmixture_param()]
#' @importFrom stats model.frame model.matrix model.response model.offset update.formula
#' @export
PLNmixture <- function(formula, data, subset, clusters = 1:5,  control = PLNmixture_param()) {

  ## Temporary test for deprecated use of list()
  if (!inherits(control, "PLNmodels_param"))
    stop("We now use the function PLNmixture_param() to generate the list of parameters that controls the fit:
    replace 'list(my_arg = xx)' by PLNmixture_param(my_arg = xx) and see the documentation of PLNmixture_param().")

  ## Verbosity flag, evaluated once and reused for all progress messages
  verbose <- control$trace > 0

  # remove the intercept term if any (will be used to deal with group means)
  the_call <- match.call(expand.dots = FALSE)
  the_call$formula <- update.formula(formula(the_call), ~ . -1)

  ## extract the data matrices and weights
  args <- extract_model(the_call, parent.frame())

  ## Instantiate the collection of PLN models
  if (verbose) {
    cat("\n Initialization...")
    cat("\n\n Adjusting", length(clusters), "PLN mixture models.\n")
  }
  myPLN <- PLNmixturefamily$new(clusters, args$Y, args$X, args$O, args$formula, control)

  ## Now adjust the PLN models
  myPLN$optimize(control$config_optim)

  ## Smoothing to avoid local minima.
  ## Only the progress message is conditional: smooth() is always called and
  ## receives the whole control object (including the 'smoothing' setting).
  if (control$smoothing != "none" && verbose) cat("\n\n Smoothing PLN mixture models.\n")
  myPLN$smooth(control)

  ## Post-treatments, driven by the post-treatment and optimization configurations
  if (verbose) cat("\n Post-treatments")
  myPLN$postTreatment(control$config_post, control$config_optim)

  if (verbose) cat("\n DONE!\n")
  myPLN
}

#' Control of a PLNmixture fit
#'
#' Helper to define list of parameters to control the PLNmixture fit. All arguments have defaults.
#'
#' @param backend optimization backend used, either "nlopt" or "torch". Default is "nlopt"
#' @param trace an integer for verbosity. It is also copied into `config_optim` and `config_post`.
#' @param covariance character setting the model for the covariance matrices of the mixture components. Either "full", "diagonal" or "spherical". Default is "spherical".
#' @param init_cl The initial clustering to apply. Either 'kmeans', 'CAH' or a user defined clustering given as a list of clusterings, the size of which is equal to the number of clusters considered. Default is 'kmeans'.
#' @param smoothing The smoothing to apply. Either 'none', 'forward', 'backward' or 'both'. Default is 'both'.
#' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). Its entries override the defaults of the chosen backend. See details
#' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). Its entries override the package defaults.
#' @param inception Set up the parameters initialization: by default, the model is initialized with a multivariate linear model applied on
#'    log-transformed data, and with the same formula as the one provided by the user. However, the user can provide a PLNfit (typically obtained from a previous fit),
#'    which sometimes speeds up the inference.
#'
#' @return list of parameters configuring the fit, with class "PLNmodels_param".
#' @details See [PLN_param()] for a full description of the generic optimization parameters. PLNmixture_param() also has additional parameters controlling the optimization due to the inner-outer loop structure of the optimizer:
#' * "ftol_out" outer solver stops when an optimization step changes the objective function by less than ftol_out multiplied by the absolute value of the parameter. Default is 1e-3
#' * "maxit_out" outer solver stops when the number of iteration exceeds maxit_out. Default is 50
#' * "it_smooth" number of the iterations of the smoothing procedure. Default is 1.
#'
#' @seealso [PLN_param()]
#' @export
PLNmixture_param <- function(
    backend       = "nlopt"    ,
    trace         = 1          ,
    covariance    = "spherical",
    init_cl       = "kmeans"   ,
    smoothing     = "both"     ,
    config_optim  = list()     ,
    config_post   = list()     ,
    inception     = NULL # pretrained PLNfit used as initialization
) {
  if (!is.null(inception)) stopifnot(isPLNfit(inception))

  ## post-treatment config: package defaults, overridden by user-supplied entries
  config_pst <- config_post_default_PLNmixture
  config_pst[names(config_post)] <- config_post
  config_pst$trace <- trace

  ## optimization config
  ## The admissible backends are listed explicitly: without them, match.arg()
  ## takes its choices from the formal default ("nlopt") and rejects "torch".
  backend <- match.arg(backend, c("nlopt", "torch"))
  if (backend == "nlopt") {
    stopifnot(config_optim$algorithm %in% available_algorithms_nlopt)
    config_opt <- config_default_nlopt
  } else {
    stopifnot(config_optim$algorithm %in% available_algorithms_torch)
    config_opt <- config_default_torch
  }
  ## defaults specific to the outer loop and the smoothing of mixture models;
  ## set before the user overrides so that config_optim can replace them
  config_opt$ftol_out  <- 1e-3
  config_opt$maxit_out <- 50
  config_opt$it_smooth <- 1
  config_opt[names(config_optim)] <- config_optim
  config_opt$trace <- trace

  structure(list(
    backend       = backend    ,
    trace         = trace      ,
    covariance    = covariance ,
    init_cl       = init_cl    ,
    smoothing     = smoothing  ,
    config_optim  = config_opt ,
    config_post   = config_pst ,
    inception     = inception   ), class = "PLNmodels_param")
}
PLN-team/PLNmodels documentation built on Oct. 13, 2024, 4:01 a.m.