R/plots.R

Defines functions plot_contrast_results plot_interaction_communities plot_contrast_heatmap plot_interaction_heatmap plot_view_contributions plot_improvement_stats

Documented in plot_contrast_heatmap plot_contrast_results plot_improvement_stats plot_interaction_communities plot_interaction_heatmap plot_view_contributions

# mistyR plotting functions
# Copyleft (ɔ) 2020 Jovan Tanevski <jovan.tanevski@uni-heidelberg.de>

#' Plot observed performance and improvement per target
#'
#' Generates a plot of the mean (+- standard deviation) of the performance value
#' per target across all samples from the results.
#'
#' @param misty.results a results list generated by
#'     \code{\link{collect_results}()}.
#' @param measure performance measure to be plotted (See
#'     \code{\link{collect_results}()}).
#' @param trim display targets with performance value above (if R2 or gain) or
#'     below (otherwise) this value only.
#'
#' @return The \code{misty.results} list (invisibly).
#'
#' @seealso \code{\link{collect_results}()} to generate a
#'     results list from raw results.
#'
#' @family plotting functions
#'
#' @examples
#' all.samples <- list.dirs("results", recursive = FALSE)
#'
#' collect_results(all.samples) %>% plot_improvement_stats()
#'
#' misty.results <- collect_results(all.samples)
#' misty.results %>% plot_improvement_stats(measure = "gain.RMSE")
#' misty.results %>% plot_improvement_stats(measure = "intra.R2")
#' @export
plot_improvement_stats <- function(misty.results,
                                   measure = c(
                                     "gain.R2", "multi.R2", "intra.R2",
                                     "gain.RMSE", "multi.RMSE", "intra.RMSE"
                                   ),
                                   trim = -Inf) {
  measure.type <- match.arg(measure)

  assertthat::assert_that(("improvements.stats" %in% names(misty.results)),
    msg = "The provided result list is malformed. Consider using collect_results()."
  )

  # +1 when larger values are better (gains and R2), -1 for raw RMSE, so that
  # a single ">=" comparison trims in the correct direction.
  inv <- sign((stringr::str_detect(measure.type, "gain") |
    stringr::str_detect(measure.type, "RMSE", negate = TRUE)) - 0.5)

  # `measure` and `mean` below refer to columns of improvements.stats.
  plot.data <- misty.results$improvements.stats %>%
    dplyr::filter(measure == measure.type, inv * mean >= inv * trim)

  assertthat::assert_that(assertthat::not_empty(plot.data),
    msg = "Invalid selection of measure and/or trim value."
  )

  set2.orange <- "#FC8D62"

  results.plot <- ggplot2::ggplot(
    plot.data,
    ggplot2::aes(
      x = stats::reorder(target, -mean),
      y = mean
    )
  ) +
    ggplot2::geom_pointrange(ggplot2::aes(
      ymin = mean - sd,
      ymax = mean + sd
    )) +
    ggplot2::geom_point(color = set2.orange) +
    ggplot2::theme_classic() +
    # Label with the resolved measure. The raw `measure` argument is the whole
    # vector of choices when the default is used.
    ggplot2::ylab(measure.type) +
    ggplot2::xlab("Target") +
    ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90, hjust = 1))

  print(results.plot)

  invisible(misty.results)
}

