R/fibre-fit.R

Defines functions fibre_impl fibre_bridge fibre.recipe fibre.formula fibre.matrix fibre.data.frame fibre.default fibre

Documented in fibre fibre fibre.data.frame fibre.default fibre.formula fibre.matrix fibre.recipe

#' Fit a `fibre`
#'
#' `fibre()` fits a model. It is an S3 generic with methods for data frames,
#' matrices, formulas and recipes.
#'
#' @param x Depending on the context:
#'
#'   * A __data frame__ of predictors.
#'   * A __matrix__ of predictors.
#'   * A __recipe__ specifying a set of preprocessing steps
#'     created from [recipes::recipe()].
#'
#' @param y When `x` is a __data frame__ or __matrix__, `y` is the outcome
#' specified as:
#'
#'   * A __data frame__ with 1 numeric column.
#'   * A __matrix__ with 1 numeric column.
#'   * A numeric __vector__.
#'
#' @param data When a __recipe__ or __formula__ is used, `data` is specified as:
#'
#'   * A __data frame__ containing both the predictors and the outcome.
#'
#' @param formula A formula specifying the outcome terms on the left-hand side,
#' and the predictor terms on the right-hand side.
#'
#' @param intercept A logical. Should an intercept be included in the model?
#' @param family A single character string naming the outcome family passed to
#' the fitting engine. Defaults to `"gaussian"`.
#' @param engine A single character string naming the engine used to fit the
#' model: one of `"inla"` (the default), `"glmnet"` or `"torch"`. The `"torch"`
#' engine is accepted but is not yet supported by the fitting code.
#' @param engine_options A named list of options passed to the engine. Entries
#' are merged over, and override, the defaults `fibre()` sets for that engine.
#' @param ncores An integer. The number of cores to use for parallel processing.
#' With `engine = "inla"`, `NULL` uses INLA's default number of threads.
#' @param verbose An integer giving the level of verbosity: `0` is silent,
#' values above `0` print progress messages, and values above `1` also turn on
#' the engine's own verbose output.
#' @param fit A logical. Should the model be fit? If `FALSE`, the model is not
#' fit; instead, a list of data is returned that can be used to fit the model
#' later. This is useful for debugging, or for fitting the model with a custom
#' model design or a currently unsupported engine.
#'
#' @param ... Not currently used, but required for extensibility.
#'
#' @return
#'
#' A `fibre` object.
#'
#' @examples
#' predictors <- mtcars[, -1]
#' outcome <- mtcars[, 1]
#'
#' # XY interface
#' #mod <- fibre(predictors, outcome)
#'
#' # Formula interface
#' #mod2 <- fibre(mpg ~ ., mtcars)
#'
#' # Recipes interface
#' #library(recipes)
#' #rec <- recipe(mpg ~ ., mtcars)
#' #rec <- step_log(rec, disp)
#' #mod3 <- fibre(rec, mtcars)
#'
#' @export
fibre <- function(x, ...) {
  UseMethod("fibre")
}

#' @export
#' @rdname fibre
fibre.default <- function(x, ...) {
  # Fallback for any class without a dedicated method: fail with a
  # call-free error that names the first class of `x`.
  first_class <- class(x)[[1]]
  stop(
    paste0("`fibre()` is not defined for a '", first_class, "'."),
    call. = FALSE
  )
}

# XY method - data frame

#' @export
#' @rdname fibre
fibre.data.frame <- function(x, y,
                             intercept = TRUE,
                             engine = c("inla", "glmnet", "torch"),
                             engine_options = list(),
                             ncores = NULL,
                             verbose = 0,
                             fit = TRUE,
                             family = "gaussian",
                             ...) {
  # Fit from a data frame of predictors `x` and an outcome `y`.
  engine <- match.arg(engine)

  # Honour the documented `intercept` argument through the XY blueprint
  # (hardhat's default XY blueprint would otherwise never add one).
  blueprint <- hardhat::default_xy_blueprint(intercept = intercept)
  processed <- hardhat::mold(x, y, blueprint = blueprint)

  # fibre_bridge() takes `family` before `engine`; pass everything by name so
  # arguments cannot shift into the wrong formal.
  fibre_bridge(processed,
               family = family,
               engine = engine,
               engine_options = engine_options,
               ncores = ncores,
               verbose = verbose,
               fit = fit,
               ...)
}

# XY method - matrix

