
# Defines functions mash_plot_sig_by_condition mash_plot_pairwise_sharing mash_plot_manhattan_by_condition mash_plot_Ulist mash_plot_covar mash_plot_effects reorder_cormat get_U_by_mass get_mash_metadata get_GxE get_significant_results get_pairwise_sharing get_n_significant_conditions get_Ulist get_colnames multiply_list scale_cov expand_cov get_estimated_pi_no_collapse get_estimated_pi

# Documented in expand_cov get_colnames get_estimated_pi get_GxE get_mash_metadata get_n_significant_conditions get_pairwise_sharing get_significant_results get_U_by_mass mash_plot_covar mash_plot_effects mash_plot_manhattan_by_condition mash_plot_pairwise_sharing mash_plot_sig_by_condition mash_plot_Ulist reorder_cormat scale_cov

# --- Get Results from Mash Object with mashr get functions ---------

#' @title Return the estimated mixture proportions
#' @description Use get_estimated_pi to extract the estimates of the mixture
#'     proportions for different types of covariance matrix. This tells you
#'     which covariance matrices have most of the mass.
#' @param m the mash result
#' @param dimension indicates whether you want the mixture proportions for the
#'     covariances, grid, or all
#' @return a named vector containing the estimated mixture proportions.
#' @details If the fit was done with `usepointmass=TRUE` then the first
#'     element of the returned vector will correspond to the null, and the
#'     remaining elements to the non-null covariance matrices. Suppose the fit
#'     was done with $K$ covariances and a grid of length $L$. If
#'     `dimension=cov` then the returned vector will be of length $K$
#'     (or $K+1$ if `usepointmass=TRUE`).  If `dimension=grid` then
#'     the returned vector will be of length $L$ (or $L+1$).  If
#'     `dimension=all` then the returned vector will be of length $LK$ (or
#'     $LK+1$). The names of the vector will be informative for which
#'     combination each element corresponds to.
#' @importFrom ashr get_fitted_g
get_estimated_pi <- function(m, dimension = c("cov", "grid", "all")) {
  dimension <- match.arg(dimension)
  if (dimension == "all") {
    # No collapsing: one proportion per (grid value, covariance) component.
    return(get_estimated_pi_no_collapse(m))
  }

  g <- get_fitted_g(m)
  pihat <- g$pi
  pihat_names <- NULL
  pi_null <- NULL

  if (isTRUE(g$usepointmass)) {
    # The first mixture component is the point mass at zero (the null).
    pihat_names <- "null"
    pi_null <- pihat[1]
    pihat <- pihat[-1]
  }

  # The non-null components are ordered with the covariance index varying
  # fastest (see scale_cov), so a column-major K x L matrix has rows =
  # covariance matrices and columns = grid values.
  pihat <- matrix(pihat, nrow = length(g$Ulist))
  if (dimension == "cov") {
    pihat <- rowSums(pihat)
    pihat_names <- c(pihat_names, names(g$Ulist))
  } else {
    # dimension == "grid" (match.arg guarantees no other value remains)
    pihat <- colSums(pihat)
    pihat_names <- c(pihat_names, seq_along(g$grid))
  }

  pihat <- c(pi_null, pihat)
  names(pihat) <- pihat_names
  pihat
}
# Return the estimated mixture proportions of a mash fit without collapsing
# over covariances or grid values: one element per mixture component, named
# after the matching entry of expand_cov() (e.g. "null", "<U name>.<grid idx>").
get_estimated_pi_no_collapse <- function(m) {
  g <- get_fitted_g(m)
  pihat <- g$pi
  names(pihat) <- names(expand_cov(g$Ulist, g$grid, g$usepointmass))
  pihat
}

#' @title Create expanded list of covariance matrices expanded by
#'   grid, Sigma_{lk} = omega_l^2 U_k
#' @description This is an internal (non-exported) function. This help
#'   page provides additional documentation mainly intended for
#'   developers and expert users.
#' @param Ulist a list of covariance matrices
#' @param grid a grid of scalar values (standard deviations) by which the
#'   covariance matrices are to be scaled
#' @param usepointmass if TRUE adds a point mass at 0 (null component)
#'   to the list
#' @return This takes the covariance matrices in Ulist and multiplies
#'   them by the squared grid values. If usepointmass is TRUE then it
#'   prepends a null component (an all-zero matrix named "null").
#' @keywords internal
expand_cov <- function(Ulist, grid, usepointmass = TRUE) {
  scaled_Ulist <- scale_cov(Ulist, grid)
  if (usepointmass) {
    R <- nrow(Ulist[[1]])
    scaled_Ulist <- c(list(null = matrix(0, nrow = R, ncol = R)), scaled_Ulist)
  }
  scaled_Ulist
}

#' @title Scale each covariance matrix in list Ulist by a scalar in
#' vector grid
#' @description This is an internal (non-exported) function. This help
#'   page provides additional documentation mainly intended for
#'   developers and expert users.
#' @param Ulist a list of matrices
#' @param grid a vector of scaling factors (standard deviations); each matrix
#'   is multiplied by grid^2
#' @return a list with length length(Ulist)*length(grid), ordered with the
#'   Ulist index varying fastest and named "<Ulist name>.<grid index>"
#' @keywords internal
scale_cov <- function(Ulist, grid) {
  orig_names <- names(Ulist)
  # Outer loop over the grid, inner over Ulist: U_1 g_1^2, ..., U_K g_1^2,
  # U_1 g_2^2, ...
  scaled <- unlist(lapply(grid^2, function(x) multiply_list(Ulist, x)),
                   recursive = FALSE)
  names(scaled) <- unlist(lapply(seq_along(grid),
                                 function(x) paste0(orig_names, ".", x)),
                          recursive = FALSE)
  scaled
}

# Scale every element of a list by the scalar x, keeping the list's names.
# In this file each element is a covariance matrix.
multiply_list <- function(Ulist, x) {
  scale_one <- function(mat) {
    x * mat
  }
  lapply(Ulist, scale_one)
}

#' @title Get column names from a mash object
#' @description This function extracts the column names from the local false
#' sign rate table of a mash object's results. This can tell you the condition
#' names or phenotype names used in the mash object. That can be useful for
#' looking at a subset of these columns, say.
#' @param m An object of type mash
#' @return A vector of phenotype names (the column names of
#'     \code{m$result$lfsr}).
#' @examples
#'     \dontrun{get_colnames(m = mash_obj)}
get_colnames <- function(m) {
  colnames(m$result$lfsr)
}

