R/OutcomeModels.R

Defines functions getOutcomeModel print.OutcomeModel confint.OutcomeModel coef.OutcomeModel getTimeAtRisk createSubgroupCounts getOutcomeCounts filterAndTidyCovariates modelTypeToCyclopsModelType getInformativePopulation fitOutcomeModel

Documented in fitOutcomeModel getOutcomeModel

# Copyright 2024 Observational Health Data Sciences and Informatics
#
# This file is part of CohortMethod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Create an outcome model, and compute the relative risk
#'
#' @details
#' For likelihood profiling, either specify the `profileGrid` for a completely user-defined grid, or
#' `profileBounds` for an adaptive grid. Both should be defined on the log effect size scale. When both
#' `profileGrid` and `profileBounds` are `NULL` likelihood profiling is disabled. Specifying both
#' (i.e. both non-`NULL`) is an error.
#'
#' The function does not throw when a model cannot be fitted (e.g. empty population, no outcomes,
#' or a Cyclops error). Instead, the reason is stored in the `outcomeModelStatus` field of the
#' returned object, which is `"OK"` only when the model and its confidence interval were computed.
#'
#' @description
#' Creates an outcome model, and computes the relative risk
#'
#' @param population            A population object generated by [createStudyPopulation()],
#'                              potentially filtered by other functions.
#'
#' @param cohortMethodData      An object of type [CohortMethodData] as generated using
#'                              [getDbCohortMethodData()]. Can be omitted if not using covariates and
#'                              not using interaction terms.
#' @param modelType             The type of outcome model that will be used. Possible values are
#'                              "logistic", "poisson", or "cox".
#' @param stratified            Should the regression be conditioned on the strata defined in the
#'                              population object (e.g. by matching or stratifying on propensity
#'                              scores)?
#' @param useCovariates         Whether to use the covariates in the `cohortMethodData`
#'                              object in the outcome model.
#' @param inversePtWeighting    Use inverse probability of treatment weighting (IPTW)
#' @param interactionCovariateIds  An optional vector of covariate IDs to use to estimate interactions
#'                                 with the main treatment effect.
#' @param excludeCovariateIds   Exclude these covariates from the outcome model.
#' @param includeCovariateIds   Include only these covariates in the outcome model.
#' @param profileGrid           A one-dimensional grid of points on the log(relative risk) scale where
#'                              the likelihood for coefficient of variables is sampled. See details.
#' @param profileBounds         The bounds (on the log relative risk scale) for the adaptive sampling
#'                              of the likelihood function. See details.
#' @param prior                 The prior used to fit the model. See [Cyclops::createPrior()]
#'                              for details.
#' @param control               The control object used to control the cross-validation used to
#'                              determine the hyperparameters of the prior (if applicable). See
#'                              [Cyclops::createControl()] for details.
#'
#' @return
#' An object of class `OutcomeModel`. Generic functions `print`, `coef`, and
#' `confint` are available.
#'
#' @export
fitOutcomeModel <- function(population,
                            cohortMethodData = NULL,
                            modelType = "logistic",
                            stratified = FALSE,
                            useCovariates = FALSE,
                            inversePtWeighting = FALSE,
                            interactionCovariateIds = c(),
                            excludeCovariateIds = c(),
                            includeCovariateIds = c(),
                            profileGrid = NULL,
                            profileBounds = c(log(0.1), log(10)),
                            prior = createPrior("laplace", useCrossValidation = TRUE),
                            control = createControl(
                              cvType = "auto",
                              seed = 1,
                              resetCoefficients = TRUE,
                              startingVariance = 0.01,
                              tolerance = 2e-07,
                              cvRepetitions = 10,
                              noiseLevel = "quiet"
                            )) {
  # Argument type checks: failures are collected and reported together by reportAssertions().
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertDataFrame(population, null.ok = TRUE, add = errorMessages)
  checkmate::assertNames(names(population), must.include = c("rowId", "outcomeCount", "treatment", "timeAtRisk", "personSeqId"), add = errorMessages)
  if (!is.null(cohortMethodData)) {
    checkmate::assertNames(names(cohortMethodData), must.include = c("analysisRef", "cohorts", "covariateRef", "covariates", "outcomes"), add = errorMessages)
  }
  checkmate::assertClass(cohortMethodData, "CohortMethodData", null.ok = TRUE, add = errorMessages)
  checkmate::assertChoice(modelType, c("logistic", "poisson", "cox"), add = errorMessages)
  checkmate::assertLogical(stratified, len = 1, add = errorMessages)
  checkmate::assertLogical(useCovariates, len = 1, add = errorMessages)
  checkmate::assertLogical(inversePtWeighting, len = 1, add = errorMessages)
  .assertCovariateId(interactionCovariateIds, null.ok = TRUE, add = errorMessages)
  .assertCovariateId(excludeCovariateIds, null.ok = TRUE, add = errorMessages)
  .assertCovariateId(includeCovariateIds, null.ok = TRUE, add = errorMessages)
  checkmate::assertNumeric(profileGrid, null.ok = TRUE, add = errorMessages)
  checkmate::assertNumeric(profileBounds, null.ok = TRUE, len = 2, add = errorMessages)
  checkmate::assertClass(prior, "cyclopsPrior", add = errorMessages)
  checkmate::assertClass(control, "cyclopsControl", add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)
  # Cross-argument consistency checks; each of these stops immediately.
  if (stratified && nrow(population) > 0 && is.null(population$stratumId)) {
    stop("Requested stratified analysis, but no stratumId column found in population. Please use matchOnPs or stratifyByPs to create strata.")
  }
  if (is.null(cohortMethodData) && useCovariates) {
    stop("Requested all covariates for model, but no cohortMethodData object specified")
  }
  if (is.null(cohortMethodData) && length(interactionCovariateIds) != 0) {
    stop("Requesting interaction terms in model, but no cohortMethodData object specified")
  }
  if (any(excludeCovariateIds %in% interactionCovariateIds)) {
    stop("Can't exclude covariates that are to be used for interaction terms")
  }
  if (any(includeCovariateIds %in% excludeCovariateIds)) {
    stop("Can't exclude covariates that are to be included")
  }
  if (inversePtWeighting && is.null(population$iptw)) {
    stop("Requested inverse probability weighting, but no IPTW are provided. Use createPs to generate them")
  }
  # Note: profileBounds is non-NULL by default, so callers passing profileGrid must also
  # pass profileBounds = NULL explicitly.
  if (!is.null(profileGrid) && !is.null(profileBounds)) {
    stop("Can't specify both a grid and bounds for likelihood profiling")
  }

  # Defaults for every output field; these are what is returned when no model can be fitted.
  start <- Sys.time()
  treatmentEstimate <- NULL
  interactionEstimates <- NULL
  mainEffectEstimates <- NULL
  mainEffectTerms <- NULL
  coefficients <- NULL
  fit <- NULL
  priorVariance <- NULL
  logLikelihood <- NA
  treatmentVarId <- NA
  subgroupCounts <- NULL
  logLikelihoodProfile <- NULL
  status <- "NO MODEL FITTED"
  metaData <- attr(population, "metaData")

  # Degenerate populations are reported through `status` rather than by throwing.
  if (nrow(population) == 0) {
    status <- "NO SUBJECTS IN POPULATION, CANNOT FIT"
  } else if (sum(population$outcomeCount) == 0) {
    status <- "NO OUTCOMES FOUND FOR POPULATION, CANNOT FIT"
  } else {
    # Informative population ---------------------------------------------------------
    informativePopulation <- getInformativePopulation(
      population = population,
      stratified = stratified,
      inversePtWeighting = inversePtWeighting,
      modelType = modelType
    )

    # Both treatment arms must be present in the informative population to estimate an effect.
    if (sum(informativePopulation$treatment == 1) == 0 || sum(informativePopulation$treatment == 0) == 0) {
      status <- "NO STRATA WITH BOTH TARGET, COMPARATOR, AS WELL AS THE OUTCOME. CANNOT FIT"
    } else {
      if (useCovariates) {
        # Add covariates ---------------------------------------------------------------------------------
        covariateData <- filterAndTidyCovariates(
          cohortMethodData = cohortMethodData,
          includeRowIds = informativePopulation$rowId,
          includeCovariateIds = includeCovariateIds,
          excludeCovariateIds = excludeCovariateIds
        )
        on.exit(close(covariateData))
        metaData$deletedRedundantCovariateIdsForOutcomeModel <- attr(covariateData, "metaData")$deletedRedundantCovariateIds
        metaData$deletedInfrequentCovariateIdsForOutcomeModel <- attr(covariateData, "metaData")$deletedInfrequentCovariateIds

        mainEffectTerms <- covariateData$covariates %>%
          distinct(.data$covariateId) %>%
          inner_join(covariateData$covariateRef, by = "covariateId") %>%
          select(id = "covariateId", name = "covariateName") %>%
          collect()

        # The treatment variable gets an ID one above the largest covariate ID in
        # cohortMethodData, so it cannot collide with an existing covariate.
        treatmentVarId <- cohortMethodData$covariates %>%
          summarise(value = max(.data$covariateId, na.rm = TRUE)) %>%
          pull() + 1

        treatmentCovariate <- informativePopulation %>%
          select("rowId", covariateValue = "treatment") %>%
          mutate(covariateId = treatmentVarId)

        appendToTable(covariateData$covariates, treatmentCovariate)

        # Covariate ID 0 denotes the intercept; an intercept is only added for
        # non-stratified, non-Cox models (see addIntercept below).
        if (stratified || modelType == "cox") {
          prior$exclude <- treatmentVarId # Exclude treatment variable from regularization
        } else {
          prior$exclude <- c(0, treatmentVarId) # Exclude treatment variable and intercept from regularization
        }
      } else {
        # Don't add covariates, only use treatment as covariate ----------------------------------------------
        treatmentVarId <- 1

        treatmentCovariate <- informativePopulation %>%
          select("rowId", covariateValue = "treatment") %>%
          mutate(covariateId = treatmentVarId)

        covariateData <- Andromeda::andromeda(
          covariates = treatmentCovariate,
          covariateRef = cohortMethodData$covariateRef
        )
        on.exit(close(covariateData))
        # The user-supplied prior is replaced: without covariates the model is fitted unregularized.
        prior <- createPrior("none")
      }

      # Interaction terms -----------------------------------------------------------------------------------
      interactionTerms <- NULL
      if (length(interactionCovariateIds) != 0) {
        # Covariate values of the interaction covariates, restricted to the informative population.
        covariateData$covariatesSubset <- cohortMethodData$covariates %>%
          filter(.data$covariateId %in% interactionCovariateIds) %>%
          filter(.data$rowId %in% local(informativePopulation$rowId))

        # Cannot have interaction terms without main effects, so add covariates if they haven't been included already:
        if (!useCovariates) {
          appendToTable(covariateData$covariates, covariateData$covariatesSubset)
          mainEffectTerms <- covariateData$covariatesSubset %>%
            distinct(.data$covariateId) %>%
            inner_join(covariateData$covariateRef, by = "covariateId") %>%
            select(id = "covariateId", name = "covariateName") %>%
            collect()
        } else {
          # TODO: check if main effect covariate exists in data
          # NOTE(review): collect() returns a (possibly empty) table rather than NULL, so this
          # check looks like it can never fail - confirm intended behavior.
          mainEffectTermsCheck <- !is.null(covariateData$covariates %>%
            distinct(.data$covariateId) %>%
            inner_join(covariateData$covariateRef, by = "covariateId") %>%
            select(id = "covariateId", name = "covariateName") %>%
            collect())

          if (!mainEffectTermsCheck) {
            stop("No main effects exist.")
          }
        }

        # Create interaction terms
        interactionTerms <- covariateData$covariateRef %>%
          filter(.data$covariateId %in% interactionCovariateIds) %>%
          select("covariateId", "covariateName") %>%
          collect()

        # Interaction IDs are allocated consecutively, directly after the treatment variable ID.
        interactionTerms$interactionId <- treatmentVarId + seq_len(nrow(interactionTerms))
        interactionTerms$interactionName <- paste("treatment", interactionTerms$covariateName, sep = " * ")
        interactionTerms$covariateName <- NULL
        covariateData$interactionTerms <- interactionTerms

        # Interaction covariates are only created for rows in the target group (treatment == 1),
        # carrying over the value of the underlying covariate.
        targetRowIds <- informativePopulation$rowId[informativePopulation$treatment == 1]
        interactionCovariates <- covariateData$covariatesSubset %>%
          filter(.data$rowId %in% targetRowIds) %>%
          inner_join(covariateData$interactionTerms, by = "covariateId") %>%
          select("rowId", "interactionId", "covariateValue") %>%
          rename(covariateId = .data$interactionId)

        appendToTable(covariateData$covariates, interactionCovariates)

        # Keep only interaction terms that actually occur in the covariate data.
        uniqueCovariateIds <- covariateData$covariates %>%
          distinct(.data$covariateId) %>%
          pull()
        interactionTerms <- interactionTerms %>%
          filter(.data$interactionId %in% uniqueCovariateIds, )

        if (nrow(interactionTerms) == 0) {
          interactionTerms <- NULL
        } else {
          if (useCovariates) {
            # Neither the interaction terms nor their main effects are regularized.
            prior$exclude <- unique(c(prior$exclude, interactionTerms$covariateId, interactionTerms$interactionId))
          }
          subgroupCounts <- createSubgroupCounts(interactionTerms$covariateId, covariateData$covariatesSubset, population, modelType)
        }
      }

      # Fit model -------------------------------------------------------------------------------------------
      covariateData$outcomes <- informativePopulation
      outcomes <- covariateData$outcomes
      if (stratified) {
        # Attach the stratum to every covariate row for the conditioned model.
        covariates <- covariateData$covariates %>%
          inner_join(select(covariateData$outcomes, "rowId", "stratumId"), by = "rowId")
      } else {
        covariates <- covariateData$covariates
      }
      cyclopsData <- Cyclops::convertToCyclopsData(
        outcomes = outcomes,
        covariates = covariates,
        addIntercept = (!stratified && !modelType == "cox"),
        modelType = modelTypeToCyclopsModelType(
          modelType,
          stratified
        ),
        checkRowIds = FALSE,
        normalize = NULL,
        quiet = TRUE
      )

      if (!is.null(interactionTerms)) {
        # Check separability:
        separability <- Cyclops::getUnivariableSeparability(cyclopsData)
        # The treatment variable itself is never removed, even if flagged as separable.
        separability[as.character(treatmentVarId)] <- FALSE
        if (any(separability)) {
          removeCovariateIds <- as.numeric(names(separability)[separability])
          # Add main effects of separable interaction effects, and the other way around:
          if (!useCovariates) {
            removeCovariateIds <- unique(c(
              removeCovariateIds,
              interactionTerms$covariateId[interactionTerms$interactionId %in% removeCovariateIds]
            ))
          }
          removeCovariateIds <- unique(c(
            removeCovariateIds,
            interactionTerms$interactionId[interactionTerms$covariateId %in% removeCovariateIds]
          ))
          covariates <- covariates %>%
            filter(!.data$covariateId %in% removeCovariateIds)

          # Rebuild the Cyclops data without the separable terms.
          cyclopsData <- Cyclops::convertToCyclopsData(
            outcomes = outcomes,
            covariates = covariates,
            addIntercept = (!stratified && !modelType == "cox"),
            modelType = modelTypeToCyclopsModelType(
              modelType,
              stratified
            ),
            checkSorting = TRUE,
            checkRowIds = FALSE,
            normalize = NULL,
            quiet = TRUE
          )
          warning("Separable interaction terms found and removed")
          ref <- interactionTerms[interactionTerms$interactionId %in% removeCovariateIds, ]
          message("Separable interactions:")
          for (i in seq_len(nrow(ref))) {
            message(paste(ref[i, ], collapse = "\t"))
          }
          interactionTerms <- interactionTerms[!(interactionTerms$interactionId %in% removeCovariateIds), ]
          if (nrow(interactionTerms) == 0) {
            interactionTerms <- NULL
          }
        }
      }

      # `fit` holds either a fitted Cyclops model or a character string describing why fitting failed.
      if (prior$priorType != "none" && prior$useCrossValidation && control$selectorType == "byPid" &&
          length(unique(informativePopulation$stratumId)) < control$fold) {
        fit <- "NUMBER OF INFORMATIVE STRATA IS SMALLER THAN THE NUMBER OF CV FOLDS, CANNOT FIT"
      } else {
        fit <- tryCatch(
          {
            Cyclops::fitCyclopsModel(cyclopsData, prior = prior, control = control)
          },
          error = function(e) {
            # The error message becomes the model status instead of propagating the error.
            e$message
          }
        )
      }
      if (is.character(fit)) {
        status <- fit
      } else {
        # Retrieve likelihood profile
        if (!is.null(profileGrid) || !is.null(profileBounds)) {
          logLikelihoodProfile <- Cyclops::getCyclopsProfileLogLikelihood(
            object = fit,
            parm = treatmentVarId,
            x = profileGrid,
            bounds = profileBounds,
            tolerance = 0.1,
            includePenalty = TRUE
          )
        }
        if (fit$return_flag != "SUCCESS") {
          status <- fit$return_flag
        } else {
          status <- "OK"
          coefficients <- coef(fit)
          logRr <- coef(fit)[names(coef(fit)) == as.character(treatmentVarId)]
          # c(0, -Inf, Inf) is a sentinel marking a failed confidence interval computation.
          ci <- tryCatch(
            {
              confint(fit, parm = treatmentVarId, includePenalty = TRUE)
            },
            error = function(e) {
              missing(e) # suppresses R CMD check note
              c(0, -Inf, Inf)
            }
          )
          if (identical(ci, c(0, -Inf, Inf))) {
            status <- "ERROR COMPUTING CI"
          }
          # The standard error is derived from the width of the 95% confidence interval.
          seLogRr <- (ci[3] - ci[2]) / (2 * qnorm(0.975))
          # Log likelihood ratio of the fitted model versus the treatment coefficient fixed at 0.
          llNull <- Cyclops::getCyclopsProfileLogLikelihood(
            object = fit,
            parm = treatmentVarId,
            x = 0,
            includePenalty = FALSE
          )$value
          llr <- fit$log_likelihood - llNull
          treatmentEstimate <- tibble(
            logRr = logRr,
            logLb95 = ci[2],
            logUb95 = ci[3],
            seLogRr = seLogRr,
            llr = llr
          )
          priorVariance <- fit$variance[1]
          logLikelihood <- fit$log_likelihood
          if (!is.null(mainEffectTerms)) {
            logRr <- coef(fit)[match(as.character(mainEffectTerms$id), names(coef(fit)))]
            if (prior$priorType == "none") {
              ci <- tryCatch(
                {
                  confint(fit,
                          parm = mainEffectTerms$id, includePenalty = TRUE,
                          overrideNoRegularization = TRUE
                  )
                },
                error = function(e) {
                  missing(e) # suppresses R CMD check note
                  t(array(c(0, -Inf, Inf), dim = c(3, nrow(mainEffectTerms))))
                }
              )
            } else {
              # With a regularizing prior no confidence intervals are computed for the main
              # effects; every row gets the (0, -Inf, Inf) placeholder.
              ci <- t(array(c(0, -Inf, Inf), dim = c(3, nrow(mainEffectTerms))))
            }
            seLogRr <- (ci[, 3] - ci[, 2]) / (2 * qnorm(0.975))
            # NOTE(review): the column below is spelled `coariateName` (presumably meant to be
            # `covariateName`). Renaming it changes the returned object - confirm no callers
            # depend on the current spelling before fixing.
            mainEffectEstimates <- tibble(
              covariateId = mainEffectTerms$id,
              coariateName = mainEffectTerms$name,
              logRr = logRr,
              logLb95 = ci[, 2],
              logUb95 = ci[, 3],
              seLogRr = seLogRr
            )
          }

          if (!is.null(interactionTerms)) {
            logRr <- coef(fit)[match(as.character(interactionTerms$interactionId), names(coef(fit)))]
            ci <- tryCatch(
              {
                confint(fit, parm = interactionTerms$interactionId, includePenalty = TRUE)
              },
              error = function(e) {
                missing(e) # suppresses R CMD check note
                t(array(c(0, -Inf, Inf), dim = c(3, nrow(interactionTerms))))
              }
            )
            seLogRr <- (ci[, 3] - ci[, 2]) / (2 * qnorm(0.975))
            interactionEstimates <- data.frame(
              covariateId = interactionTerms$covariateId,
              interactionName = interactionTerms$interactionName,
              logRr = logRr,
              logLb95 = ci[, 2],
              logUb95 = ci[, 3],
              seLogRr = seLogRr
            )
          }
        }
      }
    }
  }
  # The outcome model object starts from the population's metadata and is extended with
  # the model results.
  outcomeModel <- metaData
  outcomeModel$outcomeModelTreatmentVarId <- treatmentVarId
  outcomeModel$outcomeModelCoefficients <- coefficients
  outcomeModel$logLikelihoodProfile <- logLikelihoodProfile
  outcomeModel$outcomeModelPriorVariance <- priorVariance
  outcomeModel$outcomeModelLogLikelihood <- logLikelihood
  outcomeModel$outcomeModelType <- modelType
  outcomeModel$outcomeModelStratified <- stratified
  outcomeModel$outcomeModelUseCovariates <- useCovariates
  outcomeModel$inversePtWeighting <- inversePtWeighting
  # The iptwEstimator field from the metadata is only reported (as targetEstimator) when
  # IPTW was actually used; the original field is always dropped.
  if (inversePtWeighting) {
    outcomeModel$targetEstimator <- outcomeModel$iptwEstimator
  }
  outcomeModel$iptwEstimator <- NULL
  outcomeModel$outcomeModelTreatmentEstimate <- treatmentEstimate
  outcomeModel$outcomeModelmainEffectEstimates <- mainEffectEstimates
  if (length(interactionCovariateIds) != 0) {
    outcomeModel$outcomeModelInteractionEstimates <- interactionEstimates
  }
  outcomeModel$outcomeModelStatus <- status
  # Counts are computed on the full input population, not only the informative subset.
  outcomeModel$populationCounts <- getCounts(population, "Population count")
  outcomeModel$outcomeCounts <- getOutcomeCounts(population, modelType)
  outcomeModel$timeAtRisk <- getTimeAtRisk(population, modelType)
  if (!is.null(subgroupCounts)) {
    outcomeModel$subgroupCounts <- subgroupCounts
  }
  class(outcomeModel) <- "OutcomeModel"
  delta <- Sys.time() - start
  message(paste("Fitting outcome model took", signif(delta, 3), attr(delta, "units")))
  ParallelLogger::logDebug("Outcome model fitting status is: ", status)
  return(outcomeModel)
}

