R/internal_glmHelpers.R

Defines functions clean_fastglm fast_model_matrix prepare.data_cached inline.pred predict_model fit_glm

Documented in clean_fastglm fit_glm inline.pred predict_model prepare.data_cached

#' Fit a GLM using the package specified in params@glm.package
#'
#' Dispatches to fastglm::fastglm() when `params@glm.package == "fastglm"`
#' and to parglm::parglm.fit() otherwise.
#' @param X model matrix
#' @param y response vector
#' @param family family object (e.g. quasibinomial())
#' @param weights optional prior weights vector; not passed to the fitter when
#'   NULL
#' @param params SEQparams object; reads `glm.package`, `fastglm.method`,
#'   `parglm.control` and `nthreads`
#' @param start optional starting values for the coefficients; discarded
#'   (treated as NULL) when its length differs from `ncol(X)`
#' @return the fitted model object returned by fastglm() or parglm.fit()
#' @importFrom fastglm fastglm
#' @importFrom parglm parglm.fit parglm.control
#' @importFrom stats binomial
#' @keywords internal
fit_glm <- function(X, y, family, weights = NULL, params, start = NULL) {
  # A warm start that does not match the design's column count is discarded.
  if (!is.null(start) && length(start) != ncol(X)) start <- NULL
  if (params@glm.package == "fastglm") {
    if (is.null(weights)) {
      fastglm(X, y, family = family, method = params@fastglm.method, start = start)
    } else {
      fastglm(X, y, family = family, weights = weights, method = params@fastglm.method, start = start)
    }
  } else {
    # parglm does not support quasi-likelihood families; substitute the
    # equivalent standard family with the same link. The estimating equations
    # (hence the coefficients) are identical; only the dispersion differs.
    to_standard <- list(quasibinomial = binomial, quasipoisson = stats::poisson)
    if (isTRUE(family$family %in% names(to_standard))) {
      family <- to_standard[[family$family]](link = family$link)
    }
    ctrl <- if (is.null(params@parglm.control)) parglm.control(method = "FAST") else params@parglm.control
    # The thread count always comes from params, overriding any user control.
    ctrl$nthreads <- params@nthreads
    if (is.null(weights)) {
      parglm.fit(X, y, family = family, control = ctrl, start = start)
    } else {
      parglm.fit(X, y, family = family, weights = weights, control = ctrl, start = start)
    }
  }
}

#' Predict from a model fitted by fit_glm
#'
#' fastglm models are dispatched to their predict() method. Any other model
#' (e.g. the list returned by parglm.fit) is evaluated directly as
#' `X %*% coef(model)`, mapped through `model$family$linkinv` when
#' `type = "response"`.
#' @param model model object returned by fit_glm
#' @param X model matrix for prediction; columns must be in the same order as
#'   the model coefficients
#' @param type "response" (default) or "link"
#' @return numeric vector of predictions for the rows of `X`
#' @keywords internal
predict_model <- function(model, X, type = "response") {
  # Validate up front so both branches reject unknown types consistently
  # (previously the non-fastglm branch silently returned the link).
  type <- match.arg(type, c("response", "link"))
  if (inherits(model, "fastglm")) {
    return(predict(model, X, type))
  }
  beta <- coef(model)
  # Aliased coefficients of a rank-deficient fit are reported as NA; treat them
  # as 0 (as predict.glm effectively does) instead of making every
  # prediction NA.
  beta[is.na(beta)] <- 0
  eta <- drop(as.matrix(X) %*% beta)
  if (type == "response") model$family$linkinv(eta) else eta
}

