# R/plot_correlation.R

#' plot_correlation
#'
#' plot_correlation can plot correlation using a correlation table.
#'
#' @param cor_tab A data frame that contains at least two columns to calculate correlation. Columns
#' are different samples. Doesn't require rownames.
#'
#' @param x Column name(s) of cor_tab. If more than one name is given, plot_correlation calculates
#' the correlation with y for each of them separately and facets the plot by x.
#'
#' @param y Colname of cor_tab.
#'
#' @param method A character string indicating which correlation coefficient is to be used. One of
#' "pearson", "kendall", or "spearman"; default is "spearman".
#'
#' @param histogram Add marginal histograms to the plot (via ggExtra::ggMarginal). Only used when x
#' is a single column. Default is FALSE.
#'
#' @param heatmap Default FALSE. If TRUE, generates a correlation heatmap from the input table.
#' Doesn't require x or y.
#'
#' @param range A vector of numeric breaks, which will be passed to circlize::colorRamp2(
#' breaks), default is c(-1, 0, 1).
#'
#' @param range_color A vector of colors corresponding to the values in range, which will be passed
#' to circlize::colorRamp2(colors), default is c("blue", "white", "red").
#'
#' @param row_order Default NULL. Input a vector of ordered rownames if need to adjust row order for
#' heatmap. Manually setting row order will turn off clustering.
#'
#' @param col_order Default NULL. Input a vector of ordered colnames if need to adjust column order
#' for heatmap. Manually setting column order will turn off clustering.
#'
#' @param row_title Same as ComplexHeatmap::Heatmap(row_title).
#'
#' @param col_title Same as ComplexHeatmap::Heatmap(column_title).
#'
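#' @param font_size Default 12. Font size of the number in heatmap.
#'
#' @return If heatmap is TRUE, the object returned by ComplexHeatmap::Heatmap(). Otherwise a ggplot
#' object, or the result of ggExtra::ggMarginal() when histogram is TRUE and x is a single column.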
#'
#' @export

plot_correlation <- function(
  cor_tab,
  x = NA,
  y = NA,
  method = "spearman",
  histogram = FALSE,
  heatmap = FALSE,
  range = c(-1, 0, 1),
  range_color = c("blue", "white", "red"),
  row_order = NULL,
  col_order = NULL,
  row_title = NA,
  col_title = NA,
  font_size = 12
) {
  # Two modes:
  #   * heatmap = TRUE  -> returns a ComplexHeatmap::Heatmap of cor(cor_tab);
  #     `x` and `y` are not used.
  #   * heatmap = FALSE -> returns a scatter plot of `y` against `x` (one facet
  #     per element of `x` when several are given), annotated with the
  #     correlation estimate and p-value from cor.test().

  # ---- Heatmap mode ----
  if (heatmap) {
    cormat <- cor(cor_tab, method = method)
    # Map correlation values onto the requested color gradient.
    col_fun <- circlize::colorRamp2(range, range_color)
    # A missing title (NA / NULL) is forwarded to Heatmap() as character(0).
    as_title <- function(title) {
      if (length(title) == 0 || all(is.na(title))) character(0) else title
    }
    return(
      ComplexHeatmap::Heatmap(
        cormat,
        name = "correlation",
        col = col_fun,
        # A manually supplied order switches clustering off on that axis.
        cluster_rows = is.null(row_order),
        row_order = row_order,
        row_title = as_title(row_title),
        cluster_columns = is.null(col_order),
        column_order = col_order,
        column_title = as_title(col_title),
        # Print the correlation (one decimal place) inside each cell.
        # Note: x / y here are the cell coordinates, not the outer arguments.
        cell_fun = function(j, i, x, y, width, height, fill) {
          grid::grid.text(
            sprintf("%.1f", cormat[i, j]), x, y,
            gp = grid::gpar(fontsize = font_size)
          )
        }
      )
    )
  }

  # ---- Scatter-plot mode: validate inputs ----
  # anyNA() + `||` keep the condition scalar even when `x` names several
  # columns (a length > 1 `if` condition is an error in R >= 4.2).
  if (length(x) == 0 || length(y) == 0 || anyNA(x) || anyNA(y)) {
    stop("Please input both 'x' and 'y' for plotting correlation.")
  }
  if (length(y) != 1) {
    stop("'y' must be a single column name of 'cor_tab'.")
  }
  missing_cols <- setdiff(as.character(c(x, y)), colnames(cor_tab))
  if (length(missing_cols) > 0) {
    stop("Column(s) not found in 'cor_tab': ", paste(missing_cols, collapse = ", "))
  }

  # Human-readable name of the coefficient shown in the plot.
  unit <- switch(
    method,
    pearson = "Pearson's r",
    spearman = "Spearman's Rho",
    kendall = "Kendall's Tau",
    stop("Input method is not supported.")
  )

  # ---- Single correlation ----
  if (length(x) == 1) {
    cor_res <- cor.test(cor_tab[[x]], cor_tab[[y]], method = method)
    # na.rm = TRUE: cor.test() ignores incomplete pairs, so the annotation
    # position must not become NA because of them either.
    x_max <- max(cor_tab[[x]], na.rm = TRUE)
    y_max <- max(cor_tab[[y]], na.rm = TRUE)
    # `.data[[!!name]]` injects the column name as a string, so it works for
    # non-syntactic names and cannot be confused with a column called "x"/"y".
    g <- ggplot(data = cor_tab, aes(x = .data[[!!x]], y = .data[[!!y]])) +
      geom_point() +
      geom_smooth(method = "lm") +
      annotate(geom = "text",
               x = x_max / 2,
               y = y_max * 1.1,
               label = paste0(unit, " = ", round(cor_res$estimate, 2))) +
      annotate(geom = "text",
               x = x_max / 2,
               y = y_max * 1.05,
               label = format_cor_p_value(cor_res$p.value)) +
      labs(x = as.character(x), y = as.character(y)) +
      theme_classic() +
      theme(panel.grid = element_blank(),
            axis.text.y = element_text(size = 14),
            axis.text.x = element_text(size = 14),
            axis.title = element_text(size = 16),
            legend.text = element_text(size = 12))
    if (histogram) {
      return(ggExtra::ggMarginal(g, type = "histogram", fill = "white"))
    }
    return(g)
  }

  # ---- Several correlations, one facet per x column ----
  # (`histogram` is not used in this branch.)
  # One row of annotation text per facet.
  cor_res <- do.call(rbind, lapply(as.character(x), function(col) {
    res <- cor.test(cor_tab[[col]], cor_tab[[y]], method = method)
    data.frame(
      facet = col,
      correlation = paste0(unit, " = ", round(as.numeric(res$estimate), 2)),
      pvalue = format_cor_p_value(res$p.value),
      stringsAsFactors = FALSE
    )
  }))
  # Stack the x columns into long format: `facet` holds the original column
  # name, `values` its data. all_of() makes the selection unambiguous even if
  # cor_tab itself has a column named "x".
  long_tab <- gather(cor_tab, key = "facet", value = "values", tidyr::all_of(x))
  ggplot(long_tab, aes(x = .data[["values"]], y = .data[[!!y]])) +
    facet_wrap(vars(facet), scales = "free") +
    geom_point() +
    geom_smooth(method = "lm") +
    xlab("") +
    ylab(as.character(y)) +
    theme_classic() +
    theme(panel.grid = element_blank(),
          axis.text.y = element_text(size = 12),
          axis.text.x = element_text(size = 12),
          axis.title = element_text(size = 16),
          legend.text = element_text(size = 12)) +
    # Anchor the labels to the top-left corner of each panel. -Inf / Inf are
    # ignored when the free scales are trained, so the text never stretches an
    # axis; hjust / vjust nudge it inside the panel.
    geom_text(
      data = cor_res,
      mapping = aes(x = -Inf, y = Inf, label = correlation),
      hjust = -0.1, vjust = 2
    ) +
    geom_text(
      data = cor_res,
      mapping = aes(x = -Inf, y = Inf, label = pvalue),
      hjust = -0.1, vjust = 4
    )
}

# Format a single p-value as plot annotation text: values below 2.2e-16 are
# reported as an upper bound, values below 0.001 in scientific notation, and
# everything else rounded to three decimals. An NA p-value is reported as NA
# instead of raising an error.
format_cor_p_value <- function(p) {
  if (is.na(p)) {
    return("p-value = NA")
  }
  if (p < 2.2e-16) {
    return("p-value < 2.2e-16")
  }
  if (p < 0.001) {
    return(paste0("p-value = ", formatC(p, format = "e", digits = 1)))
  }
  paste0("p-value = ", round(p, 3))
}
# yeguanhuav/visual16S documentation built on Feb. 19, 2022, 10:32 a.m.