# Prepares the population for model fitting.
#
# - Renames `outcomeCount` to `y`; for Cox and logistic models `y` is reduced to 0/1.
# - Adds a `time` column (`survivalTime` for Cox, `timeAtRisk` otherwise).
# - For stratified analyses, keeps only the strata containing at least one outcome.
# - Adds a `weights` column (copied from `iptw`) when inverse probability weighting is used.
#
# Returns the population restricted to the columns required for the requested model:
# rowId, y, treatment, plus stratumId / time / weights where applicable.
getInformativePopulation <- function(population, stratified, inversePtWeighting, modelType) {
  isCox <- modelType == "cox"

  population <- rename(population, y = "outcomeCount")
  if (!stratified) {
    population$stratumId <- NULL
  }

  if (isCox) {
    population$time <- population$survivalTime
  } else {
    population$time <- population$timeAtRisk
  }
  if (isCox || modelType == "logistic") {
    # Only the presence of an outcome matters for these model types
    population$y[population$y != 0] <- 1
  }

  if (stratified) {
    # Strata without any outcome are dropped
    strataWithOutcome <- population %>%
      filter(.data$y != 0) %>%
      distinct(.data$stratumId)
    informativePopulation <- inner_join(strataWithOutcome, population, by = "stratumId")
  } else {
    informativePopulation <- population
  }

  informativePopulation$weights <- if (inversePtWeighting) informativePopulation$iptw else NULL

  # c() silently drops the NULLs produced by the non-matching conditions
  columns <- c(
    "rowId", "y", "treatment",
    if (stratified) "stratumId",
    if (modelType == "poisson" || isCox) "time",
    if (inversePtWeighting) "weights"
  )
  return(informativePopulation[, columns])
}

