R/DimPlot.R

Defines functions CenterTitle DimPlot2_SelColDisc DimPlot2_SelColCont GetXYAesthetics DimPlot2_LabelClusters DimPlot2_AutoPointSize DimPlot2_PlotSingle DimPlot2_PlotVars DimPlot2_GetData is_continuous DimPlot2

Documented in DimPlot2

#' @title Create an Enhanced Dimensional Reduction Plot
#' @description This function creates a dimension reduction plot that can handle both discrete and continuous variables seamlessly. It incorporates additional customization options for visual representation and automatically recognizes input variable types to optimize visualization.
#' @param seu Seurat object containing single-cell data for visualization.
#' @param features Variables to be visualized in the plot, accepting both discrete and continuous variables.
#'   Default: NULL. When `features`, `group.by` and `cells.highlight` are all NULL, the active identities (`Idents(seu)`) are plotted.
#' @param group.by Alias for `features`. Default: NULL.
#' @param split.by A metadata column name (or a vector with one value per cell) by which to split the plot into facets, one per unique value. This can be useful for visualizing differences across conditions or experiments.
#'   Default: NULL.
#' @param cells A vector specifying a subset of cells to include in the plot: cell names, or a logical vector with one value per cell in `seu`.
#'   Default: all cells are included.
#' @param slot Which data slot to use for pulling expression data. Accepts 'data', 'scale.data', or 'counts'. Default: 'data'.
#' @param assay Specify the assay from which to retrieve data.
#'   Default: NULL, which will use the default assay.
#' @param dims A two-length numeric vector specifying which dimensions to use for the x and y axes, typically from a PCA, tSNE, or UMAP reduction.
#'   Default: c(1, 2).
#' @param reduction Which dimensionality reduction to use. If NULL, `DefaultDimReduc(seu)` is used (which searches in order of 'umap', 'tsne', then 'pca').
#'   Default: NULL.
#' @param priority Specifies which to prioritize when metadata column names conflict with gene names: 'expr' for expression, 'none' for metadata.
#'   Note: this argument is accepted but not currently referenced by the implementation.
#'   Default: c("expr", "none").
#' @param ncol Number of columns used when combining multiple plots with `cowplot::plot_grid()`.
#'   Default: NULL.
#' @param nrow Number of rows used when combining multiple plots with `cowplot::plot_grid()`.
#'   Default: NULL.
#' @param nrow.each Number of facet rows (passed to `facet_wrap()`) for each plot when `split.by` is used.
#'   Default: NULL.
#' @param ncol.legend Integer specifying the number of columns in the plot legend of discrete variables. Default: NULL.
#' @param cols Flexible color settings for the plot, accepting a variety of inputs:
#'
#'   - A vector specifying a global color setting similar to Seurat's `DimPlot`/`FeaturePlot`.
#'
#'   - A list specifying colors for each variable type (discrete/continuous) or for each individual variable. For example, `list(discrete = "auto", continuous = "A")` applies automatic styling from `color_pro()` for discrete variables and `viridis` "A" style for continuous variables. More detailed setups can include `list("cluster" = "pro_blue", "CD14" = c("#EEEEEE", "black"))`.
#'
#'   For continuous variables:
#'
#'     - Predefined color schemes from the `viridis` package ("A", "B", "C", "D", "E").
#'
#'     - Named vector with keys "low", "mid", and "high" for three-point gradients. Example: `c(low = "blue", mid = "white", high = "red")`.
#'
#'     - Two-point gradient with keys "low" and "high". Example: `c(low = "blue", high = "red")`.
#'
#'     - Custom color gradient using a vector of colors.
#'
#'   For discrete variables:
#'
#'     - Seven color_pro styles: "default", "light", "pro_red", "pro_yellow", "pro_green", "pro_blue", "pro_purple".
#'
#'     - Five color_iwh styles: "iwh_default", "iwh_intense", "iwh_pastel", "iwh_all", "iwh_all_hard".
#'
#'     - Brewer color scales as specified by `brewer.pal.info`.
#'
#'     - Any manually specified colors.
#'
#'   Default: list(discrete = "auto", continuous = "A").
#' @param load.cols When TRUE, automatically loads pre-stored color information for variables from `seu@misc[["var_colors"]]`.
#'   Default: TRUE.
#' @param pt.size Point size for plotting, adjusts the size of each cell in the plot.
#'   Default: NULL, which picks a size automatically from the number of cells.
#' @param shape.by Metadata column name (or a vector with one value per cell) mapped to point shape, allowing for additional visual distinctions.
#'   Default: NULL.
#' @param alpha.by Metadata column name (or a vector with one value per cell) mapped to point transparency, which can be helpful in densely plotted areas.
#'   Default: NULL.
#' @param order Boolean determining whether to plot cells in order of expression. Can be useful if points representing higher expression levels are being buried.
#'   Default: c(discrete = FALSE, continuous = TRUE).
#' @param shuffle Randomly shuffles the order of points to prevent burial under other points.
#'   Default: c(discrete = TRUE, continuous = FALSE).
#' @param label Whether to label clusters or other features in the plot.
#'   Default: FALSE.
#' @param label.color Color of the cluster labels, such as "black". When NULL, boxed labels (`box = TRUE`) use the cluster colors and plain text labels are drawn in black.
#'   Default: NULL.
#' @param box Whether to draw a box around labels to enhance visibility.
#'   Default: FALSE.
#' @param index.title Specify a prefix for cluster indices when labels are replaced by numerical indices to simplify the plot. Setting it turns labels on for discrete variables; it is not applied to the `cells.highlight` plot.
#'   Default: NULL.
#' @param repel Whether to use a repelling algorithm to avoid overlapping text labels.
#'   Default: FALSE.
#' @param label.size Size of the text labels used for clusters or features.
#'   Default: 4.
#' @param theme Allows customization of ggplot themes, for example, to remove axes or adjust text.
#'   Default: NULL.
#' @param cells.highlight A vector of cell names to highlight; simpler input than Seurat's approach, focusing on ease of use.
#'   Default: NULL.
#' @param cols.highlight Color of the highlighted cells in the `cells.highlight` plot; unselected cells are drawn in '#C3C3C3'.
#'   Default: '#DE2D26'.
#' @param sizes.highlight Point size used for all cells in the `cells.highlight` plot.
#'   Default: 1.
#' @param na.value Color to use for NA points when using a custom color scale.
#'   Default: 'grey80'.
#' @param raster Whether to convert the plot points to a raster format (via scattermore), which can help with performance on large datasets.
#'   Default: NULL, which rasterizes automatically when more than 100,000 cells are plotted.
#' @param raster.dpi Two-length numeric vector giving the pixel resolution of rasterized plots.
#'   Default: NULL, which uses c(512, 512).
#' @param combine Whether to combine multiple plots into a single object using `cowplot::plot_grid()`.
#'   Default: TRUE.
#' @param align Passed to `cowplot::plot_grid()`: 'h' for horizontal, 'v' for vertical, or 'hv' for both.
#'   Default: 'hv'.
#' @return A combined plot from `cowplot::plot_grid()` if `combine` is TRUE; otherwise, a named list of ggplot objects (one per variable), allowing for flexible plot arrangements.
#' @details `DimPlot2` extends the functionality of Seurat's visualization tools by combining the features of `DimPlot` and `FeaturePlot` into a single, more versatile function. It automatically recognizes whether the input features are discrete or continuous, adjusting the visualization accordingly. This makes `DimPlot2` ideal for exploring complex scRNA-seq data without the need to switch between different plotting functions based on variable types. The function also offers advanced customization options for colors, themes, and labeling, making it highly adaptable to various data visualization needs.
#' @examples
#' library(Seurat)
#' library(SeuratExtend)
#'
#' # Create a basic dimensional reduction plot with default settings
#' DimPlot2(pbmc)
#'
#' # Visualize different variables, including both discrete and continuous types
#' DimPlot2(pbmc, features = c("cluster", "orig.ident", "CD14", "CD3D"))
#'
#' # Split the visualization by a specific variable for comparative analysis
#' DimPlot2(pbmc, features = c("cluster", "CD14"), split.by = "orig.ident", ncol = 1)
#'
#' # Highlight specific cells, such as a particular cluster
#' b_cells <- colnames(pbmc)[pbmc$cluster == "B cell"]
#' DimPlot2(pbmc, cells.highlight = b_cells)
#'
#' # Apply advanced customization for colors and themes
#' DimPlot2(
#'   pbmc,
#'   features = c("cluster", "orig.ident", "CD14", "CD3D"),
#'   cols = list(
#'     "cluster" = "pro_blue",
#'     "CD14" = "D",
#'     "CD3D" = c("#EEEEEE", "black")
#'   ),
#'   theme = NoAxes())
#'
#' # Enhance the plot with labels and bounding boxes
#' DimPlot2(pbmc, label = TRUE, box = TRUE, label.color = "black", repel = TRUE, theme = NoLegend())
#'
#' # Use indices instead of long cluster names to simplify labels in the plot
#' DimPlot2(pbmc, index.title = "C", box = TRUE, label.color = "black")
#' @rdname DimPlot2
#' @export

