R/addCohortSurvival.R

Defines functions addCohortSurvival .dropTempTables .executeSurvivalSql .getSurvivalSqlTemplate .buildCensorExpression .generateTempNames .validateSurvivalInputs .getDbms

Documented in addCohortSurvival .buildCensorExpression .dropTempTables .executeSurvivalSql .generateTempNames .getDbms .getSurvivalSqlTemplate .validateSurvivalInputs

# ===========================================================================
# Internal helpers
# ===========================================================================

#' Determine the DBMS dialect of a connection
#'
#' Asks \code{DatabaseConnector::dbms()} first. If that call raises an error,
#' the \code{"dbms"} attribute of \code{connection} is used instead.
#'
#' @param connection Connection object to inspect.
#' @return The value reported by \code{DatabaseConnector::dbms()} or, when that
#'   call errors, the \code{"dbms"} attribute of \code{connection}. An error is
#'   raised when the resolved value is \code{NULL}.
#' @keywords internal
.getDbms <- function(connection) {
  # Fallback used only when the DatabaseConnector lookup signals an error.
  readAttribute <- function(e) attr(connection, "dbms")

  resolved <- tryCatch(
    DatabaseConnector::dbms(connection),
    error = readAttribute
  )

  if (!is.null(resolved)) {
    return(resolved)
  }
  stop("Could not determine DBMS from `connection`.")
}

#' Validate survival-specific inputs
#'
#' Checks the scalar arguments of the survival query before any SQL is built.
#' Every numeric argument must be a single non-missing number, so that a bad
#' value is reported with the argument's own message rather than with R's
#' generic "missing value where TRUE/FALSE needed" error.
#'
#' @param outcomeDateVariable Must be exactly \code{"cohort_start_date"} or
#'   \code{"cohort_end_date"}.
#' @param chunkSize \code{NULL}, or a single finite, positive, whole number.
#' @param outcomeWashout Single number >= 0; \code{Inf} is accepted.
#' @param minDaysToEvent Single number >= 0.
#' @param followUpDays Single number >= 0; \code{Inf} is accepted.
#' @return \code{NULL}, invisibly. Called for its side effect of raising an
#'   error on the first invalid argument.
#' @keywords internal
.validateSurvivalInputs <- function(outcomeDateVariable, chunkSize,
                                    outcomeWashout, minDaysToEvent,
                                    followUpDays) {
  # TRUE only for a length-one, non-NA numeric; keeps the comparisons below
  # from ever handing `if` an NA or a vector.
  is_scalar_number <- function(x) {
    is.numeric(x) && length(x) == 1L && !is.na(x)
  }

  allowed_cols <- c("cohort_start_date", "cohort_end_date")
  if (!is.character(outcomeDateVariable) || length(outcomeDateVariable) != 1L ||
    !(outcomeDateVariable %in% allowed_cols)) {
    stop("`outcomeDateVariable` must be 'cohort_start_date' or 'cohort_end_date'.")
  }

  if (!is.null(chunkSize)) {
    # floor() instead of as.integer(): as.integer() returns NA (with a
    # warning) for Inf and for values beyond the 32-bit integer range.
    chunk_ok <- is_scalar_number(chunkSize) && is.finite(chunkSize) &&
      chunkSize > 0 && chunkSize == floor(chunkSize)
    if (!chunk_ok) {
      stop("`chunkSize` must be a positive integer when provided.")
    }
  }

  if (!is_scalar_number(outcomeWashout) || outcomeWashout < 0) {
    stop("`outcomeWashout` must be >= 0 (or Inf).")
  }
  if (!is_scalar_number(minDaysToEvent) || minDaysToEvent < 0) {
    stop("`minDaysToEvent` must be >= 0.")
  }
  if (!is_scalar_number(followUpDays) || followUpDays < 0) {
    stop("`followUpDays` must be >= 0 (or Inf).")
  }
  invisible(NULL)
}

#' Build the set of temp-table names used by one survival run
#'
#' Draws eight random lowercase letters and appends them to a fixed
#' \code{#surv_<stage>_} prefix for each of the seven working tables.
#'
#' @return Named list with the table names \code{target}, \code{obs},
#'   \code{coh_obs}, \code{outcome}, \code{washout}, \code{events} and
#'   \code{result} (in that order), followed by \code{id}, the random suffix
#'   shared by all of them.
#' @keywords internal
.generateTempNames <- function() {
  suffix <- paste0(sample(letters, 8, replace = TRUE), collapse = "")

  # list element name -> short stage tag embedded in the table name
  stageTags <- c(
    target  = "tgt",
    obs     = "obs",
    coh_obs = "co",
    outcome = "out",
    washout = "wo",
    events  = "ev",
    result  = "res"
  )
  tableNames <- lapply(
    stageTags,
    function(tag) paste0("#surv_", tag, "_", suffix)
  )

  c(tableNames, list(id = suffix))
}