#' Inline prediction helper for models returned by fit_glm
#'
#' Selects the covariate formula for `case`/`type` (from `cache` when it has
#' an entry, otherwise by parsing the formula strings stored in `params`),
#' builds the design matrix for `newdata`, and predicts on the response scale
#' (or via multinomial.predict() when `multi = TRUE`). For outcome/survival
#' predictions with `params@followup.class`, followup is treated as a factor
#' with levels 0..max(params@DT$followup).
#' @param model a model object returned by fit_glm (fastglm or parglm.fit),
#'   or a multinomial model when `multi = TRUE`
#' @param newdata filler for a .SD from data.table
#' @param params parameter from SEQuential
#' @param type type of prediction: "numerator", "denominator", or (for
#'   `case = "default"`) "outcome"; not used when `case = "surv"`
#' @param case case type: "default", "LTFU", "visit", "surv"
#' @param multi multinomial flag
#' @param target target level for multinomial
#' @param cache optional formula cache from init_formula_cache
#' @return predictions for the rows of `newdata`
#'
#' @keywords internal
inline.pred <- function(model, newdata, params, type = NULL, case = "default", multi = FALSE, target = NULL, cache = NULL) {
  is_outcome_pred <- case == "surv" || identical(type, "outcome")
  # Use cache if provided, otherwise fall back to parsing
  if (!is.null(cache)) {
    cached <- switch(
      case,
      "default" = switch(
        type,
        "numerator" = cache$numerator,
        "denominator" = cache$denominator,
        "outcome" = cache$covariates
      ),
      "LTFU" = switch(
        type,
        "numerator" = cache$cense_numerator,
        "denominator" = cache$cense_denominator
      ),
      "visit" = switch(
        type,
        "numerator" = cache$visit_numerator,
        "denominator" = cache$visit_denominator
      ),
      "surv" = cache$covariates
    )

    if (!is.null(cached)) {
      factor_cols <- if (params@followup.class && is_outcome_pred && "followup" %in% cached$cols)
        list(followup = 0L:max(params@DT$followup, na.rm = TRUE)) else NULL
      X <- fast_model_matrix(cached$formula, newdata, cached$cols, is_simple = cached$is_simple, factor_cols = factor_cols)
      pred <- if (!multi) predict_model(model, X, "response") else multinomial.predict(model, X, target)
      return(pred)
    }
  }

  # Fallback to original parsing (for backwards compatibility)
  covs <- switch(
    case,
    "default" = switch(
      type,
      "numerator" = params@numerator,
      "denominator" = params@denominator,
      "outcome" = params@covariates
    ),
    "LTFU" = switch(
      type,
      "numerator" = params@cense.numerator,
      "denominator" = params@cense.denominator
    ),
    "visit" = switch(
      type,
      "numerator" = params@visit.numerator,
      "denominator" = params@visit.denominator
    ),
    "surv" = params@covariates
  )
  cols <- formula_vars(covs)
  pred_data <- newdata[, cols, with = FALSE]
  if (params@followup.class && is_outcome_pred && "followup" %in% cols) {
    fup_levels <- 0L:max(params@DT$followup, na.rm = TRUE)
    pred_data[, followup := factor(followup, levels = fup_levels)]
  }
  pred_formula <- as.formula(paste0("~", covs))
  # model.matrix() does not pass `na.action` on to the model.frame() it builds
  # internally, so that frame would use getOption("na.action") (na.omit by
  # default) and silently drop rows with missing covariates, misaligning the
  # predictions with `newdata`. Build the frame explicitly with na.pass so
  # such rows are kept and yield NA predictions.
  pred_frame <- stats::model.frame(pred_formula, data = pred_data, na.action = stats::na.pass)
  X <- model.matrix(pred_formula, data = pred_frame)

  pred <- if (!multi) predict_model(model, X, "response") else multinomial.predict(model, X, target)
  return(pred)
}