# Return the list of covariance matrices (Ulist) stored in the fitted
# mixture prior (m$fitted_g) of a mash object.
get_Ulist <- function(m) {
  m$fitted_g$Ulist
}

#' Count number of conditions each effect is significant in
#' @param m the mash result (from joint or 1by1 analysis)
#' @param thresh indicates the threshold below which to call signals significant
#' @param conditions which conditions to include in check (default to all)
#' @param sig_fn the significance function used to extract significance from
#'     mash object; eg could be ashr::get_lfsr or ashr::get_lfdr. (Small values
#'     must indicate significant.)
#' @return a vector containing the number of significant conditions, one
#'     element per effect
get_n_significant_conditions <- function(m, thresh = 0.05, conditions = NULL,
                                         sig_fn = get_lfsr) {
  if (is.null(conditions)) {
    conditions <- seq_len(get_ncond(m))
  }
  # Logical effects x conditions matrix; counting TRUEs per row gives the
  # number of conditions in which each effect is significant.
  apply(sig_fn(m)[, conditions, drop = FALSE] < thresh, 1, sum)
}

#' Compute the proportion of (significant) signals shared by magnitude in each
#' pair of conditions, based on the posterior mean
#' @param m the mash fit
#' @param factor a number between 0 and 1 - the factor within which effects are
#'     considered to be shared.
#' @param lfsr_thresh the lfsr threshold for including an effect in the
#'     assessment
#' @param FUN a function to be applied to the estimated effect sizes before
#'     assessing sharing. The most obvious choice beside the default
#'     'FUN=identity' would be 'FUN=abs' if you want to ignore the sign of the
#'     effects when assessing sharing.
#' @details For each pair of tissues, first identify the effects that are
#'     significant (by lfsr<lfsr_thresh) in at least one of the two tissues.
#'     Then compute what fraction of these have an estimated (posterior mean)
#'     effect size within a factor `factor` of one another. The results are
#'     returned as an R by R matrix.
#' @return A symmetric R by R matrix of sharing proportions, with row and
#'     column names taken from the posterior mean matrix.
#' @examples
#' \dontrun{
#' get_pairwise_sharing(m) # sharing by magnitude (same sign)
#' get_pairwise_sharing(m, factor=0) # sharing by sign
#' get_pairwise_sharing(m, FUN=abs) # sharing by magnitude when sign is ignored
#' }
#' @export
get_pairwise_sharing <- function(m, factor = 0.5, lfsr_thresh = 0.05,
                                 FUN = identity) {
  R <- get_ncond(m)
  pm <- get_pm(m)
  # Significant effects per condition, computed once instead of once per pair.
  sig <- lapply(seq_len(R), function(k) {
    get_significant_results(m, thresh = lfsr_thresh, conditions = k)
  })
  S <- matrix(NA, nrow = R, ncol = R)
  for (i in seq_len(R)) {
    for (j in i:R) {
      # Effects significant in at least one of the two conditions.
      a <- union(sig[[i]], sig[[j]])
      ratio <- FUN(pm[a, i]) / FUN(pm[a, j])  # divide effect sizes
      # With factor = 0 the upper bound is Inf, i.e. sharing by sign only.
      S[i, j] <- mean(ratio > factor & ratio < (1 / factor))
    }
  }
  # Sharing is symmetric: mirror the upper triangle into the lower one.
  S[lower.tri(S, diag = FALSE)] <- t(S)[lower.tri(S, diag = FALSE)]
  colnames(S) <- row.names(S) <- colnames(m$result$PosteriorMean)
  S
}


#' From a mash result, get effects that are significant in at least one condition
#' @param m the mash result (from joint or 1by1 analysis)
#' @param thresh indicates the threshold below which to call signals significant
#' @param conditions which conditions to include in check (default to all)
#' @param sig_fn the significance function used to extract significance from mash object; eg could be ashr::get_lfsr or ashr::get_lfdr. (Small values must indicate significant.)
#' @return a vector containing the indices of the significant effects, by order of most significant to least
#' @importFrom ashr get_lfsr
#' @export
get_significant_results <- function(m, thresh = 0.05, conditions = NULL,
                                    sig_fn = ashr::get_lfsr) {
  if (is.null(conditions)) {
    conditions <- seq_len(get_ncond(m))
  }
  # Smallest (most significant) value of each effect across the conditions.
  top <- apply(sig_fn(m)[, conditions, drop = FALSE], 1, min)
  sig <- which(top < thresh)
  # Most significant effects first.
  sig[order(top[sig], decreasing = FALSE)]
}

