R/StudyPopulation.R

Defines functions plotTimeToEvent getCounts getAttritionTable createStudyPopulation fastDuplicated

Documented in createStudyPopulation getAttritionTable plotTimeToEvent

# Copyright 2026 Observational Health Data Sciences and Informatics
#
# This file is part of CohortMethod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Flags rows whose values in `columns` all equal those of the row directly
# above them. Only adjacent rows are compared, so `data` is expected to be
# sorted on `columns` by the caller. The first row is never flagged.
#
# data:    A data frame.
# columns: Names (or indices) of the columns to compare.
#
# Returns a logical vector with one element per row of `data`.
fastDuplicated <- function(data, columns) {
  rowCount <- nrow(data)
  if (rowCount == 0) {
    return(vector())
  }
  if (rowCount == 1) {
    return(c(FALSE))
  }
  current <- 2:rowCount
  previous <- 1:(rowCount - 1)
  # One logical vector per column: does the row match its predecessor?
  sameAsPrevious <- lapply(columns, function(column) {
    data[current, column] == data[previous, column]
  })
  # A row is a duplicate only when every column matches.
  allSame <- Reduce(`&`, sameAsPrevious[-1], sameAsPrevious[[1]])
  return(c(FALSE, allSame))
}

#' Create a study population
#'
#' @details
#' Create a study population by enforcing certain inclusion and exclusion criteria, defining a risk
#' window, and determining which outcomes fall inside the risk window.
#'
#' @template CohortMethodData
#'
#' @param population                  If specified, this population will be used as the starting
#'                                    point instead of the cohorts in the `cohortMethodData` object.
#' @param outcomeId                   The ID of the outcome. If NULL, no outcome-specific
#'                                    transformations will be performed.
#' @param createStudyPopulationArgs   An object of type `CreateStudyPopulationArgs` as created by
#'                                    the [createCreateStudyPopulationArgs()] function.
#' @return
#' A `tibble` specifying the study population. This `tibble` will have the following columns:
#'
#' - `rowId`: A unique identifier for an exposure.
#' - `personSeqId`: The person sequence ID of the subject.
#' - `cohortStartDate`: The index date.
#' - `outcomeCount`: The number of outcomes observed during the risk window.
#' - `timeAtRisk`: The number of days in the risk window.
#' - `survivalTime`: The number of days until either the outcome or the end of the risk window.
#'
#' The `outcomeCount`, `timeAtRisk`, and `survivalTime` columns are only added when `outcomeId`
#' is specified.
#'
#' @export
createStudyPopulation <- function(cohortMethodData,
                                  population = NULL,
                                  outcomeId = NULL,
                                  createStudyPopulationArgs = createCreateStudyPopulationArgs()) {
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertClass(cohortMethodData, "CohortMethodData", add = errorMessages)
  checkmate::assertDataFrame(population, null.ok = TRUE, add = errorMessages)
  checkmate::assertNumeric(outcomeId, null.ok = TRUE, add = errorMessages)
  if (!is.null(outcomeId)) checkmate::assertTRUE(all(outcomeId %% 1 == 0), add = errorMessages)
  checkmate::assertR6(createStudyPopulationArgs, "CreateStudyPopulationArgs", add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)

  # TRUE when the anchor string ends in "end" (case-insensitive), e.g. "cohort end".
  isEnd <- function(anchor) {
    return(grepl("end$", anchor, ignore.case = TRUE))
  }

  if (is.null(population)) {
    metaData <- attr(cohortMethodData, "metaData")
    population <- cohortMethodData$cohorts |>
      collect() |>
      select(-"personId")
  } else {
    metaData <- attr(population, "metaData")
  }
  metaData$targetEstimator <- "ate"

  if (createStudyPopulationArgs$removeSubjectsWithPriorOutcome) {
    if (is.null(outcomeId)) {
      message("No outcome specified so skipping removing people with prior outcomes")
    } else {
      message("Removing subjects with prior outcomes (if any)")
      outcomes <- cohortMethodData$outcomes |>
        filter(.data$outcomeId == !!outcomeId) |>
        collect()
      if (isEnd(createStudyPopulationArgs$startAnchor)) {
        # The risk window starts relative to cohort end, so the cut-off for a
        # "prior" outcome differs per exposure and needs daysToCohortEnd.
        outcomes <- merge(outcomes, population[, c("rowId", "daysToCohortEnd")])
        priorOutcomeRowIds <- outcomes |>
          filter(
            .data$daysToEvent > -createStudyPopulationArgs$priorOutcomeLookback &
              .data$daysToEvent < .data$daysToCohortEnd + createStudyPopulationArgs$riskWindowStart
          ) |>
          pull("rowId")
      } else {
        priorOutcomeRowIds <- outcomes |>
          filter(
            .data$daysToEvent > -createStudyPopulationArgs$priorOutcomeLookback &
              .data$daysToEvent < createStudyPopulationArgs$riskWindowStart
          ) |>
          pull("rowId")
      }
      population <- population |>
        filter(!(.data$rowId %in% priorOutcomeRowIds))
      metaData$attrition <- rbind(
        metaData$attrition,
        getCounts(population, paste("No prior outcome"))
      )
    }
  }
  # Create risk windows (riskStart and riskEnd are day offsets relative to cohort start):
  population$riskStart <- rep(createStudyPopulationArgs$riskWindowStart, nrow(population))
  if (isEnd(createStudyPopulationArgs$startAnchor)) {
    population$riskStart <- population$riskStart + population$daysToCohortEnd
  }
  population$riskEnd <- rep(createStudyPopulationArgs$riskWindowEnd, nrow(population))
  if (isEnd(createStudyPopulationArgs$endAnchor)) {
    population$riskEnd <- population$riskEnd + population$daysToCohortEnd
  }
  # Truncate the risk window at the end of observation
  idx <- population$riskEnd > population$daysToObsEnd
  population$riskEnd[idx] <- population$daysToObsEnd[idx]

  if (!is.null(createStudyPopulationArgs$maxDaysAtRisk)) {
    idx <- population$riskEnd > population$riskStart + createStudyPopulationArgs$maxDaysAtRisk
    if (any(idx)) {
      population$riskEnd[idx] <- population$riskStart[idx] + createStudyPopulationArgs$maxDaysAtRisk
    }
  }
  if (createStudyPopulationArgs$censorAtNewRiskWindow) {
    message("Censoring time at risk of recurrent subjects at start of new time at risk")
    if (nrow(population) > 1) {
      population$startDate <- population$cohortStartDate + population$riskStart
      population$endDate <- population$cohortStartDate + population$riskEnd
      # Order each person's exposures chronologically by the calendar start of
      # their risk window. The adjacent-row comparison below relies on this
      # order; sorting on riskStart (an offset relative to each exposure's own
      # index date) does not guarantee it.
      population <- population |>
        arrange(.data$personSeqId, .data$startDate)
      idx <- seq_len(nrow(population) - 1)
      # Rows whose risk window runs into the next risk window of the same person
      idx <- which(population$endDate[idx] >= population$startDate[idx + 1] &
                     population$personSeqId[idx] == population$personSeqId[idx + 1])
      if (length(idx) > 0) {
        # End the earlier window the day before the next one starts
        population$endDate[idx] <- population$startDate[idx + 1] - 1
        population$riskEnd[idx] <- population$endDate[idx] - population$cohortStartDate[idx]
        # Drop exposures left without any time at risk after censoring
        idx <- population$riskEnd < population$riskStart
        if (any(idx)) {
          population <- population[!idx, ]
        }
      }
      population$startDate <- NULL
      population$endDate <- NULL
      metaData$attrition <- rbind(
        metaData$attrition,
        getCounts(population, paste("Censoring at start of new time-at-risk"))
      )
    }
  }
  if (createStudyPopulationArgs$minDaysAtRisk != 0) {
    message(paste("Removing subjects with less than", createStudyPopulationArgs$minDaysAtRisk, "day(s) at risk (if any)"))
    # Days at risk are counted inclusively, hence the + 1
    population <- population |>
      filter(1 + .data$riskEnd - .data$riskStart >= createStudyPopulationArgs$minDaysAtRisk)
    metaData$attrition <- rbind(metaData$attrition, getCounts(population, paste(
      "Have at least",
      createStudyPopulationArgs$minDaysAtRisk,
      "days at risk"
    )))
  }
  if (is.null(outcomeId)) {
    message("No outcome specified so not creating outcome and time variables")
  } else {
    # Select outcomes during time at risk
    outcomes <- cohortMethodData$outcomes |>
      filter(.data$outcomeId == !!outcomeId) |>
      collect()
    outcomes <- merge(outcomes, population[, c("rowId", "riskStart", "riskEnd")])
    outcomes <- outcomes |>
      filter(
        .data$daysToEvent >= .data$riskStart &
          .data$daysToEvent <= .data$riskEnd
      )

    # Create outcome count column
    if (nrow(outcomes) == 0) {
      population$outcomeCount <- rep(0, nrow(population))
    } else {
      outcomeCount <- outcomes |>
        group_by(.data$rowId) |>
        summarise(outcomeCount = length(.data$outcomeId))
      population$outcomeCount <- 0
      population$outcomeCount[match(outcomeCount$rowId, population$rowId)] <- outcomeCount$outcomeCount
    }

    # Create time at risk column
    population$timeAtRisk <- population$riskEnd - population$riskStart + 1

    # Create survival time column: time to the first outcome in the risk
    # window, or the full time at risk when there is no outcome
    firstOutcomes <- outcomes |>
      arrange(.data$rowId, .data$daysToEvent) |>
      filter(!duplicated(.data$rowId))
    population$daysToEvent <- rep(NA, nrow(population))
    population$daysToEvent[match(firstOutcomes$rowId, population$rowId)] <- firstOutcomes$daysToEvent
    population$survivalTime <- population$timeAtRisk
    hasOutcome <- population$outcomeCount != 0
    population$survivalTime[hasOutcome] <- population$daysToEvent[hasOutcome] - population$riskStart[hasOutcome] + 1
  }
  attr(population, "metaData") <- metaData
  ParallelLogger::logDebug("Study population has ", nrow(population), " rows")
  return(population)
}