#' Build a nested CASE-WHEN censor expression
#'
#' Produces a SQL expression that evaluates to the smallest of the enabled
#' censoring times. \code{days_to_exit} is always a candidate; each enabled
#' option adds one more candidate, and the candidates are folded into nested
#' \code{CASE WHEN a < b THEN a ELSE b END} terms.
#'
#' The censor-date candidate contains the literal placeholder
#' \code{'@@censor_date'}; it is left for the caller's template rendering to
#' substitute.
#'
#' @param censorOnCohortExit Logical; add days from cohort start to cohort end.
#' @param hasCensorDate Logical; add days from cohort start to the censor date.
#' @param hasFollowUpLimit Logical; add the fixed \code{followUpDays} limit.
#' @param followUpDays Numeric follow-up limit; only read when
#'   \code{hasFollowUpLimit} is TRUE.
#' @return A single character string holding the SQL expression.
#' @keywords internal
.buildCensorExpression <- function(censorOnCohortExit, hasCensorDate,
                                   hasFollowUpLimit, followUpDays) {
  candidates <- "days_to_exit"
  if (censorOnCohortExit) {
    candidates <- c(
      candidates, "DATEDIFF(day, cohort_start_date, cohort_end_date)"
    )
  }
  if (hasCensorDate) {
    candidates <- c(
      candidates, "DATEDIFF(day, cohort_start_date, '@censor_date')"
    )
  }
  if (hasFollowUpLimit) {
    # as.character(100000) is "1e+05", which is not a usable SQL literal here;
    # format(..., scientific = FALSE) always writes plain digits.
    candidates <- c(
      candidates, format(followUpDays, scientific = FALSE, trim = TRUE)
    )
  }

  # Fold left to right: each new candidate wraps the expression built so far.
  Reduce(
    function(soFar, candidate) {
      sprintf(
        "CASE WHEN %s < %s THEN %s ELSE %s END",
        candidate, soFar, candidate, soFar
      )
    },
    candidates
  )
}