DimPlot2 <- function(
    seu,
    features = NULL,
    group.by = NULL,
    split.by = NULL,
    cells = NULL,
    slot = "data",
    assay = NULL,
    dims = c(1, 2),
    reduction = NULL,
    priority = c("expr", "none"),
    ncol = NULL,
    nrow = NULL,
    nrow.each = NULL,
    ncol.legend = NULL,
    cols = list(discrete = "auto", continuous = "A"),
    load.cols = TRUE,
    pt.size = NULL,
    shape.by = NULL,
    alpha.by = NULL,
    order = c(discrete = FALSE, continuous = TRUE),
    shuffle = c(discrete = TRUE, continuous = FALSE),
    label = FALSE,
    label.color = NULL,
    box = FALSE,
    index.title = NULL,
    repel = FALSE,
    label.size = 4,
    theme = NULL,
    cells.highlight = NULL,
    cols.highlight = "#DE2D26",
    sizes.highlight = 1,
    na.value = "grey80",
    raster = NULL,
    raster.dpi = NULL,
    combine = TRUE,
    align = "hv"
) {
  # Embedding coordinates (+ optional split/shape/alpha columns) and one
  # column per variable to plot.
  plot_data <- DimPlot2_GetData(
    seu = seu,
    features = features,
    group.by = group.by,
    split.by = split.by,
    cells = cells,
    slot = slot,
    assay = assay,
    dims = dims,
    reduction = reduction,
    shape.by = shape.by,
    alpha.by = alpha.by,
    cells.highlight = cells.highlight
  )

  data.dim <- plot_data$data.dim
  data.var <- plot_data$data.var
  vars <- colnames(data.var)

  # Resolve a per-type logical option: `par` may be a named logical vector
  # (e.g. c(discrete = FALSE, continuous = TRUE)), a single logical applied
  # to both types, or anything else (falls back to `default`).
  get_disc_cont_par <- function(par, type, default) {
    if (is.vector(par) && is.logical(par) && type %in% names(par)) {
      return(par[type])
    }
    if (is.logical(par) && length(par) == 1) {
      return(par)
    }
    default
  }

  library(ggplot2)
  p <- list()
  for (i in vars) {
    data.single <- data.dim
    data.single$var <- data.var[[i]]
    # "Selected_cells" is the synthetic column created for `cells.highlight`;
    # it keeps its Unselected/Selected factor and a fixed two-colour scale.
    is.highlight <- identical(i, "Selected_cells")

    if (is_continuous(data.single$var)) {
      scale_color <- DimPlot2_SelColCont(
        seu = seu,
        var = i,
        cols = cols,
        load.cols = load.cols
      )
      order.c <- get_disc_cont_par(order, "continuous", TRUE)
      shuffle.c <- get_disc_cont_par(shuffle, "continuous", FALSE)

      # Labels and highlighting only apply to discrete variables.
      p[[i]] <- DimPlot2_PlotSingle(
        data.plot = data.single,
        title = i,
        cols = scale_color,
        pt.size = pt.size,
        order = order.c,
        shuffle = shuffle.c,
        label = FALSE,
        label.color = NULL,
        repel = FALSE,
        box = FALSE,
        label.size = 4,
        cols.highlight = "#DE2D26",
        sizes.highlight = 1,
        na.value = na.value,
        nrow.each = nrow.each,
        theme = theme,
        raster = raster,
        raster.dpi = raster.dpi)
    } else {
      if (!is.highlight) data.single$var <- factor(data.single$var)
      n <- nlevels(data.single$var)
      if (n > 100) stop("Discrete variable '",i,"' has > 100 values. Unable to plot.")

      # Per-variable copy: forcing labels on for `index.title` must not leak
      # into other iterations through the shared `label` argument.
      label.i <- label
      labels <- waiver()
      if (!is.null(index.title) && !is.highlight) {
        # Replace level names by short indices (e.g. "C1") on the plot and
        # show "C1 <original name>" in the legend.
        data.single$var_orig <- factor(data.single$var)
        index <- paste0(index.title, seq_len(nlevels(data.single$var_orig)))
        index_cluster <- paste(index, levels(data.single$var_orig))
        data.single$var <- factor(index[data.single$var_orig], levels = index)
        data.single$index_cluster <- factor(index_cluster[data.single$var_orig], levels = index_cluster)
        label.i <- TRUE
        labels <- index_cluster
      }

      scale_color <- if (is.highlight) {
        NULL
      } else {
        DimPlot2_SelColDisc(
          seu = seu,
          n = n,
          var = i,
          cols = cols,
          load.cols = load.cols,
          label = label.i,
          labels = labels,
          box = box
        )
      }

      order.d <- get_disc_cont_par(order, "discrete", FALSE)
      shuffle.d <- get_disc_cont_par(shuffle, "discrete", TRUE)

      p[[i]] <- DimPlot2_PlotSingle(
        data.plot = data.single,
        title = i,
        cols = scale_color,
        pt.size = pt.size,
        order = order.d,
        shuffle = shuffle.d,
        label = label.i,
        label.color = label.color,
        repel = repel,
        box = box,
        label.size = label.size,
        cols.highlight = cols.highlight,
        sizes.highlight = sizes.highlight,
        na.value = na.value,
        nrow.each = nrow.each,
        theme = theme,
        raster = raster,
        raster.dpi = raster.dpi,
        ncol.legend = ncol.legend)
    }
  }

  if (!combine) {
    return(p)
  }
  import("cowplot")
  plot_grid(plotlist = p, ncol = ncol, nrow = nrow, align = align)
}

