R/simulate_counts.R

Defines functions .generate_qc_summary .perturb_cells .create_damage_label .assign_damage_levels .select_damaged_cells .get_beta_shape_parameters .get_steepness_value .check_simulate_inputs simulate_counts

Documented in simulate_counts

#' simulate_counts
#'
#' Function to simulate damaged cells by perturbing the gene expression of
#' existing cells.
#'
#' 'DamageDetective' models damage in single-cell RNA sequencing data as the
#' loss of cytoplasmic RNA, where cells experiencing greater RNA loss are
#' assumed to be more extensively damaged, while those with minimal loss are
#' considered largely intact. The perturbation process introduces RNA loss into
#' existing cells and is controlled by three key parameters: the **target
#' proportion of damage**, which specifies the fraction of cells to be
#' perturbed; the **target level of damage**, which defines the extent of RNA
#' loss across cells; and the **target distribution of damage**, which
#' determines how the different levels of RNA loss are distributed across
#' cells.
#'
#' Based on these parameters, cells are randomly selected and each is assigned
#' a target proportion of RNA loss. The number of transcripts to remove is this
#' proportion multiplied by the cell's total non-mitochondrial count, and the
#' perturbation is applied through weighted sampling without replacement from
#' the cell's non-mitochondrial gene counts. Each transcript's probability of
#' being lost is proportional to the abundance of its gene, so highly expressed
#' genes are more likely to lose transcripts, while ribosomal genes are
#' down-weighted by 'ribosome_penalty'. Once the target RNA loss is reached,
#' the cell's expression profile is updated, and the process repeats for all
#' selected cells.
#'
#' @param count_matrix Matrix or dgCMatrix containing the counts from
#'  single cell RNA sequencing data, with genes as rows and cells as columns.
#' @param damage_proportion Numeric describing what proportion
#'  of the input data should be altered to resemble damaged data.
#'
#'  * Must range between 0 and 1.
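#' @param annotated_celltypes Boolean specifying whether the input matrix has
#'  cell type information stored. If TRUE, each cell's type is read from the
#'  prefix of its column name before the first underscore (e.g.
#'  "Tcell_AAACCTG"), and damaged cells are drawn from each cell type in
#'  proportion to its size.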
#'
#'  * Default is FALSE
#' @param target_damage Numeric vector of length 2 giving the lower and upper
#'  bounds of the level of damage that will be introduced.
#'
#'  Here, damage refers to the proportion of cytoplasmic RNA lost by a cell,
#'  where values closer to 1 indicate more loss and therefore more heavily
#'  damaged cells.
#'
#'  * Default is c(0.1, 0.8)
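#' @param damage_distribution String specifying whether the damage levels of
#'  the damaged cells should cluster towards the lower or upper bound of
#'  'target_damage', or be spread symmetrically between them. There are three
#'  valid options:
#'
#'  * "right_skewed": damage levels cluster towards the lower bound (the
#'    distribution mean lies 30% of the way into the range)
#'  * "left_skewed": damage levels cluster towards the upper bound (the mean
#'    lies 70% of the way into the range)
#'  * "symmetric": damage levels are centred on the midpoint of the range
#'
#'  * Default is "right_skewed"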
#' @param distribution_steepness String specifying how tightly the damage
#'  levels are concentrated about the mean of the distribution chosen in
#'  'damage_distribution'. A steeper distribution makes its shift towards one
#'  end of 'target_damage' more apparent. There are three valid options, each
#'  setting the sum of the two beta shape parameters:
#'
#'  * "shallow" (4)
#'  * "moderate" (7)
#'  * "steep" (14)
#'
#'  * Default is "moderate"
#' @param beta_shape_parameters Numeric vector of length 2 giving the two
#'  shape parameters (shape1, shape2) of the beta distribution explicitly, as
#'  used by stats::rbeta(). This offers greater flexibility than the
#'  'damage_distribution' and 'distribution_steepness' parameters and, when
#'  supplied, overrides them.
#'
#'  * Default is NULL
#' @param ribosome_penalty Numeric between 0 and 1 specifying the factor by
#'  which the probability of losing a transcript from a ribosomal gene is
#'  multiplied. Here, values closer to 0 represent a greater penalty.
#'
#'  * Default is 0.001.
#' @param generate_plot Boolean specifying whether the QC plot should
#'  be generated. QC plots are generated by default as we recommend
#'  verifying that the perturbed data retain the characteristics of true
#'  single cell data.
#'
#'  * Default is TRUE.
#' @param plot_ribosomal_penalty Boolean specifying whether the output QC plot
#'  should focus on only the ribosomal proportion or contain additional QC
#'  information. If TRUE, this can be useful for visualising the impact of
#'  the ribosomal penalty parameter.
#'
#'  * Default is FALSE.
#' @param display_plot Boolean specifying whether the output QC plot should
#'   be printed when the function runs. Naturally, this is only relevant
#'   when generate_plot is TRUE.
#'
#'   * Default is TRUE.
#' @param palette Character vector containing three colours to create the
#'    continuous palette for damaged cells.
#'
#'  * Default is c("grey", "#7023FD", "#E60006").
#' @param organism String specifying the organism of origin of the input
#'  data where there are two standard options,
#'
#'  * "Hsap"
#'  * "Mmus"
#'
#'  If a user wishes to use a non-standard organism they must input a list
#'  containing strings for the patterns to match mitochondrial and ribosomal
#'  genes of the organism. If available, nuclear-encoded genes that are likely
#'  retained in the nucleus, such as in nuclear speckles, must also
#'  be specified. An example for humans is below,
#'
#'  * organism = list(mito_pattern = "^MT-",
#'                    ribo_pattern = "^(RPS|RPL)",
#'                    nuclear = c("NEAT1", "XIST", "MALAT1"))
#'
#'  * Default is "Hsap"
#' @param seed Numeric specifying the random seed to ensure reproducibility of
#'  the function's output. Setting a seed ensures that the random sampling
#'  and perturbation processes produce the same results when the function
#'  is run multiple times with the same input data and parameters.
#'
#'  * Default is NULL.
#' @return A list containing the altered count matrix ('matrix'), a data frame
#'  of summary statistics ('qc_summary') and, if 'generate_plot' is TRUE, a
#'  'ggplot2' object of the quality control metrics of the alteration ('plot').
#' @import ggplot2
#' @importFrom stats rbeta
#' @importFrom withr with_seed
#' @import patchwork
#' @importFrom ggpubr get_legend
#' @importFrom cowplot ggdraw draw_label plot_grid
#' @export
#' @examples
#' data("test_counts", package = "DamageDetective")
#'
#' simulated_damage <- simulate_counts(
#'   count_matrix = test_counts,
#'   damage_proportion = 0.1,
#'   ribosome_penalty = 0.01,
#'   target_damage = c(0.5, 0.9),
#'   generate_plot = FALSE,
#'   seed = 7
#' )
simulate_counts <- function(
    count_matrix,
    damage_proportion,
    annotated_celltypes = FALSE,
    target_damage = c(0.1, 0.8),
    damage_distribution = "right_skewed",
    distribution_steepness = "moderate",
    beta_shape_parameters = NULL,
    ribosome_penalty = 0.001,
    generate_plot = TRUE,
    palette = c("grey", "#7023FD", "#E60006"),
    plot_ribosomal_penalty = FALSE,
    display_plot = TRUE,
    seed = NULL,
    organism = "Hsap"
) {
  # Data preparations ----
  .check_simulate_inputs(
    count_matrix,
    damage_proportion,
    beta_shape_parameters,
    target_damage,
    distribution_steepness,
    damage_distribution,
    ribosome_penalty,
    organism
  )

  # Calculate the number of damaged cells to simulate
  total_cells <- ncol(count_matrix)
  damaged_cell_number <- round(total_cells * damage_proportion)

  # Derive the beta shape parameters that define the damage distribution
  steepness_value <- .get_steepness_value(distribution_steepness)
  shape_params <- .get_beta_shape_parameters(
    damage_distribution, steepness_value, beta_shape_parameters
  )

  # Retrieve genes corresponding to the organism of interest
  gene_idx <- get_organism_indices(count_matrix, organism)

  # Select target cells for perturbation ----
  damaged_cell_selections <- .select_damaged_cells(
    count_matrix, annotated_celltypes, damage_proportion,
    damaged_cell_number, seed
  )

  # Draw a damage level for each selected cell, rescaled to 'target_damage'
  damage_levels <- .assign_damage_levels(
    damaged_cell_selections, shape_params, target_damage, seed
  )

  # Store the assigned damage levels
  damage_label <- .create_damage_label(
    count_matrix, damaged_cell_selections, damage_levels
  )

  # Perturb selected cells ----
  count_matrix <- .perturb_cells(
    count_matrix, damaged_cell_selections, damage_label,
    gene_idx, ribosome_penalty, seed
  )

  # Consolidate & plot simulation output ----
  qc_summary <- .generate_qc_summary(count_matrix, damage_label, gene_idx)

  if (generate_plot) {
    final_plot <- if (plot_ribosomal_penalty) {
      plot_ribosomal_penalty(qc_summary, palette)
    } else {
      plot_simulation_outcome(qc_summary, palette)
    }

    if (display_plot) {
      print(final_plot)
    }

    return(list(
      matrix = count_matrix,
      qc_summary = qc_summary,
      plot = final_plot
    ))
  }

  return(list(
    matrix = count_matrix,
    qc_summary = qc_summary
  ))
}

.check_simulate_inputs <- function(
    count_matrix, damage_proportion, beta_shape_parameters,
    target_damage, distribution_steepness, damage_distribution,
    ribosome_penalty, organism
) {
  # Check that a count matrix is given in a supported format
  if (is.null(count_matrix)) stop("Please provide 'count_matrix' input.")
  if (!inherits(count_matrix, "matrix") &&
      !inherits(count_matrix, "CsparseMatrix")) {
    stop("Please ensure 'count_matrix' is a matrix or a sparse matrix (dgCMatrix).")
  }

  # Ensure user adjustments to default parameters are executable
  if (is.null(damage_proportion)) stop("Please provide 'damage_proportion'.")
  if (!is.numeric(damage_proportion) ||
      damage_proportion < 0 ||
      damage_proportion > 1
  ) {
    stop("Please ensure 'damage_proportion' is a numeric between 0 and 1.")
  }
  if (!is.null(beta_shape_parameters) &&
      length(beta_shape_parameters) != 2) {
    stop("Please ensure 'beta_shape_parameters' is of length 2.")
  }
  if (!is.numeric(target_damage) || length(target_damage) != 2 ||
      target_damage[1] < 0 || target_damage[2] > 1 ||
      target_damage[1] >= target_damage[2]) {
    stop(paste(
      "Please ensure 'target_damage' is a numeric vector of length 2,",
      "with values between 0 and 1, and the first value is less than the second."
    ))
  }
  if (!distribution_steepness %in% c("shallow", "moderate", "steep")) {
    stop(paste(
      "Please ensure 'distribution_steepness' is one of",
      "'shallow', 'moderate', or 'steep'."
    ))
  }
  if (!damage_distribution %in% c("right_skewed", "left_skewed", "symmetric")) {
    stop(paste(
      "Please ensure 'damage_distribution' is one of",
      "'right_skewed', 'left_skewed', or 'symmetric'."
    ))
  }
  if (!is.numeric(ribosome_penalty) || ribosome_penalty < 0 ||
      ribosome_penalty > 1) {
    stop("Please ensure 'ribosome_penalty' is a numeric between 0 and 1.")
  }
  # A standard organism is a single string; a custom organism is a list of
  # three elements (mitochondrial pattern, ribosomal pattern, nuclear genes).
  # Testing the string case first avoids a length-3 condition inside if().
  is_standard_organism <- is.character(organism) && length(organism) == 1 &&
    organism %in% c("Hsap", "Mmus")
  if (!is_standard_organism && length(organism) != 3) {
    stop(paste(
      "Please ensure 'organism' is one of 'Hsap' or 'Mmus',",
      "see documentation for non-standard organisms."
    ))
  }
}

.get_steepness_value <- function(distribution_steepness) {
  steepness_levels <- list(
    shallow = 4,
    moderate = 7,
    steep = 14
  )
  return(steepness_levels[[distribution_steepness]])
}

.get_beta_shape_parameters <- function(
    damage_distribution, steepness_value, beta_shape_parameters
) {
  if (!is.null(beta_shape_parameters)) {
    return(beta_shape_parameters)
  }

  if (damage_distribution == "right_skewed") {
    return(c(steepness_value * 0.3, steepness_value * 0.7))
  } else if (damage_distribution == "left_skewed") {
    return(c(steepness_value * 0.7, steepness_value * 0.3))
  } else if (damage_distribution == "symmetric") {
    return(c(steepness_value * 0.5, steepness_value * 0.5))
  }
}
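
# Illustrative sketch (a hypothetical helper, not called by simulate_counts):
# closed-form mean and standard deviation of a beta distribution. Applied to
# the shapes returned by '.get_beta_shape_parameters()', it shows that the
# steepness value k equals shape1 + shape2, so the mean depends only on
# 'damage_distribution' (0.3, 0.7 or 0.5) while the spread shrinks as k grows
# from 4 ("shallow") to 14 ("steep"); the variance here reduces to 0.21 / (k + 1)
# for the skewed options.
#
# Example: .sketch_beta_summary(4 * 0.3, 4 * 0.7) has mean 0.3, and its sd is
# larger than that of .sketch_beta_summary(14 * 0.3, 14 * 0.7).
.sketch_beta_summary <- function(shape1, shape2) {
  total <- shape1 + shape2
  c(
    # Beta mean: a / (a + b)
    mean = shape1 / total,
    # Beta variance: a * b / ((a + b)^2 * (a + b + 1))
    sd = sqrt(shape1 * shape2 / (total^2 * (total + 1)))
  )
}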

.select_damaged_cells <- function(
    count_matrix, annotated_celltypes, damage_proportion,
    damaged_cell_number, seed
) {
  if (annotated_celltypes) {
    cell_types <- as.factor(sub("_.*", "", colnames(count_matrix)))
    cell_type_counts <- table(cell_types)

    if (length(cell_type_counts) <= 1 || length(cell_type_counts) >= 1000) {
      stop("Please ensure cell types are assigned correctly.")
    }

    damage_per_type <- round(cell_type_counts * damage_proportion)
    total_damaged <- sum(damage_per_type)

    if (total_damaged != damaged_cell_number) {
      diff <- damaged_cell_number - total_damaged
      damage_per_type <- damage_per_type + round(
        diff * damage_per_type / sum(damage_per_type)
      )
    }

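    damaged_cell_selections <- unlist(lapply(
      names(damage_per_type),
      function(ct) {
        cells_of_type <- which(cell_types == ct)
        # Sample positions with sample.int(): sample(x) treats a length-one
        # numeric x as 1:x, which would select the wrong cell when a cell
        # type contains a single cell
        withr::with_seed(seed, {
          cells_of_type[sample.int(
            length(cells_of_type),
            size = as.integer(damage_per_type[ct]),
            replace = FALSE
          )]
        })
      }
    ))
  } else {
    damaged_cell_selections <- withr::with_seed(seed, {
      sample.int(ncol(count_matrix), damaged_cell_number, replace = FALSE)
    })
  }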

  return(damaged_cell_selections)
}
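
# Illustrative sketch (a hypothetical helper, not called by simulate_counts):
# splitting a total number of damaged cells across cell types so that the
# per-type counts always sum exactly to the total and never exceed a type's
# size. It uses the largest-remainder (Hamilton) method, a swapped-in
# technique: '.select_damaged_cells()' above instead rounds each type's share
# and then applies a proportional correction, which can leave a small
# mismatch (e.g. three types of 3 cells at proportion 0.5 give 3 instead of 4).
#
# Example: .sketch_allocate_damage(c(A = 7, B = 5, C = 3), 0.3) returns
# per-type counts that sum to round(15 * 0.3).
.sketch_allocate_damage <- function(cell_type_counts, damage_proportion) {
  total_damaged <- round(sum(cell_type_counts) * damage_proportion)
  quotas <- as.numeric(cell_type_counts) * damage_proportion
  allocation <- floor(quotas)
  # Give the remaining cells to the types with the largest fractional parts
  extra <- max(0, total_damaged - sum(allocation))
  if (extra > 0) {
    top <- order(quotas - allocation, decreasing = TRUE)[seq_len(extra)]
    allocation[top] <- allocation[top] + 1
  }
  stats::setNames(as.integer(allocation), names(cell_type_counts))
}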

.assign_damage_levels <- function(
    damaged_cell_selections, shape_params, target_damage, seed
) {
  damage_levels <- withr::with_seed(seed, {
    stats::rbeta(length(damaged_cell_selections),
                 shape1 = shape_params[1], shape2 = shape_params[2])
  })

  damage_levels <- target_damage[1] +
    (target_damage[2] - target_damage[1]) * damage_levels

  return(damage_levels)
}
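
# Illustrative sketch (a hypothetical helper, not called by simulate_counts):
# because '.assign_damage_levels()' maps a beta draw x in [0, 1] linearly to
# target_damage[1] + (target_damage[2] - target_damage[1]) * x, the expected
# damage level follows directly from the beta mean shape1 / (shape1 + shape2).
# With the defaults (moderate right-skewed shapes 2.1 and 4.9, target_damage
# of c(0.1, 0.8)) this is 0.1 + 0.7 * 0.3 = 0.31, i.e. a damaged cell loses
# about 31% of its non-mitochondrial counts on average.
.sketch_expected_damage <- function(shape_params, target_damage) {
  beta_mean <- shape_params[1] / sum(shape_params)
  # Apply the same affine rescaling as '.assign_damage_levels()'
  target_damage[1] + (target_damage[2] - target_damage[1]) * beta_mean
}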

.create_damage_label <- function(
    count_matrix, damaged_cell_selections, damage_levels
) {
  total_cells <- ncol(count_matrix)
  damage_status <- rep("control", total_cells)
  damage_status[damaged_cell_selections] <- "damaged"

  damage_label <- data.frame(
    barcode = colnames(count_matrix),
    status = damage_status,
    damage_level = 0,
    stringsAsFactors = FALSE
  )

  damaged_cells <- match(
    colnames(count_matrix)[damaged_cell_selections], damage_label$barcode
  )
  damage_label$damage_level[damaged_cells] <- damage_levels

  return(damage_label)
}

.perturb_cells <- function(
    count_matrix, damaged_cell_selections, damage_label,
    gene_idx, ribosome_penalty, seed
) {
  barcode_map <- match(colnames(count_matrix), damage_label$barcode)

  # 'ribo_idx' holds row indices of 'count_matrix'; convert them to positions
  # within 'non_mito_idx' so they line up with the per-gene vectors below
  ribo_pos <- which(gene_idx$non_mito_idx %in% gene_idx$ribo_idx)

  for (i in seq_along(damaged_cell_selections)) {
    cell <- damaged_cell_selections[i]
    cell_damage_level <- damage_label$damage_level[barcode_map[cell]]

    # Non-mitochondrial counts of this cell and the number of transcripts to remove
    gene_totals <- as.vector(count_matrix[gene_idx$non_mito_idx, cell])
    total_count <- sum(gene_totals)
    total_loss <- round(cell_damage_level * total_count)
    if (total_loss == 0) next

    # One entry per transcript, labelled by the row index of its gene
    transcripts <- rep(gene_idx$non_mito_idx, times = gene_totals)

    # Per-gene loss weights proportional to abundance, penalised for ribosomal genes
    probabilities <- gene_totals / total_count
    probabilities[ribo_pos] <- probabilities[ribo_pos] * ribosome_penalty
    probabilities <- probabilities / sum(probabilities)
    prob_repeated <- rep(probabilities, times = gene_totals)

    # Sample transcript positions rather than gene labels so that exactly
    # 'total_loss' transcripts are removed; matching on gene labels would drop
    # every transcript of any gene that lost at least one
    lost_positions <- withr::with_seed(seed, {
      sample.int(
        length(transcripts),
        size = total_loss,
        replace = FALSE,
        prob = prob_repeated
      )
    })

    remaining_counts <- table(
      factor(
        transcripts[-lost_positions],
        levels = gene_idx$non_mito_idx
      )
    )

    count_matrix[gene_idx$non_mito_idx, cell] <- as.integer(remaining_counts)
  }

  return(count_matrix)
}
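
# Illustrative sketch (a hypothetical helper, not called by simulate_counts):
# the per-cell loss step on a plain named count vector. Each transcript is
# expanded to its own entry, weighted by its gene's relative abundance
# (multiplied by 'penalty' for penalised genes such as ribosomal ones), and
# 'loss_fraction' of all transcripts are removed by sampling positions
# without replacement, which removes exactly the intended number. With
# penalty = 0, sample.int() errors if more transcripts must be removed than
# the unpenalised genes hold.
.sketch_remove_transcripts <- function(gene_counts, loss_fraction,
                                       penalised = rep(FALSE, length(gene_counts)),
                                       penalty = 1) {
  total <- sum(gene_counts)
  n_loss <- round(loss_fraction * total)
  if (n_loss == 0) return(gene_counts)
  # One entry per transcript, labelled by the position of its gene
  transcripts <- rep(seq_along(gene_counts), times = gene_counts)
  # Gene weights proportional to abundance, scaled down for penalised genes;
  # every transcript inherits the weight of its gene
  gene_weight <- gene_counts / total
  gene_weight[penalised] <- gene_weight[penalised] * penalty
  transcript_weight <- gene_weight[transcripts]
  lost_pos <- sample.int(length(transcripts), size = n_loss,
                         replace = FALSE, prob = transcript_weight)
  # Count the surviving transcripts of each gene
  remaining <- tabulate(transcripts[-lost_pos], nbins = length(gene_counts))
  stats::setNames(remaining, names(gene_counts))
}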

.generate_qc_summary <- function(count_matrix, damage_label, gene_idx) {
  matched_indices <- match(colnames(count_matrix), damage_label$barcode)
  total_counts <- colSums(count_matrix)

  qc_summary <- data.frame(
    Cell = colnames(count_matrix),
    Damaged_Level = as.numeric(damage_label$damage_level[matched_indices]),
    Original_Features = colSums(count_matrix != 0),
    New_Features = colSums(count_matrix != 0),
    Original_MitoProp = colSums(
      count_matrix[gene_idx$mito_idx, , drop = FALSE]
    ) / total_counts,
    New_MitoProp = colSums(
      count_matrix[gene_idx$mito_idx, , drop = FALSE]
    ) / total_counts,
    Original_RiboProp = colSums(
      count_matrix[gene_idx$ribo_idx, , drop = FALSE]
    ) / total_counts,
    New_RiboProp = colSums(
      count_matrix[gene_idx$ribo_idx, , drop = FALSE]
    ) / total_counts
  )

  return(qc_summary)
}
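
# Illustrative sketch (a hypothetical helper, not called by simulate_counts):
# '.generate_qc_summary()' only receives the perturbed matrix, so its
# Original_* and New_* columns are computed from the same data. Computing the
# metrics once on a copy of the matrix taken before '.perturb_cells()' and
# once afterwards would give a genuine before/after comparison. This sketch
# assumes a base R matrix (genes as rows, cells as columns) and row indices
# for the mitochondrial and ribosomal genes.
#
# Example (hypothetical variable names):
#   before <- .sketch_qc_metrics(original_matrix, gene_idx$mito_idx, gene_idx$ribo_idx)
#   after  <- .sketch_qc_metrics(perturbed_matrix, gene_idx$mito_idx, gene_idx$ribo_idx)
.sketch_qc_metrics <- function(count_matrix, mito_idx, ribo_idx) {
  total_counts <- colSums(count_matrix)
  data.frame(
    # Number of genes with at least one count in each cell
    Features = colSums(count_matrix != 0),
    # Fraction of each cell's counts from mitochondrial / ribosomal genes
    MitoProp = colSums(count_matrix[mito_idx, , drop = FALSE]) / total_counts,
    RiboProp = colSums(count_matrix[ribo_idx, , drop = FALSE]) / total_counts
  )
}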

DamageDetective documentation built on April 4, 2025, 2:39 a.m.