#' Return the parameterised SQL template
#'
#' Returns the whole multi-statement script as one character string. Nothing
#' is executed here: \code{addCohortSurvival()} passes the string to
#' \code{SqlRender::render()} and then \code{SqlRender::translate()}.
#'
#' Age / gender columns are included only when the corresponding
#' \code{@@include_age} / \code{@@include_gender} flags are TRUE.
#' The person-table JOIN is included when \code{@@include_demographics} is TRUE.
#'
#' The script follows the eight steps labelled in its SQL comments:
#' \enumerate{
#'   \item copy the target cohort rows (optionally with age and gender),
#'   \item collect observation periods that contain a cohort start date,
#'   \item compute \code{days_to_exit} from cohort start to observation end,
#'   \item copy the outcome cohort rows for subjects in the target cohort,
#'   \item flag cohort entries with an outcome inside the washout window
#'     (all flags are 0 when \code{@@has_washout} is FALSE),
#'   \item compute \code{days_to_event} as the minimum day difference to an
#'     outcome on or after cohort start,
#'   \item derive \code{time} and \code{status},
#'   \item select the result table.
#' }
#'
#' Step 7 writes \code{ROW_NUMBER() OVER (ORDER BY subject_id,
#' cohort_start_date)} into the output column named \code{subject_id}, so the
#' returned \code{subject_id} is a row number rather than the cohort table's
#' original value.
#'
#' @return A length-one character vector containing the SQL template.
#' @keywords internal
.getSurvivalSqlTemplate <- function() {
  # The function body is a single string literal. The `--` lines inside it are
  # SQL comments that are part of the template text, not R comments.
  "
  -- Step 1: Extract target cohort (with optional demographics)
  DROP TABLE IF EXISTS @temp_target;

  SELECT
    subject_id,
    cohort_start_date,
    cohort_end_date,
    cohort_definition_id
    {@include_age}?{, YEAR(cohort_start_date) - p.year_of_birth AS age_years}
    {@include_gender}?{,
      CASE
        WHEN p.gender_concept_id = 8507 THEN 'Male'
        WHEN p.gender_concept_id = 8532 THEN 'Female'
        ELSE 'Other'
      END AS gender}
  INTO @temp_target
  FROM @cohort_database_schema.@target_cohort_table
  {@include_demographics}?{
  JOIN @cdm_database_schema.person p
    ON p.person_id = subject_id
  }
  WHERE cohort_definition_id = @target_cohort_id;

  CREATE INDEX idx_@temp_id_target ON @temp_target (subject_id, cohort_start_date);

  -- Step 2: Overlapping observation periods
  DROP TABLE IF EXISTS @temp_obs;

  SELECT
    person_id,
    observation_period_start_date,
    observation_period_end_date
  INTO @temp_obs
  FROM @cdm_database_schema.@observation_period_table oop
  INNER JOIN @temp_target t
    ON oop.person_id = t.subject_id
   AND t.cohort_start_date >= oop.observation_period_start_date
   AND t.cohort_start_date <= oop.observation_period_end_date;

  CREATE INDEX idx_@temp_id_obs ON @temp_obs (person_id);

  -- Step 3: Compute days to observation exit
  DROP TABLE IF EXISTS @temp_coh_obs;

  SELECT
    t.subject_id,
    t.cohort_start_date,
    t.cohort_end_date,
    t.cohort_definition_id,
    {@include_age}?{t.age_years,}
    {@include_gender}?{t.gender,}
    DATEDIFF(day, t.cohort_start_date, o.observation_period_end_date) AS days_to_exit
  INTO @temp_coh_obs
  FROM @temp_target t
  INNER JOIN @temp_obs o
    ON t.subject_id = o.person_id
   AND t.cohort_start_date >= o.observation_period_start_date
   AND t.cohort_start_date <= o.observation_period_end_date;

  CREATE INDEX idx_@temp_id_coh_obs ON @temp_coh_obs (subject_id, cohort_start_date);

  -- Step 4: Extract outcome events
  DROP TABLE IF EXISTS @temp_outcome;

  SELECT
    subject_id,
    @outcome_date_variable AS outcome_date,
    cohort_definition_id
  INTO @temp_outcome
  FROM @outcome_database_schema.@outcome_cohort_table
  WHERE cohort_definition_id = @outcome_cohort_id
    AND subject_id IN (SELECT subject_id FROM @temp_target);

  CREATE INDEX idx_@temp_id_outcome ON @temp_outcome (subject_id, outcome_date);

  -- Step 5: Washout filter
  DROP TABLE IF EXISTS @temp_washout;

  {@has_washout}?{
  SELECT
    c.subject_id,
    c.cohort_start_date,
    CASE WHEN COUNT(o.subject_id) > 0 THEN 1 ELSE 0 END AS event_in_washout
  INTO @temp_washout
  FROM @temp_coh_obs c
  LEFT JOIN @temp_outcome o
    ON c.subject_id = o.subject_id
   AND o.outcome_date >= DATEADD(day, -@washout, c.cohort_start_date)
   AND o.outcome_date <  c.cohort_start_date
  GROUP BY c.subject_id, c.cohort_start_date;
  }:{
  SELECT
    subject_id,
    cohort_start_date,
    0 AS event_in_washout
  INTO @temp_washout
  FROM @temp_coh_obs;
  }

  CREATE INDEX idx_@temp_id_washout ON @temp_washout (subject_id, cohort_start_date);

  -- Step 6: Join events, compute days-to-event
  DROP TABLE IF EXISTS @temp_events;

  SELECT
    c.subject_id,
    c.cohort_start_date,
    c.cohort_end_date,
    c.cohort_definition_id,
    {@include_age}?{c.age_years,}
    {@include_gender}?{c.gender,}
    c.days_to_exit,
    w.event_in_washout,
    MIN(DATEDIFF(day, c.cohort_start_date, o.outcome_date)) {@add_day}?{+ 1} AS days_to_event
  INTO @temp_events
  FROM @temp_coh_obs c
  LEFT JOIN @temp_washout w
    ON c.subject_id = w.subject_id
   AND c.cohort_start_date = w.cohort_start_date
  LEFT JOIN @temp_outcome o
    ON c.subject_id = o.subject_id
   AND o.outcome_date >= c.cohort_start_date
   {@has_min_days}?{AND DATEDIFF(day, c.cohort_start_date, o.outcome_date) >= @min_days_to_event}
  GROUP BY
    c.subject_id, c.cohort_start_date, c.cohort_end_date,
    c.cohort_definition_id, {@include_age}?{c.age_years,} {@include_gender}?{c.gender,}
    c.days_to_exit, w.event_in_washout;

  -- Step 7: Compute time and status
  DROP TABLE IF EXISTS @temp_result;

  SELECT
    ROW_NUMBER() OVER (ORDER BY subject_id, cohort_start_date) AS subject_id,
    CASE
      WHEN event_in_washout = 1 THEN NULL
      {@has_censor_date}?{
      WHEN cohort_start_date > CAST('@censor_date' AS DATE) THEN NULL
      }
      WHEN days_to_event IS NOT NULL AND days_to_event <= @censor_expression THEN days_to_event
      ELSE @censor_expression
    END AS time,
    CASE
      WHEN event_in_washout = 1 THEN NULL
      {@has_censor_date}?{
      WHEN cohort_start_date > CAST('@censor_date' AS DATE) THEN NULL
      }
      WHEN days_to_event IS NOT NULL AND days_to_event <= @censor_expression THEN 1
      ELSE 0
    END AS status
    {@include_age}?{, age_years}
    {@include_gender}?{, gender}
  INTO @temp_result
  FROM @temp_events;

  -- Step 8: Return results
  SELECT * FROM @temp_result
  {@has_chunking}?{
    ORDER BY subject_id
    OFFSET @offset ROWS FETCH NEXT @chunk_size ROWS ONLY
  };
  "
}