# TRUE when `vec` should be drawn with a continuous colour scale: numeric
# (integer or double) and not a factor. Always returns a single logical.
is_continuous <- function(vec) {
  # is.numeric() is already FALSE for factors; the explicit check documents
  # intent. `&&` because both operands are scalar tests.
  is.numeric(vec) && !is.factor(vec)
}

# Collect everything DimPlot2 needs from a Seurat object.
#
# Returns a list with:
#   data.dim - data.frame with the two requested embedding columns for
#              `cells`, plus optional factor columns `split.by`, `shape.by`
#              and `alpha.by` (see DimPlot2_PlotVars).
#   data.var - data.frame with one column per variable to plot; the active
#              identities (column "Idents") when nothing was requested.
# `cells` may be NULL (all cells), a logical mask over all cells, or names.
DimPlot2_GetData <- function(
    seu,
    features = NULL,
    group.by = NULL,
    split.by = NULL,
    cells = NULL,
    slot = "data",
    assay = NULL,
    dims = c(1, 2),
    reduction = NULL,
    shape.by = NULL,
    alpha.by = NULL,
    cells.highlight = NULL
) {
  # Seurat generics (DefaultAssay, Embeddings, FetchData, Idents) must be on
  # the search path.
  if(!require(SeuratObject)) library(Seurat)

  # Only the local copy of `seu` is modified here.
  if (!is.null(assay)) DefaultAssay(seu) <- assay

  # cells: keep both the names (`cells`) and a logical mask over all cells
  # (`cells.l`); DimPlot2_PlotVars needs the mask for full-length vectors.
  cells <- cells %||% colnames(seu)
  if (is.logical(cells)) {
    if (length(cells) != ncol(seu)) {
      stop("Logical value of 'cells' should be the same length as cells in ",
           "Seurat object")
    }
    cells.l <- cells
    cells <- colnames(seu)[cells]
  } else {
    if (all(!cells %in% colnames(seu))) {
      stop("'cells' not found in Seurat object")
    } else if (any(!cells %in% colnames(seu))) {
      cells.out <- setdiff(cells, colnames(seu))
      stop(length(cells.out), " cell(s) not found in Seurat object: '",
           cells.out[1], "'...")
    }
    cells.l <- colnames(seu) %in% cells
  }

  # reduction
  if (length(x = dims) != 2 || !is.numeric(x = dims)) {
    stop("'dims' must be a two-length integer vector")
  }
  reduction <- reduction %||% DefaultDimReduc(object = seu)
  # drop = FALSE keeps a 1 x 2 matrix when a single cell is requested;
  # otherwise it collapses to a vector and becomes a 2 x 1 data.frame.
  data.dim <- Embeddings(object = seu[[reduction]])[cells, dims, drop = FALSE]
  data.dim <- as.data.frame(x = data.dim)

  # vars
  if (!is.null(cells.highlight)) {
    # Synthetic metadata column consumed by DimPlot2 / DimPlot2_PlotSingle.
    seu@meta.data[["Selected_cells"]] <- factor(
      ifelse(
        colnames(seu) %in% cells.highlight,
        "Selected", "Unselected"),
      levels = c("Unselected","Selected")
    )
  }
  if (is.null(features) && is.null(group.by) && is.null(cells.highlight)) {
    data.var <- data.frame(Idents = factor(Idents(seu)[cells]))
  } else {
    vars <- c(features, group.by)
    if (!is.null(cells.highlight)) vars <- c(vars, "Selected_cells")
    # SeuratObject 5 renamed the `slot` argument of FetchData to `layer`.
    if (utils::packageVersion("SeuratObject") >= "5.0.0") {
      data.var <- FetchData(object = seu, vars = vars, cells = cells, layer = slot, clean = "none")
    } else {
      data.var <- FetchData(object = seu, vars = vars, cells = cells, slot = slot)
    }
  }

  # split.by, shape.by, alpha.by become factor columns of data.dim.
  plot_vars <- list(
    split.by = split.by,
    shape.by = shape.by,
    alpha.by = alpha.by)
  for (i in names(plot_vars)) {
    if (!is.null(plot_vars[[i]])) {
      data.dim[[i]] <- DimPlot2_PlotVars(
        seu = seu,
        var = plot_vars[[i]],
        var_name = i,
        cells = cells,
        cells.l = cells.l)
    }
  }

  list(
    data.dim = data.dim,
    data.var = data.var
  )
}

