R/OutcomeModels.R

Defines functions getOutcomeModel print.OutcomeModel confint.OutcomeModel coef.OutcomeModel getTimeAtRisk createSubgroupCounts getOutcomeCounts filterAndTidyCovariates modelTypeToCyclopsModelType getInformativePopulation fitOutcomeModel

Documented in fitOutcomeModel getOutcomeModel

# Copyright 2026 Observational Health Data Sciences and Informatics
#
# This file is part of CohortMethod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Create an outcome model, and compute the relative risk
#'
#' @details
#' For likelihood profiling, either specify the `profileGrid` for a completely user-defined grid, or
#' `profileBounds` for an adaptive grid. Both should be defined on the log effect size scale. When both
#' `profileGrid` and `profileBounds` are `NULL` likelihood profiling is disabled.
#'
#' The returned object always carries an `outcomeModelStatus` element. It is `"OK"` when the model
#' was fitted and a confidence interval could be computed; otherwise it holds a message describing
#' why no (complete) estimate is available, and the estimate elements are `NULL`.
#'
#' @description
#' Creates an outcome model, and computes the relative risk.
#'
#' @param population            A population object generated by [createStudyPopulation()],
#'                              potentially filtered by other functions.
#'
#' @param cohortMethodData      An object of type [CohortMethodData] as generated using
#'                              [getDbCohortMethodData()]. Can be omitted if not using covariates and
#'                              not using interaction terms.
#' @param fitOutcomeModelArgs   An object of type `FitOutcomeModelArgs` as generated using the
#'                              [createFitOutcomeModelArgs()] function.
#'
#' @return
#' An object of class `OutcomeModel`. Generic functions `print`, `coef`, and
#' `confint` are available.
#'
#' @export
fitOutcomeModel <- function(population,
                            cohortMethodData = NULL,
                            fitOutcomeModelArgs = createFitOutcomeModelArgs()) {
  # Input validation ------------------------------------------------------------------------------------
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertDataFrame(population, null.ok = TRUE, add = errorMessages)
  checkmate::assertNames(names(population), must.include = c("rowId", "outcomeCount", "treatment", "timeAtRisk", "survivalTime", "personSeqId"), add = errorMessages)
  if (!is.null(cohortMethodData)) {
    checkmate::assertNames(names(cohortMethodData), must.include = c("analysisRef", "cohorts", "covariateRef", "covariates", "outcomes"), add = errorMessages)
  }
  checkmate::assertClass(cohortMethodData, "CohortMethodData", null.ok = TRUE, add = errorMessages)
  checkmate::assertR6(fitOutcomeModelArgs, "FitOutcomeModelArgs", add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)
  if (fitOutcomeModelArgs$stratified && nrow(population) > 0 && is.null(population$stratumId)) {
    stop("Requested stratified analysis, but no stratumId column found in population. Please use matchOnPs or stratifyByPs to create strata.")
  }
  if (is.null(cohortMethodData) && fitOutcomeModelArgs$useCovariates) {
    stop("Requested all covariates for model, but no cohortMethodData object specified")
  }
  if (is.null(cohortMethodData) && length(fitOutcomeModelArgs$interactionCovariateIds) != 0) {
    stop("Requesting interaction terms in model, but no cohortMethodData object specified")
  }
  if (fitOutcomeModelArgs$inversePtWeighting && is.null(population$iptw)) {
    stop("Requested inverse probability weighting, but no IPTW are provided. Use createPs to generate them")
  }

  # Defaults for everything reported in the result; they stay in place when no model can be fitted.
  start <- Sys.time()
  treatmentEstimate <- NULL
  interactionEstimates <- NULL
  mainEffectEstimates <- NULL
  mainEffectTerms <- NULL
  coefficients <- NULL
  fit <- NULL
  priorVariance <- NULL
  logLikelihood <- NA
  treatmentVarId <- NA
  subgroupCounts <- NULL
  logLikelihoodProfile <- NULL
  prior <- fitOutcomeModelArgs$prior
  status <- "NO MODEL FITTED"

  # The result object starts out as the population's metadata and is extended below.
  outcomeModel <- attr(population, "metaData")
  outcomeModel$outcomeModelType <- fitOutcomeModelArgs$modelType
  outcomeModel$outcomeModelStratified <- fitOutcomeModelArgs$stratified
  outcomeModel$outcomeModelUseCovariates <- fitOutcomeModelArgs$useCovariates
  outcomeModel$inversePtWeighting <- fitOutcomeModelArgs$inversePtWeighting
  if (fitOutcomeModelArgs$inversePtWeighting) {
    outcomeModel$targetEstimator <- outcomeModel$iptwEstimator
  }
  outcomeModel$iptwEstimator <- NULL
  outcomeModel$populationCounts <- getCounts(population, "Population count")
  outcomeModel$outcomeCounts <- getOutcomeCounts(population, fitOutcomeModelArgs$modelType)
  outcomeModel$timeAtRisk <- getTimeAtRisk(population, fitOutcomeModelArgs$modelType)

  if (nrow(population) == 0) {
    status <- "NO SUBJECTS IN POPULATION, CANNOT FIT"
  } else if (sum(population$outcomeCount) == 0) {
    status <- "NO OUTCOMES FOUND FOR POPULATION, CANNOT FIT"
  } else {
    # Informative population ---------------------------------------------------------
    informativePopulation <- getInformativePopulation(
      population = population,
      stratified = fitOutcomeModelArgs$stratified,
      inversePtWeighting = fitOutcomeModelArgs$inversePtWeighting,
      modelType = fitOutcomeModelArgs$modelType
    )

    if (sum(informativePopulation$treatment == 1) == 0 || sum(informativePopulation$treatment == 0) == 0) {
      status <- "NO STRATA WITH BOTH TARGET, COMPARATOR, AS WELL AS THE OUTCOME. CANNOT FIT"
    } else {
      if (fitOutcomeModelArgs$useCovariates) {
        # Add covariates ---------------------------------------------------------------------------------
        covariateData <- filterAndTidyCovariates(
          cohortMethodData = cohortMethodData,
          includeRowIds = informativePopulation$rowId,
          includeCovariateIds = fitOutcomeModelArgs$includeCovariateIds,
          excludeCovariateIds = fitOutcomeModelArgs$excludeCovariateIds
        )
        # add = TRUE so this cleanup can never silently replace another exit handler.
        on.exit(close(covariateData), add = TRUE)
        outcomeModel$deletedRedundantCovariateIdsForOutcomeModel <- attr(covariateData, "metaData")$deletedRedundantCovariateIds
        outcomeModel$deletedInfrequentCovariateIdsForOutcomeModel <- attr(covariateData, "metaData")$deletedInfrequentCovariateIds

        mainEffectTerms <- covariateData$covariates |>
          distinct(.data$covariateId) |>
          inner_join(covariateData$covariateRef, by = "covariateId") |>
          select(id = "covariateId", name = "covariateName") |>
          collect()

        # Use an ID one above the largest existing covariate ID so the treatment variable cannot
        # collide with a real covariate.
        treatmentVarId <- cohortMethodData$covariates |>
          summarise(value = max(.data$covariateId, na.rm = TRUE)) |>
          pull() + 1

        treatmentCovariate <- informativePopulation |>
          select("rowId", covariateValue = "treatment") |>
          mutate(covariateId = treatmentVarId)

        appendToTable(covariateData$covariates, treatmentCovariate)

        if (fitOutcomeModelArgs$stratified || fitOutcomeModelArgs$modelType == "cox") {
          prior$exclude <- treatmentVarId # Exclude treatment variable from regularization
        } else {
          prior$exclude <- c(0, treatmentVarId) # Exclude treatment variable and intercept from regularization
        }
      } else {
        # Don't add covariates, only use treatment as covariate ----------------------------------------------
        treatmentVarId <- 1

        treatmentCovariate <- informativePopulation |>
          select("rowId", covariateValue = "treatment") |>
          mutate(covariateId = treatmentVarId)

        # NOTE(review): cohortMethodData may be NULL on this path, in which case covariateRef is
        # NULL as well - confirm Andromeda::andromeda() accepts that.
        covariateData <- Andromeda::andromeda(
          covariates = treatmentCovariate,
          covariateRef = cohortMethodData$covariateRef
        )
        on.exit(close(covariateData), add = TRUE)
        # With treatment as the only covariate there is nothing to regularize.
        prior <- createPrior("none")
      }

      # Interaction terms -----------------------------------------------------------------------------------
      interactionTerms <- NULL
      if (length(fitOutcomeModelArgs$interactionCovariateIds) != 0) {
        covariateData$covariatesSubset <- cohortMethodData$covariates |>
          filter(.data$covariateId %in% fitOutcomeModelArgs$interactionCovariateIds) |>
          filter(.data$rowId %in% local(informativePopulation$rowId))

        # Cannot have interaction terms without main effects, so add covariates if they haven't been included already:
        if (!fitOutcomeModelArgs$useCovariates) {
          appendToTable(covariateData$covariates, covariateData$covariatesSubset)
          mainEffectTerms <- covariateData$covariatesSubset |>
            distinct(.data$covariateId) |>
            inner_join(covariateData$covariateRef, by = "covariateId") |>
            select(id = "covariateId", name = "covariateName") |>
            collect()
        } else {
          # TODO: check if main effect covariate exists in data
          # NOTE(review): collect() never returns NULL, so this check is always TRUE and the stop()
          # below is unreachable. Left unchanged because the intended condition is not clear from
          # this file.
          mainEffectTermsCheck <- !is.null(covariateData$covariates |>
                                             distinct(.data$covariateId) |>
                                             inner_join(covariateData$covariateRef, by = "covariateId") |>
                                             select(id = "covariateId", name = "covariateName") |>
                                             collect())

          if (!mainEffectTermsCheck) {
            stop("No main effects exist.")
          }
        }

        # Create interaction terms
        interactionTerms <- covariateData$covariateRef |>
          filter(.data$covariateId %in% fitOutcomeModelArgs$interactionCovariateIds) |>
          select("covariateId", "covariateName") |>
          collect()

        # Interaction IDs are numbered consecutively after the treatment variable ID.
        interactionTerms$interactionId <- treatmentVarId + seq_len(nrow(interactionTerms))
        interactionTerms$interactionName <- paste("treatment", interactionTerms$covariateName, sep = " * ")
        interactionTerms$covariateName <- NULL
        covariateData$interactionTerms <- interactionTerms

        # The interaction covariate is only non-zero for target (treatment == 1) rows.
        targetRowIds <- informativePopulation$rowId[informativePopulation$treatment == 1]
        # Call to compute() is required or else DuckDB complains about rowId being ambiguous (even
        # though it does not exist in covariateData$interactionTerms)
        interactionCovariates <- covariateData$covariatesSubset |>
          filter(.data$rowId %in% targetRowIds) |>
          inner_join(covariateData$interactionTerms, by = "covariateId") |>
          compute() |>
          select("rowId", "interactionId", "covariateValue") |>
          rename(covariateId = "interactionId")

        appendToTable(covariateData$covariates, interactionCovariates)

        # Keep only interaction terms that actually occur in the data.
        uniqueCovariateIds <- covariateData$covariates |>
          distinct(.data$covariateId) |>
          pull()
        interactionTerms <- interactionTerms |>
          filter(.data$interactionId %in% uniqueCovariateIds)

        if (nrow(interactionTerms) == 0) {
          interactionTerms <- NULL
        } else {
          if (fitOutcomeModelArgs$useCovariates) {
            # Neither the interaction terms nor their main effects are regularized.
            prior$exclude <- unique(c(prior$exclude,
                                      interactionTerms$covariateId,
                                      interactionTerms$interactionId))
          }
          subgroupCounts <- createSubgroupCounts(
            interactionCovariateIds = interactionTerms$covariateId,
            covariatesSubset = covariateData$covariatesSubset,
            population = population,
            modelType = fitOutcomeModelArgs$modelType
          )
        }
      }

      # Fit model -------------------------------------------------------------------------------------------
      # NOTE(review): for unstratified analyses informativePopulation has no stratumId column, so the
      # number of distinct strata evaluates to 0 here - confirm this is the intended behaviour when
      # selectorType is "byPid".
      if (prior$priorType != "none" &&
          isTRUE(prior$useCrossValidation) &&
          fitOutcomeModelArgs$control$selectorType == "byPid" &&
          length(unique(informativePopulation$stratumId)) < fitOutcomeModelArgs$control$fold) {
        # A character value in `fit` is treated as a failure status further down.
        fit <- "NUMBER OF INFORMATIVE STRATA IS SMALLER THAN THE NUMBER OF CV FOLDS, CANNOT FIT"
      } else {
        covariateData$outcomes <- informativePopulation
        outcomes <- covariateData$outcomes
        if (fitOutcomeModelArgs$stratified) {
          covariates <- covariateData$covariates |>
            inner_join(select(covariateData$outcomes, "rowId", "stratumId"), by = "rowId")
        } else {
          covariates <- covariateData$covariates
        }
        # Free as much memory as possible before we load data into Andromeda:
        rm(population)
        rm(informativePopulation)
        Andromeda::flushAndromeda(covariateData)

        cyclopsData <- Cyclops::convertToCyclopsData(
          outcomes = outcomes,
          covariates = covariates,
          addIntercept = (!fitOutcomeModelArgs$stratified && !fitOutcomeModelArgs$modelType == "cox"),
          modelType = modelTypeToCyclopsModelType(
            fitOutcomeModelArgs$modelType,
            fitOutcomeModelArgs$stratified
          ),
          checkRowIds = FALSE,
          normalize = NULL,
          quiet = TRUE
        )

        if (!is.null(interactionTerms)) {
          # Check separability:
          separability <- Cyclops::getUnivariableSeparability(cyclopsData)
          # Never drop the treatment variable itself.
          separability[as.character(treatmentVarId)] <- FALSE
          if (any(separability)) {
            removeCovariateIds <- as.numeric(names(separability)[separability])
            # Add main effects of separable interaction effects, and the other way around:
            if (!fitOutcomeModelArgs$useCovariates) {
              removeCovariateIds <- unique(c(
                removeCovariateIds,
                interactionTerms$covariateId[interactionTerms$interactionId %in% removeCovariateIds]
              ))
            }
            removeCovariateIds <- unique(c(
              removeCovariateIds,
              interactionTerms$interactionId[interactionTerms$covariateId %in% removeCovariateIds]
            ))
            covariates <- covariates |>
              filter(!.data$covariateId %in% removeCovariateIds)

            # Rebuild the Cyclops data without the separable covariates.
            cyclopsData <- Cyclops::convertToCyclopsData(
              outcomes = outcomes,
              covariates = covariates,
              addIntercept = (!fitOutcomeModelArgs$stratified && !fitOutcomeModelArgs$modelType == "cox"),
              modelType = modelTypeToCyclopsModelType(
                fitOutcomeModelArgs$modelType,
                fitOutcomeModelArgs$stratified
              ),
              checkSorting = TRUE,
              normalize = NULL,
              quiet = TRUE
            )
            warning("Separable interaction terms found and removed")
            ref <- interactionTerms[interactionTerms$interactionId %in% removeCovariateIds, ]
            message("Separable interactions:")
            for (i in seq_len(nrow(ref))) {
              message(paste(ref[i, ], collapse = "\t"))
            }
            interactionTerms <- interactionTerms[!(interactionTerms$interactionId %in% removeCovariateIds), ]
            if (nrow(interactionTerms) == 0) {
              interactionTerms <- NULL
            }
          }
        }
        # A fitting error is converted into its message, which then becomes the model status.
        fit <- tryCatch(
          {
            Cyclops::fitCyclopsModel(cyclopsData, prior = prior, control = fitOutcomeModelArgs$control)
          },
          error = function(e) {
            e$message
          }
        )
      }
      if (is.character(fit)) {
        status <- fit
      } else {
        # Retrieve likelihood profile
        if (!is.null(fitOutcomeModelArgs$profileGrid) ||
            !is.null(fitOutcomeModelArgs$profileBounds)) {
          logLikelihoodProfile <- Cyclops::getCyclopsProfileLogLikelihood(
            object = fit,
            parm = treatmentVarId,
            x = fitOutcomeModelArgs$profileGrid,
            bounds = fitOutcomeModelArgs$profileBounds,
            tolerance = 0.1,
            includePenalty = TRUE,
            returnDerivatives = TRUE
          )
        }
        if (fit$return_flag != "SUCCESS") {
          status <- fit$return_flag
        } else {
          status <- "OK"
          coefficients <- coef(fit)
          logRr <- coef(fit)[names(coef(fit)) == as.character(treatmentVarId)]
          if (fitOutcomeModelArgs$bootstrapCi) {
            # Bootstrap failures fall back to an unbounded interval and a missing standard error.
            bootstrapSummary <- tryCatch(
              {
                bootstrap <- Cyclops::runBootstrap(fit, fitOutcomeModelArgs$bootstrapReplicates)
                bootstrap$summary[as.character(treatmentVarId),]
              },
              error = function(e) {
                missing(e) # suppresses R CMD check note
                list(
                  bpi_lower = -Inf,
                  bpi_upper = Inf,
                  std_err = NA
                )
              }
            )
            # Same layout as the confint() result below: estimate placeholder, lower, upper.
            ci <- c(0, bootstrapSummary$bpi_lower, bootstrapSummary$bpi_upper)
            seLogRr <- bootstrapSummary$std_err
          } else {
            ci <- tryCatch(
              {
                confint(fit, parm = treatmentVarId, includePenalty = TRUE)
              },
              error = function(e) {
                missing(e) # suppresses R CMD check note
                c(0, -Inf, Inf)
              }
            )
            # The sentinel interval above marks a failed confidence interval computation.
            if (identical(ci, c(0, -Inf, Inf))) {
              status <- "ERROR COMPUTING CI"
            }
            # Standard error derived from the width of the 95% interval.
            seLogRr <- (ci[3] - ci[2]) / (2 * qnorm(0.975))
          }
          # Log likelihood ratio of the fitted model versus the treatment effect fixed at 0.
          llNull <- Cyclops::getCyclopsProfileLogLikelihood(
            object = fit,
            parm = treatmentVarId,
            x = 0,
            includePenalty = FALSE
          )$value
          llr <- fit$log_likelihood - llNull
          treatmentEstimate <- tibble(
            logRr = logRr,
            logLb95 = ci[2],
            logUb95 = ci[3],
            seLogRr = seLogRr,
            llr = llr
          )
          priorVariance <- fit$variance[1]
          logLikelihood <- fit$log_likelihood
          if (!is.null(mainEffectTerms)) {
            logRr <- coef(fit)[match(as.character(mainEffectTerms$id), names(coef(fit)))]
            if (prior$priorType == "none") {
              ci <- tryCatch(
                {
                  confint(fit,
                          parm = mainEffectTerms$id, includePenalty = TRUE,
                          overrideNoRegularization = TRUE
                  )
                },
                error = function(e) {
                  missing(e) # suppresses R CMD check note
                  t(array(c(0, -Inf, Inf), dim = c(3, nrow(mainEffectTerms))))
                }
              )
            } else {
              # No confidence intervals are computed for main effects under a regularizing prior.
              ci <- t(array(c(0, -Inf, Inf), dim = c(3, nrow(mainEffectTerms))))
            }
            seLogRr <- (ci[, 3] - ci[, 2]) / (2 * qnorm(0.975))
            # NOTE(review): the column name `coariateName` looks like a typo for `covariateName`.
            # It is kept as is because downstream code may depend on the current name.
            mainEffectEstimates <- tibble(
              covariateId = mainEffectTerms$id,
              coariateName = mainEffectTerms$name,
              logRr = logRr,
              logLb95 = ci[, 2],
              logUb95 = ci[, 3],
              seLogRr = seLogRr
            )
          }

          if (!is.null(interactionTerms)) {
            logRr <- coef(fit)[match(as.character(interactionTerms$interactionId), names(coef(fit)))]
            ci <- tryCatch(
              {
                confint(fit, parm = interactionTerms$interactionId, includePenalty = TRUE)
              },
              error = function(e) {
                missing(e) # suppresses R CMD check note
                t(array(c(0, -Inf, Inf), dim = c(3, nrow(interactionTerms))))
              }
            )
            seLogRr <- (ci[, 3] - ci[, 2]) / (2 * qnorm(0.975))
            interactionEstimates <- data.frame(
              covariateId = interactionTerms$covariateId,
              interactionName = interactionTerms$interactionName,
              logRr = logRr,
              logLb95 = ci[, 2],
              logUb95 = ci[, 3],
              seLogRr = seLogRr
            )
          }
        }
      }
    }
  }

  # Assemble result -------------------------------------------------------------------------------------
  outcomeModel$outcomeModelTreatmentVarId <- treatmentVarId
  outcomeModel$outcomeModelCoefficients <- coefficients
  outcomeModel$logLikelihoodProfile <- logLikelihoodProfile
  outcomeModel$outcomeModelPriorVariance <- priorVariance
  outcomeModel$outcomeModelLogLikelihood <- logLikelihood
  outcomeModel$outcomeModelTreatmentEstimate <- treatmentEstimate
  outcomeModel$outcomeModelmainEffectEstimates <- mainEffectEstimates
  if (length(fitOutcomeModelArgs$interactionCovariateIds) != 0) {
    outcomeModel$outcomeModelInteractionEstimates <- interactionEstimates
  }
  outcomeModel$outcomeModelStatus <- status
  if (!is.null(subgroupCounts)) {
    outcomeModel$subgroupCounts <- subgroupCounts
  }
  class(outcomeModel) <- "OutcomeModel"
  delta <- Sys.time() - start
  message(paste("Fitting outcome model took", signif(delta, 3), attr(delta, "units")))
  ParallelLogger::logDebug("Outcome model fitting status is: ", status)
  return(outcomeModel)
}