#' @title Get data frames of types of GxE from a mash object
#' @description Performs set operations to determine pairwise GxE for effects
#'     from a mash object.
#' @param m An object of type mash
#' @param thresh Numeric. The threshold for including an effect in the assessment
#' @param factor a number between 0 and 1. The factor within which effects are
#'     considered to be shared.
#' @return A list containing eight data frames. Those with names that start
#'     "S_" contain significant effects of different types between pairs of
#'     named rows and columns. S_all_pairwise contains all significant effects;
#'     NS_pairwise contains all non-significant effects. S_CN contains effects
#'     significant in only one condition, and effects with a significantly
#'     different magnitude (differential sensitivity). This dataframe is not
#'     conservative using the local false sign rate test - we can't determine
#'     the sign of one of the effects for effects significant in only one
#'     condition - so it's not recommended to use this, but included. S_2_no
#'     contains effects significant in both conditions that do not differ
#'     significantly in magnitude. These effects do not have GxE. S_AP contains
#'     effects significant in both conditions that differ in their sign - and
#'     have antagonistic pleiotropy. S_DS contains effects significant in both
#'     conditions that differ in the magnitude of their effect, but not their
#'     sign - differentially sensitive alleles. S_1_row and S_1_col contain
#'     effects that are significant in just one of the two conditions - the row
#'     or the column, respectively.
#' @importFrom dplyr between mutate filter
#' @importFrom tibble enframe
#' @importFrom magrittr %>%
#' @importFrom rlang .data
#' @export
get_GxE <- function(m, factor = 0.4, thresh = 0.05) {
  R <- get_ncond(m)  # number of conditions
  cond_names <- get_colnames(m)
  # Empty R x R count matrix labelled by condition on both margins.
  empty_counts <- function() {
    matrix(NA, nrow = R, ncol = R, dimnames = list(cond_names, cond_names))
  }
  S_all <- empty_counts()
  S_2_no <- empty_counts()
  S_CN <- empty_counts()
  S_AP <- empty_counts()
  S_DS <- empty_counts()
  NS_pair <- empty_counts()
  S_i <- empty_counts()
  S_j <- empty_counts()

  pm <- get_pm(m)
  conds <- seq_len(R)
  # Per-condition effect sets, computed once instead of once per pair:
  # significant effects, all effects (thresh = 1) and the non-significant rest.
  sig <- lapply(conds, function(k) {
    get_significant_results(m, thresh = thresh, conditions = k)
  })
  all_eff <- lapply(conds, function(k) {
    get_significant_results(m, thresh = 1, conditions = k)
  })
  ns <- Map(setdiff, all_eff, sig)

  for (i in conds) {
    for (j in conds) {
      if (i == j) {
        S_all[i, j] <- length(sig[[i]])
        S_CN[i, j] <- 0
        # Not conservative!!
        S_2_no[i, j] <- length(sig[[i]])
        S_AP[i, j] <- 0
        S_DS[i, j] <- 0
        NS_pair[i, j] <- length(ns[[i]])
        S_i[i, j] <- 0
        S_j[i, j] <- 0
        next
      }

      # Markers where we aren't sure of the sign in either condition
      # (aka most of the effects)
      NS_pair[i, j] <- length(intersect(ns[[i]], ns[[j]]))

      # Markers where we are sure of the sign in just one condition
      ms_isigi <- intersect(sig[[i]], ns[[j]])  # significant in i, not in j
      ms_jsigj <- intersect(sig[[j]], ns[[i]])  # significant in j, not in i

      # Markers where we are sure of the sign in both conditions. The
      # intersection (not the union) keeps the single-condition markers above
      # from being counted a second time in the categories below.
      sig_both <- intersect(sig[[i]], sig[[j]])
      effects_df <- pm[sig_both, i] %>%
        enframe(name = "Marker", value = "Effect_i") %>%
        mutate(Effect_j = pm[sig_both, j],
               # positive ratio: same sign, so no antagonistic pleiotropy
               ratio = .data$Effect_i / .data$Effect_j,
               # positive APratio: opposite signs, antagonistic pleiotropy
               APratio = .data$Effect_i / -.data$Effect_j)

      # No GxE in this pair: same sign and magnitudes within `factor`.
      ms_sig2_noGxE <- effects_df %>%
        filter(between(.data$ratio, factor, 1 / factor))
      # DS: same sign, but magnitudes differ by more than `factor`.
      ms_sig2_DS <- effects_df %>%
        filter(.data$ratio > 0 & !between(.data$ratio, factor, 1 / factor))
      # AP: the two effects have opposite signs.
      ms_sig2_AP <- effects_df %>%
        filter(between(.data$APratio, 0, 1E10))

      S_all[i, j] <- sum(length(ms_isigi), length(ms_jsigj),
                         nrow(ms_sig2_noGxE), nrow(ms_sig2_DS),
                         nrow(ms_sig2_AP))
      S_CN[i, j] <- sum(length(ms_isigi), length(ms_jsigj), nrow(ms_sig2_DS))
      # Not conservative!!
      S_2_no[i, j] <- nrow(ms_sig2_noGxE)
      S_AP[i, j] <- nrow(ms_sig2_AP)
      S_DS[i, j] <- nrow(ms_sig2_DS)
      S_i[i, j] <- length(ms_isigi)
      S_j[i, j] <- length(ms_jsigj)
    }
  }
  list(S_all_pairwise = S_all, S_CN = S_CN, S_2_no = S_2_no, S_AP = S_AP,
       S_DS = S_DS, NS_pairwise = NS_pair, S_1_row = S_i,
       S_1_col = S_j)
}

#' @title Get switchgrass metadata for SNPs in a mash object.
#' @description Takes a mash object and a bigsnp object and returns
#'     a dataframe containing summary metadata for each SNP in the mash object.
#'     This information includes details for the alternate allele of each SNP
#'     such as its frequency in the bigsnp object, the average latitude,
#'     longitude, elevation, and year that the SNP was sampled, and how
#'     frequently individuals with the SNP have specific subpop and ecotype
#'     identities in the bigsnp object.
#'     NB: This is VERY slow and best for up to tens of thousands of
#'     SNPs, NOT millions of SNPs.
#' @param m A mash object (outputted by mash).
#' @param snp A bigsnp object
#' @param saveoutput Logical. Should the output be saved to the path?
#' @param suffix Character. Optional. A unique suffix used to save the files,
#'     instead of the current date & time.
#' @return A data frame containing details for the alternate allele of each SNP
#'     such as its frequency in the bigsnp object, the average latitude,
#'     longitude, elevation, and year that the SNP was sampled, and how
#'     frequently individuals with the SNP have specific subpop and ecotype
#'     identities in the bigsnp object.
#' @note This function is VERY slow and best for up to tens of thousands of
#'     SNPs, NOT millions of SNPs. Use vcftools instead for millions of SNPs.
#' @importFrom dplyr summarise mutate filter select
#' @importFrom tidyselect everything
#' @importFrom lubridate year
#' @importFrom stats weighted.mean
#' @importFrom tibble add_row
#' @importFrom magrittr %>%
#' @importFrom rlang .data
#' @export
get_mash_metadata <- function(m, snp, suffix = "", saveoutput = FALSE) {
  if (!inherits(snp, "bigSNP")) {
    stop("snp needs to be a bigSNP object, produced by the bigsnpr package.")
  }
  markers <- get_marker_df(m)
  filtercol <- which(snp$map$marker.ID %in% markers$Marker)
  if (length(filtercol) == 0) {
    stop("None of the markers in the mash object were found in snp$map$marker.ID.")
  }
  if (str_sub(suffix, end = 1) %in% c("")) {
    suffix <- paste0("_", get_date_filename())
  }

  # NOTE(review): genotype rows are matched to metadata rows by position, so
  # this assumes pvdiv_metadata (after filtering) is in the same order as
  # snp$fam$sample.ID — confirm.
  metadata1 <- switchgrassGWAS::pvdiv_metadata %>%
    filter(.data$PLANT_ID %in% snp$fam$sample.ID) %>%
    mutate(Atlantic = ifelse(.data$GENETIC_SUBPOP_KINSHIP == "Atlantic", 1, 0),
           Gulf = ifelse(.data$GENETIC_SUBPOP_KINSHIP == "Gulf", 1, 0),
           Midwest = ifelse(.data$GENETIC_SUBPOP_KINSHIP == "Midwest", 1, 0),
           Coastal = ifelse(.data$ECOTYPE_NNET == "Coastal", 1, 0),
           Upland = ifelse(.data$ECOTYPE_NNET == "Upland", 1, 0),
           Lowland = ifelse(.data$ECOTYPE_NNET == "Lowland", 1, 0),
           COLL_YEAR = year(.data$COLL_DATE))

  # Collect one-row summaries in a preallocated list and bind once, instead
  # of growing a tibble row by row (quadratic copying).
  rows <- vector("list", length(filtercol))
  for (k in seq_along(filtercol)) {
    rows[[k]] <- summarise_alt_allele_metadata(metadata1, snp, filtercol[k])
    # Periodic checkpoint so that long runs are not lost.
    if (isTRUE(saveoutput) && k %% 1000 == 0) {
      write_rds(dplyr::bind_rows(rows[seq_len(k)]),
                paste0("Metadata_associated_with_alternate_alleles_",
                       suffix, ".rds"), compress = "gz")
    }
  }
  m1 <- dplyr::bind_rows(rows)

  if (isTRUE(saveoutput)) {
    write_rds(m1, paste0("Metadata_associated_with_alternate_alleles_",
                         suffix, "_full.rds"), compress = "gz")
  }
  m1
}