# Turn a split/shape/alpha specification into a factor named by `cells`.
#
# `var` may be NULL (returns NULL), a single meta.data column name, a
# vector with one value per cell in `seu` (subset with the mask `cells.l`),
# or a vector with one value per selected cell. `var_name` is only used in
# the error message.
DimPlot2_PlotVars <- function(
    seu,
    var,
    var_name,
    cells,
    cells.l
) {
  if (is.null(var)) return(NULL)

  n.var <- length(var)
  if (n.var == 1) {
    if (!var %in% colnames(seu@meta.data)) {
      stop("Cannot find '", var, "' in meta.data")
    }
    values <- seu[[var]][cells,]
  } else if (n.var == ncol(seu)) {
    values <- var[cells.l]
  } else if (n.var == length(cells)) {
    values <- var
  } else {
    stop("'",var_name,"' should be variable name in 'meta.data' or ",
         "string with the same length of cells")
  }

  out <- factor(values)
  names(out) <- cells
  out
}

# Draw one scatter plot of the embedding coloured by `data.plot$var`.
#
# `data.plot` has the two embedding columns first, a `var` column, and
# optionally `split.by`, `shape.by`, `alpha.by` columns. `cols` is a ggplot
# scale (or NULL) added after the theme. When `title` is "Selected_cells" and
# `var` has levels Unselected/Selected, the cells.highlight colouring and
# `sizes.highlight` point size are used instead of `cols`/`pt.size`.
DimPlot2_PlotSingle <- function (
    data.plot,
    title = NULL,
    cols = NULL,
    pt.size = NULL,
    order = NULL,
    shuffle = NULL,
    label = FALSE,
    label.color = NULL,
    repel = FALSE,
    box = FALSE,
    label.size = 4,
    cols.highlight = "#DE2D26",
    sizes.highlight = 1,
    na.value = "grey80",
    nrow.each = NULL,
    theme = NULL,
    raster = NULL,
    raster.dpi = NULL,
    ncol.legend = NULL)
{
  library(ggplot2)
  pt.size <- pt.size %||% DimPlot2_AutoPointSize(data = data.plot, raster = raster)
  dims <- colnames(data.plot)[1:2]
  n.cells <- nrow(x = data.plot)
  if (n.cells > 1e+05 && !isFALSE(raster)) {
    message("Rasterizing points since number of points exceeds 100,000.",
            "\nTo disable this behavior set `raster=FALSE`")
  }
  raster <- raster %||% (n.cells > 1e+05)
  if (!is.null(x = raster.dpi)) {
    if (!is.numeric(x = raster.dpi) || length(x = raster.dpi) != 2)
      stop("'raster.dpi' must be a two-length numeric vector")
  } else raster.dpi <- c(512, 512)

  if (isTRUE(x = shuffle)) {
    # seq_len() so an empty data.frame stays empty (1:0 would add an NA row).
    data.plot <- data.plot[sample(x = seq_len(n.cells)), ]
  }
  # identical() keeps this a single TRUE/FALSE even when `title` is NULL.
  is.highlight <- identical(title, "Selected_cells") &&
    identical(levels(data.plot$var), c("Unselected", "Selected"))
  # Highlight plots always draw selected cells last (on top).
  if (isTRUE(x = order) || is.highlight) {
    data.plot <- data.plot[order(data.plot$var), ]
  }

  # When these columns are absent, aes() below falls back to these local
  # NULLs, which drops the aesthetic.
  if(!"shape.by" %in% colnames(data.plot)) shape.by <- NULL
  if(!"alpha.by" %in% colnames(data.plot)) alpha.by <- NULL

  plot <- ggplot(data = data.plot, mapping = aes(x = .data[[dims[1]]], y = .data[[dims[2]]], color = var, shape = shape.by, alpha = alpha.by))
  if (is.highlight) pt.size <- sizes.highlight
  plot <-
    if (isTRUE(x = raster)) {
      import("scattermore")
      plot + geom_scattermore(pointsize = pt.size, pixels = raster.dpi)
    } else {
      plot + geom_point(size = pt.size)
    }
  if (!is_continuous(data.plot$var)) {
    plot <- plot +
      guides(color = guide_legend(override.aes = list(size = 3), ncol = ncol.legend))
  }
  plot <- plot + labs(color = NULL, title = title)
  if (isTRUE(label) && !identical(title, "Selected_cells")) {
    plot <- DimPlot2_LabelClusters(
      plot = plot,
      id = "var",
      repel = repel,
      label.color = label.color,
      box = box,
      size = label.size)
  }
  if ("split.by" %in% colnames(data.plot)) {
    plot <- plot + facet_wrap(vars(split.by), nrow = nrow.each)
  }
  import("cowplot")
  plot <- plot + theme_cowplot() + CenterTitle()
  if (is.highlight) {
    plot <- plot +
      scale_color_manual(values = c("#C3C3C3", cols.highlight)) +
      theme
  } else {
    plot <- plot + cols + theme
  }
  plot
}