# Reduce the study population to the rows and columns needed for model fitting.
#
# The outcome count is exposed as `y` (made binary for cox and logistic models), `time` holds
# the survival time for cox models and the time at risk otherwise, and `weights` holds the
# IPTW when inverse probability weighting is requested. For stratified analyses only strata
# containing at least one outcome are kept.
getInformativePopulation <- function(population, stratified, inversePtWeighting, modelType) {
  population <- rename(population, y = "outcomeCount")
  if (!stratified) {
    population$stratumId <- NULL
  }

  isCox <- modelType == "cox"
  if (isCox || modelType == "logistic") {
    # These model types only distinguish outcome versus no outcome.
    population$y[population$y != 0] <- 1
  }
  population$time <- if (isCox) population$survivalTime else population$timeAtRisk

  if (stratified) {
    strataWithOutcome <- population |>
      filter(.data$y != 0) |>
      distinct(.data$stratumId)
    informativePopulation <- inner_join(strataWithOutcome, population, by = "stratumId")
  } else {
    informativePopulation <- population
  }

  informativePopulation$weights <- if (inversePtWeighting) informativePopulation$iptw else NULL

  # Optional columns are appended in a fixed order; c() drops the NULLs of unmet conditions.
  columns <- c(
    "rowId", "y", "treatment",
    if (stratified) "stratumId",
    if (modelType == "poisson" || modelType == "cox") "time",
    if (inversePtWeighting) "weights"
  )
  return(informativePopulation[, columns])
}

