R/AttritionDiagram.R

Defines functions drawAttritionDiagram

Documented in drawAttritionDiagram

# @file AttritionDiagram.R
#
# Copyright 2024 Observational Health Data Sciences and Informatics
#
# This file is part of CohortMethod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Draw the attrition diagram
#'
#' @description
#' \code{drawAttritionDiagram} draws the attrition diagram, showing how many people were excluded from
#' the study population, and for what reasons.
#'
#' @details
#' The diagram has two columns of boxes. The left column starts with the original cohort counts,
#' lists one box per attrition step (the step description), and ends with the final study
#' population counts. For every attrition step, a box in the right column shows how many persons
#' in the target and comparator cohorts were removed by that step.
#'
#' @param object            Either an object of type \code{cohortMethodData}, a population object
#'                          generated by functions like \code{createStudyPopulation}, or an object of
#'                          type \code{outcomeModel}.
#' @param targetLabel       A label to use for the target cohort.
#' @param comparatorLabel   A label to use for the comparator cohort.
#' @param fileName          Name of the file where the plot should be saved, for example 'plot.png'.
#'                          See the function \code{ggsave} in the ggplot2 package for supported file
#'                          formats. If \code{NULL} (the default), no file is written.
#'
#'
#' @return
#' A ggplot object. Use the \code{\link[ggplot2]{ggsave}} function to save to file in a different
#' format.
#'
#' @export
drawAttritionDiagram <- function(object,
                                 targetLabel = "Target",
                                 comparatorLabel = "Comparator",
                                 fileName = NULL) {
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertCharacter(targetLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(comparatorLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(fileName, len = 1, null.ok = TRUE, add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)

  attrition <- getAttritionTable(object)
  if (NROW(attrition) < 1) {
    stop("No attrition information found in 'object'", call. = FALSE)
  }
  nRows <- nrow(attrition)

  # Formats the two cohort counts as they appear inside a box.
  cohortCounts <- function(nTarget, nComparator) {
    paste0(targetLabel, ": n = ", nTarget, "\n", comparatorLabel, ": n = ", nComparator)
  }

  # Rows 2..nRows are the attrition steps. seq_len()[-1] is empty when the table only holds the
  # original cohorts, whereas 2:nRows would count backwards (2, 1) in that case.
  stepRows <- seq_len(nRows)[-1]
  stepDescriptions <- vapply(stepRows, function(row) {
    paste(strwrap(as.character(attrition$description[row]), width = 30), collapse = "\n")
  }, character(1))
  # Persons lost in a step: the count after the previous row minus the count after this row.
  stepLosses <- vapply(stepRows, function(row) {
    cohortCounts(
      attrition$targetPersons[row - 1] - attrition$targetPersons[row],
      attrition$comparatorPersons[row - 1] - attrition$comparatorPersons[row]
    )
  }, character(1))

  leftBoxText <- c(
    paste0(
      "Original cohorts:\n",
      cohortCounts(attrition$targetPersons[1], attrition$comparatorPersons[1])
    ),
    stepDescriptions,
    paste0(
      "Study population:\n",
      cohortCounts(attrition$targetPersons[nRows], attrition$comparatorPersons[nRows])
    )
  )
  # The first entry is a placeholder so right-box indices line up with the left-box indices.
  rightBoxText <- c("", stepLosses)
  nSteps <- length(leftBoxText)

  # Layout, in plot coordinates (both axes run from 0 to 1).
  rowHeight <- 1 / nSteps
  boxHeight <- rowHeight - 0.03
  boxWidth <- 0.45
  shadowOffset <- 0.01
  arrowLength <- 0.01
  # Horizontal center of a column (1 = left, 2 = right; 1.5 is midway between them).
  columnX <- function(column) {
    0.25 + ((column - 1) / 2)
  }
  # Vertical center of a row; row 1 is at the top of the plot.
  rowY <- function(row) {
    1 - (row - 0.5) * rowHeight
  }

  # Adds a single line segment. The `!!` operator injects the numeric values into the aesthetic
  # mapping, so each layer carries its own constants.
  segment <- function(p, x1, y1, x2, y2) {
    p + ggplot2::geom_segment(ggplot2::aes(x = !!x1, y = !!y1, xend = !!x2, yend = !!y2))
  }
  # Line from (x1, y1) to (x2, y2) with a downward-pointing arrow head at (x2, y2).
  downArrow <- function(p, x1, y1, x2, y2) {
    p <- segment(p, x1, y1, x2, y2)
    p <- segment(p, x2, y2, x2 + arrowLength, y2 + arrowLength)
    p <- segment(p, x2, y2, x2 - arrowLength, y2 + arrowLength)
    return(p)
  }
  # Line from (x1, y1) to (x2, y2) with a right-pointing arrow head at (x2, y2).
  rightArrow <- function(p, x1, y1, x2, y2) {
    p <- segment(p, x1, y1, x2, y2)
    p <- segment(p, x2, y2, x2 - arrowLength, y2 + arrowLength)
    p <- segment(p, x2, y2, x2 - arrowLength, y2 - arrowLength)
    return(p)
  }
  # Box centered on (x, y): a translucent, offset shadow first, then the outlined box on top.
  box <- function(p, x, y) {
    halfWidth <- boxWidth / 2
    halfHeight <- boxHeight / 2
    p <- p + ggplot2::geom_rect(
      ggplot2::aes(
        xmin = !!(x - halfWidth + shadowOffset),
        ymin = !!(y - halfHeight - shadowOffset),
        xmax = !!(x + halfWidth + shadowOffset),
        ymax = !!(y + halfHeight - shadowOffset)
      ),
      fill = rgb(0, 0, 0, alpha = 0.2)
    )
    p <- p + ggplot2::geom_rect(
      ggplot2::aes(
        xmin = !!(x - halfWidth),
        ymin = !!(y - halfHeight),
        xmax = !!(x + halfWidth),
        ymax = !!(y + halfHeight)
      ),
      fill = rgb(0.94, 0.94, 0.94),
      color = "black"
    )
    return(p)
  }
  # Text at (x, y). The text is injected as-is: geom_text() does not parse labels by default, so
  # wrapping it in quote characters would draw those quotes on the plot.
  label <- function(p, x, y, text, hjust = 0) {
    p <- p + ggplot2::geom_text(
      ggplot2::aes(x = !!x, y = !!y, label = !!text),
      hjust = hjust,
      size = 3.7
    )
    return(p)
  }

  leftRows <- seq_len(nSteps)
  # Every left box except the last has an arrow down to the next left box.
  arrowRows <- seq_len(nSteps - 1)
  # Only the attrition steps (not the first or last row) have a right-hand box. This is empty
  # when there are no attrition steps.
  exclusionRows <- seq_len(nSteps - 1)[-1]

  # Layers are added back to front: arrows and their Y/N labels, then boxes, then box text.
  p <- ggplot2::ggplot()
  for (i in arrowRows) {
    p <- downArrow(
      p,
      columnX(1), rowY(i) - (boxHeight / 2),
      columnX(1), rowY(i + 1) + (boxHeight / 2)
    )
    p <- label(p, columnX(1) + 0.02, rowY(i + 0.5), "Y")
  }
  for (i in exclusionRows) {
    p <- rightArrow(
      p,
      columnX(1) + boxWidth / 2, rowY(i),
      columnX(2) - boxWidth / 2, rowY(i)
    )
    p <- label(p, columnX(1.5), rowY(i) - 0.02, "N", 0.5)
  }
  for (i in leftRows) {
    p <- box(p, columnX(1), rowY(i))
  }
  for (i in exclusionRows) {
    p <- box(p, columnX(2), rowY(i))
  }
  for (i in leftRows) {
    p <- label(p, columnX(1) - boxWidth / 2 + 0.02, rowY(i), text = leftBoxText[i])
  }
  for (i in exclusionRows) {
    p <- label(p, columnX(2) - boxWidth / 2 + 0.02, rowY(i), text = rightBoxText[i])
  }
  # Strip all axes, grids and backgrounds so that only the diagram itself is visible.
  p <- p + ggplot2::theme(
    legend.position = "none",
    plot.background = ggplot2::element_blank(),
    panel.grid.major = ggplot2::element_blank(),
    panel.grid.minor = ggplot2::element_blank(),
    panel.border = ggplot2::element_blank(),
    panel.background = ggplot2::element_blank(),
    axis.text = ggplot2::element_blank(),
    axis.title = ggplot2::element_blank(),
    axis.ticks = ggplot2::element_blank()
  )

  if (!is.null(fileName)) {
    ggplot2::ggsave(filename = fileName, plot = p, width = 6, height = 7, dpi = 400)
  }
  return(p)
}
OHDSI/CohortMethod documentation built on Oct. 9, 2024, 12:50 p.m.