# Default point size: 1 for rasterized plots; otherwise it shrinks with the
# number of rows (1583 / nrow), capped at 1.
DimPlot2_AutoPointSize <- function (data, raster = NULL) {
  # Scalar decision: plain if/else rather than ifelse().
  if (isTRUE(x = raster)) {
    1
  } else {
    min(1583 / nrow(x = data), 1)
  }
}

# Add one label per group of `id` at the per-group median of the plot's x/y
# aesthetics (optionally computed per `split.by` value).
#
# plot        ggplot whose data contains the x/y columns and the `id` column.
# id          Name of the grouping column in plot$data (DimPlot2 uses "var").
# clusters    Subset of groups to label; default all non-NA groups.
# labels      Replacement label text, one per labelled group.
# label.color Label colour. When NULL, boxed labels keep the plot's colour
#             mapping and plain text labels are black.
# split.by    Optional column of plot$data; medians are computed per value.
# repel       Use ggrepel geoms to avoid overlapping labels.
# box         Use geom_label*/geom_label_repel (boxed) instead of text.
# geom        Geom class used to locate the x/y aesthetics; "GeomSpatial"
#             additionally flips y and restores the axis ranges.
# position    "median" or "nearest" (snap to the closest point; NOTE(review):
#             relies on nn2(), which is not brought into scope here).
# ...         Forwarded to the label geom (DimPlot2 passes `size`).
DimPlot2_LabelClusters <- function(
    plot,
    id,
    clusters = NULL,
    labels = NULL,
    label.color = NULL,
    split.by = NULL,
    repel = TRUE,
    box = FALSE,
    geom = 'GeomPoint',
    position = "median",
    ...
) {
  if (repel) import("ggrepel")
  xynames <- unlist(x = GetXYAesthetics(plot = plot, geom = geom), use.names = TRUE)
  if (!id %in% colnames(x = plot$data)) {
    stop("Cannot find variable ", id, " in plotting data")
  }
  if (!is.null(x = split.by) && !split.by %in% colnames(x = plot$data)) {
    warning("Cannot find splitting variable ", split.by, " in plotting data")
    split.by <- NULL
  }
  data <- plot$data[, c(xynames, id, split.by)]
  possible.clusters <- as.character(x = na.omit(object = unique(x = data[, id])))
  groups <- clusters %||% possible.clusters
  if (any(!groups %in% possible.clusters)) {
    stop("The following clusters were not found: ", paste(groups[!groups %in% possible.clusters], collapse = ","))
  }
  pb <- ggplot_build(plot = plot)
  if (geom == 'GeomSpatial') {
    xrange.save <- layer_scales(plot = plot)$x$range$range
    yrange.save <- layer_scales(plot = plot)$y$range$range
    data[, xynames["y"]] = max(data[, xynames["y"]]) - data[, xynames["y"]] + min(data[, xynames["y"]])
    if (!pb$plot$plot_env$crop) {
      y.transform <- c(0, nrow(x = pb$plot$plot_env$image)) - pb$layout$panel_params[[1]]$y.range
      data[, xynames["y"]] <- data[, xynames["y"]] + sum(y.transform)
    }
  }
  # First column of the first built layer, row-aligned with plot$data
  # (presumably the resolved point colour).
  data <- cbind(data, color = pb$data[[1]][[1]])
  labels.loc <- lapply(
    X = groups,
    FUN = function(group) {
      data.use <- data[data[, id] == group, , drop = FALSE]
      data.medians <- if (!is.null(x = split.by)) {
        do.call(
          what = 'rbind',
          args = lapply(
            X = unique(x = data.use[, split.by]),
            FUN = function(split) {
              medians <- apply(
                X = data.use[data.use[, split.by] == split, xynames, drop = FALSE],
                MARGIN = 2,
                FUN = median,
                na.rm = TRUE
              )
              medians <- as.data.frame(x = t(x = medians))
              medians[, split.by] <- split
              return(medians)
            }
          )
        )
      } else {
        as.data.frame(x = t(x = apply(
          X = data.use[, xynames, drop = FALSE],
          MARGIN = 2,
          FUN = median,
          na.rm = TRUE
        )))
      }
      data.medians[, id] <- group
      data.medians$color <- data.use$color[1]
      return(data.medians)
    }
  )
  if (position == "nearest") {
    labels.loc <- lapply(X = labels.loc, FUN = function(x) {
      group.data <- data[as.character(x = data[, id]) == as.character(x[3]), ]
      nearest.point <- nn2(data = group.data[, 1:2], query = as.matrix(x = x[c(1,2)]), k = 1)$nn.idx
      x[1:2] <- group.data[nearest.point, 1:2]
      return(x)
    })
  }
  labels.loc <- do.call(what = 'rbind', args = labels.loc)
  labels.loc[, id] <- factor(x = labels.loc[, id], levels = levels(data[, id]))
  labels <- labels %||% groups
  n.labelled <- length(x = unique(x = labels.loc[, id]))
  if (n.labelled != length(x = labels)) {
    stop("Length of labels (", length(x = labels),  ") must be equal to the number of clusters being labeled (", n.labelled, ").")
  }
  names(x = labels) <- groups
  for (group in groups) {
    labels.loc[labels.loc[, id] == group, id] <- labels[group]
  }

  # `.data` pronoun instead of the deprecated aes_string(); the column names
  # come from GetXYAesthetics() and `id`.
  label.mapping <- aes(
    x = .data[[xynames[["x"]]]],
    y = .data[[xynames[["y"]]]],
    label = .data[[id]]
  )
  if (box) {
    geom.use <- if (repel) geom_label_repel else geom_label
    if (is.null(label.color)) {
      # No colour override: boxed labels keep the plot's colour mapping.
      plot <- plot + geom.use(
        data = labels.loc,
        mapping = label.mapping,
        show.legend = FALSE,
        ...
      )
    } else {
      plot <- plot + geom.use(
        data = labels.loc,
        mapping = label.mapping,
        show.legend = FALSE,
        color = label.color,
        ...
      )
    }
  } else {
    geom.use <- if (repel) geom_text_repel else geom_text
    plot <- plot + geom.use(
      data = labels.loc,
      mapping = label.mapping,
      show.legend = FALSE,
      color = label.color %||% "black",
      ...
    )
  }
  # restore old axis ranges
  if (geom == 'GeomSpatial') {
    plot <- suppressMessages(expr = plot + coord_fixed(xlim = xrange.save, ylim = yrange.save))
  }
  plot
}