#' Execute pre-rendered SQL statements and return the final result set
#'
#' Splits \code{sql} into statements and executes every statement except the
#' last one. The result is then fetched in one of two ways:
#' \itemize{
#'   \item without chunking, the last statement is run as the query;
#'   \item with chunking, the last statement is not used; instead
#'     \code{tempResult} is read in pages of \code{chunkSize} rows ordered by
#'     \code{subject_id}, and the pages are row-bound.
#' }
#'
#' @param connection Connection passed to \code{DatabaseConnector::executeSql()}
#'   and \code{DatabaseConnector::querySql()}.
#' @param sql Rendered and translated SQL script.
#' @param dbms Target dialect used to translate the paging query.
#' @param tempEmulationSchema Passed to \code{SqlRender::translate()} for the
#'   paging query.
#' @param hasChunking Logical; fetch the result page by page?
#' @param chunkSize Rows per page; only read when \code{hasChunking} is TRUE.
#' @param tempResult Name of the result table to page through.
#' @return Data frame with the fetched rows. In chunked mode a result with no
#'   rows is returned as the zero-row data frame delivered by the database, so
#'   its columns are still present.
#' @keywords internal
.executeSurvivalSql <- function(connection, sql, dbms, tempEmulationSchema,
                                hasChunking, chunkSize, tempResult) {
  statements <- SqlRender::splitSql(sql)
  nStatements <- length(statements)

  # Run all setup statements (everything except the trailing SELECT)
  for (statement in statements[seq_len(max(nStatements - 1L, 0L))]) {
    DatabaseConnector::executeSql(
      connection, paste0(statement, ";"),
      progressBar = FALSE, reportOverallTime = FALSE
    )
  }

  if (!hasChunking) {
    return(
      DatabaseConnector::querySql(
        connection, paste0(statements[nStatements], ";")
      )
    )
  }

  # Row counts must reach the SQL text as plain digits. R's default
  # number-to-text conversion writes e.g. 1e5 as "1e+05", which is not valid
  # after OFFSET / FETCH NEXT.
  asSqlInteger <- function(x) format(x, scientific = FALSE, trim = TRUE)

  chunks <- list()
  offset <- 0
  repeat {
    chunkSql <- SqlRender::render(
      "SELECT * FROM @t ORDER BY subject_id OFFSET @offset ROWS FETCH NEXT @n ROWS ONLY",
      t = tempResult,
      offset = asSqlInteger(offset),
      n = asSqlInteger(chunkSize)
    )
    chunkSql <- SqlRender::translate(
      chunkSql,
      targetDialect = dbms, tempEmulationSchema = tempEmulationSchema
    )
    chunk <- DatabaseConnector::querySql(connection, chunkSql)

    # Always keep the first page, even when empty, so that an empty result
    # still carries the column layout the caller selects from.
    if (nrow(chunk) > 0 || length(chunks) == 0L) {
      chunks[[length(chunks) + 1L]] <- chunk
    }
    # A short (or empty) page means the table is exhausted.
    if (nrow(chunk) < chunkSize) break
    offset <- offset + chunkSize
  }
  do.call(rbind, chunks)
}