# Translates the CohortMethod model type, together with the stratification flag, into the
# model type code passed to Cyclops:
#   logistic -> "clr" (stratified) or "lr"
#   poisson  -> "cpr" (stratified) or "pr"
#   cox      -> "cox" (regardless of stratification)
# Any other model type is an error.
modelTypeToCyclopsModelType <- function(modelType, stratified) {
  if (modelType == "cox") {
    return("cox")
  }
  if (modelType == "logistic") {
    return(if (stratified) "clr" else "lr")
  }
  if (modelType == "poisson") {
    return(if (stratified) "cpr" else "pr")
  }
  stop(paste("Unknown model type:", modelType))
}

# Restricts the covariates in `cohortMethodData` to the requested rows and covariate IDs,
# and passes the result through FeatureExtraction::tidyCovariateData().
#
# cohortMethodData:     Object holding the `covariates`, `covariateRef`, and `analysisRef`
#                       tables, with a `metaData` attribute.
# includeRowIds:        Row IDs to keep.
# includeCovariateIds:  If non-empty, only these covariate IDs are kept.
# excludeCovariateIds:  If non-empty, these covariate IDs are removed.
#
# Returns the object produced by FeatureExtraction::tidyCovariateData(). The caller is
# responsible for closing it.
filterAndTidyCovariates <- function(cohortMethodData,
                                    includeRowIds,
                                    includeCovariateIds,
                                    excludeCovariateIds) {
  covariates <- cohortMethodData$covariates %>%
    filter(.data$rowId %in% includeRowIds)

  if (length(includeCovariateIds) != 0) {
    covariates <- covariates %>%
      filter(.data$covariateId %in% includeCovariateIds)
  }
  if (length(excludeCovariateIds) != 0) {
    # Must filter on the *exclude* list. Filtering on includeCovariateIds here would ignore
    # the exclusions and, when both lists are given, drop every included covariate.
    covariates <- covariates %>%
      filter(!.data$covariateId %in% excludeCovariateIds)
  }
  filteredCovariateData <- Andromeda::andromeda(
    covariates = covariates,
    covariateRef = cohortMethodData$covariateRef,
    analysisRef = cohortMethodData$analysisRef
  )
  # Close the intermediate object on exit, so it is also released when tidying fails.
  on.exit(close(filteredCovariateData), add = TRUE)

  metaData <- attr(cohortMethodData, "metaData")
  metaData$populationSize <- length(includeRowIds)
  attr(filteredCovariateData, "metaData") <- metaData
  class(filteredCovariateData) <- "CovariateData"

  covariateData <- FeatureExtraction::tidyCovariateData(filteredCovariateData)
  return(covariateData)
}