# Summarise sample metadata weighted by the alternate-allele dosage of one SNP
# (column `col` of snp$genotypes). Returns a one-row tibble whose first column
# is the marker ID, followed by alternate-allele frequency and dosage-weighted
# means of location, year, ecotype and subpopulation variables.
summarise_alt_allele_metadata <- function(metadata, snp, col) {
  metadata %>%
    mutate(SNP = snp$genotypes[, col]) %>%
    summarise(
      Alt_freq = mean(.data$SNP / 2, na.rm = TRUE),
      Alt_Latitude = weighted.mean(.data$LATITUDE, .data$SNP, na.rm = TRUE),
      Alt_Longitude = weighted.mean(-.data$LONGITUDE, .data$SNP,
                                    na.rm = TRUE),
      Alt_Elevation = weighted.mean(.data$ELEVATION, .data$SNP, na.rm = TRUE),
      Alt_Year = weighted.mean(.data$COLL_YEAR, .data$SNP, na.rm = TRUE),
      Alt_Coastal_d = weighted.mean(.data$Coastal, .data$SNP, na.rm = TRUE),
      Alt_Coastal_c = weighted.mean(.data$COASTAL_ASSIGNMENT, .data$SNP,
                                    na.rm = TRUE),
      Alt_Lowland_d = weighted.mean(.data$Lowland, .data$SNP, na.rm = TRUE),
      Alt_Lowland_c = weighted.mean(.data$LOWLAND_ASSIGNMENT, .data$SNP,
                                    na.rm = TRUE),
      Alt_Upland_d = weighted.mean(.data$Upland, .data$SNP, na.rm = TRUE),
      Alt_Upland_c = weighted.mean(.data$UPLAND_ASSIGNMENT, .data$SNP,
                                   na.rm = TRUE),
      Alt_Atlantic_d = weighted.mean(.data$Atlantic, .data$SNP, na.rm = TRUE),
      Alt_Atlantic_c = weighted.mean(.data$ATLANTIC_Q, .data$SNP,
                                     na.rm = TRUE),
      Alt_Gulf_d = weighted.mean(.data$Gulf, .data$SNP, na.rm = TRUE),
      Alt_Gulf_c = weighted.mean(.data$GULF_Q, .data$SNP, na.rm = TRUE),
      Alt_Midwest_d = weighted.mean(.data$Midwest, .data$SNP, na.rm = TRUE),
      Alt_Midwest_c = weighted.mean(.data$MIDWEST_DA, .data$SNP, na.rm = TRUE)
    ) %>%
    mutate(Marker = snp$map$marker.ID[col]) %>%
    select("Marker", everything())
}

#' @title Get the positions of objects in a mash object Ulist that are above
#'      some mass threshold.
#' @description Get the positions of objects in a mash object Ulist that are
#'      above some mass threshold.
#' @param m An object of type mash
#' @param thresh Numeric. The mass threshold for including a covariance matrix
#' @return An integer vector of positions in the mash object's Ulist whose
#'      covariance matrices carry an estimated mixture mass of at least
#'      \code{thresh} (as reported by \code{mash_plot_covar}).
#' @export
get_U_by_mass <- function(m, thresh = 0.05) {
  covar_df <- mash_plot_covar(m, saveoutput = FALSE)$covar_df
  mass_thresh <-
    covar_df$`Covariance Matrix`[which(covar_df$Mass >= thresh)]
  # Matching is by name, so only Ulist entries whose names survive
  # mash_plot_covar's relabelling are found.
  which(names(get_Ulist(m)) %in% mass_thresh)
}

#' @title Reorder correlation matrix
#' @description  Reorder correlation coefficients from a matrix of things
#'     (including NA's) and hierarchically cluster them
#' @param cormat A correlation matrix
#' @return \code{cormat} with rows and columns permuted into the order of a
#'     hierarchical clustering of its rows.
#' @importFrom cluster daisy
#' @importFrom stats hclust
reorder_cormat <- function(cormat) {
  # Gower dissimilarity between rows tolerates NA entries; the resulting
  # cluster order is applied to both rows and columns.
  dd <- daisy(cormat, metric = "gower")
  hc <- hclust(dd)
  cormat[hc$order, hc$order]
}

# --- Plot & Save Plots ---------