#' Drop each of the given temp tables, ignoring execution failures
#'
#' For every entry of \code{tables} a \code{DROP TABLE IF EXISTS} statement is
#' rendered, translated to \code{dbms} and executed. Errors raised while
#' executing a statement are swallowed by \code{try(..., silent = TRUE)}, so
#' one failing drop does not prevent the remaining ones.
#'
#' @param connection Connection passed to
#'   \code{DatabaseConnector::executeSql()}.
#' @param dbms Target dialect for \code{SqlRender::translate()}.
#' @param tempEmulationSchema Passed to \code{SqlRender::translate()}.
#' @param tables Table names to drop.
#' @return \code{NULL}, invisibly.
#' @keywords internal
.dropTempTables <- function(connection, dbms, tempEmulationSchema, tables) {
  dropOne <- function(tableName) {
    rendered <- SqlRender::render("DROP TABLE IF EXISTS @t;", t = tableName)
    dropSql <- SqlRender::translate(
      rendered,
      targetDialect = dbms,
      tempEmulationSchema = tempEmulationSchema
    )
    # Only the execution is best-effort; rendering problems still surface.
    try(
      DatabaseConnector::executeSql(
        connection, dropSql,
        progressBar = FALSE, reportOverallTime = FALSE
      ),
      silent = TRUE
    )
  }

  lapply(tables, dropOne)
  invisible(NULL)
}


# ===========================================================================
# Main entry point
# ===========================================================================