#' Build the response vector and design matrix for a weight or outcome model
#'
#' Looks up the pre-computed formula and column list for `case`/`type` in
#' `cache`, filters `weight` to the rows that enter the model, derives the
#' response `y`, and builds `X` with fast_model_matrix(). This is the
#' cache-based replacement for the former prepare.data().
#'
#' @param weight data after undergoing preparation; it is manipulated with
#'   data.table syntax (`[`, `:=`, `get()`), so presumably a data.table —
#'   confirm against callers
#' @param params parameter from SEQuential
#' @param type type of model, e.g. d0 = "denominator". For weight models,
#'   "numerator" selects the numerator formula and any other value the
#'   denominator formula. For `case = "surv"`, "compevent" selects
#'   `params@compevent` as the response; any other value selects
#'   `params@outcome`
#' @param model model number, e.g. d0 = "zero model". For `case = "default"`
#'   this is the treatment level: rows are restricted by comparing it with
#'   `tx_lag` (or the baseline-treatment column named by `cache$tx_bas`), and
#'   it is matched against `params@treat.level` to pick the excused column
#' @param case one of "default", "LTFU", "visit", "surv"
#' @param cache formula cache; reads the per-model entries (each providing
#'   `formula`, `cols`, `is_simple`) plus `tx_bas` and `time_sq_col`
#' @return list with `y` (response vector) and `X` (design matrix)
#'
#' @keywords internal
prepare.data_cached <- function(weight, params, type, model, case, cache) {
  
  # Get the right cached formula/cols based on case and type. An unrecognised
  # `case` makes switch() return NULL, which triggers the error below.
  cached <- switch(
    case,
    "default" = if (type == "numerator") cache$numerator else cache$denominator,
    "LTFU" = if (type == "numerator") cache$cense_numerator else cache$cense_denominator,
    "visit" = if (type == "numerator") cache$visit_numerator else cache$visit_denominator,
    "surv" = cache$covariates
  )
  
  if (is.null(cached)) {
    stop("Missing formula cache for case=", case, ", type=", type)
  }
  
  formula <- cached$formula
  cols <- cached$cols
  
  # ----- Case: default (treatment weight models) -----
  if (case == "default") {
    # Restrict to rows whose lagged treatment equals `model`; the excused
    # numerator model conditions on the baseline-treatment column instead.
    if (params@weight.lag_condition) {
      weight <- if (type == "numerator" && params@excused) {
        weight[get(cache$tx_bas) == model]
      } else {
        weight[tx_lag == model]
      }
    }
    
    # Without pre-expansion weighting the denominator model excludes followup 0.
    if (type == "denominator" && !params@weight.preexpansion) {
      weight <- weight[followup != 0L]
    }
    
    # Drop rows flagged in this treatment level's excused column; an NA entry
    # in params@excused.cols means the level has no excused column.
    # NOTE(review): if `model` is not in params@treat.level, `target` is NA and
    # the `[[` lookup will not behave as intended — confirm callers always pass
    # a valid level.
    if (params@excused) {
      target <- match(model, unlist(params@treat.level))
      excused_col <- params@excused.cols[[target]]
      if (!is.na(excused_col)) {
        weight <- weight[get(excused_col) == 0L]
      }
    }
    
    # Post-expansion weighting with (deviation-)excused switching models the
    # `censored` column; otherwise the treatment column is the response.
    y <- if (!params@weight.preexpansion && (params@excused || params@deviation.excused)) {
      weight[["censored"]]
    } else {
      weight[[params@treatment]]
    }
    
    # ----- Case: LTFU (loss-to-follow-up censoring models) -----
  } else if (case == "LTFU") {
    # Keep only rows with an observed censoring indicator. This subset is a new
    # table, so the column added below does not reach the caller's data.
    weight <- weight[!is.na(get(params@cense))]
    
    # Only compute the squared time column if not already present
    sq_col <- cache$time_sq_col
    if (!sq_col %in% names(weight)) {
      weight[, (sq_col) := get(params@time)^2]
    }
    # abs(x - 1) flips the indicator (0 -> 1, 1 -> 0), assuming 0/1 coding.
    y <- abs(weight[[params@cense]] - 1L)
    
    # ----- Case: visit (visit-process models) -----
  } else if (case == "visit") {
    # Unlike the LTFU branch, `weight` has not been subset here, so `:=` may
    # add the squared time column to the caller's table by reference.
    sq_col <- cache$time_sq_col
    if (!sq_col %in% names(weight)) {
      weight[, (sq_col) := get(params@time)^2]
    }
    y <- weight[[params@visit]]
    
    # ----- Case: surv (outcome / competing-event models) -----
  } else if (case == "surv") {
    # Rows with a missing outcome are dropped for both response types.
    weight <- weight[!is.na(get(params@outcome))]
    y <- if (type == "compevent") weight[[params@compevent]] else weight[[params@outcome]]
  }
  
  # ----- Build design matrix -----
  # With params@followup.class, outcome models treat followup as a factor with
  # levels 0..max(params@DT$followup).
  factor_cols <- if (params@followup.class && case == "surv" && "followup" %in% cols)
    list(followup = 0L:max(params@DT$followup, na.rm = TRUE)) else NULL
  X <- fast_model_matrix(formula, weight, cols, is_simple = cached$is_simple, factor_cols = factor_cols)
  return(list(y = y, X = X))
}