#' Plot view contributions per target
#'
#' Generate a stacked barplot of the average view contribution fraction per target
#' across all samples from the results.
#'
#' Only targets that survive trimming on \code{trim.measure} are shown; the
#' direction of trimming follows the same rule as in
#' \code{\link{plot_improvement_stats}()}.
#'
#' @inheritParams plot_improvement_stats
#'
#' @param trim.measure the measure used for trimming.
#'
#' @return The \code{misty.results} list (invisibly).
#'
#' @seealso \code{\link{collect_results}()} to generate a
#'     results list from raw results.
#'
#' @family plotting functions
#'
#' @examples
#' all.samples <- list.dirs("results", recursive = FALSE)
#'
#' collect_results(all.samples) %>% plot_view_contributions()
#' @export
plot_view_contributions <- function(misty.results, trim = -Inf,
                                    trim.measure = c(
                                      "gain.R2", "multi.R2", "intra.R2",
                                      "gain.RMSE", "multi.RMSE", "intra.RMSE"
                                    )) {
  trim.measure.type <- match.arg(trim.measure)

  # Both summary tables produced by collect_results() are required.
  for (required.table in c("contributions.stats", "improvements.stats")) {
    assertthat::assert_that(required.table %in% names(misty.results),
      msg = "The provided result list is malformed. Consider using collect_results()."
    )
  }

  is.gain <- stringr::str_detect(trim.measure.type, "gain")
  is.rmse <- stringr::str_detect(trim.measure.type, "RMSE")
  # Larger is better for gains and R2 (+1); smaller is better for raw RMSE (-1).
  trim.direction <- if (is.gain || !is.rmse) 1 else -1

  trimmed.stats <- dplyr::filter(
    misty.results$improvements.stats,
    measure == trim.measure.type,
    trim.direction * mean >= trim.direction * trim
  )
  kept.targets <- dplyr::pull(trimmed.stats, target)

  assertthat::assert_that(assertthat::not_empty(kept.targets),
    msg = "Invalid selection of trim measure and/or value."
  )

  contributions <- dplyr::filter(
    misty.results$contributions.stats,
    target %in% kept.targets
  )

  bar.layer <- ggplot2::geom_col(ggplot2::aes(group = view, fill = view))
  rotated.labels <- ggplot2::theme(
    axis.text.x = ggplot2::element_text(angle = 90, hjust = 1)
  )

  results.plot <- ggplot2::ggplot(
    contributions,
    ggplot2::aes(x = target, y = fraction)
  ) +
    bar.layer +
    ggplot2::scale_fill_brewer(palette = "Set2") +
    ggplot2::theme_classic() +
    ggplot2::labs(x = "Target", y = "Contribution") +
    rotated.labels

  print(results.plot)

  invisible(misty.results)
}

#' Plot importance heatmap for a view
#'
#' Generate a heatmap with importances of predictor-target interaction.
#'
#' @inheritParams plot_view_contributions
#'
#' @param view abbreviated name of the view.
#' @param cutoff importance threshold. Importances below this value will
#' be colored white in the heatmap and considered as not relevant.
#' @param clean a \code{logical} indicating whether to remove from the heatmap
#' the rows and columns whose importances are all below \code{cutoff}.
#'
#' @return The \code{misty.results} list (invisibly).
#'
#' @seealso \code{\link{collect_results}()} to generate
#'     a results list from raw results.
#'
#' @family plotting functions
#'
#' @examples
#' all.samples <- list.dirs("results", recursive = FALSE)
#'
#' collect_results(all.samples) %>%
#'   plot_interaction_heatmap("intra") %>%
#'   plot_interaction_heatmap("para.10", cutoff = 0.5)
#' @export
plot_interaction_heatmap <- function(misty.results, view, cutoff = 1,
                                     trim = -Inf, trim.measure = c(
                                       "gain.R2", "multi.R2", "intra.R2",
                                       "gain.RMSE", "multi.RMSE", "intra.RMSE"
                                     ),
                                     clean = FALSE) {
  trim.measure.type <- match.arg(trim.measure)

  assertthat::assert_that(("importances.aggregated" %in% names(misty.results)),
    msg = "The provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that(("improvements.stats" %in% names(misty.results)),
    msg = "The provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that((view %in%
    (misty.results$importances.aggregated %>% dplyr::pull(view))),
  msg = "The selected view cannot be found in the results table."
  )

  # +1 when larger values of the trim measure are better (gains and R2),
  # -1 for raw RMSE, so a single ">=" comparison trims in the right direction.
  inv <- sign((stringr::str_detect(trim.measure.type, "gain") |
    stringr::str_detect(trim.measure.type, "RMSE", negate = TRUE)) - 0.5)

  targets <- misty.results$improvements.stats %>%
    dplyr::filter(
      measure == trim.measure.type,
      inv * mean >= inv * trim
    ) %>%
    dplyr::pull(target)

  # Same guard as plot_view_contributions(): fail loudly instead of drawing an
  # empty heatmap when trimming removes every target.
  assertthat::assert_that(assertthat::not_empty(targets),
    msg = "Invalid selection of trim measure and/or value."
  )

  # `!!view` injects the function argument so it is not confused with the
  # `view` column of the table.
  plot.data <- misty.results$importances.aggregated %>%
    dplyr::filter(view == !!view, Target %in% targets)

  if (clean) {
    # A predictor/target is kept when at least one of its importances reaches
    # the cutoff. Filtering on the comparison itself is used instead of
    # summing the passing importances, because that sum need not be positive
    # when cutoff <= 0. Rows with NA importance are dropped by the filter.
    relevant <- plot.data %>%
      dplyr::filter(Importance >= cutoff)

    plot.data.clean <- plot.data %>%
      dplyr::filter(
        Predictor %in% relevant$Predictor,
        Target %in% relevant$Target
      )
  } else {
    plot.data.clean <- plot.data
  }

  set2.blue <- "#8DA0CB"

  results.plot <- ggplot2::ggplot(
    plot.data.clean,
    ggplot2::aes(
      x = Predictor,
      y = Target
    )
  ) +
    ggplot2::geom_tile(ggplot2::aes(fill = Importance)) +
    # White up to the cutoff, blue above it.
    ggplot2::scale_fill_gradient2(
      low = "white",
      mid = "white",
      high = set2.blue,
      midpoint = cutoff
    ) +
    ggplot2::theme_classic() +
    ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90, hjust = 1)) +
    ggplot2::coord_equal() +
    ggplot2::ggtitle(view)

  print(results.plot)

  invisible(misty.results)
}

