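# R/grn.R
#
# Defines functions GRNPlot GRNHeatmap GetGRN GetTFGeneCorrelation GRNSpatialPlot AddTargetAssay TopEmbGRN NetCentPlot
#
# Documented in GetGRN GetTFGeneCorrelation GRNHeatmap GRNPlot GRNSpatialPlot AddTargetAssay TopEmbGRN NetCentPlot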

#' Get TF gene correlation
#'
#' This function computes the correlation between TF binding activity and
#' gene expression along the trajectory. Both signals are extracted with
#' \code{GetTrajectory} (smoothing window of 7; only the gene expression is
#' log2-normalised) and scaled with \code{TrajectoryHeatmap} before the
#' correlation is computed.
#'
#' @param object A Seurat object
#' @param tf.use A character vector specifying which TFs to use for the
#' correlation computation. If NULL, all TFs are used. An error is raised if
#' any of the requested TFs is not available.
#' @param gene.use A character vector specifying which genes to use for the
#' correlation computation. Genes that are not present in the trajectory
#' matrix are silently ignored. If NULL, all genes returned by
#' \code{TrajectoryHeatmap} are used.
#' @param tf.assay Assay that includes TF activity data. Default: "chromvar"
#' @param gene.assay Assay that includes gene expression activity data. Default: "RNA"
#' @param atac.assay Assay that includes peaks data. Its motif names are used
#' as row names of the TF activity matrix. Default: "ATAC"
#' @param trajectory.name Trajectory name
#' @param groupEvery The number of sequential percentiles to group together when generating a trajectory.
#' This is similar to smoothing via a non-overlapping sliding window across pseudo-time.
#'
#' @return A data frame in long format with one row per TF-gene pair and the
#' columns \code{tf}, \code{gene}, \code{correlation}, \code{t_stat},
#' \code{p_value} (two-sided) and \code{fdr}.
#' @export
#'
GetTFGeneCorrelation <- function(object,
                                 tf.use = NULL,
                                 gene.use = NULL,
                                 tf.assay = "chromvar",
                                 gene.assay = "RNA",
                                 atac.assay="ATAC",
                                 trajectory.name = "Trajectory",
                                 groupEvery=1) {
  ## get tf activity and gene expression along trajectory
  trajMM <- GetTrajectory(
    object,
    assay = tf.assay,
    slot = "data",
    trajectory.name = trajectory.name,
    groupEvery = groupEvery,
    smoothWindow = 7,
    log2Norm = FALSE
  )

  trajRNA <- GetTrajectory(
    object,
    assay = gene.assay,
    slot = "data",
    trajectory.name = trajectory.name,
    groupEvery = groupEvery,
    smoothWindow = 7,
    log2Norm = TRUE
  )

  ## label the TF activity rows with the motif names stored in the ATAC assay
  rownames(trajMM) <- object@assays[[atac.assay]]@motifs@motif.names

  tf_activity <- suppressMessages(
    TrajectoryHeatmap(
      trajMM,
      varCutOff = 0,
      pal = paletteContinuous(set = "solarExtra"),
      limits = c(-2, 2),
      name = "TF activity",
      returnMatrix = TRUE
    )
  )

  gene_expression <- suppressMessages(
    TrajectoryHeatmap(
      trajRNA,
      varCutOff = 0.9,
      pal = paletteContinuous(set = "solarExtra"),
      limits = c(-2, 2),
      name = "Gene expression",
      returnMatrix = TRUE
    )
  )

  ## here we filter the TFs according to our correlation analysis
  if (!is.null(tf.use)) {
    missing_tfs <- setdiff(tf.use, rownames(tf_activity))
    if (length(missing_tfs) > 0) {
      stop(
        "The following TFs were not found in the TF activity matrix: ",
        paste(missing_tfs, collapse = ", ")
      )
    }

    ## drop = FALSE keeps a matrix even when a single TF is requested
    tf_activity <- tf_activity[tf.use, , drop = FALSE]
  }

  ## here we filter the genes by only considering genes that are linked to peaks
  if (!is.null(gene.use)) {
    sel_genes <- intersect(rownames(gene_expression), gene.use)

    ## subset the gene expression matrix (keep a matrix for a single gene)
    gene_expression <- gene_expression[sel_genes, , drop = FALSE]
  }

  ## compute the correlation of TF activity and gene expression along the trajectory
  ## df.cor -> gene by TF matrix
  df.cor <- cor(t(gene_expression), t(tf_activity)) %>%
    as.data.frame()

  if (!is.null(tf.use)) {
    ## drop = FALSE keeps a data frame even for a single TF column
    df.cor <- df.cor[, tf.use, drop = FALSE]
  }

  df.cor$gene <- rownames(df.cor)

  df.cor <- df.cor %>%
    tidyr::pivot_longer(!gene, names_to = "tf", values_to = "correlation") %>%
    dplyr::select(c(tf, gene, correlation))

  ## t-statistic of a correlation coefficient: r / sqrt((1 - r^2) / (n - 2)),
  ## with n the number of trajectory bins (columns). The denominator is floored
  ## at a tiny positive value to avoid dividing by zero when |r| == 1.
  n_df <- ncol(tf_activity) - 2
  residual <- pmax(1 - df.cor$correlation ^ 2, 0.00000000000000001, na.rm = TRUE)
  df.cor$t_stat <- df.cor$correlation / sqrt(residual / n_df)

  ## two-sided p-value from the t distribution, then FDR correction
  df.cor$p_value <- 2 * pt(-abs(df.cor$t_stat), n_df)
  df.cor$fdr <- p.adjust(df.cor$p_value, method = "fdr")

  return(df.cor)

}

#' Get gene regulatory network
#'
#' This function will generate the final prediction of TF-gene network. It takes the
#' TF-gene correlation, peak-TF binding prediction, and peak-to-gene links as input.
#' A TF is linked to a gene when at least one peak bound by the TF is linked
#' to that gene; peaks without any linked gene do not contribute to the network.
#'
#' @param motif.matching A sparse matrix of peak by motif to indicate if a peak is bound
#' by a motif. This matrix should only contain 0 and 1 with 1 indicating a binding
#' event and 0 indicating no binding site(s). Row names must be peak names and
#' column names TF names.
#' @param df.cor A data frame of TF-gene correlation as generated by using the function
#' \code{\link{GetTFGeneCorrelation}}. It must contain the columns \code{tf}
#' and \code{gene}.
#' @param df.p2g A data frame containing predicted peak-to-gene links as generated
#' by using the function \code{\link{PeakToGene}}. It must contain the columns
#' \code{peak} and \code{gene}.
#'
#' @return A data frame representing gene regulatory network with one row per
#' TF-gene pair, the number of supporting peaks (\code{n_peaks}) and the
#' columns of \code{df.cor}. Pairs without an entry in \code{df.cor} are kept
#' with missing correlation values.
#' @export
#'
GetGRN <- function(motif.matching = NULL,
                   df.cor = NULL,
                   df.p2g = NULL) {
  if (is.null(motif.matching)) {
    stop("Please provide a motif matching matrix!")
  }

  if (is.null(df.cor)) {
    stop("Please provide a tf-gene correlation matrix!")
  }

  if (is.null(df.p2g)) {
    stop("Please provide peak-to-gene links!")
  }

  ## We next use peak-to-gene links to predict the target genes for each TF.
  message("Filtering network by peak-to-gene links...")

  # convert the stored entries of the sparse matrix to a peak-TF data frame
  summ <- Matrix::summary(motif.matching)

  df.p2m <- data.frame(peak = rownames(motif.matching)[summ$i],
                       tf = colnames(motif.matching)[summ$j])

  df.p2g <- df.p2g[, c("peak", "gene"), drop = FALSE]

  ## inner join: a bound peak without any linked gene would otherwise yield
  ## TF -> NA edges in the network
  df.m2g <- dplyr::inner_join(df.p2m, df.p2g, by = "peak") %>%
    dplyr::group_by(tf, gene) %>%
    dplyr::summarise(n_peaks = dplyr::n(), .groups = "drop") %>%
    as.data.frame()

  ## Attach the TF-gene correlation to every TF-gene pair supported by
  ## at least one bound and linked peak
  message("Filtering network by TF binding site prediction...")
  df.grn <- dplyr::left_join(df.m2g, df.cor, by = c("tf", "gene"))

  return(df.grn)

}

#' Get a heatmap of TF gene correlation
#'
#' This function will generate a heatmap to visualize the TF-gene correlation computed
#' by the \code{\link{GetTFGeneCorrelation}}. Genes are shown as rows and TFs
#' as columns.
#'
#' @param tf.gene.cor A data frame representing TF-gene correlation. It must
#' contain the columns \code{tf}, \code{gene} and \code{correlation}.
#' @param tf.timepoint A numeric vector of TF time points along the trajectory,
#' used as column annotation. If NULL, no annotation is drawn. The values are
#' expected to follow the column (TF) order of the heatmap.
#' @param km Number of k-means clusters, applied to both rows and columns
#'
#' @return A heatmap
#' @export
#'
GRNHeatmap <- function(tf.gene.cor,
                       tf.timepoint = NULL,
                       km = 1) {

  ## reshape the long TF-gene table into a gene by TF matrix-like data frame
  mat.cor <- tf.gene.cor %>%
    as.data.frame() %>%
    dplyr::select(c(tf, gene, correlation)) %>%
    tidyr::pivot_wider(names_from = tf, values_from = correlation) %>%
    textshape::column_to_rownames("gene")

  ## optional column annotation showing the time point of each TF
  column_ha <- NULL
  if (!is.null(tf.timepoint)) {
    time_cols <- ArchR::paletteContinuous(set = "blueYellow",
                                          n = length(tf.timepoint))
    col_fun <- circlize::colorRamp2(tf.timepoint, time_cols)

    column_ha <-
      ComplexHeatmap::HeatmapAnnotation(time_point = tf.timepoint,
                                        col = list(time_point = col_fun))
  }

  ## columns keep their input order; only rows are hierarchically clustered
  ht <- ComplexHeatmap::Heatmap(
    as.matrix(mat.cor),
    name = "correlation",
    cluster_columns = FALSE,
    clustering_method_rows = "ward.D2",
    top_annotation = column_ha,
    show_row_names = FALSE,
    show_column_names = TRUE,
    row_km = km,
    column_km = km,
    border = TRUE
  )

  return(ht)

}

#' Get a graph
#'
#' This function will generate a graph to visualize the predicted gene regulatory network.
#' TFs are sized by an importance score derived from scaled PageRank and
#' betweenness, and colored by their time point; genes are colored by cluster.
#'
#' @param df.grn A data frame representing predicted network. The first two
#' columns define the edges; it must contain the columns \code{tf} and \code{gene}.
#' @param tfs.use A string list of TFs to keep in the network. If NULL, all TFs are used.
#' @param show.tf.labels Whether or not to show a label for every TF
#' @param tfs.timepoint Time points of TFs. A named vector whose names are TF
#' names; colors are assigned by position, so it is expected to be ordered by time.
#' @param genes.cluster A data frame containing clustering results of genes,
#' with the columns \code{gene} and \code{cluster}. Cluster labels are
#' expected to be 1, 2, ..., k.
#' @param genes.use A string list of genes to keep in the network. If NULL, all genes are used.
#' @param genes.highlight A string list to include gene names for plotting
#' @param cols.highlight Color code for highlighted genes
#' @param seed Random seed used for the graph layout
#' @param plot.importance Whether or not plot the scatter plot to visualize importance score of each TF
#' @param min.importance The minimum importance score for showing the TF labels.
#' Currently not applied: all TFs are labelled when \code{show.tf.labels = TRUE}.
#' @param remove.isolated Whether or not to remove vertices without any edge
#'
#' @importFrom igraph layout_with_fr
#' @importFrom igraph page_rank
#' @importFrom igraph betweenness
#' @importFrom igraph graph_from_data_frame
#' @importFrom ggraph geom_edge_link
#' @importFrom ggraph geom_node_point
#' @importFrom ggraph geom_node_label
#' @importFrom igraph E
#' @importFrom igraph V
#' @return A ggplot object
#' @export
#'
GRNPlot <- function(df.grn,
                    tfs.use = NULL,
                    show.tf.labels = TRUE,
                    tfs.timepoint = NULL,
                    genes.cluster = NULL,
                    genes.use = NULL,
                    genes.highlight = NULL,
                    cols.highlight = "#984ea3",
                    seed = 42,
                    plot.importance = TRUE,
                    min.importance = 2,
                    remove.isolated = FALSE) {
  if (is.null(tfs.timepoint)) {
    stop("Need time point for each TF!")
  }

  if (!is.null(tfs.use)) {
    df.grn <- subset(df.grn, tf %in% tfs.use)
  }

  if (!is.null(genes.use)) {
    df.grn <- subset(df.grn, gene %in% genes.use)
  }

  ## a node that acts as TF anywhere in the network is treated as TF only
  tf.list <- unique(df.grn$tf)
  gene.list <- setdiff(unique(df.grn$gene), tf.list)

  # create graph from data frame
  g <- igraph::graph_from_data_frame(df.grn, directed = TRUE)

  # remove the isolated if indicated
  if (remove.isolated) {
    isolated <- which(igraph::degree(g) == 0)
    g <- igraph::delete.vertices(g, isolated)
  }

  # compute pagerank and betweenness
  pagerank <- page_rank(g, weights = E(g)$weights)
  bet <-
    betweenness(g,
                weights = E(g)$weights,
                normalized = TRUE)

  ## keep the centrality measures of TFs only and z-score them
  df_measure <- data.frame(
    tf = V(g)$name,
    pagerank = pagerank$vector,
    betweenness = bet
  )
  df_measure <- df_measure[df_measure$tf %in% df.grn$tf, , drop = FALSE]
  df_measure$pagerank <- scale(df_measure$pagerank)[, 1]
  df_measure$betweenness <- scale(df_measure$betweenness)[, 1]

  # compute importance only for TFs based on centrality and betweenness:
  # Euclidean distance to the point of minimal pagerank and betweenness
  min.page <- min(df_measure$pagerank)
  min.bet <- min(df_measure$betweenness)
  df_measure$importance <-
    sqrt((df_measure$pagerank - min.page) ^ 2 +
           (df_measure$betweenness - min.bet) ^ 2)

  if (plot.importance) {
    p <- ggplot(data = df_measure) + aes(x = reorder(tf, -importance),
                                         y = importance) +
      geom_point() +
      xlab("TFs") + ylab("Importance") +
      cowplot::theme_cowplot() +
      theme(axis.text.x = element_text(angle = 60, hjust = 1))

    print(p)
  }

  # assign size to each node
  # for TFs, the size is proportional to the importance
  tf_size <- df_measure$importance
  names(tf_size) <- df_measure$tf

  ## for genes, we use the minimum size of TFs
  ## (one entry per gene in gene.list so that values and names line up)
  gene_size <- rep(min(df_measure$importance), length(gene.list))
  names(gene_size) <- gene.list
  v_size <- c(tf_size, gene_size)
  V(g)$size <- v_size[V(g)$name]

  # assign color to each node
  ## TFs are colored by pseudotime point (i-th TF gets the i-th gradient color)
  cols.tf <- ArchR::paletteContinuous(set = "blueYellow",
                                      n = length(tfs.timepoint))
  names(cols.tf) <- names(tfs.timepoint)

  ## genes are colored based on the clustering
  if (is.null(genes.cluster)) {
    cols.gene <- rep("gray", length(gene.list))
    names(cols.gene) <- gene.list
  } else {
    genes.cluster <- subset(genes.cluster, gene %in% gene.list)

    cols <-
      ArchR::paletteDiscrete(values = as.character(genes.cluster$cluster))

    ## cluster label x receives the x-th palette color; genes whose label
    ## is not one of 1..length(cols) are left without a color
    cluster.idx <- match(genes.cluster$cluster, seq_along(cols))
    has.color <- !is.na(cluster.idx)

    cols.gene <- unname(cols)[cluster.idx[has.color]]
    names(cols.gene) <- genes.cluster$gene[has.color]
  }
  v_color <- c(cols.tf, cols.gene)
  v_color <- v_color[V(g)$name]

  ## assign alpha: TFs are opaque, genes are semi-transparent
  tf_alpha <- rep(1, length(tf.list))
  gene_alpha <- rep(0.5, length(gene.list))
  names(tf_alpha) <- tf.list
  names(gene_alpha) <- gene.list
  v_alpha <- c(tf_alpha, gene_alpha)
  V(g)$alpha <- v_alpha[V(g)$name]

  # compute layout
  set.seed(seed)
  layout <- layout_with_fr(
    g,
    weights = E(g)$weights,
    dim = 2,
    niter = 1000
  )

  p <- ggraph::ggraph(g, layout = layout) +
    geom_edge_link(edge_colour = "gray", edge_alpha = 0.25) +
    geom_node_point(aes(
      size = V(g)$size,
      color = as.factor(name),
      alpha = V(g)$alpha
    ),
    show.legend = FALSE) +
    scale_size(range = c(1, 10)) +
    scale_color_manual(values = v_color) +
    ## use the alpha values literally instead of rescaling them
    ggplot2::scale_alpha_identity()

  if (show.tf.labels) {
    p <- p + geom_node_label(
      aes(
        filter = V(g)$name %in% tf.list,
        label = V(g)$name
      ),
      repel = TRUE,
      hjust = "inward",
      color = "#ff7f00",
      size = 5,
      show.legend = FALSE,
      max.overlaps = Inf
    )
  }

  # highlight some genes
  if (!is.null(genes.highlight)) {
    p <-
      p + geom_node_label(
        aes(
          filter = V(g)$name %in% genes.highlight,
          label = V(g)$name
        ),
        repel = TRUE,
        hjust = "inward",
        size = 5,
        color = cols.highlight,
        show.legend = FALSE
      )
  }

  p <- p + theme_void()

  return(p)
}


#' Plot the target gene in spatial space
#'
#' This function plots the overall expression of all target genes
#' for a particular TF using spatial transcriptome data. The overall
#' expression is computed as a module score with \code{AddModuleScore} and
#' displayed with \code{\link[Seurat]{SpatialFeaturePlot}}.
#'
#' @param object A Seurat object of spatial transcriptome data
#' @param assay Which assay to plot
#' @param df.grn A data frame including the inferred gene regulatory network.
#' It must contain the columns \code{tf} and \code{gene}.
#' @param tf.use Which TF to use
#' @param vis.option Options for visualization. Default: "B"
#' @param ... Additional arguments passed to
#' \code{\link[Seurat]{SpatialFeaturePlot}}, e.g. \code{min.cutoff} and
#' \code{max.cutoff}.
#'
#' @return A ggplot object
#' @export
#'
GRNSpatialPlot <- function(object, assay,
                           df.grn,
                           tf.use,
                           vis.option = "B", ...){
    # select all targets for a TF
    df.target <- subset(df.grn, tf == tf.use)

    # fail early with a clear message instead of scoring an empty gene set
    if (nrow(df.target) == 0) {
        stop("No target genes found for TF: ", tf.use)
    }

    DefaultAssay(object) <- assay

    # a single gene set is scored, so the result is stored as "Cluster1"
    geneset <- list(tf.use = df.target$gene)
    object <- AddModuleScore(object, features = geneset)

    p <- Seurat::SpatialFeaturePlot(object, features = "Cluster1", ...) +
        scale_fill_viridis(option = vis.option) +
        ggtitle(glue::glue("{tf.use} targets")) +
        labs(fill = "")

    return(p)

}

#' Add the TF target expression
#'
#' This function will create a new assay by using the predicted gene
#' regulatory network where features are the selected TFs. Each value represents
#' the average expression of all targets calculated by using the function
#' \code{\link{AddModuleScore}} from Seurat. The module scores are also kept
#' in the meta data as columns named \code{tf_target_1}, \code{tf_target_2}, ...
#'
#' @param object A Seurat object used as input
#' @param target.assay Name for the new assay. Default: "target"
#' @param rna.assay Name of the assay used to compute the module scores. Default: "RNA"
#' @param df.grn A data frame containing the predicted gene regulatory network.
#' It must contain the columns \code{tf} and \code{gene}.
#'
#' @return A Seurat object with a new assay named by target.assay
#' @export
AddTargetAssay <- function(object,
                           target.assay = "target",
                           rna.assay = "RNA",
                           df.grn = NULL){

    if (is.null(df.grn)) {
        stop("Please provide a gene regulatory network!")
    }

    # one gene set per TF; drop = TRUE skips unused factor levels, which
    # would otherwise produce empty gene sets
    df.genes <- split(df.grn$gene, df.grn$tf, drop = TRUE)

    score.prefix <- "tf_target_"
    object <- AddModuleScore(object, features = df.genes,
                             assay = rna.assay,
                             name = score.prefix)

    # AddModuleScore names its output <name><index>; select exactly these
    # columns so that unrelated or stale meta data columns are not picked up
    score.cols <- paste0(score.prefix, seq_along(df.genes))
    target_gex <- as.data.frame(object@meta.data)[, score.cols, drop = FALSE]

    colnames(target_gex) <- names(df.genes)

    # store the TF by cell score matrix under the requested assay name
    object[[target.assay]] <- CreateAssayObject(data = t(target_gex))

    return(object)

}





#' PCA of Topological Measures
#'
#' This function will generate a graph centrality PCA embedding. For every
#' node the out-degree, in-degree, normalized betweenness and PageRank are
#' computed; the PCA is run on out-degree, in-degree (both min-max scaled)
#' and betweenness. A biplot is printed as a side effect.
#'
#' @param df.grn An igraph object representing the predicted network (despite
#' its name, this is not a data frame). Vertices must be named.
#' @param gene.cluster Currently not used
#' @param axis Indices of the two principal components shown in the biplot
#'
#' @importFrom igraph layout_with_fr
#' @importFrom igraph page_rank
#' @importFrom factoextra fviz_pca_biplot
#' @importFrom igraph betweenness
#' @importFrom igraph degree
#' @importFrom igraph graph_from_data_frame
#' @importFrom ggraph geom_edge_link
#' @importFrom ggraph geom_node_point
#' @importFrom ggraph geom_node_label
#' @importFrom ggrepel geom_label_repel
#' @importFrom igraph E
#' @importFrom igraph V
#' @return a prcomp output
#' @export
TopEmbGRN <- function(df.grn,gene.cluster=NULL,axis=c(1,2)){
  node.names <- V(df.grn)$name
  g.directed <- igraph::as.directed(df.grn)

  ## one row per node with its topological measures
  netembb <- tibble::tibble(
    "nodes" = node.names,
    "outdegree" = degree(g.directed, mode = "out")[node.names],
    "indegree" = degree(g.directed, mode = "in")[node.names],
    "mediator" = betweenness(df.grn, normalized = TRUE)[node.names],
    "pagerank" = page_rank(df.grn)$vector[node.names],
    "type" = V(df.grn)$type
  )

  ## min-max scaling to [0, 1]; a constant vector is mapped to 0 instead of
  ## NaN so that prcomp does not fail
  rescale01 <- function(x) {
    rng <- range(x)
    if (rng[2] == rng[1]) {
      return(rep(0, length(x)))
    }
    (x - rng[1]) / (rng[2] - rng[1])
  }
  netembb$indegree <- rescale01(netembb$indegree)
  netembb$outdegree <- rescale01(netembb$outdegree)

  ## pagerank is reported in the table but not part of the embedding
  pcaemb <- prcomp(netembb[, c("outdegree", "indegree", "mediator")],
                   center = TRUE)
  rownames(pcaemb$x) <- netembb$nodes

  ## symmetric axis limits around the origin
  x <- max(abs(pcaemb$x[, axis[1]]))
  y <- max(abs(pcaemb$x[, axis[2]]))

  p <- fviz_pca_biplot(pcaemb,
                       axes = axis,
                       pointshape = 21, pointsize = 0.5, labelsize = 12,
                       repel = TRUE, max.overlaps = 150, label = "var",
                       arrowsize = 1.5) +
    geom_label_repel(aes(label = rownames(pcaemb$x)),
                     hjust = 0, vjust = 0, size = 4) +
    xlim(-(x), (x)) +
    ylim(-(y), (y)) +
    theme(text = element_text(size = 4),
          axis.title = element_text(size = 7.5),
          axis.text = element_text(size = 7.5))

  print(p)
  return(pcaemb)

}


#' Centric Plot
#'
#' This function will generate a plot centered in a specific gene/TF. Nodes
#' are sized by their betweenness and edges are colored by their
#' \code{weights} attribute.
#'
#' @param netobj scMEGA net (an igraph object with named vertices, a
#' \code{type} vertex attribute and a \code{weights} edge attribute)
#' @param gene Name of the gene/TF placed in the center of the layout
#' @param highlights A string list of node names to label. If NULL, nodes of
#' type "TF/Gene" with above-average betweenness are labelled.
#' @importFrom ggraph geom_edge_link
#' @importFrom ggraph geom_node_point
#' @importFrom ggraph geom_node_label
#' @importFrom ggraph scale_edge_color_viridis
#' @importFrom igraph betweenness
#' @importFrom igraph degree
#' @importFrom igraph V
#' @return a ggraph plot
#' @export
NetCentPlot <- function(netobj,gene,highlights=NULL){
  focus.idx <- which(V(netobj)$name == gene)
  if (length(focus.idx) == 0) {
    stop("Gene '", gene, "' was not found in the network!")
  }

  d <- igraph::betweenness(netobj)

  ## only the labelling rule differs between the two modes
  if (is.null(highlights)) {
    label.layer <- geom_node_label(
      aes(filter = (d > mean(d)) & (type == "TF/Gene"),
          label = name),
      size = 5,
      repel = TRUE
    )
  } else {
    label.layer <- geom_node_label(
      aes(filter = V(netobj)$name %in% highlights,
          label = name),
      size = 5,
      repel = TRUE
    )
  }

  p <- ggraph::ggraph(netobj, layout = "focus", focus = focus.idx) +
    geom_edge_link(aes(colour = weights, alpha = weights), edge_width = 0.1) +
    geom_node_point(aes(size = d), shape = 20) +
    label.layer +
    scale_edge_color_viridis() +
    coord_fixed() +
    ggraph::theme_graph() +
    theme(legend.position = "none")

  return(p)

}
# CostaLab/scMEGA documentation built on Sept. 25, 2024, 6:11 a.m.