#' Calculate survival data for a cohort
#'
#' Renders the survival SQL template for the given target and outcome cohorts,
#' runs it on \code{connection}, and returns one row per retained cohort entry
#' with its follow-up time and event status. The working temp tables are
#' dropped when the function exits, whether it succeeds or fails.
#'
#' @param connection DatabaseConnector connection object
#' @param cdmDatabaseSchema Schema containing CDM tables
#' @param cohortDatabaseSchema Schema containing the target cohort table
#' @param targetCohortTable Name of the target cohort table
#' @param targetCohortId ID of the target cohort
#' @param outcomeCohortTable Name of the outcome cohort table
#' @param outcomeCohortId ID of the outcome cohort (default: 1)
#' @param outcomeDatabaseSchema Schema containing the outcome cohort table
#'   (default: \code{cohortDatabaseSchema})
#' @param outcomeDateVariable Date variable for outcome ("cohort_start_date" or "cohort_end_date")
#' @param outcomeWashout Washout period in days before cohort start. The
#'   washout filter is applied only for a finite value greater than 0; the
#'   default \code{Inf} (and 0) disables it.
#' @param minDaysToEvent Minimum days between cohort entry and outcome event (default: 0)
#' @param censorOnCohortExit Whether to censor at cohort exit (default: FALSE)
#' @param censorOnDate Specific date to censor at: a single value that
#'   \code{as.Date()} can convert (default: NULL, no censor date)
#' @param followUpDays Maximum follow-up days (default: Inf)
#' @param includeAge Whether to include age_years in the output (default: FALSE)
#' @param includeGender Whether to include gender in the output (default: FALSE)
#' @param observationPeriodTable Name of observation period table (default: "observation_period")
#' @param tempEmulationSchema Schema for temp table emulation (default: NULL)
#' @param chunkSize Optional chunk size for large cohorts (default: NULL)
#' @param addDay Logical; add one day to the computed days-to-event? (default: FALSE)
#'
#' @return Data frame with columns \code{subject_id}, \code{time}, \code{status}
#'   and optionally \code{age_years} and/or \code{gender}. \code{subject_id}
#'   is the row number assigned by the SQL template (ordered by subject and
#'   cohort start date), not the identifier stored in the cohort table.
#' @keywords internal
#'
addCohortSurvival <- function(
    connection,
    cdmDatabaseSchema,
    cohortDatabaseSchema,
    targetCohortTable,
    targetCohortId,
    outcomeCohortTable,
    outcomeCohortId = 1,
    outcomeDatabaseSchema = cohortDatabaseSchema,
    outcomeDateVariable = "cohort_start_date",
    outcomeWashout = Inf,
    minDaysToEvent = 0,
    censorOnCohortExit = FALSE,
    censorOnDate = NULL,
    followUpDays = Inf,
    includeAge = FALSE,
    includeGender = FALSE,
    observationPeriodTable = "observation_period",
    tempEmulationSchema = NULL,
    chunkSize = NULL,
    addDay = FALSE) {
  # ---- setup -------------------------------------------------------------
  dbms <- .getDbms(connection)
  .validateSurvivalInputs(
    outcomeDateVariable, chunkSize, outcomeWashout, minDaysToEvent, followUpDays
  )

  # Resolve the censor date before any SQL is built: an unparseable or missing
  # value would otherwise be formatted as "NA" and embedded in the query.
  hasCensorDate <- !is.null(censorOnDate)
  censorDate <- ""
  if (hasCensorDate) {
    parsedDate <- tryCatch(
      as.Date(censorOnDate),
      error = function(e) as.Date(NA)
    )
    if (length(parsedDate) != 1L || is.na(parsedDate)) {
      stop("`censorOnDate` must be a single value coercible to a Date.")
    }
    censorDate <- format(parsedDate, "%Y-%m-%d")
  }

  # Compute flags
  hasWashout       <- is.finite(outcomeWashout) && outcomeWashout > 0
  hasFollowUpLimit <- is.finite(followUpDays)
  hasMinDays       <- minDaysToEvent > 0
  hasChunking      <- !is.null(chunkSize)

  # Plain numbers are written as fixed-notation text before rendering, because
  # R's default conversion turns values such as 100000 into "1e+05".
  # Classed objects and non-numeric values are passed through unchanged.
  sqlNumber <- function(x) {
    if (is.numeric(x) && !is.object(x)) {
      format(x, scientific = FALSE, trim = TRUE)
    } else {
      x
    }
  }

  # Temp table names + cleanup on exit (elements 1:7 are the table names,
  # the last element is the shared random suffix)
  tmp <- .generateTempNames()
  on.exit(
    .dropTempTables(connection, dbms, tempEmulationSchema, unlist(tmp[1:7])),
    add = TRUE
  )

  # ---- build SQL ---------------------------------------------------------
  censorExpr <- .buildCensorExpression(
    censorOnCohortExit, hasCensorDate, hasFollowUpLimit, followUpDays
  )

  sql <- SqlRender::render(
    .getSurvivalSqlTemplate(),
    cdm_database_schema      = cdmDatabaseSchema,
    cohort_database_schema   = cohortDatabaseSchema,
    target_cohort_table      = targetCohortTable,
    target_cohort_id         = sqlNumber(targetCohortId),
    outcome_database_schema  = outcomeDatabaseSchema,
    outcome_cohort_table     = outcomeCohortTable,
    outcome_cohort_id        = sqlNumber(outcomeCohortId),
    outcome_date_variable    = outcomeDateVariable,
    observation_period_table = observationPeriodTable,
    temp_target              = tmp$target,
    temp_obs                 = tmp$obs,
    temp_coh_obs             = tmp$coh_obs,
    temp_outcome             = tmp$outcome,
    temp_washout             = tmp$washout,
    temp_events              = tmp$events,
    temp_result              = tmp$result,
    temp_id                  = tmp$id,
    include_age              = includeAge,
    include_gender           = includeGender,
    include_demographics     = includeAge || includeGender,
    has_washout              = hasWashout,
    washout                  = sqlNumber(if (hasWashout) outcomeWashout else 0),
    has_min_days             = hasMinDays,
    min_days_to_event        = sqlNumber(minDaysToEvent),
    has_censor_date          = hasCensorDate,
    censor_date              = censorDate,
    censor_expression        = censorExpr,
    has_chunking             = hasChunking,
    chunk_size               = sqlNumber(if (hasChunking) chunkSize else 0),
    offset                   = 0,
    add_day                  = addDay
  )
  sql <- SqlRender::translate(
    sql, targetDialect = dbms, tempEmulationSchema = tempEmulationSchema
  )

  # ---- execute & return --------------------------------------------------
  result <- tryCatch(
    .executeSurvivalSql(
      connection, sql, dbms, tempEmulationSchema, hasChunking, chunkSize, tmp$result
    ),
    error = function(e) {
      stop("Failed to execute survival query: ", conditionMessage(e))
    }
  )

  # Column-name case returned by the driver is normalised to lower case
  # before the fixed output columns are selected.
  names(result) <- tolower(names(result))
  outputCols <- c("subject_id", "time", "status")
  if (includeAge)    outputCols <- c(outputCols, "age_years")
  if (includeGender) outputCols <- c(outputCols, "gender")
  result[, outputCols, drop = FALSE]
}

Try the OdysseusSurvivalModule package in your browser

Any scripts or data that you put into this service are public.

OdysseusSurvivalModule documentation built on April 3, 2026, 5:06 p.m.