#' Plot heatmap of local contrast between two views
#'
#' The heatmap shows the interactions that are present and have importance above
#' a \code{cutoff} value in the \code{to.view} but not in the \code{from.view}.
#'
#' Both views must have the same predictor and target markers; the two
#' importance tables are aligned by marker name before they are compared.
#'
#' @inheritParams plot_interaction_heatmap
#'
#' @param from.view,to.view abbreviated name of the view.
#'
#' @return The \code{misty.results} list (invisibly).
#'
#' @seealso \code{\link{collect_results}()} to generate a
#'     results list from raw results.
#'
#' @family plotting functions
#'
#' @examples
#' all.samples <- list.dirs("results", recursive = FALSE)
#'
#' misty.results <- collect_results(all.samples)
#'
#' misty.results %>%
#'   plot_contrast_heatmap("intra", "para.10")
#'
#' misty.results %>%
#'   plot_contrast_heatmap("intra", "para.10", cutoff = 0.5)
#' @export
plot_contrast_heatmap <- function(misty.results, from.view, to.view, cutoff = 1,
                                  trim = -Inf, trim.measure = c(
                                    "gain.R2", "multi.R2", "intra.R2",
                                    "gain.RMSE", "multi.RMSE", "intra.RMSE"
                                  )) {
  trim.measure.type <- match.arg(trim.measure)

  assertthat::assert_that(("importances.aggregated" %in% names(misty.results)),
    msg = "The provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that(("improvements.stats" %in% names(misty.results)),
    msg = "The provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that((from.view %in%
    (misty.results$importances.aggregated %>% dplyr::pull(view))),
  msg = "The selected from.view cannot be found in the results table."
  )

  assertthat::assert_that((to.view %in%
    (misty.results$importances.aggregated %>% dplyr::pull(view))),
  msg = "The selected to.view cannot be found in the results table."
  )

  # +1 when larger values of the trim measure are better (gains and R2),
  # -1 for raw RMSE, so a single ">=" comparison trims in the right direction.
  inv <- sign((stringr::str_detect(trim.measure.type, "gain") |
    stringr::str_detect(trim.measure.type, "RMSE", negate = TRUE)) - 0.5)

  targets <- misty.results$importances.aggregated %>%
    dplyr::slice(0) %>%
    dplyr::pull(Target)
  targets <- misty.results$improvements.stats %>%
    dplyr::filter(
      measure == trim.measure.type,
      inv * mean >= inv * trim
    ) %>%
    dplyr::pull(target)

  # Reshape one view into a Predictor x Target table with rows sorted by
  # predictor and target columns sorted by name. pivot_wider() orders rows and
  # columns by first appearance, so without sorting, the element-wise
  # comparison of two views could silently pair different markers.
  widen_view <- function(selected.view) {
    wide <- misty.results$importances.aggregated %>%
      dplyr::filter(view == selected.view, Target %in% targets) %>%
      tidyr::pivot_wider(
        id_cols = -c(view, nsamples),
        names_from = "Target",
        values_from = "Importance"
      ) %>%
      dplyr::arrange(Predictor)

    wide %>%
      dplyr::select(
        "Predictor",
        dplyr::all_of(sort(setdiff(colnames(wide), "Predictor")))
      )
  }

  from.view.wide <- widen_view(from.view)
  to.view.wide <- widen_view(to.view)

  assertthat::assert_that(
    identical(from.view.wide$Predictor, to.view.wide$Predictor) &&
      identical(colnames(from.view.wide), colnames(to.view.wide)),
    msg = "The predictor and target markers of from.view and to.view must match."
  )

  # TRUE where the interaction is below the cutoff in from.view but reaches
  # it in to.view (NA where either importance is missing).
  mask <- ((from.view.wide %>% dplyr::select(-Predictor)) < cutoff) &
    ((to.view.wide %>% dplyr::select(-Predictor)) >= cutoff)

  assertthat::assert_that(sum(mask, na.rm = TRUE) > 0,
    msg = "All values are cut off while contrasting."
  )

  masked <- (to.view.wide %>%
    tibble::column_to_rownames("Predictor")) * mask

  # Keep predictors/targets with at least one contrasting interaction. This is
  # decided from the mask rather than from the importance sums, which need not
  # be positive when cutoff <= 0.
  keep.rows <- rowSums(mask, na.rm = TRUE) > 0
  keep.cols <- colSums(mask, na.rm = TRUE) > 0

  plot.data <- masked[keep.rows, keep.cols, drop = FALSE] %>%
    tibble::rownames_to_column("Predictor") %>%
    tidyr::pivot_longer(
      cols = -Predictor,
      names_to = "Target",
      values_to = "Importance"
    )

  set2.blue <- "#8DA0CB"

  results.plot <- ggplot2::ggplot(plot.data, ggplot2::aes(x = Predictor, y = Target)) +
    ggplot2::geom_tile(ggplot2::aes(fill = Importance)) +
    ggplot2::scale_fill_gradient2(low = "white", mid = "white", high = set2.blue, midpoint = cutoff) +
    ggplot2::theme_classic() +
    ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90, hjust = 1)) +
    ggplot2::coord_equal() +
    ggplot2::ggtitle(paste0(to.view, " - ", from.view))

  print(results.plot)

  invisible(misty.results)
}