# Counts outcomes in the target (treatment == 1) and comparator (treatment == 0) groups.
#
# For Cox and logistic models the outcome count per row is first reduced to 0/1.
# Returns a one-row tibble with, per group:
#   *Persons:   number of distinct personSeqId values having an outcome
#   *Exposures: number of rows having an outcome
#   *Outcomes:  sum of the (possibly reduced) outcome counts
getOutcomeCounts <- function(population, modelType) {
  population <- rename(population, y = "outcomeCount")
  if (modelType == "cox" || modelType == "logistic") {
    population$y[population$y != 0] <- 1
  }

  inTarget <- population$treatment == 1
  inComparator <- population$treatment == 0
  hasOutcome <- population$y != 0

  targetIdsWithOutcome <- population$personSeqId[inTarget & hasOutcome]
  comparatorIdsWithOutcome <- population$personSeqId[inComparator & hasOutcome]

  counts <- tibble(
    targetPersons = length(unique(targetIdsWithOutcome)),
    comparatorPersons = length(unique(comparatorIdsWithOutcome)),
    targetExposures = length(targetIdsWithOutcome),
    comparatorExposures = length(comparatorIdsWithOutcome),
    targetOutcomes = sum(population$y[inTarget]),
    comparatorOutcomes = sum(population$y[inComparator])
  )
  return(counts)
}