# Return list(x = , y = ) with the labels of the x/y aesthetics used by the
# first layer whose geom class is `geom`. A "GeomPoint" request also matches
# a scattermore layer (used when DimPlot2 rasterizes). With plot.first = TRUE
# the plot-level mapping wins over the layer mapping; otherwise the reverse.
GetXYAesthetics <- function (plot, geom = "GeomPoint", plot.first = TRUE) {
  # vapply: always a character vector, even for a plot with no layers.
  geoms <- vapply(
    X = plot$layers,
    FUN = function(layer) class(x = layer$geom)[1],
    FUN.VALUE = character(1)
  )
  if (geom == "GeomPoint" && "GeomScattermore" %in% geoms) {
    geom <- "GeomScattermore"
  }
  matches <- which(x = geoms == geom)
  if (length(x = matches) == 0) {
    stop("Cannot find a geom of class ", geom)
  }
  layer.mapping <- plot$layers[[min(matches)]]$mapping
  if (plot.first) {
    x <- as_label(x = plot$mapping$x %||% layer.mapping$x)
    y <- as_label(x = plot$mapping$y %||% layer.mapping$y)
  } else {
    x <- as_label(x = layer.mapping$x %||% plot$mapping$x)
    y <- as_label(x = layer.mapping$y %||% plot$mapping$y)
  }
  list(x = x, y = y)
}