#' Plot marker interaction communities
#'
#' Identify and plot a graph of marker interaction communities.
#'
#' The communities are identified using the Louvain algorithm. Communities can
#' be extracted only from views that have the same predictor and target markers.
#'
#' @inheritParams plot_interaction_heatmap
#'
#' @return The \code{misty.results} list (invisibly).
#'
#' @seealso \code{\link{collect_results}()} to generate a
#'     results list from raw results.
#'
#' @family plotting functions
#'
#' @examples
#' all.samples <- list.dirs("results", recursive = FALSE)
#'
#' misty.results <- collect_results(all.samples)
#'
#' misty.results %>%
#'   plot_interaction_communities("intra") %>%
#'   plot_interaction_communities("para.10")
#'
#' misty.results %>%
#'   plot_interaction_communities("para.10", cutoff = 0.5)
#' @export
plot_interaction_communities <- function(misty.results, view, cutoff = 1) {
  assertthat::assert_that(("importances.aggregated" %in% names(misty.results)),
    msg = "The provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that((view %in%
    (misty.results$importances.aggregated %>% dplyr::pull(view))),
  msg = "The selected view cannot be found in the results table."
  )

  # Predictor x Target importance table for the selected view.
  view.wide <- misty.results$importances.aggregated %>%
    dplyr::filter(view == !!view) %>%
    tidyr::pivot_wider(
      id_cols = -c(view, nsamples),
      names_from = "Target",
      values_from = "Importance"
    )

  markers <- view.wide %>%
    dplyr::select(-Predictor) %>%
    colnames()
  predictors <- view.wide %>%
    dplyr::pull(Predictor)

  assertthat::assert_that(
    length(markers) == length(predictors) && setequal(markers, predictors),
    msg = "The predictor and target markers in the view must match."
  )

  assertthat::assert_that(requireNamespace("igraph",
    versionCheck = list(op = ">=", version = "1.2.7"),
    quietly = TRUE
  ),
  msg = "The package igraph (>= 1.2.7) is required to calculate the interaction communities."
  )

  A <- view.wide %>%
    dplyr::select(-Predictor) %>%
    as.matrix()
  # igraph reads A as a square adjacency matrix whose vertex names come from
  # the column names, so row i must describe the same marker as column i.
  # Rows follow the Predictor order of the wide table, which can differ from
  # the Target column order, so they are reordered explicitly.
  A <- A[match(markers, predictors), , drop = FALSE]
  rownames(A) <- markers
  A[A < cutoff | is.na(A)] <- 0

  # "plus" symmetrises the directed importances (A + t(A)) into an undirected,
  # weighted graph; markers without any retained interaction are removed.
  G <- igraph::graph_from_adjacency_matrix(A, mode = "plus", weighted = TRUE)
  G <- igraph::delete_vertices(G, which(igraph::degree(G) == 0))

  # Louvain modularity optimisation, as documented above (uses the edge
  # weights of G).
  C <- igraph::cluster_louvain(G)

  layout <- igraph::layout_with_fr(G)

  igraph::plot.igraph(G,
    layout = layout, mark.groups = C, main = view, vertex.size = 4,
    vertex.color = "black", vertex.label.dist = 1,
    vertex.label.color = "black", vertex.label.font = 2, vertex.label.cex = 0.66
  )

  invisible(misty.results)
}