#' Get the attrition table for a population
#'
#' @details
#' For an object of class `OutcomeModel` the table is read from its `attrition` element. For any
#' other object it is read from the `attrition` element of the object's `metaData` attribute.
#'
#' @param object   Either an object of type [CohortMethodData], a population object generated by
#'                 functions like [createStudyPopulation()], or an object of type
#'                 `OutcomeModel`.
#'
#' @return
#' A `tibble` specifying the number of people and exposures in the population after specific steps
#' of filtering.
#'
#' @export
getAttritionTable <- function(object) {
  if (!is(object, "OutcomeModel")) {
    # Everything except an outcome model keeps its attrition in the metaData attribute
    metaData <- attr(object, "metaData")
    return(metaData$attrition)
  }
  object$attrition
}

# Builds a one-row attrition record for a population: the number of distinct
# persons and the number of exposures (rows) in the target (treatment == 1)
# and comparator (treatment == 0) groups.
#
# population:  A data frame with `personSeqId` and `treatment` columns.
# description: Label stored in the `description` column of the result.
#
# Returns a one-row tibble.
getCounts <- function(population, description = "") {
  targetIds <- population$personSeqId[population$treatment == 1]
  comparatorIds <- population$personSeqId[population$treatment == 0]
  tibble(
    description = description,
    targetPersons = length(unique(targetIds)),
    comparatorPersons = length(unique(comparatorIds)),
    targetExposures = length(targetIds),
    comparatorExposures = length(comparatorIds)
  )
}