# Computes population counts, outcome counts, and time at risk for each subgroup defined by
# an interaction covariate.
#
# interactionCovariateIds:  Covariate IDs defining the subgroups.
# covariatesSubset:         Covariate table (rowId, covariateId, ...) used to determine which
#                           rows belong to each subgroup.
# population:               The study population.
# modelType:                Model type, forwarded to getOutcomeCounts() and getTimeAtRisk().
#
# Returns a table with one row per subgroup covariate ID (column `subgroupCovariateId`).
createSubgroupCounts <- function(interactionCovariateIds, covariatesSubset, population, modelType) {
  # Counts for a single subgroup. Named differently from the enclosing function to avoid
  # shadowing it.
  countSingleSubgroup <- function(subgroupCovariateId) {
    subgroupRowIds <- covariatesSubset %>%
      filter(.data$covariateId %in% subgroupCovariateId) %>%
      distinct(.data$rowId) %>%
      pull()

    subgroup <- population %>%
      filter(.data$rowId %in% subgroupRowIds)

    # Columns are selected by name (strings): using `.data$col` inside rename() is deprecated
    # in tidyselect, and strings match the rename() calls elsewhere in this file.
    outcomeCounts <- rename(
      getOutcomeCounts(subgroup, modelType),
      targetOutcomePersons = "targetPersons",
      comparatorOutcomePersons = "comparatorPersons",
      targetOutcomeExposures = "targetExposures",
      comparatorOutcomeExposures = "comparatorExposures"
    )
    counts <- bind_cols(
      getCounts(subgroup, "Population count"),
      outcomeCounts,
      getTimeAtRisk(subgroup, modelType)
    )
    counts$description <- NULL
    counts$subgroupCovariateId <- subgroupCovariateId
    return(counts)
  }
  subgroupCounts <- lapply(interactionCovariateIds, countSingleSubgroup)
  subgroupCounts <- bind_rows(subgroupCounts)
  return(subgroupCounts)
}