#' Plot heatmap of contrast between two result lists
#'
#' Plot interexperiment contrast of views.
#'
#' The heatmaps show the interactions that are present and have importance above
#' a \code{cutoff.to} value in the \code{views} of \code{misty.results.to} but
#' not present or have importance below \code{cutoff.from} in the \code{views}
#' of \code{misty.results.from}.
#'
#' For every view, the importance tables of the two result lists are aligned by
#' predictor and target name before they are compared.
#'
#' @inheritParams plot_interaction_heatmap
#'
#' @param misty.results.from,misty.results.to a results list generated by
#'     \code{\link{collect_results}()}.
#' @param views one or more abbreviated names of views.
#' @param cutoff.from,cutoff.to importance thresholds respective to the result
#'     lists.
#'
#' @return The \code{misty.results.from} list (invisibly).
#'
#' @seealso \code{\link{collect_results}()} to generate a
#'     results list from raw results.
#'
#' @family plotting functions
#'
#' @examples
#' # if for example the available samples come from different grades of tumors
#'
#' grade1.results <- collect_results(c("results/synthetic1", "results/synthetic2"))
#' grade3.results <- collect_results("results/synthetic10")
#'
#' # highlight interactions present in grade 1 tumors but not in grade 3 tumors
#' # in the paraview
#'
#' grade3.results %>% plot_contrast_results(grade1.results, views = "para.10")
#'
#' # see the loss of interactions in all views with lower sensitivity
#'
#' plot_contrast_results(grade3.results, grade1.results, cutoff.from = 0.75, cutoff.to = 0.5)
#' @export
plot_contrast_results <- function(misty.results.from, misty.results.to,
                                  views = NULL, cutoff.from = 1, cutoff.to = 1,
                                  trim = -Inf, trim.measure = c(
                                    "gain.R2", "multi.R2", "intra.R2",
                                    "gain.RMSE", "multi.RMSE", "intra.RMSE"
                                  )) {
  trim.measure.type <- match.arg(trim.measure)

  assertthat::assert_that(("importances.aggregated" %in% names(misty.results.from)),
    msg = "The first provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that(("improvements.stats" %in% names(misty.results.from)),
    msg = "The first provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that(("importances.aggregated" %in% names(misty.results.to)),
    msg = "The second provided result list is malformed. Consider using collect_results()."
  )

  assertthat::assert_that(("improvements.stats" %in% names(misty.results.to)),
    msg = "The second provided result list is malformed. Consider using collect_results()."
  )

  if (is.null(views)) {
    # Default to all views of the first list, which must all exist in the second.
    assertthat::assert_that(rlang::is_empty(setdiff(
      misty.results.from$importances.aggregated %>%
        dplyr::pull(view) %>%
        unique(),
      misty.results.to$importances.aggregated %>%
        dplyr::pull(view) %>%
        unique()
    )),
    msg = "The requested views do not exist in both result lists."
    )
    views <- misty.results.from$importances.aggregated %>%
      dplyr::pull(view) %>%
      unique()
  } else {
    assertthat::assert_that(all(views %in%
      (misty.results.from$importances.aggregated %>%
        dplyr::pull(view))) &&
      all(views %in%
        (misty.results.to$importances.aggregated %>%
          dplyr::pull(view))),
    msg = "The requested views do not exist in both result lists."
    )
  }

  # Every predictor and target of the first list must also exist, per view,
  # in the second list.
  assertthat::assert_that(
    all(views %>% purrr::map_lgl(function(current.view) {
      rlang::is_empty(setdiff(
        misty.results.from$importances.aggregated %>%
          dplyr::filter(view == current.view) %>%
          dplyr::pull(Predictor) %>%
          unique(),
        misty.results.to$importances.aggregated %>%
          dplyr::filter(view == current.view) %>%
          dplyr::pull(Predictor) %>%
          unique()
      )) &&
        rlang::is_empty(setdiff(
          misty.results.from$importances.aggregated %>%
            dplyr::filter(view == current.view) %>%
            dplyr::pull(Target) %>%
            unique(),
          misty.results.to$importances.aggregated %>%
            dplyr::filter(view == current.view) %>%
            dplyr::pull(Target) %>%
            unique()
        ))
    })),
    msg = "Incompatible predictors and targets."
  )

  # +1 when larger values of the trim measure are better (gains and R2),
  # -1 for raw RMSE, so a single ">=" comparison trims in the right direction.
  inv <- sign((stringr::str_detect(trim.measure.type, "gain") |
    stringr::str_detect(trim.measure.type, "RMSE", negate = TRUE)) - 0.5)

  targets <- misty.results.from$improvements.stats %>%
    dplyr::filter(
      measure == trim.measure.type,
      inv * mean >= inv * trim
    ) %>%
    dplyr::pull(target)

  # Reshape one view of a result list into a Predictor x Target table with
  # rows sorted by predictor and target columns sorted by name. pivot_wider()
  # orders rows and columns by first appearance, so without sorting, the
  # element-wise comparison of two lists could silently pair different markers.
  widen_view <- function(importances, selected.view) {
    wide <- importances %>%
      dplyr::filter(view == selected.view, Target %in% targets) %>%
      tidyr::pivot_wider(
        id_cols = -c(view, nsamples),
        names_from = "Target",
        values_from = "Importance"
      ) %>%
      dplyr::arrange(Predictor)

    wide %>%
      dplyr::select(
        "Predictor",
        dplyr::all_of(sort(setdiff(colnames(wide), "Predictor")))
      )
  }

  views %>% purrr::walk(function(current.view) {
    from.view.wide <- widen_view(
      misty.results.from$importances.aggregated, current.view
    )
    to.view.wide <- widen_view(
      misty.results.to$importances.aggregated, current.view
    )

    assertthat::assert_that(
      identical(from.view.wide$Predictor, to.view.wide$Predictor) &&
        identical(colnames(from.view.wide), colnames(to.view.wide)),
      msg = "Incompatible predictors and targets."
    )

    # TRUE where the interaction is below cutoff.from in the first list but
    # reaches cutoff.to in the second (NA where either value is missing).
    mask <- ((from.view.wide %>% dplyr::select(-Predictor)) < cutoff.from) &
      ((to.view.wide %>% dplyr::select(-Predictor)) >= cutoff.to)

    assertthat::assert_that(sum(mask, na.rm = TRUE) > 0,
      msg = "All values are cut off while contrasting."
    )

    masked <- (to.view.wide %>%
      tibble::column_to_rownames("Predictor")) * mask

    # Keep predictors/targets with at least one contrasting interaction. This
    # is decided from the mask rather than from the importance sums, which
    # need not be positive when cutoff.to <= 0.
    keep.rows <- rowSums(mask, na.rm = TRUE) > 0
    keep.cols <- colSums(mask, na.rm = TRUE) > 0

    plot.data <- masked[keep.rows, keep.cols, drop = FALSE] %>%
      tibble::rownames_to_column("Predictor") %>%
      tidyr::pivot_longer(
        cols = -Predictor,
        names_to = "Target",
        values_to = "Importance"
      )

    set2.blue <- "#8DA0CB"

    results.plot <- ggplot2::ggplot(plot.data, ggplot2::aes(x = Predictor, y = Target)) +
      ggplot2::geom_tile(ggplot2::aes(fill = Importance)) +
      ggplot2::scale_fill_gradient2(low = "white", mid = "white", high = set2.blue, midpoint = cutoff.to) +
      ggplot2::theme_classic() +
      ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90, hjust = 1)) +
      ggplot2::coord_equal() +
      ggplot2::ggtitle(current.view)

    print(results.plot)
  })

  invisible(misty.results.from)
}
saezlab/misty documentation built on March 25, 2024, 4:11 p.m.