#' @export
#' @rdname fibre
fibre.matrix <- function(x, y,
                         intercept = TRUE,
                         engine = c("inla", "glmnet", "torch"),
                         engine_options = list(),
                         ncores = NULL,
                         verbose = 0,
                         fit = TRUE,
                         family = "gaussian",
                         ...) {
  # Fit from a matrix of predictors `x` and an outcome `y`.
  engine <- match.arg(engine)

  # Honour the documented `intercept` argument through the XY blueprint
  # (hardhat's default XY blueprint would otherwise never add one).
  blueprint <- hardhat::default_xy_blueprint(intercept = intercept)
  processed <- hardhat::mold(x, y, blueprint = blueprint)

  # fibre_bridge() takes `family` before `engine`; pass everything by name so
  # arguments cannot shift into the wrong formal.
  fibre_bridge(processed,
               family = family,
               engine = engine,
               engine_options = engine_options,
               ncores = ncores,
               verbose = verbose,
               fit = fit,
               ...)
}

# Formula method

#' @export
#' @rdname fibre
fibre.formula <- function(formula, data,
                          intercept = TRUE,
                          family = "gaussian",
                          engine = c("inla", "glmnet", "torch"),
                          engine_options = list(),
                          ncores = NULL,
                          verbose = 0,
                          fit = TRUE,
                          ...) {
  engine <- match.arg(engine)

  # The package's own formula blueprint handles the intercept and records
  # per-term information in `extras$model_info`.
  bp <- fibre_formula_blueprint(intercept = intercept)
  processed <- hardhat::mold(formula, data, blueprint = bp)

  n_model_terms <- length(processed$extras$model_info)
  if (engine == "glmnet" && n_model_terms > 1) {
    rlang::warn('engine = "glmnet" currently only concatenates multiple bre() calls in a model (so only one parameter is fit on all rates across all bre calls).')
  }

  # TODO: a check that glmnet multi-outcome models use "mgaussian" or
  # "multinomial" was drafted here but is currently disabled.

  fibre_bridge(processed,
               family = family,
               engine = engine,
               engine_options = engine_options,
               ncores = ncores,
               verbose = verbose,
               fit = fit,
               ...)
}

# Recipe method

#' @export
#' @rdname fibre
fibre.recipe <- function(x, data,
                         intercept = TRUE,
                         engine = c("inla", "glmnet", "torch"),
                         engine_options = list(),
                         ncores = NULL,
                         verbose = 0,
                         fit = TRUE,
                         family = "gaussian",
                         ...) {
  # Fit from a recipe `x` prepped and baked on `data`.
  engine <- match.arg(engine)

  # Honour the documented `intercept` argument through the recipe blueprint
  # (hardhat's default recipe blueprint would otherwise never add one).
  blueprint <- hardhat::default_recipe_blueprint(intercept = intercept)
  processed <- hardhat::mold(x, data, blueprint = blueprint)

  # fibre_bridge() takes `family` before `engine`; pass everything by name so
  # arguments cannot shift into the wrong formal.
  fibre_bridge(processed,
               family = family,
               engine = engine,
               engine_options = engine_options,
               ncores = ncores,
               verbose = verbose,
               fit = fit,
               ...)
}

# ------------------------------------------------------------------------------
# Bridge

# Unpack the output of hardhat::mold() and hand it to fibre_impl().
#
# `processed$extras$model_info` holds one element per model term; the
# `phyf`, `rate_dist`, `hyper`, `latent` and `label` fields of every element
# are split out into parallel lists. `mixture_of` is not extracted because
# fibre_impl() has no parameter for it. `...` is accepted for extensibility
# and currently unused. Returns whatever fibre_impl() returns.
fibre_bridge <- function(processed, family, engine, engine_options, ncores, verbose, fit, ...) {
  model_info <- processed$extras$model_info

  # Pull one named field out of every element of `model_info`.
  pluck_field <- function(field) {
    purrr::map(model_info, field)
  }

  fibre_impl(processed$predictors, processed$outcomes,
             processed$extras$offset, pluck_field("phyf"),
             pluck_field("rate_dist"), pluck_field("hyper"),
             pluck_field("latent"), pluck_field("label"),
             family,
             engine,
             engine_options,
             processed$blueprint,
             ncores = ncores,
             verbose = verbose,
             fit = fit)
}


# ------------------------------------------------------------------------------
# Implementation