# Sums the follow-up time per treatment group.
#
# Uses `survivalTime` for Cox models and `timeAtRisk` for all other model types.
# Returns a one-row tibble with `targetDays` (treatment == 1) and `comparatorDays`
# (treatment == 0).
getTimeAtRisk <- function(population, modelType) {
  followUp <- if (modelType == "cox") population$survivalTime else population$timeAtRisk
  inTarget <- population$treatment == 1
  inComparator <- population$treatment == 0
  daysPerGroup <- tibble(
    targetDays = sum(followUp[inTarget]),
    comparatorDays = sum(followUp[inComparator])
  )
  return(daysPerGroup)
}


#' @export
coef.OutcomeModel <- function(object, ...) {
  # Only the treatment effect (log relative risk) is exposed; NULL when no estimate is stored.
  treatmentEstimate <- object[["outcomeModelTreatmentEstimate"]]
  return(treatmentEstimate[["logRr"]])
}

#' @export
confint.OutcomeModel <- function(object, parm, level = 0.95, ...) {
  missing(parm) # suppresses R CMD check note
  # Only the 95% interval is stored in the outcome model object
  if (level != 0.95) {
    stop("Only supporting 95% confidence interval")
  }
  treatmentEstimate <- object$outcomeModelTreatmentEstimate
  bounds <- c(treatmentEstimate$logLb95, treatmentEstimate$logUb95)
  return(bounds)
}