# Translate a model type ("logistic", "poisson" or "cox") plus the stratification flag into the
# corresponding Cyclops model type code. Any other model type is an error.
modelTypeToCyclopsModelType <- function(modelType, stratified) {
  if (modelType == "cox") {
    # The same code is used for stratified and unstratified cox models.
    return("cox")
  }
  if (modelType == "logistic") {
    return(if (stratified) "clr" else "lr")
  }
  if (modelType == "poisson") {
    return(if (stratified) "cpr" else "pr")
  }
  stop(paste("Unknown model type:", modelType))
}

# Restrict the covariates to the given rows and covariate IDs, then tidy them.
#
# cohortMethodData:    object providing the `covariates`, `covariateRef` and `analysisRef` tables.
# includeRowIds:       row IDs whose covariates are kept.
# includeCovariateIds: if non-empty, only these covariate IDs are kept.
# excludeCovariateIds: if non-empty, these covariate IDs are dropped.
#
# Returns the object produced by FeatureExtraction::tidyCovariateData(). The caller is
# responsible for closing it.
filterAndTidyCovariates <- function(cohortMethodData,
                                    includeRowIds,
                                    includeCovariateIds,
                                    excludeCovariateIds) {
  covariates <- cohortMethodData$covariates |>
    filter(.data$rowId %in% includeRowIds)

  if (length(includeCovariateIds) != 0) {
    covariates <- covariates |>
      filter(.data$covariateId %in% includeCovariateIds)
  }
  if (length(excludeCovariateIds) != 0) {
    # Fixed: this used to test against includeCovariateIds, so exclusions were never applied.
    covariates <- covariates |>
      filter(!.data$covariateId %in% excludeCovariateIds)
  }
  filteredCovariateData <- Andromeda::andromeda(
    covariates = covariates,
    covariateRef = cohortMethodData$covariateRef,
    analysisRef = cohortMethodData$analysisRef
  )
  # Close the intermediate object on every exit path, including errors while tidying.
  on.exit(close(filteredCovariateData), add = TRUE)

  metaData <- attr(cohortMethodData, "metaData")
  metaData$populationSize <- length(includeRowIds)
  attr(filteredCovariateData, "metaData") <- metaData
  class(filteredCovariateData) <- "CovariateData"

  covariateData <- FeatureExtraction::tidyCovariateData(filteredCovariateData)
  return(covariateData)
}