# Build engine-specific model data and, optionally, fit the model.
#
# Arguments (as supplied by fibre_bridge()):
#   predictors, outcomes  Predictor and outcome data from hardhat::mold().
#   offset                Offset from the blueprint extras.
#                         NOTE(review): not referenced anywhere in this body.
#   pfcs, rate_dists, hypers, latents, labels
#                         Lists with one element per entry of the blueprint's
#                         `model_info` (its `phyf`, `rate_dist`, `hyper`,
#                         `latent` and `label` fields respectively).
#   family                Outcome family name; resolved via get_families() for
#                         INLA, passed straight to glmnet::glmnet() for glmnet.
#   engine                "inla" or "glmnet". Any other value (e.g. "torch")
#                         has no data-shaping branch and aborts with
#                         "Invalid engine argument" when `fit = TRUE`.
#   engine_options        List merged over the engine defaults with
#                         utils::modifyList(). For INLA, `control.family$hyper`
#                         is pulled out and used as `control.family`; for
#                         glmnet, `alpha` and `what` are read from it.
#   blueprint             The hardhat blueprint, forwarded to post-processing.
#   ncores                INLA `num.threads` (INLA's default when NULL); also
#                         forwarded to fibre_process_fit_glmnet().
#   verbose               > 0 prints progress messages; > 1 also enables
#                         engine-level verbosity.
#   fit                   If FALSE, skip fitting.
#
# Returns the post-processed fit from fibre_process_fit_inla() or
# fibre_process_fit_glmnet() when `fit = TRUE`; otherwise the shaped data list
# with the model formula (NULL for glmnet) stored in attribute "formula".
fibre_impl <- function(predictors, outcomes,
                       offset, pfcs,
                       rate_dists, hypers,
                       latents, labels,
                       family,
                       engine,
                       engine_options,
                       blueprint,
                       ncores,
                       verbose,
                       fit) {

  if(verbose > 0) {
    cli::cli_progress_message("Setting up data for model...")
  }

  # Neither glmnet nor INLA supports latent dimensions; fail early with a hint.
  if(engine == "glmnet" && sum(unlist(latents)) > 0) {
    rlang::abort('engine = "glmnet" does not support argument latent > 0, try engine = "torch"')
  }

  if(engine == "inla" && sum(unlist(latents)) > 0) {
    rlang::abort('engine = "inla" does not support argument latent > 0, try engine = "torch"')
  }

  # glmnet cannot take missing outcomes: fit only on complete rows, but keep
  # the unfiltered data in `to_predict` for post-processing (presumably so
  # predictions can still be made for the dropped rows -- confirm).
  if(engine == "glmnet") {
    complete <- complete.cases(outcomes)
    to_predict <- list(predictors = predictors,
                       pfcs = pfcs,
                       outcomes = outcomes)
    outcomes <- outcomes[complete, ]
    predictors <- predictors[complete, ]
    pfcs <- purrr::map(pfcs,
                       ~ .x[complete])
  }

  # Reshape inputs for the chosen engine. There is no branch for other engines
  # (e.g. "torch"), so `dat_list` is NULL in that case.
  dat_list <- switch(engine,
                     inla = shape_data_inla(pfcs,
                                            predictors,
                                            !family %in% c("surv"),
                                            outcomes,
                                            latents),
                     glmnet = shape_data_glmnet(pfcs,
                                                predictors,
                                                !family %in% c("mgaussian", "multinomial"),
                                                outcomes,
                                                rate_dists))

  #return(dat_list)

  # Only INLA needs a model formula; for INLA this is a list whose `form` and
  # `hypers` elements are unpacked below.
  form <- switch(engine,
                 inla = make_inla_formula(dat_list$dat, dat_list$y),
                 glmnet = NULL)


  if(engine == "inla") {

    # get_families() returns both the INLA family name(s) and a default for
    # `control.family`.
    family <- get_families(family, colnames(dat_list$y))
    family_hyper <- family$hyper
    family <- family$family

    # A user-supplied `control.family$hyper` replaces the default wholesale and
    # is removed from engine_options so modifyList() below does not merge it
    # back in. NOTE(review): the default (`family$hyper`) is used as the whole
    # `control.family` value while the override comes from one level deeper
    # (`control.family$hyper`) -- confirm the two have the same shape.
    if(!is.null(engine_options$control.family$hyper)) {
      family_hyper <- engine_options$control.family$hyper
      engine_options$control.family$hyper <- NULL
    }

    #print(hypers)
    hyper_names <- form$hypers
    form <- form$form
    inla_dat <- INLA::inla.stack(data = list(y = dat_list$y),
                                 A = list(dat_list$A),
                                 effects = list(dat_list$dat),
                                 compress = FALSE,
                                 remove.unused = FALSE)
    if(!is.null(ncores)) {
      num.threads <- ncores
    } else {
      num.threads <- INLA::inla.getOption("num.threads")
    }
    # Defaults for INLA::inla(); any entry in engine_options overrides them
    # (keep.null = TRUE lets users explicitly set an option to NULL).
    inla_options <- list(control.predictor = list(A = INLA::inla.stack.A(inla_dat),
                                                  link = 1,
                                                  compute = TRUE),
                         control.family = family_hyper,
                         control.compute = list(config = TRUE),
                         inla.mode = "experimental",
                         verbose = verbose > 1,
                         num.threads = num.threads)
    inla_options <- utils::modifyList(inla_options, engine_options, keep.null = TRUE)

    # Bind the per-term hyperparameter specs into the formula's environment
    # under the names the generated formula refers to, so INLA can resolve
    # them when it evaluates the formula.
    n_re <- length(hyper_names$re)
    if(n_re > 0) {
      hypers_re <- hypers[seq_along(hyper_names$re)]
      names(hypers_re) <- hyper_names$re
      rlang::env_bind(rlang::f_env(form), !!!hypers_re)
    }

    # NOTE(review): unlike the block above, this reads `hypers$latent` and
    # `hypers$copy` from the per-term `hypers` list rather than from
    # `hyper_names`; `hyper_names$latent` / `hyper_names$copy` may have been
    # intended. Latent terms are rejected for INLA earlier in this function,
    # so this branch is probably unreachable at present -- verify.
    if(length(hypers$latent) > 0) {
      hypers_latent <- hypers[seq_along(hypers$latent) + n_re]
      names(hypers_latent) <- hypers$latent
      hypers_copy <- rep(list(list(beta = list(fixed = FALSE))), length(hypers$copy))
      names(hypers_copy) <- hypers$copy
      rlang::env_bind(rlang::f_env(form), !!!hypers_latent)
      rlang::env_bind(rlang::f_env(form), !!!hypers_copy)
    }

  }

  if(engine == "glmnet") {
    # Defaults for glmnet::glmnet(); engine_options entries override them.
    glmnet_options <- list(standardize = FALSE,
                           nlambda = 200,
                           lambda.min.ratio = 0.001 / nrow(dat_list$dat),
                           trace.it = as.numeric(verbose > 1))
    # `alpha` is passed to glmnet explicitly and to post-processing, so it is
    # pulled out of engine_options (default 1).
    if(!is.null(engine_options$alpha)) {
      alpha <- engine_options$alpha
      engine_options$alpha <- NULL
    } else {
      alpha <- 1
    }
    # NOTE(review): `engine_options$what` is not removed here, so it is also
    # forwarded to glmnet::glmnet() -- confirm glmnet tolerates it.
    glmnet_options <- utils::modifyList(glmnet_options, engine_options, keep.null = TRUE)
  }

  #return(list(inla_dat, form, family, family_hyper))

  if(verbose > 0) {
    cli::cli_progress_message("Fitting model with {engine}...")
  }

  if(fit) {
    # glmnet is called with intercept = FALSE; presumably any intercept column
    # comes from the blueprint / shaped design matrix instead -- confirm.
    fit <- switch(engine,
                  inla = rlang::exec(INLA::inla, formula = form,
                                     data = INLA::inla.stack.data(inla_dat),
                                     family = family,
                                     !!!inla_options),
                  glmnet = rlang::exec(glmnet::glmnet,
                                       x = dat_list$dat,
                                       dat_list$y,
                                       family = family,
                                       alpha = alpha,
                                       penalty.factor = dat_list$penalty_factor,
                                       intercept = FALSE,
                                       !!!glmnet_options),
                  rlang::abort("Invalid engine argument")
    )

    if(verbose > 0) {
      cli::cli_progress_message("Post processing model outputs...")
    }

    fit <- list(fit = fit, renamer = dat_list$renamer)
    #print(dat_list$dat_uncompressed)
    # Which quantities glmnet post-processing extracts; defaults to "phylosig".
    if(engine == "glmnet" && !is.null(engine_options$what)) {
      what <- engine_options$what
    } else {
      what <- c("phylosig")
    }
    return(switch(engine,
           inla = fibre_process_fit_inla(fit, blueprint,
                                         dat_list$dat_uncompressed,
                                         pfcs,
                                         rate_dists,
                                         labels,
                                         engine,
                                         dat_list,
                                         verbose = verbose),
           glmnet = fibre_process_fit_glmnet(fit, blueprint,
                                             dat_list,
                                             labels,
                                             alpha,
                                             to_predict,
                                             engine,
                                             rate_dists,
                                             pfcs,
                                             verbose = verbose,
                                             ncores = ncores,
                                             what = what,
                                             n_y = ncol(outcomes),
                                             multi = !family %in% c("mgaussian", "multinomial"))
           ))

  } else {
    # No fit requested: return the shaped data with the formula attached.
    # NOTE(review): for engines without a shaping branch (e.g. "torch"),
    # `dat_list` and `form` are NULL here -- confirm this path is intended.
    attr(dat_list, "formula") <- form
    return(dat_list)
  }
}
rdinnager/fibre documentation built on Dec. 14, 2024, 10:33 a.m.