#' ggplot of single mash effect
#' @description Creates a plot with point estimates and standard errors for
#'     effects of a single SNP in multiple conditions.
#' @param m An object of type mash
#' @param n Optional. Integer or integer vector. The result number to plot, in
#'     order of significance. 1 would be the top result, for example. Find
#'     these with \code{\link{get_significant_results}}.
#' @param i Optional. Integer or integer vector. The result number to plot, in
#'     the order of the mash object. 1 would be the first marker in the mash
#'     object, for example. Find these with \code{\link{get_marker_df}}.
#' @param marker Optional. Print the marker name on the plot?
#' @param saveoutput Logical. Should the output be saved to the path?
#' @param suffix Character. Optional. A unique suffix used to save the files,
#'     instead of the current date & time.
#' @return A list with the marker label (\code{marker}), the data frame of
#'     posterior means and posterior standard deviations per condition
#'     (\code{effect_df}), and the ggplot object (\code{ggobject}).
#' @note Specify only one of n or i.
#' @importFrom ashr get_psd
#' @importFrom cowplot save_plot
#' @importFrom tibble enframe
#' @importFrom dplyr mutate case_when
#' @importFrom tidyr separate
#' @import ggplot2
#' @importFrom purrr as_vector
#' @importFrom stringr str_replace str_replace_all
#' @export
mash_plot_effects <- function(m, n = NA, i = NA, marker = TRUE,
                              saveoutput = FALSE, suffix = "") {
  stopifnot((typeof(n) %in% c("double", "integer") |
               typeof(i) %in% c("double", "integer")))

  if (typeof(n) != "logical") {
    # n ranks results by significance; translate it into row indices of m.
    i <- get_significant_results(m)[n]
    marker_df <- names(i) %>%
      enframe(name = NULL, value = "Marker")
  } else {
    marker_df <- get_marker_df(m)[i, ]
  }
  marker_name <- format_mash_marker_label(marker_df)

  effectplot <- get_colnames(m) %>%
    enframe(name = NULL, value = "Conditions") %>%
    mutate(mn = get_pm(m)[i, ],
           se = get_psd(m)[i, ]) %>%
    # Strip the "Stand_Bhat_"/"Bhat_" prefix and "-mean" suffix mash adds.
    mutate(Conditions = str_replace(.data$Conditions, "(Stand_)??Bhat_?", ""),
           Conditions = str_replace(.data$Conditions, "-mean$", ""))

  ggobject <- ggplot(data = effectplot) +
    geom_point(mapping = aes(x = as.factor(.data$Conditions), y = .data$mn)) +
    switchgrassGWAS::theme_oeco +
    geom_errorbar(mapping = aes(ymin = .data$mn - .data$se,
                                ymax = .data$mn + .data$se,
                                x = .data$Conditions), width = 0.3) +
    geom_hline(yintercept = 0, lty = 2) +
    labs(x = "Conditions", y = "Effect Size") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
  if (isTRUE(marker)) {
    ggobject <- ggobject + ggtitle(label = marker_name)
  }
  if (isTRUE(saveoutput)) {
    if (!(str_sub(suffix, end = 1) %in% c("", "_"))) {
      suffix <- paste0("_", suffix)
    }
    if (str_sub(suffix, end = 1) %in% c("")) {
      suffix <- paste0("_", get_date_filename())
    }
    save_plot(filename = paste0("Effect_plot_",
                                str_replace_all(marker_name, " ", "_"),
                                suffix, ".png"),
              plot = ggobject, base_asp = 0.9, base_height = 3.5)
  }
  list(marker = marker_name, effect_df = effectplot, ggobject = ggobject)
}

# Turn "Chr_Pos" marker names (column `Marker` of marker_df) into plot labels
# such as "Chr05 12.3 Mb". Integer chromosome numbers below 10 are zero-padded;
# non-integer chromosome names are used as-is.
format_mash_marker_label <- function(marker_df) {
  labelled <- marker_df %>%
    separate("Marker", into = c("Chr", "Pos"), sep = "_", convert = TRUE) %>%
    mutate(Mb = round(.data$Pos / 1000000, digits = 1)) %>%
    mutate(Marker = case_when(
      typeof(.data$Chr) == "integer" & .data$Chr < 10 ~
        paste0("Chr0", .data$Chr, " ", .data$Mb, " Mb"),
      typeof(.data$Chr) == "integer" & .data$Chr >= 10 ~
        paste0("Chr", .data$Chr, " ", .data$Mb, " Mb"),
      TRUE ~ paste0(.data$Chr, " ", .data$Mb, " Mb")
    ))
  labelled$Marker
}

#' ggplot of covariance matrix masses
#' @description Creates a bar plot using ggplot of the masses that are on each
#'     covariance matrix specified in the mash model.
#' @param m An object of type mash
#' @param saveoutput Logical. Should the output be saved to the path?
#' @param suffix Character. Optional. A unique suffix used to save the files,
#'     instead of the current date & time.
#' @return A list with the data frame of masses per covariance matrix
#'     (\code{covar_df}) and the ggplot object (\code{ggobject}).
#' @importFrom cowplot save_plot
#' @importFrom tibble enframe
#' @importFrom dplyr mutate arrange desc
#' @importFrom stringr str_replace
#' @import ggplot2
#' @note This plot can be useful for seeing the overall patterns of effects in
#'     the data used in mash. Non-significant effects will add mass to the
#'     "no_effects" covariance matrix, while significant effects will add mass
#'     to one of the other covariance matrices. You can use mash_plot_Ulist()
#'     to plot the covariance matrix patterns themselves.
#' @export
mash_plot_covar <- function(m, saveoutput = FALSE, suffix = "") {
  df <- get_estimated_pi(m)
  df <- enframe(df, name = "Covariance Matrix", value = "Mass") %>%
    # Same label cleanup as the other mash plots in this file; the null
    # (point mass) component is relabelled "no_effects".
    mutate(`Covariance Matrix` = str_replace(.data$`Covariance Matrix`,
                                             "(Stand_)??Bhat_?", ""),
           `Covariance Matrix` = str_replace(.data$`Covariance Matrix`,
                                             "^null$", "no_effects")) %>%
    arrange(desc(.data$`Covariance Matrix`))
  # Fix the bar order to the sorted order above.
  df$`Covariance Matrix` <- factor(df$`Covariance Matrix`,
                                   levels = df$`Covariance Matrix`)
  ggobject <- ggplot(df) +
    geom_bar(aes(x = .data$Mass, y = .data$`Covariance Matrix`),
             stat = "identity") +
    switchgrassGWAS::theme_oeco

  if (isTRUE(saveoutput)) {
    if (!(str_sub(suffix, end = 1) %in% c("", "_"))) {
      suffix <- paste0("_", suffix)
    }
    if (str_sub(suffix, end = 1) %in% c("")) {
      suffix <- paste0("_", get_date_filename())
    }
    save_plot(paste0("Covariance_matrix_mass_plot", suffix, ".png"),
              plot = ggobject, base_asp = 1, base_height = 4)
  }
  list(covar_df = df, ggobject = ggobject)
}