#' @export
print.OutcomeModel <- function(x, ...) {
  # Model settings and fitting status
  writeLines(paste("Model type:", x$outcomeModelType))
  writeLines(paste("Stratified:", x$outcomeModelStratified))
  writeLines(paste("Use covariates:", x$outcomeModelUseCovariates))
  writeLines(paste("Use inverse probability of treatment weighting:", x$inversePtWeighting))
  writeLines(paste("Target estimand:", x$targetEstimator))
  writeLines(paste("Status:", x$outcomeModelStatus))
  if (!is.null(x$outcomeModelPriorVariance) && !is.na(x$outcomeModelPriorVariance)) {
    writeLines(paste("Prior variance:", x$outcomeModelPriorVariance))
  }
  writeLines("")

  # Estimates table: the treatment effect, followed by any interaction effects
  estimates <- x$outcomeModelTreatmentEstimate
  if (!is.null(estimates)) {
    rowLabels <- "treatment"
    interactions <- x$outcomeModelInteractionEstimates
    if (!is.null(interactions)) {
      # Columns 1:4 of the treatment estimate and 3:6 of the interaction estimates both hold
      # logRr, logLb95, logUb95, seLogRr
      estimates <- rbind(estimates[, 1:4], interactions[, 3:6])
      rowLabels <- c(rowLabels, as.character(interactions$interactionName))
    }
    output <- data.frame(
      exp(estimates$logRr),
      exp(estimates$logLb95),
      exp(estimates$logUb95),
      estimates$logRr,
      estimates$seLogRr
    )
    colnames(output) <- c("Estimate", "lower .95", "upper .95", "logRr", "seLogRr")
    rownames(output) <- rowLabels
    printCoefmat(output)
  }
  # Print methods return their argument invisibly so the object can be piped or assigned
  invisible(x)
}