# Count persons, exposures and outcomes with the outcome, separately for the target
# (treatment == 1) and comparator (treatment == 0) groups. For cox and logistic models
# outcome counts are first made binary.
getOutcomeCounts <- function(population, modelType) {
  population <- rename(population, y = "outcomeCount")
  if (modelType == "cox" || modelType == "logistic") {
    population$y[population$y != 0] <- 1
  }

  isTarget <- population$treatment == 1
  isComparator <- population$treatment == 0
  hasOutcome <- population$y != 0

  targetWithOutcome <- population$personSeqId[isTarget & hasOutcome]
  comparatorWithOutcome <- population$personSeqId[isComparator & hasOutcome]

  return(tibble(
    targetPersons = length(unique(targetWithOutcome)),
    comparatorPersons = length(unique(comparatorWithOutcome)),
    targetExposures = length(targetWithOutcome),
    comparatorExposures = length(comparatorWithOutcome),
    targetOutcomes = sum(population$y[isTarget]),
    comparatorOutcomes = sum(population$y[isComparator])
  ))
}

# Compute population counts, outcome counts and time at risk for each interaction covariate.
#
# interactionCovariateIds: covariate IDs defining the subgroups.
# covariatesSubset:        covariates table restricted to the interaction covariates.
# population:              the study population.
# modelType:               outcome model type, passed on to getOutcomeCounts() and getTimeAtRisk().
#
# Returns one row per subgroup covariate ID, identified by `subgroupCovariateId`.
createSubgroupCounts <- function(interactionCovariateIds, covariatesSubset, population, modelType) {
  # Counts for the subgroup of rows that have the given covariate.
  countsForSubgroup <- function(subgroupCovariateId) {
    subgroupRowIds <- covariatesSubset |>
      filter(.data$covariateId %in% subgroupCovariateId) |>
      distinct(.data$rowId) |>
      pull()

    subgroup <- population |>
      filter(.data$rowId %in% subgroupRowIds)

    # Columns are renamed by string: `.data$` is deprecated inside selection verbs such as
    # rename(). The new names avoid clashing with the population count columns.
    outcomeCounts <- rename(
      getOutcomeCounts(subgroup, modelType),
      targetOutcomePersons = "targetPersons",
      comparatorOutcomePersons = "comparatorPersons",
      targetOutcomeExposures = "targetExposures",
      comparatorOutcomeExposures = "comparatorExposures"
    )

    counts <- bind_cols(
      getCounts(subgroup, "Population count"),
      outcomeCounts,
      getTimeAtRisk(subgroup, modelType)
    )
    counts$description <- NULL
    counts$subgroupCovariateId <- subgroupCovariateId
    return(counts)
  }
  subgroupCounts <- lapply(interactionCovariateIds, countsForSubgroup)
  subgroupCounts <- bind_rows(subgroupCounts)
  return(subgroupCounts)
}