#' ggplot of specific covariance matrix patterns
#' @description Creates a tile plot using ggplot of the covariance matrices
#'     specified in the mash model.
#' @param m An object of type mash
#' @param range Numeric vector. Which covariance matrices should be plotted?
#'     The default (NA) plots all of them.
#' @param saveoutput Logical. Should the output be saved to the path?
#' @param suffix Character. Optional. A unique suffix used to save the files,
#'     instead of the current date & time.
#' @param limits should there be plot limits of -1 and 1? Default is true.
#' @return A list of dataframes used to make the tile plots and the plots
#'     themselves, named "<matrix name>_df" and "<matrix name>_ggobject".
#' @importFrom cowplot save_plot
#' @importFrom tibble enframe as_tibble
#' @importFrom dplyr mutate filter
#' @importFrom rlist list.append
#' @importFrom rlang .data
#' @importFrom stringr str_replace
#' @importFrom tidyr pivot_longer
#' @import ggplot2
#' @export
mash_plot_Ulist <- function(m, range = NA, saveoutput = FALSE, suffix = "",
                            limits = TRUE) {
  Ulist <- get_Ulist(m)
  Ureturn <- list()
  stopifnot(typeof(range) %in% c("double", "integer", "logical"))
  if (typeof(range) == "logical") {
    range <- seq_along(Ulist)
  }
  # Condition labels without the "Stand_Bhat_"/"Bhat_" prefix or "-mean".
  cond_labels <- str_replace(get_colnames(m), "(Stand_)??Bhat_?", "") %>%
    str_replace("-mean$", "")

  for (u in range) {
    U1 <- Ulist[[u]]
    u_name <- names(Ulist)[u]
    # Remove half the tiles (the upper triangle) if the matrix is symmetric.
    if (isSymmetric(U1)) {
      U1[upper.tri(U1)] <- NA
    }
    if (ncol(U1) == length(cond_labels)) {
      colnames(U1) <- cond_labels
      U1 <- as_tibble(U1, .name_repair = "unique") %>%
        mutate(rowU = cond_labels) %>%
        pivot_longer(cols = -tidyselect::all_of("rowU"), names_to = "colU",
                     values_to = "covar") %>%
        filter(!is.na(.data$covar))
    } else {
      U1 <- as_tibble(U1, rownames = "rowU", .name_repair = "unique") %>%
        pivot_longer(cols = -tidyselect::all_of("rowU"), names_to = "colU",
                     values_to = "covar") %>%
        mutate(rowU = paste0("Condition", .data$rowU),
               colU = str_replace(.data$colU, "\\.\\.\\.", "Condition")) %>%
        filter(!is.na(.data$covar))
    }

    if (isTRUE(limits)) {
      U1_covar <- U1 %>%
        ggplot(aes(x = .data$rowU, y = .data$colU)) +
        switchgrassGWAS::theme_oeco +
        geom_tile(aes(fill = .data$covar), na.rm = TRUE) +
        scale_fill_gradientn(colors = c("#440154FF", "#3B528BFF", "#2C728EFF",
                                        "white", "#27AD81FF", "#5DC863FF",
                                        "#FDE725FF"), limits = c(-1, 1)) +
        geom_text(aes(label = round(.data$covar, 1)), color = "darkgrey") +
        theme(legend.position = c(0.2, 0.9),
              axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
        xlab("") + ylab("") + ggtitle(u_name)
    } else {
      lim_lower <- min(U1$covar)
      # Pad the upper limit by 1% of its magnitude so the maximum tile stays
      # inside the scale (also when the maximum is negative).
      lim_upper <- max(U1$covar) + abs(max(U1$covar)) * 0.01
      U1_covar <- U1 %>%
        ggplot(aes(x = .data$rowU, y = .data$colU)) +
        switchgrassGWAS::theme_oeco +
        geom_tile(aes(fill = .data$covar), na.rm = TRUE) +
        scale_fill_gradientn(colors = c("#440154FF", "#472D7BFF", "#3B528BFF",
                                        "#2C728EFF", "#21908CFF", "#27AD81FF",
                                        "#5DC863FF", "#AADC32FF", "#FDE725FF"),
                             limits = c(lim_lower, lim_upper)) +
        geom_text(aes(label = round(.data$covar, 1)), color = "darkgrey") +
        theme(legend.position = c(0.2, 0.9),
              axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
        xlab("") + ylab("") + ggtitle(u_name)
    }

    # Append this matrix's data and plot under their own names; earlier
    # entries keep their names.
    Ureturn <- c(Ureturn,
                 stats::setNames(list(U1, U1_covar),
                                 paste0(u_name, c("_df", "_ggobject"))))

    if (isTRUE(saveoutput)) {
      if (!(str_sub(suffix, end = 1) %in% c("", "_"))) {
        suffix <- paste0("_", suffix)
      }
      if (str_sub(suffix, end = 1) %in% c("")) {
        suffix <- paste0("_", get_date_filename())
      }
      save_plot(paste0("Covariances_plot_", u_name, suffix, ".png"),
                plot = U1_covar, base_height = 3.8, base_asp = 1.15)
    }
  }
  Ureturn
}


#' @title Manhattan plot in ggplot colored by significant conditions
#' @description Takes a mash object and, for some vector of phenotypes, returns
#'     a Manhattan plot ggplot object (and its dataframe). Each SNP in the plot
#'     is colored by the number of phenotypes it is significant for. Even and
#'     odd chromosomes have different shapes for their SNPs, so that
#'     chromosome identity can be determined.
#' @param m A mash object (outputted by mash).
#' @param cond A vector of phenotypes. The default (NA) uses the names of
#'     every column in the mash object.
#' @param saveoutput Logical. Should the output be saved to the path?
#' @param suffix Character. Optional. A unique suffix used to save the files,
#'     instead of the current date & time.
#' @param thresh Numeric. The threshold used for the local false sign rate to
#'     call significance in a condition.
#' @return A list with \code{ggman_df}, a \code{tbl_df()} of the data used to
#'     make the Manhattan plot, and \code{ggmanobject}, the ggplot object.
#' @importFrom cowplot save_plot
#' @importFrom dplyr rename select arrange mutate left_join
#' @import ggplot2
#' @importFrom tibble as_tibble rownames_to_column enframe
#' @importFrom tidyr separate
#' @importFrom stringr str_sub
#' @examples
#' \dontrun{manhattan_out <- mash_plot_manhattan_by_condition(m = m,
#'                                                            saveoutput = TRUE)}
#' @export
mash_plot_manhattan_by_condition <- function(m, cond = NA, saveoutput = FALSE,
                                             suffix = "", thresh = 0.05){
  # cond = NA (a logical) is the sentinel for "use every condition". A
  # user-supplied vector of conditions is kept as given.
  if (typeof(cond) == "logical") {
    cond <- get_colnames(m = m)
  }

  # One row per marker: its log10 Bayes factor joined to the marker info.
  # The integer row index "value" is only needed for the join.
  log10bf_df <- get_log10bf(m = m) %>%
    as.data.frame() %>%
    rownames_to_column(var = "value") %>%
    mutate(value = as.integer(.data$value)) %>%
    as_tibble() %>%
    left_join(get_marker_df(m = m), by = "value") %>%
    dplyr::rename(log10BayesFactor = "V1") %>%
    select(-"value")

  # Marker names look like "<Chr>_<Pos>"; split them for plotting.
  ggman_df <- get_n_significant_conditions(m = m, thresh = thresh,
                                           conditions = cond) %>%
    enframe(name = "Marker", value = "Num_Sig_Conditions") %>%
    separate("Marker", into = c("Chr", "Pos"), remove = FALSE, sep = "_",
             extra = "merge", convert = TRUE) %>%
    left_join(log10bf_df, by = "Marker") %>%
    arrange(.data$Chr, .data$Pos)

  log10BF <- expression(paste("log"[10], plain("(Bayes Factor)")))

  # Alternate two fillable point shapes across chromosomes, sized to however
  # many chromosomes are present rather than a fixed 18.
  chr_shapes <- rep_len(c(21, 22), length(unique(ggman_df$Chr)))

  ggmanobject <- ggplot(data = ggman_df,
                        aes(x = .data$Pos, y = .data$log10BayesFactor)) +
    switchgrassGWAS::theme_oeco +
    geom_point(aes(color = .data$Num_Sig_Conditions,
                   fill = .data$Num_Sig_Conditions,
                   shape = as.factor(.data$Chr))) +
    facet_wrap(~ .data$Chr, nrow = 1, scales = "free_x",
               strip.position = "bottom") +
    scale_color_viridis_c(option = "B", end = 0.95) +
    scale_fill_viridis_c(option = "B", end = 0.95) +
    theme(strip.text = element_text(size = 8),
          axis.text.y = element_text(size = 8),
          axis.text.x = element_blank(),
          legend.text = element_text(size = 8),
          axis.ticks.x = element_blank(),
          panel.background = element_rect(fill = NA)) +
    labs(x = "Chromosome", y = log10BF) +
    scale_x_continuous(expand = c(0.32, 0.32)) +
    scale_shape_manual(values = chr_shapes, guide = "none")

  if (isTRUE(saveoutput)) {
    if (!(str_sub(suffix, end = 1) %in% c("", "_"))) {
      suffix <- paste0("_", suffix)
    }
    if (str_sub(suffix, end = 1) %in% c("")) {
      suffix <- paste0("_", get_date_filename())
    }
    save_plot(paste0("Manhattan_mash", suffix, ".png"), plot = ggmanobject,
              base_asp = 2.8, base_height = 3.3)
  }

  list(ggman_df = ggman_df, ggmanobject = ggmanobject)
}

#' @title Create a ggplot of pairwise sharing of mash effects
#' @description Given a correlation matrix, an RDS with a correlation matrix, or
#'     a mash object, create a ggplot of pairwise sharing of mash effects using
#'     \code{\link{get_pairwise_sharing}} and \code{\link{ggcorr}}.
#' @param m An object of type mash
#' @param effectRDS Character. Path to an RDS file containing a correlation
#'     matrix.
#' @param corrmatrix A correlation matrix
#' @param reorder Logical. Should the columns be reordered by similarity?
#' @param saveoutput Logical. Should the output be saved to the path?
#' @param suffix Character. Optional. A unique suffix used to save the files,
#'     instead of the current date & time. Only used when \code{filename} is
#'     NA.
#' @param filename Character string with an output filename. Optional; if NA,
#'     a name is built from \code{suffix}.
#' @param ... Other arguments to \code{\link{get_pairwise_sharing}}
#'     (factor, lfsr_thresh, FUN), \code{\link{ggcorr}} (geom, label,
#'     label_alpha, label_size, hjust, vjust, layout.exp, min_size, max_size),
#'     the viridis colour scale (option), or the saved plot (base_height,
#'     base_aspect_ratio, dpi).
#' @importFrom GGally ggcorr
#' @importFrom stringr str_replace_all str_replace str_sub
#' @importFrom cowplot save_plot
#' @importFrom rlang .data
#' @import ggplot2
#' @return A list containing a dataframe containing the correlations and a
#'     ggplot2 object containing the correlation plot.
#' @export
mash_plot_pairwise_sharing <- function(m = NULL, effectRDS = NULL,
                                       corrmatrix = NULL, reorder = TRUE,
                                       saveoutput = FALSE, filename = NA,
                                       suffix = "", ...){
  # Optional arguments for get_pairwise_sharing, ggcorr, and save_plot, each
  # looked up in `...` via the package helper dots() with the default shown.
  factor <- dots(name = 'factor', value = 0.5, ...)
  lfsr_thresh <- dots(name = 'lfsr_thresh', value = 0.05, ...)
  FUN <- dots(name = 'FUN', value = identity, ...)
  geom <- dots(name = 'geom', value = 'circle', ...)
  label <- dots(name = 'label', value = FALSE, ...)
  label_alpha <- dots(name = 'label_alpha', value = TRUE, ...)
  label_size <- dots(name = 'label_size', value = 3, ...)
  hjust <- dots(name = 'hjust', value = 0.95, ...)
  vjust <- dots(name = 'vjust', value = 0.3, ...)
  layout.exp <- dots(name = 'layout.exp', value = 9, ...)
  min_size <- dots(name = 'min_size', value = 0, ...)
  max_size <- dots(name = 'max_size', value = 3.5, ...)
  option <- dots(name = 'option', value = 'B', ...)
  dpi <- dots(name = 'dpi', value = 500, ...)
  base_aspect_ratio <- dots(name = 'base_aspect_ratio', value = 1.1, ...)

  # Use an RDS file or a supplied matrix when given alone; otherwise compute
  # the sharing matrix from the mash object.
  if (!is.null(effectRDS) && is.null(m) && is.null(corrmatrix)) {
    shared_effects <- readRDS(effectRDS)
  } else if (!is.null(corrmatrix) && is.null(effectRDS) && is.null(m)) {
    shared_effects <- corrmatrix
  } else if (!is.null(m)) {
    shared_effects <- get_pairwise_sharing(m = m, factor = factor,
                                           lfsr_thresh = lfsr_thresh, FUN = FUN)
    # Strip the mash column prefix and "-mean" suffix from condition labels.
    rownames(shared_effects) <- str_replace_all(rownames(shared_effects),
                                                "(Stand_)??Bhat_?", "") %>%
      str_replace("-mean$", "")
    colnames(shared_effects) <- str_replace_all(colnames(shared_effects),
                                                "(Stand_)??Bhat_?", "") %>%
      str_replace("-mean$", "")
  } else {
    stop("Please specify one of these: ",
         "1. a mash output object (m), ",
         "2. the path to an effect rds file (effectRDS), ",
         "3. a correlation matrix (corrmatrix).")
  }

  # Default plot height grows with the number of conditions.
  base_height <- dots(name = 'base_height',
                      value = nrow(shared_effects) * 0.33 + 1, ...)

  # Optionally reorder so that similar conditions sit next to each other.
  plot_matrix <- if (isTRUE(reorder)) {
    reorder_cormat(cormat = shared_effects)
  } else {
    shared_effects
  }
  corrplot <- ggcorr(data = NULL, cor_matrix = plot_matrix, geom = geom,
                     label = label, label_alpha = label_alpha,
                     label_size = label_size, hjust = hjust, vjust = vjust,
                     layout.exp = layout.exp, min_size = min_size,
                     max_size = max_size) +
    scale_color_viridis_c(option = option)

  if (isTRUE(saveoutput)) {
    # Only build a filename when the caller did not supply one.
    if (is.null(filename) || is.na(filename[1])) {
      if (!(str_sub(suffix, end = 1) %in% c("", "_"))) {
        suffix <- paste0("_", suffix)
      }
      if (str_sub(suffix, end = 1) %in% c("")) {
        suffix <- paste0("_", get_date_filename())
      }
      filename <- paste0("Mash_pairwise_shared_effects", suffix, ".png")
    }
    save_plot(filename = filename, plot = corrplot,
              base_asp = base_aspect_ratio, base_height = base_height,
              dpi = dpi)
  }

  list(corr_matrix = shared_effects, gg_corr = corrplot)
}

#' @title Significant SNPs per number of conditions
#' @description For some number of columns in a mash object that correspond to
#'     conditions, find the number of SNPs that are significant for that number
#'     of conditions.
#' @param m An object of type mash
#' @param conditions A vector of conditions. Get these with get_colnames(m).
#'     The default (NA) uses every condition.
#' @param saveoutput Logical. Save plot output to a file? Default is FALSE.
#' @param thresh What is the threshold to call an effect significant? Default
#'     is 0.05.
#' @param suffix Character. Optional. A unique suffix used to save the files,
#'     instead of the current date & time.
#' @return A list containing a dataframe of the number of SNPs significant per
#'     number of conditions, and a ggplot object using that dataframe.
#' @import ggplot2
#' @importFrom tibble enframe
#' @importFrom dplyr rename summarise filter group_by n
#' @importFrom stringr str_sub
#' @examples
#'   \dontrun{mash_plot_sig_by_condition(m = mash_obj, saveoutput = TRUE)}
#' @export
mash_plot_sig_by_condition <- function(m, conditions = NA, saveoutput = FALSE,
                                       suffix = "", thresh = 0.05){
  thresh <- as.numeric(thresh)

  # conditions = NA (a logical) is the sentinel for "use every condition".
  if (typeof(conditions) == "logical") {
    cond <- get_colnames(m)
  } else {
    cond <- conditions
  }

  # Count markers by how many conditions they are significant in, dropping
  # markers that are significant in none.
  SigHist <- get_n_significant_conditions(m = m, thresh = thresh,
                                          conditions = cond) %>%
    enframe(name = "Marker") %>%
    rename(Number_of_Conditions = "value") %>%
    group_by(.data$Number_of_Conditions) %>%
    summarise(Significant_SNPs = n(), .groups = "drop") %>%
    filter(.data$Number_of_Conditions != 0)

  vis <- ggplot(SigHist, aes(x = .data$Number_of_Conditions,
                             y = .data$Significant_SNPs)) +
    switchgrassGWAS::theme_oeco +
    geom_line() +
    geom_point() +
    geom_hline(yintercept = 0, lty = 2) +
    xlab(label = "Number of Conditions") +
    ylab(label = "Number of Significant SNPs")

  if (isTRUE(saveoutput)) {
    if (!(str_sub(suffix, end = 1) %in% c("", "_"))) {
      suffix <- paste0("_", suffix)
    }
    if (str_sub(suffix, end = 1) %in% c("")) {
      suffix <- paste0("_", get_date_filename())
    }
    # Save `vis` explicitly rather than relying on ggplot2's last_plot().
    ggsave(paste0("SNPs_significant_by_number_of_conditions", suffix, ".png"),
           plot = vis, width = 5, height = 3, units = "in", dpi = 400)
  }

  list(sighist = SigHist, ggobject = vis)
}
# Alice-MacQueen/switchgrassGWAS documentation built on Jan. 23, 2022, 7:55 p.m.