# Fast model matrix builder - avoids overhead for simple cases.
#
# Builds the design matrix for `formula` from columns `cols` of the
# data.table `data`. Columns named in `factor_cols` are converted to factors
# with the given levels (unless they already are factors). When `is_simple`
# is TRUE and every selected column is numeric, the matrix is assembled
# directly, with an intercept column named "Intercept" (model.matrix() names
# it "(Intercept)"). Rows with missing values are kept in every path, so the
# result has one row per row of `data`.
fast_model_matrix <- function(formula, data, cols, is_simple = FALSE, factor_cols = NULL) {
  # An intercept-only formula selects no columns. Subsetting zero columns
  # loses the row count, so build the column of ones directly.
  if (length(cols) == 0L) {
    return(matrix(1, nrow = nrow(data), ncol = 1L, dimnames = list(NULL, "(Intercept)")))
  }

  subset_data <- data[, ..cols]

  if (!is.null(factor_cols)) {
    for (nm in names(factor_cols)) {
      if (nm %in% names(subset_data) && !is.factor(subset_data[[nm]]))
        set(subset_data, j = nm, value = factor(subset_data[[nm]], levels = factor_cols[[nm]]))
    }
  }

  # Fast path for simple additive numeric-only models: no setDF needed
  if (isTRUE(is_simple) && all(vapply(subset_data, is.numeric, logical(1)))) {
    X <- as.matrix(subset_data)
    X <- cbind(Intercept = 1, X)
    return(X)
  }

  # Standard path: setDF in-place on the temp subset, then model.matrix.
  # model.matrix() does not forward `na.action` to its internal model.frame()
  # call (which falls back to getOption("na.action"), i.e. na.omit), so build
  # the frame explicitly with na.pass to keep rows with missing values aligned
  # with the input, as the fast path does.
  model_frame <- stats::model.frame(formula, data = setDF(subset_data), na.action = stats::na.pass)
  model.matrix(formula, data = model_frame)
}

#' Strip large components from a model object returned by fit_glm
#'
#' Removes per-observation components (design matrix, response, model frame,
#' fitted values, residuals, linear predictors, weights, effects) and the QR
#' decomposition so that stored models stay small. Coefficients and family
#' are kept, which is what predict_model() uses for non-fastglm models. When
#' `model` has a `models` element, each sub-model in it is stripped and the
#' container itself is otherwise returned unchanged.
#' @param model a model object (fastglm or parglm.fit)
#' @return the model with the listed components removed
#' @keywords internal
clean_fastglm <- function(model) {
  # `effects` is observation-length in glm.fit-style results (parglm.fit),
  # so it is dropped along with the other per-observation components.
  large_fields <- c(
    "x", "y", "model", "fitted.values", "residuals", "linear.predictors",
    "weights", "prior.weights", "qr", "effects"
  )
  strip <- function(m) {
    for (field in large_fields) {
      m[[field]] <- NULL
    }
    m
  }
  if (!is.null(model$models)) {
    model$models <- lapply(model$models, strip)
    return(model)
  }
  strip(model)
}

Try the SEQTaRget package in your browser

Any scripts or data that you put into this service are public.

SEQTaRget documentation built on May 21, 2026, 9:07 a.m.