# Pick the colour specification for continuous variable `var` and turn it
# into a scale with scale_color_cont_auto().
# Precedence (first match wins):
#   1. cols[[var]]                          (when `cols` is a list)
#   2. seu@misc$var_colors[[var]]           (when load.cols)
#   3. cols$continuous                      (when `cols` is a list)
#   4. seu@misc$var_colors$continuous       (when load.cols)
#   5. "A"                                  (when `cols` is a list)
# A non-list `cols` not overridden by stored colours is used as given.
DimPlot2_SelColCont <- function(
  seu,
  var,
  cols,
  load.cols
) {
  stored <- seu@misc[["var_colors"]]
  load_var <- stored[[var]]
  load_cont <- stored[["continuous"]]
  list_l <- is.list(cols)
  # All conditions are scalar, so use short-circuit `&&`.
  if (list_l && var %in% names(cols)) {
    cols <- cols[[var]]
  } else if (load.cols && !is.null(load_var)) {
    cols <- load_var
  } else if (list_l && "continuous" %in% names(cols)) {
    cols <- cols[["continuous"]]
  } else if (load.cols && !is.null(load_cont)) {
    cols <- load_cont
  } else if (list_l) {
    cols <- "A"
  }
  scale_color_cont_auto(cols)
}

# Pick the colour specification for discrete variable `var` (with `n`
# levels) and turn it into a scale with scale_color_disc_auto(); returns
# NULL when no colours are specified at all.
# Precedence (first match wins):
#   1. cols[[var]]                        (when `cols` is a list)
#   2. seu@misc$var_colors[[var]]         (when load.cols)
#   3. cols$discrete                      (when `cols` is a list)
#   4. seu@misc$var_colors$discrete       (when load.cols)
#   5. "pro_default"                      (when `cols` is a list)
# "auto" resolves to "pro_light" for plain-text labels (label && !box) and
# to "pro_default" otherwise. `labels` is forwarded as the legend labels.
DimPlot2_SelColDisc <- function(
    seu,
    var,
    n,
    cols,
    load.cols,
    label,
    labels = waiver(),
    box
) {
  stored <- seu@misc[["var_colors"]]
  load_var <- stored[[var]]
  load_disc <- stored[["discrete"]]
  list_l <- is.list(cols)
  # All conditions are scalar, so use short-circuit `&&`.
  if (list_l && var %in% names(cols)) {
    cols <- cols[[var]]
  } else if (load.cols && !is.null(load_var)) {
    cols <- load_var
  } else if (list_l && "discrete" %in% names(cols)) {
    cols <- cols[["discrete"]]
  } else if (load.cols && !is.null(load_disc)) {
    cols <- load_disc
  } else if (list_l) {
    cols <- "pro_default"
  }
  if (is.null(cols)) return(NULL)
  # isTRUE() guards against NA / empty `cols`, which made `if` error before.
  if (length(cols) > 0 && isTRUE(cols[[1]] == "auto")) {
    cols <- if (label && !box) "pro_light" else "pro_default"
  }
  scale_color_disc_auto(cols, n = n, labels = labels)
}

# Theme fragment that horizontally centres the plot title.
CenterTitle <- function () {
  centered <- element_text(hjust = 0.5)
  theme(plot.title = centered, validate = TRUE)
}
huayc09/SeuratExtend documentation built on July 15, 2024, 6:22 p.m.