#' Get the outcome model
#'
#' @description
#' Get the full outcome model, so showing the betas of all variables included
#' in the outcome model, not just the treatment variable.
#'
#' @param outcomeModel       An object of type `OutcomeModel` as generated using the
#'                           [fitOutcomeModel()] function.
#'
#' @template CohortMethodData
#'
#' @return
#' A data frame with columns `coefficient`, `id`, and `name`, containing one row for each
#' non-zero coefficient that can be matched to a name: covariates present in the
#' `covariateRef` table of `cohortMethodData`, the treatment variable, and the intercept.
#'
#' @export
getOutcomeModel <- function(outcomeModel, cohortMethodData) {
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertClass(outcomeModel, "OutcomeModel", add = errorMessages)
  checkmate::assertClass(cohortMethodData, "CohortMethodData", add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)

  cfs <- outcomeModel$outcomeModelCoefficients

  # Only report coefficients that were not shrunk to exactly zero
  cfs <- cfs[cfs != 0]
  # The intercept is represented by ID 0 so all names can be converted to numeric IDs
  attr(cfs, "names")[attr(cfs, "names") == "(Intercept)"] <- 0
  cfs <- data.frame(coefficient = cfs, id = as.numeric(attr(cfs, "names")))

  # Look up covariate names by the `id` column of `cfs`. (`cfs` has no `covariateId`
  # column; filtering on it would match nothing and drop every covariate from the result.)
  ref <- cohortMethodData$covariateRef %>%
    filter(.data$covariateId %in% local(cfs$id)) %>%
    select(id = "covariateId", name = "covariateName") %>%
    collect()

  # The treatment variable and the intercept are not part of covariateRef, so add them here
  ref <- bind_rows(
    ref,
    tibble(
      id = outcomeModel$outcomeModelTreatmentVarId,
      name = "Treatment"
    ),
    tibble(
      id = 0,
      name = "(Intercept)"
    )
  )

  cfs <- cfs %>%
    inner_join(ref, by = "id")
  return(cfs)
}
OHDSI/CohortMethod documentation built on Oct. 9, 2024, 12:50 p.m.