#' Plot time-to-event
#'
#' @details
#' Creates a plot showing the number of events over time in the target and comparator cohorts, both before and after
#' index date. The plot also distinguishes between events inside and outside the time-at-risk period. This requires
#' the user to (re)specify the time-at-risk using the same arguments as the [createStudyPopulation()] function.
#' Note that it is not possible to specify that people with the outcome prior should be removed, since the plot will
#' show these prior events.
#'
#' @template CohortMethodData
#'
#' @param population                       If specified, this population will be used as the starting
#'                                         point instead of the cohorts in the `cohortMethodData` object.
#' @param outcomeId                        The ID of the outcome. If NULL, no outcome-specific
#'                                         transformations will be performed.
#' @param minDaysAtRisk                    The minimum required number of days at risk.
#' @param riskWindowStart                  The start of the risk window (in days) relative to the `startAnchor`.
#' @param startAnchor                      The anchor point for the start of the risk window. Can be `"cohort start"`
#'                                         or `"cohort end"`.
#' @param riskWindowEnd                    The end of the risk window (in days) relative to the `endAnchor`.
#' @param endAnchor                        The anchor point for the end of the risk window. Can be `"cohort start"`
#'                                         or `"cohort end"`.
#' @param censorAtNewRiskWindow            If a subject is in multiple cohorts, should time-at-risk be censored
#'                                         when the new time-at-risk starts to prevent overlap?
#' @param periodLength                     The length in days of each period shown in the plot.
#' @param numberOfPeriods                  Number of periods to show in the plot. The periods are
#'                                         equally divided before and after the index date.
#' @param highlightExposedEvents           (logical) Highlight event counts during exposure in a different color?
#' @param includePostIndexTime             (logical) Show time after the index date?
#' @param showFittedLines                  (logical) Fit lines to the proportions and show them in the plot?
#' @param targetLabel                      A label to use for the target cohort.
#' @param comparatorLabel                  A label to use for the comparator cohort.
#' @param title                            Optional: the main title for the plot.
#' @param fileName                         Name of the file where the plot should be saved, for example
#'                                         'plot.png'. See [ggplot2::ggsave()] for supported file formats.
#'
#' @return
#' A ggplot object. Use the [ggplot2::ggsave()] function to save to file in a different
#' format.
#'
#' @export
plotTimeToEvent <- function(cohortMethodData,
                            population = NULL,
                            outcomeId = NULL,
                            minDaysAtRisk = 1,
                            riskWindowStart = 0,
                            startAnchor = "cohort start",
                            riskWindowEnd = 0,
                            endAnchor = "cohort end",
                            censorAtNewRiskWindow = FALSE,
                            periodLength = 7,
                            numberOfPeriods = 52,
                            highlightExposedEvents = TRUE,
                            includePostIndexTime = TRUE,
                            showFittedLines = TRUE,
                            targetLabel = "Target",
                            comparatorLabel = "Comparator",
                            title = NULL,
                            fileName = NULL) {
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertInt(periodLength, lower = 0, add = errorMessages)
  checkmate::assertInt(numberOfPeriods, lower = 0, add = errorMessages)
  checkmate::assertLogical(highlightExposedEvents, len = 1, add = errorMessages)
  checkmate::assertLogical(includePostIndexTime, len = 1, add = errorMessages)
  checkmate::assertLogical(showFittedLines, len = 1, add = errorMessages)
  checkmate::assertCharacter(targetLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(comparatorLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(title, len = 1, null.ok = TRUE, add = errorMessages)
  checkmate::assertCharacter(fileName, len = 1, null.ok = TRUE, add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)
  if (is.null(population)) {
    population <- cohortMethodData$cohorts |>
      collect()
  }
  # Prior outcomes are deliberately kept: the plot shows events before index too.
  population <- createStudyPopulation(
    cohortMethodData = cohortMethodData,
    population = population,
    outcomeId = outcomeId,
    createStudyPopulationArgs = createCreateStudyPopulationArgs(
      removeSubjectsWithPriorOutcome = FALSE,
      minDaysAtRisk = minDaysAtRisk,
      riskWindowStart = riskWindowStart,
      startAnchor = startAnchor,
      riskWindowEnd = riskWindowEnd,
      endAnchor = endAnchor,
      censorAtNewRiskWindow = censorAtNewRiskWindow
    )
  )
  outcomes <- cohortMethodData$outcomes |>
    filter(.data$outcomeId == !!outcomeId) |>
    select("rowId", "daysToEvent") |>
    collect()

  # Keep outcomes that fall within the observation period and flag those that
  # occur inside the risk window as exposed.
  outcomes <- outcomes |>
    inner_join(select(population, "rowId", "treatment", "daysFromObsStart", "daysToObsEnd", "riskStart", "riskEnd"),
               by = "rowId"
    ) |>
    filter(-.data$daysFromObsStart <= .data$daysToEvent & .data$daysToObsEnd >= .data$daysToEvent) |>
    mutate(exposed = .data$daysToEvent >= .data$riskStart & .data$daysToEvent <= .data$riskEnd)

  idxExposed <- outcomes$exposed == 1
  idxTarget <- outcomes$treatment == 1
  # Counts events and fully observed exposures for the period
  # [number * periodLength, (number + 1) * periodLength) relative to index.
  createPeriod <- function(number) {
    start <- number * periodLength
    end <- number * periodLength + periodLength
    idxInPeriod <- outcomes$daysToEvent >= start & outcomes$daysToEvent < end
    # Only exposures observed for the entire period count towards the denominator
    idxPopInPeriod <- -population$daysFromObsStart <= start & population$daysToObsEnd >= end
    tibble(
      number = number,
      start = start,
      end = end,
      eventsExposed = 0,
      eventsUnexposed = 0,
      observed = 0,
      eventsExposedTarget = sum(idxInPeriod & idxExposed & idxTarget),
      eventsExposedComparator = sum(idxInPeriod & idxExposed & !idxTarget),
      eventsUnexposedTarget = sum(idxInPeriod & !idxExposed & idxTarget),
      eventsUnexposedComparator = sum(idxInPeriod & !idxExposed & !idxTarget),
      observedTarget = sum(idxPopInPeriod & population$treatment),
      observedComparator = sum(idxPopInPeriod & !population$treatment)
    )
  }
  # NOTE(review): this sequence has numberOfPeriods + 1 elements (e.g. -26:26
  # for the default of 52), so one more period is created than requested.
  periods <- lapply(-floor(numberOfPeriods / 2):ceiling(numberOfPeriods / 2), createPeriod)
  periods <- do.call("rbind", periods)
  periods <- periods |>
    filter(.data$observedTarget > 0) |>
    mutate(
      rateExposedTarget = .data$eventsExposedTarget / .data$observedTarget,
      rateUnexposedTarget = .data$eventsUnexposedTarget / .data$observedTarget,
      rateExposedComparator = .data$eventsExposedComparator / .data$observedComparator,
      rateUnexposedComparator = .data$eventsUnexposedComparator / .data$observedComparator,
      rateTarget = (.data$eventsExposedTarget + .data$eventsUnexposedTarget) / .data$observedTarget,
      rateComparator = (.data$eventsExposedComparator + .data$eventsUnexposedComparator) / .data$observedComparator
    )
  if (!includePostIndexTime) {
    periods <- periods |>
      filter(.data$end <= 0)
  }
  # Long format: one row per period, exposure status, and cohort
  vizData <- rbind(
    tibble(
      start = periods$start,
      end = periods$end,
      rate = periods$rateExposedTarget,
      status = "Exposed events",
      type = targetLabel
    ),
    tibble(
      start = periods$start,
      end = periods$end,
      rate = periods$rateUnexposedTarget,
      status = "Unexposed events",
      type = targetLabel
    ),
    tibble(
      start = periods$start,
      end = periods$end,
      rate = periods$rateExposedComparator,
      status = "Exposed events",
      type = comparatorLabel
    ),
    tibble(
      start = periods$start,
      end = periods$end,
      rate = periods$rateUnexposedComparator,
      status = "Unexposed events",
      type = comparatorLabel
    )
  )
  vizData$type <- factor(vizData$type, levels = c(targetLabel, comparatorLabel))

  # Bars are centered on the middle of each period; rates are shown per 1,000.
  if (highlightExposedEvents) {
    plot <- ggplot2::ggplot(vizData, ggplot2::aes(
      x = .data$start + periodLength / 2,
      y = .data$rate * 1000,
      fill = .data$status
    )) +
      ggplot2::geom_col(width = periodLength, alpha = 0.7)
  } else {
    plot <- ggplot2::ggplot(vizData, ggplot2::aes(
      x = .data$start + periodLength / 2,
      y = .data$rate * 1000
    )) +
      ggplot2::geom_col(width = periodLength, alpha = 0.7, fill = rgb(0, 0, 0.8))
  }
  plot <- plot +
    ggplot2::geom_vline(xintercept = 0, colour = "#000000", lty = 1, linewidth = 1) +
    ggplot2::scale_fill_manual(values = c(
      rgb(0.8, 0, 0),
      rgb(0, 0, 0.8)
    )) +
    ggplot2::scale_x_continuous("Days since exposure start") +
    ggplot2::scale_y_continuous("Proportion (per 1,000 persons)") +
    ggplot2::facet_grid(type ~ ., scales = "free_y") +
    ggplot2::theme(
      panel.grid.minor = ggplot2::element_blank(),
      panel.background = ggplot2::element_rect(fill = "#FAFAFA", colour = NA),
      panel.grid.major = ggplot2::element_line(colour = "#AAAAAA"),
      axis.ticks = ggplot2::element_blank(),
      strip.background = ggplot2::element_blank(),
      legend.title = ggplot2::element_blank(),
      legend.position = "top"
    )

  if (showFittedLines) {
    # Fit a separate cubic polynomial (with confidence interval) to the overall
    # rate per cohort, separately for the periods before and after index.
    preTarget <- periods[periods$start < 0, ]
    preTarget <- cbind(preTarget, predict(lm(rateTarget ~ poly(number, 3), data = preTarget), interval = "confidence"))
    preTarget$type <- targetLabel
    preTarget$period <- "Pre"
    preComparator <- periods[periods$start < 0, ]
    preComparator <- cbind(preComparator, predict(lm(rateComparator ~ poly(number, 3), data = preComparator), interval = "confidence"))
    preComparator$type <- comparatorLabel
    preComparator$period <- "Pre"
    curve <- bind_rows(preTarget, preComparator)
    if (includePostIndexTime) {
      postTarget <- periods[periods$start >= 0, ]
      postTarget <- cbind(postTarget, predict(lm(rateTarget ~ poly(number, 3), data = postTarget), interval = "confidence"))
      postTarget$type <- targetLabel
      postTarget$period <- "Post"
      postComparator <- periods[periods$start >= 0, ]
      postComparator <- cbind(postComparator, predict(lm(rateComparator ~ poly(number, 3), data = postComparator), interval = "confidence"))
      postComparator$type <- comparatorLabel
      postComparator$period <- "Post"
      curve <- bind_rows(curve, postTarget, postComparator)
    }
    # `rate` and `status` are presumably added so these layers can inherit the
    # plot-level y and fill mappings; the lower bound is clamped at zero.
    curve <- curve |>
      mutate(
        rate = 0,
        status = "Exposed events",
        type = factor(.data$type, levels = c(targetLabel, comparatorLabel)),
        lwr = if_else(.data$lwr < 0, 0, .data$lwr)
      )


    plot <- plot + ggplot2::geom_ribbon(
      ggplot2::aes(
        x = .data$start + periodLength / 2,
        ymin = .data$lwr * 1000,
        ymax = .data$upr * 1000,
        group = .data$period
      ),
      fill = rgb(0, 0, 0),
      alpha = 0.3,
      data = curve
    ) +
      ggplot2::geom_line(
        ggplot2::aes(
          x = .data$start + periodLength / 2,
          y = .data$fit * 1000,
          group = .data$period
        ),
        linewidth = 1.5,
        alpha = 0.8,
        data = curve
      )
  }

  if (!is.null(title)) {
    plot <- plot + ggplot2::ggtitle(title)
  }
  if (!is.null(fileName)) {
    ggplot2::ggsave(fileName, plot, width = 7, height = 5, dpi = 400)
  }
  return(plot)
}

Try the CohortMethod package in your browser

Any scripts or data that you put into this service are public.

CohortMethod documentation built on March 21, 2026, 5:06 p.m.