# Total follow-up days for the target (treatment == 1) and comparator (treatment == 0) groups.
# Cox models use the survival time; all other model types use the time at risk.
getTimeAtRisk <- function(population, modelType) {
  days <- if (modelType == "cox") population$survivalTime else population$timeAtRisk
  isTarget <- population$treatment == 1
  isComparator <- population$treatment == 0
  return(tibble(
    targetDays = sum(days[isTarget]),
    comparatorDays = sum(days[isComparator])
  ))
}


#' @export
coef.OutcomeModel <- function(object, ...) {
  # Only the treatment effect (log relative risk) is exposed through coef().
  treatmentEstimate <- object$outcomeModelTreatmentEstimate
  return(treatmentEstimate$logRr)
}

#' @export
confint.OutcomeModel <- function(object, parm, level = 0.95, ...) {
  missing(parm) # suppresses R CMD check note
  # Only the stored 95% bounds of the treatment effect are available.
  if (level != 0.95) {
    stop("Only supporting 95% confidence interval")
  }
  treatmentEstimate <- object$outcomeModelTreatmentEstimate
  bounds <- c(treatmentEstimate$logLb95, treatmentEstimate$logUb95)
  return(bounds)
}

#' @export
print.OutcomeModel <- function(x, ...) {
  # Model settings and fitting status.
  message(paste("Model type:", x$outcomeModelType))
  message(paste("Stratified:", x$outcomeModelStratified))
  message(paste("Use covariates:", x$outcomeModelUseCovariates))
  message(paste("Use inverse probability of treatment weighting:", x$inversePtWeighting))
  message(paste("Target estimand:", x$targetEstimator))
  message(paste("Status:", x$outcomeModelStatus))
  if (!is.null(x$outcomeModelPriorVariance) && !is.na(x$outcomeModelPriorVariance)) {
    message(paste("Prior variance:", x$outcomeModelPriorVariance))
  }
  message("")

  # Coefficient table: the treatment effect, followed by any interaction effects.
  estimates <- x$outcomeModelTreatmentEstimate
  if (!is.null(estimates)) {
    rowLabels <- "treatment"
    interactions <- x$outcomeModelInteractionEstimates
    if (!is.null(interactions)) {
      # Columns 1:4 of the treatment estimate and 3:6 of the interaction estimates both hold
      # logRr, logLb95, logUb95 and seLogRr.
      estimates <- rbind(estimates[, 1:4], interactions[, 3:6])
      rowLabels <- c(rowLabels, as.character(interactions$interactionName))
    }
    output <- data.frame(
      exp(estimates$logRr),
      exp(estimates$logLb95),
      exp(estimates$logUb95),
      estimates$logRr,
      estimates$seLogRr
    )
    colnames(output) <- c("Estimate", "lower .95", "upper .95", "logRr", "seLogRr")
    rownames(output) <- rowLabels
    printCoefmat(output)
  }
  # Print methods return their argument invisibly.
  invisible(x)
}

