# R/plotScatter.R

# Defines functions: plotScatter

# Documented in: plotScatter

#' Scatter plot of canonical variates from SCCA using PMD
#'
#' Generates one scatter plot per canonical pair so that any separation
#' pattern between groups can be inspected visually.
#'
#' \code{X} and \code{Y} are first passed through \code{standsdmu()} and then
#' projected onto the first \code{K} columns of \code{CCA_out$u} and
#' \code{CCA_out$v} respectively. Each plot is annotated with the correlation
#' between the two resulting canonical variates.
#'
#' @param X Data matrix or data frame, each row is one sample, each column is
#'   one feature.
#' @param Y Data matrix or data frame with the same number of rows as
#'   \code{X}, each row is one sample, each column is one feature.
#' @param K Number of pairs of canonical variates to be plotted. Must be a
#'   single positive whole number that does not exceed the number of columns
#'   of \code{CCA_out$u} and \code{CCA_out$v}.
#' @param CCA_out List of SCCA results from PMA::CCA(); its elements \code{u}
#'   and \code{v} are used as the loadings for \code{X} and \code{Y}.
#' @param groups A vector with one entry per row of \code{X} indicating the
#'   group information of samples. It is converted to a factor and mapped to
#'   point colour in the scatter plot.
#' @import ggplot2
#' @import ggpubr
#' @importFrom stats cor
#' @export
#' @return An unnamed list of \code{K} ggplot objects, one per canonical pair,
#'   with the canonical variates as axes and samples coloured by group.
#' @examples
#' library(TestPMD)
#' data("covid")
#' out <- PMA::CCA(standsdmu(covid$metabolite), standsdmu(covid$protein),
#' typex = "standard", typez="standard",
#' penaltyx = 0.9, penaltyz = 0.9, K = 3, standardize = FALSE, trace = FALSE)
#' p_list <- plotScatter(X = covid$metabolite, Y = covid$protein,
#' CCA_out = out, groups = covid$meta$icu, K = 3)
#' ggpubr::ggarrange(plotlist = p_list, nrow = 1)
plotScatter <- function(X, Y, K, CCA_out, groups) {
  # ---- Input validation ----
  if (!is.matrix(X) && !is.data.frame(X)) {
    stop("X is not in form n*p matrix or data frame.")
  }
  if (!is.matrix(Y) && !is.data.frame(Y)) {
    stop("Y is not in form n*q matrix or data frame.")
  }
  if (!is.numeric(K)) {
    stop("K should be numeric")
  }
  # Reject vectors, NA/Inf, zero, negatives and fractions up front: they would
  # otherwise produce a broken index sequence or an obscure error further down.
  if (length(K) != 1L || !is.finite(K) || K < 1 || K %% 1 != 0) {
    stop("K should be a single positive whole number")
  }
  if (is.null(CCA_out$u) || is.null(CCA_out$v)) {
    stop("CCA_out does not have any u and v.")
  }
  # NCOL() treats a plain vector as a single column, unlike ncol() which
  # returns NULL for it.
  if (NCOL(CCA_out$u) < K || NCOL(CCA_out$v) < K) {
    stop("K is out of subscript bound of CCA_out")
  }
  if (length(groups) != nrow(X)) {
    stop("Length of groups does not match samples in X")
  }
  if (nrow(Y) != nrow(X)) {
    stop("Samples in Y do not match samples in X")
  }

  # ---- Projection inputs ----
  # as.matrix() guarantees `%*%` below works even if a data frame comes back
  # from standsdmu(); it is a no-op when the result is already a matrix.
  X <- as.matrix(standsdmu(X))
  Y <- as.matrix(standsdmu(Y))
  # drop = FALSE keeps the loadings two-dimensional when K is 1.
  U <- as.matrix(CCA_out$u)[, seq_len(K), drop = FALSE]
  V <- as.matrix(CCA_out$v)[, seq_len(K), drop = FALSE]
  if (ncol(X) != nrow(U)) {
    stop("Number of features in X does not match rows of CCA_out$u")
  }
  if (ncol(Y) != nrow(V)) {
    stop("Number of features in Y does not match rows of CCA_out$v")
  }
  groups <- factor(groups)

  # ---- One plot per canonical pair ----
  lapply(seq_len(K), function(comp) {
    X_cv <- as.vector(X %*% U[, comp])
    Y_cv <- as.vector(Y %*% V[, comp])
    corr_label <- paste0("corr=", round(stats::cor(X_cv, Y_cv), 4))

    ggplot2::ggplot(data.frame(X_cv = X_cv, Y_cv = Y_cv, groups = groups)) +
      ggplot2::geom_point(
        ggplot2::aes(x = X_cv, y = Y_cv, colour = groups),
        size = 3
      ) +
      ggplot2::labs(
        x = "Canonical variate wrt X",
        y = "Canonical variate wrt Y",
        title = paste("Canonical pair", comp, sep = " ")
      ) +
      ggplot2::theme_linedraw() +
      # A single annotation anchored to the top-left corner of the panel.
      # Infinite coordinates keep the label inside the panel whatever the data
      # range is, and annotate() draws it once instead of once per sample.
      ggplot2::annotate(
        "text",
        x = -Inf, y = Inf, hjust = -0.1, vjust = 1.5,
        label = corr_label
      ) +
      ggplot2::theme(
        plot.title = ggplot2::element_text(hjust = 0.5),
        legend.position = "right",
        panel.grid.major = ggplot2::element_blank(),
        panel.grid.minor = ggplot2::element_blank(),
        panel.background = ggplot2::element_blank(),
        axis.line = ggplot2::element_line(colour = "black")
      )
  })
}
# YunhuiQi/TestPMD documentation built on May 5, 2022, 8:23 p.m.