#' Get the outcome model
#'
#' @description
#' Get the full outcome model, so showing the betas of all variables included
#' in the outcome model, not just the treatment variable.
#'
#' @param outcomeModel       An object of type `OutcomeModel` as generated using the
#'                           [fitOutcomeModel()] function.
#'
#' @template CohortMethodData
#'
#' @return
#' A tibble.
#'
#' @export
getOutcomeModel <- function(outcomeModel, cohortMethodData) {
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertClass(outcomeModel, "OutcomeModel", add = errorMessages)
  checkmate::assertClass(cohortMethodData, "CohortMethodData", add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)

  cfs <- outcomeModel$outcomeModelCoefficients

  # Only report non-zero coefficients. The intercept gets ID 0 so every name is numeric.
  cfs <- cfs[cfs != 0]
  attr(cfs, "names")[attr(cfs, "names") == "(Intercept)"] <- 0
  cfs <- data.frame(coefficient = cfs, id = as.numeric(attr(cfs, "names")))

  # Look up covariate names. Fixed: this used to filter on `cfs$covariateId`, a column that does
  # not exist (the IDs are in `cfs$id`), so no covariate names were ever found.
  ref <- cohortMethodData$covariateRef |>
    filter(.data$covariateId %in% local(cfs$id)) |>
    select(id = "covariateId", name = "covariateName") |>
    collect()

  # The treatment variable and the intercept are not in covariateRef, so add them by hand.
  ref <- bind_rows(
    ref,
    tibble(
      id = outcomeModel$outcomeModelTreatmentVarId,
      name = "Treatment"
    ),
    tibble(
      id = 0,
      name = "(Intercept)"
    )
  )

  cfs <- cfs |>
    inner_join(ref, by = "id")
  return(cfs)
}

Try the CohortMethod package in your browser

Any scripts or data that you put into this service are public.

CohortMethod documentation built on March 21, 2026, 5:06 p.m.