# R/modules.R
#
# Defines functions pairwise_membership plot_pairwise_membership support_for_modules remove_duplicate_modules modules_from_samples match_modules match_modules_interval process_child_modules reconcile_module_names process_node_no_split modules_from_summary_networks modules_across_ages
#
# Documented in match_modules modules_across_ages modules_from_samples modules_from_summary_networks pairwise_membership support_for_modules

#' Evolution of modules: find and match modules across time slices
#'
#' @param summary_networks List of reconstructed summary networks for each age (output from `get_summary_network()`).
#' @param tree The phylogeny of the symbiont clade (e.g. parasite, herbivore), a `phylo` object.
#' @param extant_modules A `moduleWeb` object defining the modules in the extant network.
#'
#' @return A list with:
#' 1) A list of 2 elements: 1.1) a data frame containing the module information for each node at each
#' network, 1.2) a data frame of correspondence between the original and the matched module names for each network;
#' 2) A list of 2 elements: 2.1) a data frame containing the module membership of each node at each age before matching, 2.2) a list of `moduleWeb` objects for each age.
#' @importFrom rlang .data
#' @export
#'
#' @examples
#' # read data that comes with the package
#' data_path <- system.file("extdata", package = "evolnets")
#' tree <- read_tree_from_revbayes(paste0(data_path,"/tree_pieridae.tre"))
#' host_tree <- ape::read.tree(paste0(data_path,"/host_tree_pieridae.phy"))
#' history <- read_history(paste0(data_path,"/history_thin_pieridae.txt"), burnin = 0)
#'
#' # get ancestral summary networks
#' ages <- c(60, 50, 40, 0)
#' at_ages <- posterior_at_ages(history, ages, tree, host_tree)
#' summary_networks <- get_summary_networks(at_ages, threshold = 0.5, weighted = TRUE)
#'
#' # find and match modules across ancestral and extant networks
#' all_mod <- modules_across_ages(summary_networks, tree)
modules_across_ages <- function(summary_networks, tree, extant_modules = NULL) {

  # Input checking: fail early with a message that names the offending argument.
  is_matrix_list <- is.list(summary_networks) &&
    all(vapply(summary_networks, inherits, TRUE, 'matrix'))
  if (!is_matrix_list) {
    stop('`summary_networks` should be a list of matrices, usually generated by `get_summary_network`.')
  }
  if (!inherits(tree, 'phylo')) stop('`tree` should be a phylogeny of class `phylo`.')
  # `&&` is the scalar, short-circuiting operator that belongs in an `if ()`.
  if (!is.null(extant_modules) && !inherits(extant_modules, 'moduleWeb')) {
    stop('`extant_modules` should be of class `moduleWeb`.')
  }

  # Find the modules of every network first, then give them consistent names
  # across ages. Only the first element of the unmatched result (the node /
  # module table) is needed for matching.
  unmatched_modules <- modules_from_summary_networks(summary_networks, extant_modules)
  matched_modules <- match_modules(summary_networks, unmatched_modules[[1]], tree)

  list(
    matched_modules = matched_modules,
    original_modules = unmatched_modules
  )
}


#' Identify modules for each summary network at each age
#'
#' This function is called within `modules_across_ages()`.
#'
#' @inheritParams modules_across_ages
#'
#' @return A list of 2 elements: 1) a data frame containing the module membership
#' of each node at each age before matching; 2) a list of `moduleWeb` objects for each age.
#' @export
#'
#' @examples
#' \dontrun{
#'  unmatched_modules <- modules_from_summary_networks(summary_networks)
#' }
modules_from_summary_networks <- function(summary_networks, extant_modules = NULL) {

  # input checking
  if (!is.list(summary_networks) || !all(vapply(summary_networks, inherits, TRUE, 'matrix'))) {
    stop('`summary_networks` should be a list of matrices, usually generated by `get_summary_network`.')
  }
  if (!is.null(extant_modules) && !inherits(extant_modules, 'moduleWeb')) {
    stop('`extant_modules` should be of class `moduleWeb`.')
  }

  # The element names carry the age of each network; without them the ages
  # cannot be recovered (this used to fail later with a cryptic `if` error).
  net_names <- names(summary_networks)
  if (length(summary_networks) > 0 && is.null(net_names)) {
    stop('`summary_networks` should be a named list, with the age of each network as its name.')
  }
  ages <- as.numeric(net_names)

  summary_modules <- vector("list", length(summary_networks))
  module_tables <- list()

  for (i in seq_along(summary_networks)) {

    # User-supplied modules replace the computed ones for the network named "0".
    if (!is.null(extant_modules) && net_names[i] == "0") {
      wmod <- extant_modules
    } else {
      wmod <- mycomputeModules(summary_networks[[i]])
    }
    summary_modules[[i]] <- wmod

    # Second element of the module information: one entry per module.
    wmod_list <- bipartite::listModuleInformation(wmod)[[2]]

    # One small table per module; everything is bound together once below
    # instead of growing a data frame inside the loop.
    module_tables[[i]] <- lapply(seq_along(wmod_list), function(m) {
      members <- unlist(wmod_list[[m]])
      data.frame(
        name = members,
        age = rep(ages[i], length(members)),
        original_module = rep(m, length(members))
      )
    })
  }

  # Leading empty data.frame keeps the result a plain data.frame, as before.
  all_mod <- dplyr::bind_rows(data.frame(), unlist(module_tables, recursive = FALSE))

  # Row names of the networks are symbionts, column names are hosts.
  symbionts <- unique(unlist(lapply(summary_networks, rownames)))
  hosts <- unique(unlist(lapply(summary_networks, colnames)))

  all_mod <- dplyr::mutate(
    all_mod,
    type = dplyr::case_when(
      .data$name %in% symbionts ~ "symbiont",
      .data$name %in% hosts ~ "host",
      TRUE ~ "error"
    )
  )

  names(summary_modules) <- net_names

  list(
    nodes_and_original_modules_per_age = all_mod,
    moduleWeb_objects = summary_modules
  )
}

# Processing nodes with no splits within interval
#
# Fills in the rows of `mod_df_sym_age` whose `module_name` is still the
# placeholder '0', copying the module and module strengths that the same
# node name has at `age_min`.
#
# Arguments (as used below):
#   mod_df_sym_age  data frame with at least `name`, `module_name`, `strength`
#                   and the `strength_*` columns written to in the inner loop
#   mod_df_sym      data frame with `age`, `name` and `module_name`
#   sym_mod_el      table with `name`, `prop_mod` and `module_name`
#   mods_strength   table of `strength_*` columns; only its first row is used,
#                   as a template for the node's strength values
#   age_min         age whose module assignments are looked up
#
# Returns `mod_df_sym_age` with the placeholder rows updated.
process_node_no_split = function(mod_df_sym_age, mod_df_sym, sym_mod_el, mods_strength,age_min) {

  # get all nodes with no assigned module
  # (this implies the node has no split within the interval)
  no_split <- mod_df_sym_age %>%
    dplyr::filter(.data$module_name == '0') %>%
    dplyr::pull(.data$name)

  # if there are 'no-split' nodes to handle
  if (length(no_split) > 0) {
    for (i in seq_along(no_split)) {
      # assign it to the module the node was in at age_min
      mod_no_split <- mod_df_sym %>%
        dplyr::filter(.data$age == age_min,
                      .data$name == no_split[i]) %>%
        dplyr::pull(.data$module_name)


      # define "descendant" here as the direct descendant at the time
      # equal to age_min, but does not represent a "node" in the tree
      if (length(mod_no_split) == 0) {
        # if the no-split node's descendant has no module, assign NA
        mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]), 'module_name'] <- NA
      } else {
        # if the no-split node's descendant has a module, process it
        mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]), 'module_name'] <- mod_no_split

        # calculate how strongly linked to each module the node was at age min
        # (one row, one `strength_<module>` column per module the node links to)
        mod_props_no_split <- sym_mod_el %>%
          dplyr::ungroup() %>%
          dplyr::filter(.data$name == no_split[i]) %>%
          dplyr::select(.data$prop_mod, .data$module_name) %>%
          tidyr::pivot_wider(names_from = .data$module_name,
                             values_from = .data$prop_mod,
                             names_prefix = "strength_")

        # strength towards the module the node was assigned to
        mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]), 'strength'] <-
          mod_props_no_split[,paste0("strength_",mod_no_split)]

        no_split_mods <- colnames(mod_props_no_split)

        # start from the first row of the strength template and overwrite the
        # columns of the modules this node is actually linked to
        strength_no_split <- tibble::as_tibble(mods_strength) %>%
          dplyr::slice_head()
        strength_no_split[,no_split_mods] <- mod_props_no_split

        # copy every strength column into the node's row
        for (c in seq_along(strength_no_split)) {
          mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]),
                         names(strength_no_split)[c]] <- strength_no_split[names(strength_no_split)[c]]
        }
      }
    }
  }
  return(mod_df_sym_age)
}


# Reconcile all processed sub/module names within an interval
#
# Gives every original module of the network one matched module name and
# numbers the pieces when several original modules claim the same name.
#
# Arguments (as used below):
#   mod_df_sym_age  data frame with `original_module`, `module_name` and
#                   `strength` columns
#   all_submodules  character vector of submodule names (e.g. "M1.2")
#   submod_letters  those names with the ".<number>" suffix removed (e.g. "M1")
#
# Returns `mod_df_sym_age` with a reconciled `module_name` column.
reconcile_module_names <- function(mod_df_sym_age, all_submodules, submod_letters) {

  # Parent name of a (sub)module: "M1.2" -> "M1", "M1" -> "M1".
  # Names are compared exactly on this value; substring matching would make
  # "M1" also match "M10" or "M11.2".
  module_letter <- function(x) sub("\\.[0-9]*", "", x)

  # For each original module keep the module name with the largest summed
  # strength; ties are broken at random by `slice_sample()`.
  idx_name_strength <- mod_df_sym_age %>%
    dplyr::select("original_module", "module_name", "strength") %>%
    dplyr::group_by(.data$original_module, .data$module_name) %>%
    dplyr::summarise(sum_strength = sum(.data$strength), .groups = 'drop_last') %>%
    dplyr::slice_max(.data$sum_strength) %>%
    dplyr::slice_sample()

  valid_mods <- sort(unique(idx_name_strength$module_name))

  # first, give the same module name (the strongest) to all nodes in the same original module
  mod_df_sym_age$module_name <- idx_name_strength$module_name[
    match(mod_df_sym_age$original_module, idx_name_strength$original_module)
  ]

  # Are there submodules among the valid modules?
  sub_mod_left <- valid_mods[valid_mods %in% all_submodules]
  sub_mod_left_updated <- sub_mod_left
  full_mods <- dplyr::setdiff(valid_mods, sub_mod_left)

  if (length(sub_mod_left) != 0) {
    for (l in submod_letters) {
      is_sub_of_l <- module_letter(sub_mod_left) == l
      # If there is only one submodule of a module left
      # and there isn't another module with the same letter
      if (sum(is_sub_of_l) == 1 && !(l %in% full_mods)) {
        # give it the module name (just letter)
        mod_df_sym_age[
          which(module_letter(mod_df_sym_age$module_name) == l), "module_name"
        ] <- l
        idx_name_strength[
          which(module_letter(idx_name_strength$module_name) == l), "module_name"
        ] <- l
        # accumulate removals across letters instead of restarting from
        # `sub_mod_left` on every iteration
        sub_mod_left_updated <- setdiff(sub_mod_left_updated, sub_mod_left[is_sub_of_l])
      }
    }
  }

  mods_left <- unique(mod_df_sym_age[['module_name']])

  # second, create submodules (at max age) for modules (at min age) that were split
  for (m in seq_along(mods_left)) {

    # which original modules (submodules) are linked to this module name?
    originals <- idx_name_strength %>%
      dplyr::filter(.data$module_name == mods_left[m]) %>%
      dplyr::pull(.data$original_module)

    # if more than 1, give numbers to module names
    if (length(originals) > 1) {
      for (o in seq_along(originals)) {
        rows <- which(
          mod_df_sym_age$module_name == mods_left[m] &
            mod_df_sym_age$original_module == originals[o]
        )
        piece <- which(originals == originals[o])

        if (mods_left[m] %in% sub_mod_left_updated) {
          # if splitting a submodule, don't do nested splits: number the new
          # pieces after the submodules of the same parent that are still left
          sub_let <- module_letter(mods_left[m])
          nsubs <- sum(module_letter(sub_mod_left_updated) == sub_let)
          new_name <- paste0(sub_let, ".", piece + nsubs)
        } else {
          new_name <- paste0(mods_left[m], ".", piece)
        }

        mod_df_sym_age[rows, 'module_name'] <- new_name
      }
    }
  }

  mod_df_sym_age
}


# Processing child 1/2 modules
#
# Looks up the module and module strengths of one child of an internal node.
# The child is searched first among the nodes of the network at `age_min`
# (`mod_df_sym`), then among the internal nodes of the current interval
# (`nodes_interval`, matched on `label`).
#
# Arguments (as used below):
#   mod_df_sym      data frame with `age`, `name` and `module_name`
#   age_min         age of the younger network
#   children        labels of the node's children
#   nodes_interval  data frame of nodes in the interval, with `node`, `label`,
#                   `module_name`, `child1_mod`/`child2_mod` and `strength_*`
#   sym_mod_el      table with `name`, `prop_mod` and `module_name`
#   mods_strength   table of `strength_*` columns; its first row is a template
#   node_treeio     value of `nodes_interval$node` for the parent node
#   child_idx       which child to process (1 or 2)
#
# Returns a list with `nodes_interval` (child module column filled in),
# `child_strengths` (one-row table, absent/NULL when the child has no module)
# and `mod_child` (module name, or NA).
process_child_modules <- function(mod_df_sym, age_min, children, nodes_interval, sym_mod_el, mods_strength, node_treeio, child_idx)
{

  child_str <- paste0("child", child_idx, "_mod")

  if (children[child_idx] %in% dplyr::filter(mod_df_sym, .data$age == age_min)$name) {
    mod_child <- mod_df_sym %>%
      dplyr::filter(
        .data$age == age_min,
        .data$name == children[child_idx]
      ) %>%
      dplyr::pull(.data$module_name)

    if (length(mod_child) == 0) {
      nodes_interval[which(nodes_interval$node == node_treeio),
                     child_str] <- NA
      # Treat this like a child without any module: without these two
      # assignments `child_strengths` would be undefined below and the caller
      # would receive a zero-length `mod_child`.
      child_strengths <- NULL
      mod_child <- NA
    } else {
      nodes_interval[which(nodes_interval$node == node_treeio),
                     child_str] <- mod_child

      # one row, one `strength_<module>` column per module the child links to
      mod_props_child <- sym_mod_el %>%
        dplyr::ungroup() %>%
        dplyr::filter(.data$name == children[child_idx]) %>%
        dplyr::select("prop_mod", "module_name") %>%
        tidyr::pivot_wider(
          names_from = "module_name",
          values_from = "prop_mod",
          names_prefix = "strength_"
        )

      child_mods <- colnames(mod_props_child)

      # start from the first row of the template, then overwrite the columns
      # of the modules this child is linked to
      child_strengths <- tibble::as_tibble(mods_strength) %>%
        dplyr::slice_head()
      child_strengths[, child_mods] <- mod_props_child

    }
  } else if (children[child_idx] %in% nodes_interval$label) {

    # the child is itself an internal node of this interval: reuse what was
    # already stored for it
    child_strengths <- nodes_interval %>%
      dplyr::filter(.data$label == children[child_idx]) %>%
      dplyr::select(dplyr::contains("strength_"))

    mod_child <- nodes_interval %>%
      dplyr::filter(.data$label == children[child_idx]) %>%
      dplyr::pull(.data$module_name)

    nodes_interval[which(nodes_interval$node == node_treeio),
                   child_str] <- mod_child
  } else {
    # if a child node at an interval boundary has _zero_ interactions
    # that are above a threshold criterion, that child will not be
    # assigned to _any_ module; in this case, assume the worst, and
    # let that child node contribute zero information to any possible
    # module, which effectively dilutes the summed children module
    # strength values (performed later)
    child_strengths <- NULL
    mod_child <- NA
  }

  # Element-wise assignment on purpose: assigning NULL leaves the element out,
  # so `ret$child_strengths` is NULL for a child without strengths.
  ret <- list()
  ret$nodes_interval <- nodes_interval
  ret$child_strengths <- child_strengths
  ret$mod_child <- mod_child

  ret
}


# Corresponds to each iterated pass for Steps 3a through 3i
#
# Matches the modules of the network at `age_max` to the (already named)
# modules of the network at `age_min`, and appends the result to `mod_df_sym`.
#
# Arguments (as used below):
#   summary_networks   named list of networks; the one named
#                      `as.character(age_min)` is used
#   mod_df_sym         data frame of symbionts with `age`, `name`,
#                      `original_module` and `module_name` for the ages
#                      processed so far
#   unmatched_modules  data frame with `name`, `age`, `original_module`, `type`
#   tree_t             tree table with `node`, `parent`, `label` and `depth`
#   age_min, age_max   bounds of the interval; nodes with
#                      age_min < depth <= age_max are processed
#
# Returns `mod_df_sym` with the rows for `age_max` bound to the end.
match_modules_interval <- function(summary_networks, mod_df_sym, unmatched_modules, tree_t, age_min, age_max) {

  # all modules present at min age (starting with the present)
  modules_age <- mod_df_sym %>%
    dplyr::filter(.data$age == age_min) %>%
    dplyr::pull(.data$module_name) %>%
    unique()

  # find submodules (names containing a "." suffix) and their parent names,
  # and make sure the parent names are part of the module list as well
  all_submodules <- modules_age[tidyselect::matches("\\.[0-9]*", vars = modules_age)]
  submod_letters <- unique(sub("\\.[0-9]*", "", all_submodules))
  modules_age    <- sort(unique(c(modules_age, submod_letters)))

  # data frame with all nodes at the network at max age
  # ('0' / 0 are placeholders meaning "not assigned yet")
  mod_df_sym_age <- unmatched_modules %>%
    dplyr::filter(.data$age == age_max, .data$type == "symbiont") %>%
    dplyr::mutate(module_name = '0', strength = 0)

  # attach module strength scores to data frame
  # (one zero-filled `strength_<module>` column per module at age_min)
  mods_strength <- matrix(data = 0, nrow = nrow(mod_df_sym_age), ncol = length(modules_age))
  colnames(mods_strength) <- paste0("strength_", modules_age)
  mod_df_sym_age <- cbind(mod_df_sym_age, mods_strength)

  # Data frame with all internal nodes in the interval between min and max age
  nodes_interval <- tree_t %>%
    dplyr::filter(.data$depth > age_min & .data$depth <= age_max) %>%
    dplyr::arrange(.data$depth) %>%
    dplyr::mutate(child1_mod = '0', child2_mod = '0', module_name = '0', strength = 0)
  mods_strength_interval <- matrix(data = 0, nrow = nrow(nodes_interval), ncol = length(modules_age))
  colnames(mods_strength_interval) <- paste0("strength_",modules_age)
  nodes_interval <- cbind(nodes_interval, mods_strength_interval)

  # get node depths; rows keep the order produced by `arrange(depth)` above,
  # i.e. increasing depth
  # NOTE(review): presumably this guarantees that a child inside the interval
  # is processed before its parent - confirm against the tree's depth scale
  node_depth <- nodes_interval %>%
    dplyr::select(.data$node, .data$label, .data$depth) %>%
    dplyr::rename(node_name = .data$label)

  # find relative linkage of each symbiont to all modules in
  # the network at age_min
  net_age_min <- summary_networks[[as.character(age_min)]]

  # get module info for the network at age_min
  module_members_age <- unmatched_modules %>%
    dplyr::filter(.data$age == age_min)

  # Collect edge weights for age_min network for module matching procedure
  # sym_mod_el: "symbiont module edge-list"
  # For each symbiont: summed edge weight towards each host module
  # (`degree_mod`), its total over all modules (`degree`), their ratio
  # (`prop_mod`), and the matched name of that module at age_min.
  sym_mod_el <- net_age_min %>%
    as.data.frame() %>%
    tibble::rownames_to_column('name') %>%
    tidyr::pivot_longer(!.data$name, names_to = "host", values_to = "state") %>%
    dplyr::left_join(module_members_age, by = c("host" = "name")) %>%
    dplyr::filter(.data$state > 0) %>%
    dplyr::group_by(.data$name, .data$original_module) %>%  # host's original module only
    dplyr::summarise(degree_mod = sum(.data$state), .groups = 'drop_last') %>%
    dplyr::mutate(degree = sum(.data$degree_mod), prop_mod = .data$degree_mod/.data$degree) %>%
    dplyr::left_join(
      mod_df_sym %>%
        dplyr::filter(.data$age == age_min) %>%
        dplyr::select(.data$original_module, .data$module_name) %>%
        dplyr::distinct(),
      by = "original_module"
    )

  # Step 3d
  # match modules if nodes occur within the (age_min, age_max) interval
  if (nrow(node_depth) > 0) {

    # process each node within interval (age_min, age_max)
    for (n in seq_len(nrow(node_depth))) {

      # get the TreeIO index for the target node
      node_treeio <- node_depth$node[n]
      node_name_treeio <- node_depth$node_name[n]

      # get its children and their names (labels)
      children <- tree_t %>%
        dplyr::filter(.data$parent == node_treeio) %>%
        dplyr::pull(.data$label)

      # check if children are in between time intervals or in the next network
      # if they are in the network, get info from age_min
      ch1 = process_child_modules(mod_df_sym, age_min, children, nodes_interval, sym_mod_el, mods_strength, node_treeio, 1)
      nodes_interval = ch1$nodes_interval
      child1_strengths = ch1$child_strengths
      mod_child1 = ch1$mod_child

      ch2 = process_child_modules(mod_df_sym, age_min, children, nodes_interval, sym_mod_el, mods_strength, node_treeio, 2)
      nodes_interval = ch2$nodes_interval
      child2_strengths = ch2$child_strengths
      mod_child2 = ch2$mod_child

      # Combine children strengths
      # (a child without strengths contributes NULL, so the result has 0-2 rows
      # or is NULL when neither child has strengths)
      children_strengths <- rbind(child1_strengths, child2_strengths)

      # Merge submodules of the same module (letter) when each children are in submodules of the same module - or merge a submodule into the main module
      letter_mod_children <- c(sub("\\.[0-9]*","",mod_child1), sub("\\.[0-9]*","",mod_child2))

      if (!is.na(mod_child1) & !is.na(mod_child2) & mod_child1 != mod_child2 &
          letter_mod_children[1] == letter_mod_children[2]) {

        # strength columns of the two children's modules
        mods_to_merge <- children_strengths %>%
          dplyr::select(tidyselect::ends_with(c(mod_child1, mod_child2))) %>%
          colnames()

        # column name of the parent module, e.g. "strength_M1"
        mod_letter <- sub("\\.[0-9]*","",mods_to_merge)[1]

        # store the summed strength under the parent module's column and drop
        # every column whose name contains a "." (the submodule columns)
        children_strengths <- children_strengths %>%
          dplyr::mutate({{mod_letter}} := .data[[mods_to_merge[1]]] + .data[[mods_to_merge[2]]]) %>%
          dplyr::select(-tidyselect::matches("\\.[0-9]*"))
      }

      if(!is.null(children_strengths)){
        if(nrow(children_strengths) == 1){
          # if a child node at an interval boundary has _zero_ interactions
          # that are above a threshold criterion, that child will not be
          # assigned to _any_ module; in this case, assume the worst, and
          # let that child node contribute zero information to any possible
          # module, which effectively dilutes the summed children module
          # strength values

          # user option for penalty? e.g. 1/2 or 1/n_modules

          # divide only strength by 2
          mean_strengths <- children_strengths/2
          # and find strongest module
          strongest_mod_idx <- max.col(as.matrix(mean_strengths), "random")
        } else if(nrow(children_strengths) == 2){
          # get average strength
          mean_strengths <- colMeans(children_strengths)
          # and find strongest module
          strongest_mod_idx <- max.col(t(as.matrix(mean_strengths)), "random")
        }
        # NOTE(review): for any other row count `mean_strengths` and
        # `strongest_mod_idx` keep their values from a previous iteration (or
        # are undefined on the first one) - confirm this cannot happen

        # add all strengths to data frame of all nodes within the time interval
        for (c in seq_along(mean_strengths)) {
          nodes_interval[which(nodes_interval$node == node_treeio),
                         names(mean_strengths)[c]] <- mean_strengths[names(mean_strengths)[c]]
        }

        # Assign strongest module to node
        nodes_interval[which(nodes_interval$node == node_treeio),
                       'module_name'] <- sub("strength_","",names(mean_strengths)[strongest_mod_idx])
        nodes_interval[which(nodes_interval$node == node_treeio),
                       'strength'] <- mean_strengths[strongest_mod_idx]

        # Write strengths for nodes in the network at max age
        if (node_name_treeio %in% mod_df_sym_age$name) {
          for (c in seq_along(mean_strengths)) {
            mod_df_sym_age[which(mod_df_sym_age$name == node_name_treeio),
                           names(mean_strengths)[c]] <- mean_strengths[names(mean_strengths)[c]]
          }

          # also write the strongest module
          mod_df_sym_age[which(mod_df_sym_age$name == node_name_treeio),
                         'strength'] <- mean_strengths[strongest_mod_idx]
          mod_df_sym_age[which(mod_df_sym_age$name == node_name_treeio),
                         'module_name'] <- sub("strength_","",names(mean_strengths)[strongest_mod_idx])
        }
      }
    }
  }

  # Branches that don't split in this time interval #
  # (rows still carrying the '0' placeholder inherit their module from age_min)
  mod_df_sym_age = process_node_no_split(mod_df_sym_age, mod_df_sym, sym_mod_el, mods_strength, age_min)

  # Reconcile original and matched module names
  mod_df_sym_age = reconcile_module_names(mod_df_sym_age, all_submodules, submod_letters)

  # Update main mod_df_sym object
  mod_df_sym <- dplyr::bind_rows(mod_df_sym, mod_df_sym_age)

  return(mod_df_sym)
}

#' Match modules of different ancestral networks across time.
#'
#' This function is called within `modules_across_ages()` and gives the same name to modules from
#' ancestral networks at different ages that contain the same symbiont species or their parental species.
#'
#' @param summary_networks List of reconstructed summary networks for each age.
#' @param unmatched_modules A data frame outputted from `modules_from_summary_networks()` containing:
#' $name of the network node (hosts and symbionts), $age of the network, $original_module assigned to the
#' node, and $type of the node (either "symbiont" or "host").
#' @param tree The phylogeny of the symbiont clade (e.g. parasite, herbivore), a `phylo` object.
#'
#' @return A list of two elements: 1) a data frame containing the module information for each node at each
#' network; 2) a data frame of correspondence between the original and the matched module names for each network.
#' @export
#' @importFrom rlang .data
#'
#' @examples
#' \dontrun{
#' unmatched_modules <- modules_from_summary_networks(summary_networks)
#' matched_modules <- match_modules(summary_networks, unmatched_modules[[1]], tree)
#' }
match_modules <- function(summary_networks, unmatched_modules, tree) {

  # Step 1: input checking
  if (!is.list(summary_networks) || !all(vapply(summary_networks, inherits, TRUE, 'matrix'))) {
    stop('`summary_networks` should be a list of matrices, usually generated by `get_summary_network`.')
  }
  if (
    !is.data.frame(unmatched_modules) ||
    !all(c('name', 'age', 'original_module', 'type') %in% names(unmatched_modules))
  ) {
    stop('`unmatched modules` should be a data.frame with "name", "age", "original_module" and "type" columns, usually generated by `modules_from_summary_networks`.')
  }
  if (!inherits(tree, 'phylo')) stop('`tree` should be a phylogeny of class `phylo`.')

  # Ages come from the list names.
  # Make sure that ages are ordered present -> past and start at present
  ages <- sort(as.numeric(names(summary_networks)))

  # An unnamed or empty list yields no ages at all; report that with the same
  # message instead of failing on `if (NA > 0)`.
  if (length(ages) == 0 || ages[1] > 0) {
    stop('`summary_networks` has to include the extant network (age = 0).')
  }

  # Step 2: prepare for adding children module information
  # Extant symbiont modules are named "M<original module number>".
  mod_df_sym <- unmatched_modules %>%
    dplyr::filter(.data$age == 0, .data$type == "symbiont") %>%
    dplyr::mutate(module_name = paste0("M", .data$original_module))

  mod_df_host <- unmatched_modules %>%
    dplyr::filter(.data$type == "host")

  # Get a tibble with tree information for finding child nodes.
  # `depth` is the distance of each node from the most distant node,
  # rounded to 5 digits.
  node_heights <- ape::node.depth.edgelength(tree)
  tree_t <- dplyr::mutate(
    tibble::as_tibble(tree),
    depth = round(max(node_heights) - node_heights, digits = 5)
  )

  # Step 3: Reverse inherit modules
  # Find all nodes between this time slice and the previous one
  for (t in seq_len(length(ages) - 1)) {

    age_min <- ages[t]
    age_max <- ages[t + 1]

    mod_df_sym <- match_modules_interval(
      summary_networks, mod_df_sym, unmatched_modules, tree_t, age_min, age_max
    )
  }

  # set modules for hosts: hosts take the matched name of the original module
  # they belong to at each age
  mod_idx_name <- mod_df_sym %>%
    dplyr::select("age", "original_module", "module_name") %>%
    dplyr::distinct()

  mod_df_host <- dplyr::left_join(mod_df_host, mod_idx_name, by = c("age", "original_module"))

  mod_df <- dplyr::bind_rows(mod_df_sym, mod_df_host) %>%
    dplyr::rename(module = "module_name")

  # TODO: replace with user given module names
  list(
    nodes_and_modules_per_age = mod_df,
    original_and_matched_module_names = mod_idx_name
  )
}



#' Find modules in networks sampled across MCMC at specific time slices
#'
#' @param sampled_networks List of sampled networks at time slices produced by
#'   `get_sampled_networks()`.
#'
#' @return Data frame with module membership information for each sampled network at each time
#'   slice.
#' @importFrom dplyr mutate group_by distinct summarize left_join case_when select bind_rows tibble
#'   n
#' @importFrom bipartite empty
#' @export
#'
#' @examples
#' \dontrun{
#'   data_path <- system.file("extdata", package = "evolnets")
#'   tree <- read_tree_from_revbayes(paste0(data_path,"/tree_pieridae.tre"))
#'   host_tree <- ape::read.tree(paste0(data_path,"/host_tree_pieridae.phy"))
#'   history <- read_history(paste0(data_path,"/history_thin_pieridae.txt"), burnin = 0)
#'
#'   ages <- c(60,50,40,0)
#'   at_ages <- posterior_at_ages(history, ages, tree, host_tree)
#'   sampled_networks <- get_sampled_networks(at_ages)
#'
#'   mod_samples <- modules_from_samples(sampled_networks)
#' }
modules_from_samples <- function(sampled_networks) {

  # Ages come from the list names.
  ages <- as.numeric(names(sampled_networks))
  if (!(0 %in% ages)) stop('`sampled_networks` has to include the present (age = 0).')

  # Modules are computed for every time slice except the present. Selecting
  # the slices by age (rather than dropping the last element) gives the same
  # result when the present comes last and stays correct when it does not.
  past_slices <- setdiff(seq_along(ages), which(ages == 0))

  # Number of sampled networks, taken from the first dimension of the first
  # array.
  nsamp <- dim(sampled_networks[[1]])[1]

  # One small table per module, bound together once at the end.
  module_tables <- list()

  for (a in past_slices) {
    for (i in seq_len(nsamp)) {
      net <- sampled_networks[[a]][i, , ]

      # skip samples that keep at most one column after `empty()`
      if (ncol(empty(net)) > 1) {

        mod <- mycomputeModules(net)
        mod_list <- bipartite::listModuleInformation(mod)[[2]]

        for (m in seq_along(mod_list)) {
          members <- unlist(mod_list[[m]])
          module_tables[[length(module_tables) + 1L]] <- tibble(
            name = members,
            age = rep(ages[a], length(members)),
            sample = rep(i, length(members)),
            original_module = rep(m, length(members))
          )
        }
      }
    }
  }

  mod_samples <- bind_rows(module_tables)

  # remove replicate modules in fully connected networks with identical modules
  remove_duplicate_modules(mod_samples)
}


# Collapse the modules of sampled networks in which a node is listed more than
# once.
#
# For every (age, sample) network the number of distinct node names is compared
# with the number of rows. When they differ, all nodes of that network are put
# in module 1 and the duplicated rows are dropped; otherwise the network is
# returned unchanged (with `original_module` as numeric).
#
# `mod_samples` needs the columns `name`, `age`, `sample` and
# `original_module`. Join keys and grouping are spelled out so no
# "Joining with ..." / "summarise() has grouped output ..." messages are
# emitted and extra columns in `mod_samples` cannot change the join.
remove_duplicate_modules <- function(mod_samples) {

  # rows per sampled network
  rows_per_network <- mod_samples %>%
    dplyr::group_by(.data$age, .data$sample) %>%
    dplyr::summarize(n = dplyr::n(), .groups = "drop_last")

  # distinct node names per sampled network
  names_per_network <- mod_samples %>%
    dplyr::group_by(.data$age, .data$sample) %>%
    dplyr::distinct(.data$name) %>%
    dplyr::summarize(u = dplyr::n(), .groups = "drop_last")

  names_per_network %>%
    dplyr::left_join(rows_per_network, by = c("age", "sample")) %>%
    dplyr::mutate(
      problem = dplyr::case_when(
        .data$u != .data$n ~ "YES",
        .data$u == .data$n ~ "NO"
      )
    ) %>%
    dplyr::left_join(mod_samples, by = c("age", "sample")) %>%
    dplyr::mutate(
      original_module = dplyr::case_when(
        .data$problem == "YES" ~ 1,
        .data$problem == "NO" ~ as.numeric(.data$original_module)
      )
    ) %>%
    dplyr::distinct() %>%
    dplyr::select("name", "age", "sample", "original_module")
}



#' Calculate support for modules from summary networks based on modules of sampled networks
#'
#' Validate modules from summary networks by calculating the frequency with
#' which pairs of nodes are placed in the same module in networks sampled across MCMC
#'
#' @param mod_samples Data frame produced by `modules_from_samples()` containing module membership
#'   of each node for each sampled network at each time slice before the present.
#' @param modules_across_ages Data frame containing the module information for the summary network.
#'   A `list` object returned from `modules_across_ages()` or a `data.frame` object defining the
#'   modules in the networks. If a `data.frame` is passed, it must contain three columns:
#'   `$age` - the age of the network,
#'   `$name` - taxon names,
#'   `$module` - the module the taxon was assigned to.
#' @param threshold Minimum frequency with which two nodes are placed in the same module to consider it supported. Only pairs with frequency higher than this threshold are plotted with the color of the module from the summary network.
#' @param edge_list Logical. Whether to return a list of edge lists or a list of matrices of pairwise frequency.
#' @param include_all Logical. Include all nodes or only those present at the time slice?
#' @param colors Optional. Vector of module colors in the plot.
#' @param module_levels Optional. Order of modules in the color vector.
#' @param axis_text Logical. Plot taxon names?
#'
#' @return A list containing:
#' 1) `plot`: A plot of pairwise frequency with which the two nodes are placed in the same module in the ancestral network for each time slice in `ages`. Cells in the diagonal show how often a host is sampled in a network (symbionts are always present);
#' 2) `pairwise_membership`: A list of edge lists or matrices with the pairwise frequencies at each time slice;
#' 3) `mean_support`: A list of mean and geometric mean pairwise frequency for each module at each time slice.
#' @importFrom dplyr mutate case_when arrange pull arrange bind_rows distinct filter left_join inner_join full_join select summarize desc
#' @importFrom tidyr complete
#' @importFrom rlang .data
#' @importFrom patchwork wrap_plots
#' @importFrom stats reorder
#' @export
#'
#' @examples
#' \dontrun{
#'   ages <- c(60, 50, 40, 0)
#'   at_ages <- posterior_at_ages(history, ages, tree, host_tree)
#'
#'   weighted_net_50 <- get_summary_network(at_ages, pt = 0.5, weighted = TRUE)
#'   all_mod <- modules_across_ages(weighted_net_50, tree)
#'
#'   # find modules for each sampled network
#'   mod_samples <- modules_from_samples(at_ages)
#'
#'   # calculate support
#'   support <- support_for_modules(mod_samples, all_mod)
#'   support$plot
#'   support$mean_support
#' }
support_for_modules <- function(
  mod_samples, modules_across_ages, threshold = 0.7, edge_list = TRUE, include_all = FALSE,
  colors = NULL, module_levels = NULL, axis_text = FALSE
) {

  if (!is.null(modules_across_ages) && (
    !inherits(modules_across_ages, 'list') && !inherits(modules_across_ages, 'data.frame')
  )) {
    stop('`modules_across_ages` should be of class `list` or `data.frame`.')
  }
  # A list is taken to be the output of `modules_across_ages()`; only its
  # table of matched modules is used from here on.
  if (inherits(modules_across_ages, 'list')) {
    modules_across_ages <- modules_across_ages$matched_modules$nodes_and_modules_per_age
  }
  if (!all(unique(mod_samples$age) %in% unique(modules_across_ages$age))) {
    stop('`modules_across_ages` must contain all time slices in `mod_samples`.')
  }

  ages <- rev(unique(mod_samples$age))

  # calculate pairwise module membership
  pair_mod_tbl <- pairwise_membership(mod_samples, ages, edge_list)

  # Keep every node of the summary network (full join) or only the pairs that
  # also occur in the samples (inner join). A plain `if` replaces the
  # deprecated `purrr::when()`.
  join_modules <- if (include_all) dplyr::full_join else dplyr::inner_join

  # make heatmaps
  pair_heatmaps <- list()

  for (i in seq_along(ages)) {

    modules_at_age <- modules_across_ages %>%
      dplyr::filter(.data$age == ages[i]) %>%
      dplyr::select("name", "module")

    # Joining the module table twice (on `row`, then on `col`) yields the
    # `module.x` / `module.y` columns: the modules of the two nodes of a pair.
    Edge_list <- tibble::tibble(pair_mod_tbl[[i]]) %>%
      join_modules(modules_at_age, by = c("row" = "name")) %>%
      join_modules(modules_at_age, by = c("col" = "name")) %>%
      dplyr::mutate(
        # shared module of the pair, NA when the two nodes are in different ones
        Module = ifelse(.data$module.x == .data$module.y, .data$module.x, NA),
        # only pairs reaching the threshold keep their module (NA otherwise)
        supported_mod = dplyr::case_when(.data$freq >= threshold ~ .data$Module)
      )

    # order nodes by module so that modules form blocks in the heatmap
    node_order <- Edge_list %>%
      dplyr::arrange(.data$module.x) %>%
      dplyr::pull(.data$row) %>%
      unique()

    for_heatmap <- dplyr::mutate(
      Edge_list,
      row = factor(.data$row, levels = node_order),
      col = factor(.data$col, levels = node_order)
    )

    pair_heatmaps[[i]] <- data.frame(for_heatmap)
  }

  names(pair_heatmaps) <- ages

  # calculate mean support
  means <- list()

  for (i in seq_along(ages)) {
    means_age <- data.frame(module = NULL, mean = NULL, geo_mean = NULL)
    mods <- sort(unique(pair_heatmaps[[i]]$module.x))

    for (m in mods) {
      within_mod <- dplyr::filter(pair_heatmaps[[i]], .data$Module == m)
      # arithmetic and geometric mean of the pairwise frequencies in module m
      arith_mean <- mean(within_mod$freq)
      geo_mean <- exp(mean(log(within_mod$freq)))

      means_age <- dplyr::bind_rows(
        means_age,
        data.frame(module = m, mean = arith_mean, geo_mean = geo_mean)
      )
    }

    means[[i]] <- data.frame(means_age)
  }

  names(means) <- ages

  # default module order: all module names, sorted; default colours: hue palette
  if (is.null(module_levels)) {
    module_levels <- modules_across_ages %>%
      dplyr::pull(.data$module) %>%
      unique() %>%
      sort()
  }
  if (is.null(colors)) colors <- scales::hue_pal()(length(module_levels))

  plot <- plot_pairwise_membership(
    pair_heatmaps, ages,
    colors = colors, module_levels = module_levels, axis_text = axis_text
  )

  list(
    plot = plot,
    pairwise_membership = pair_heatmaps,
    mean_support = means
  )
}


# Draw one heatmap of pairwise module membership per age.
#
# pair_heatmaps: list of data frames, one per age, with columns `row`, `col`, `freq`
#                and `supported_mod` (as built in `support_for_modules()`).
# ages:          vector of ages, in the same order as `pair_heatmaps`; used for plot titles.
# colors:        fill colors passed to `ggplot2::scale_fill_manual()`.
# module_levels: factor levels for `supported_mod`, which fix the module -> color mapping.
# axis_text:     if FALSE, axis labels and ticks are removed.
#
# Returns an unnamed list of ggplot objects, one per element of `pair_heatmaps`.
plot_pairwise_membership <- function(pair_heatmaps, ages, colors, module_levels, axis_text){

  lapply(seq_along(pair_heatmaps), function(a) {

    p <- ggplot2::ggplot(
      pair_heatmaps[[a]],
      ggplot2::aes(
        x = .data$row,
        # reverse the level order on the y axis so the diagonal runs top-left to bottom-right
        y = stats::reorder(.data$col, dplyr::desc(.data$col)),
        fill = factor(.data$supported_mod, levels = module_levels),
        alpha = .data$freq
      )
    ) +
      ggplot2::geom_tile() +
      ggplot2::theme_bw() +
      # keep nodes/modules without data so all panels share the same axes and legend
      ggplot2::scale_x_discrete(drop = FALSE) +
      ggplot2::scale_y_discrete(drop = FALSE) +
      ggplot2::scale_alpha(limits = c(0,1), range = c(0.1,1), name = "Frequency") +
      ggplot2::scale_fill_manual(values = colors, na.value = "grey40", drop = FALSE) +
      ggplot2::labs(fill = "Module") +
      ggplot2::theme(
        axis.title.x = ggplot2::element_blank(),
        axis.title.y = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_text(angle = 270)
      ) +
      ggplot2::ggtitle(paste0(ages[a]," Ma"))

    if (!axis_text) {
      p <- p + ggplot2::theme(
        axis.text.x = ggplot2::element_blank(),
        axis.text.y = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_blank()
      )
    }

    p
  })

}


#' Calculate frequency that pairs of nodes fall within the same module
#'
#' For each age, counts in how many sampled networks every pair of nodes was assigned to
#' the same `original_module`, and divides that count by the number of samples at that age.
#' A node that is absent from a sample does not share a module with any node in that sample.
#'
#' @param mod_samples Output from `modules_from_samples()`. The columns `age`, `sample`,
#'   `name` and `original_module` are used.
#' @param ages Vector of network ages
#' @param edge_list Should output be an edge list? If `FALSE`, a square matrix with node
#'   names as row and column names is returned for each age instead.
#'
#' @return A list with one element per entry of `ages`. With `edge_list = TRUE`, each element
#'   is an edge list (columns `row`, `col` and `freq`) with the frequency that each pair of
#'   nodes in the network is placed in the same module across network samples. With
#'   `edge_list = FALSE`, each element is the corresponding frequency matrix.
#' @export
#'
#' @examples
#' \dontrun{
#' pairwise_membership(mod_samples, c(0))
#' }
pairwise_membership <- function(mod_samples, ages, edge_list = TRUE) {

  result <- vector("list", length(ages))

  for (a in seq_along(ages)) {
    # Subset to this age once; everything below only needs these rows.
    samples_at_age <- filter(mod_samples, .data$age == ages[a])

    taxa <- distinct(samples_at_age, .data$name)
    ntaxa <- nrow(taxa)

    # co-membership counts, taxa x taxa
    heat <- matrix(data = 0, nrow = ntaxa, ncol = ntaxa)
    rownames(heat) <- colnames(heat) <- taxa$name

    sample_ids <- samples_at_age %>%
      distinct(.data$sample) %>%
      pull(.data$sample)

    for (i in sample_ids) {
      one_sample <- filter(samples_at_age, .data$sample == i)

      for (m in unique(one_sample$original_module)) {
        members <- filter(one_sample, .data$original_module == m)$name

        # TRUE where both the row taxon and the column taxon belong to module `m`
        same_module <- outer(rownames(heat) %in% members, colnames(heat) %in% members, '&')
        heat <- heat + as.numeric(same_module)
      }
    }

    # counts -> proportion of samples
    heat <- heat / length(sample_ids)

    if (edge_list) {
      result[[a]] <- heat %>%
        as.data.frame() %>%
        tibble::rownames_to_column('row') %>%
        tidyr::pivot_longer(-row, names_to = 'col', values_to = 'freq') %>%
        dplyr::arrange(row, col)
    } else {
      result[[a]] <- heat
    }
  }

  result
}

# Placeholder bindings for column names that are used unquoted inside dplyr/tidyr calls in
# this file. NOTE(review): presumably these exist to keep `R CMD check` from reporting
# "no visible binding for global variable" notes -- confirm before removing.
name <- NULL
freq <- NULL
module.x <- NULL
module.y <- NULL



#'
#' #' Match modules of different ancestral networks across time.
#' #'
#' #' This function is called within `modules_across_ages()` and gives the same name to modules from
#' #' ancestral networks at different ages that contain the same symbiont species or their parental species.
#' #'
#' #' @param summary_networks List of reconstructed summary networks for each age.
#' #' @param unmatched_modules A data frame outputted from `modules_from_summary_networks()` containing:
#' #' $name of the network node (hosts and symbionts), $age of the network, $original_module assigned to the
#' #' node, and $type of the node (either "symbiont" or "host").
#' #' @param tree The phylogeny of the symbiont clade (e.g. parasite, herbivore), a `phylo` object.
#' #'
#' #' @return A list of two elements: 1) a data frame containing the module information for each node at each
#' #' network; 2) a data frame of correspondence between the original and the matched module names for each network.
#' #' @export
#' #' @importFrom rlang .data
#' #'
#' #' @examples
#' #' \dontrun{
#' #' unmatched_modules <- modules_from_summary_networks(summary_networks)
#' #' matched_modules <- match_modules(summary_networks, unmatched_modules[[1]], tree)
#' #' }
#' match_modules_old <- function(summary_networks, unmatched_modules, tree){
#'
#'   # input checking
#'   if (!is.list(summary_networks) || !all(vapply(summary_networks, inherits, TRUE, 'matrix'))) {
#'     stop('`summary_networks` should be a list of matrices, usually generated by `get_summary_network`.')
#'   }
#'   if (
#'     !is.data.frame(unmatched_modules) ||
#'     !all(c('name', 'age', 'original_module', 'type') %in% names(unmatched_modules))
#'   ) {
#'     stop('`unmatched modules` should be a data.frame with "name", "age", "original_module" and "type" columns, usually generated by `modules_from_summary_networks`.')
#'   }
#'   if (!inherits(tree, 'phylo')) stop('`tree` should be a phylogeny of class `phylo`.')
#'
#'   ages <- as.numeric(names(summary_networks))
#'   # Make sure that ages are ordered present -> past and start at present
#'   ages <- sort(ages)
#'
#'   if (ages[1] > 0) stop('`summary_networks` has to include the extant network (age = 0).')
#'
#'   # prepare for adding children module information
#'   mod_df_sym <- unmatched_modules %>%
#'     dplyr::filter(.data$age == 0, .data$type == "symbiont") %>%
#'     dplyr::mutate(#child1_mod = '0',
#'       #child2_mod = '0',
#'       module_name = LETTERS[.data$original_module]
#'     )
#'
#'   mod_df_host <- unmatched_modules %>%
#'     dplyr::filter(.data$type == "host")
#'
#'   # Get a tibble with tree information for finding child nodes
#'   tree_t <- dplyr::mutate(
#'     tibble::as_tibble(tree),
#'     depth = round(max(ape::node.depth.edgelength(tree)) - ape::node.depth.edgelength(tree), digits = 5)
#'   )
#'
#'   # Find all nodes between this time slice and the previous one
#'   # Reverse inherit modules
#'   for (t in seq_len(length(ages) - 1)) {
#'
#'     age_min <- ages[t]
#'     age_max <- ages[t + 1]
#'
#'     # all modules present at min age (starting with the present)
#'     modules_age <- mod_df_sym %>%
#'       dplyr::filter(.data$age == age_min) %>%
#'       dplyr::pull(.data$module_name) %>%
#'       unique()
#'
#'     # ? need to update matches() ? #### can't remember why
#'     all_submodules <- modules_age[tidyselect::matches("\\d", vars = modules_age)]
#'     submod_letters <- unique(sub("\\d", "", all_submodules))
#'
#'     modules_age <- sort(unique(c(modules_age, submod_letters)))
#'
#'     # Data frame with all nodes at the network at max age
#'     mod_df_sym_age <- unmatched_modules %>%
#'       dplyr::filter(.data$age == age_max,
#'                     .data$type == "symbiont") %>%
#'       dplyr::mutate(#child1_mod = '0',
#'         #child2_mod = '0',
#'         module_name = '0',
#'         strength = 0)
#'     mods_strength <- matrix(data = 0, nrow = nrow(mod_df_sym_age), ncol = length(modules_age))
#'     colnames(mods_strength) <- paste0("strength_", modules_age)
#'     mod_df_sym_age <- cbind(mod_df_sym_age, mods_strength)
#'
#'     # Data frame with all internal nodes in the interval between min and max age
#'     nodes_interval <- tree_t %>%
#'       dplyr::filter(.data$depth > age_min & .data$depth <= age_max) %>%
#'       dplyr::arrange(.data$depth) %>%
#'       # dplyr::left_join(dplyr::select(filter(mod_df_sym, .data$age == ages[t]),
#'       #                                c("name","original_module","module_name")),
#'       #                  by = c("label" = "name")) %>%
#'       dplyr::mutate(child1_mod = '0',
#'                     child2_mod = '0',
#'                     module_name = '0',
#'                     strength = 0)
#'     mods_strength_interval <- matrix(data = 0,nrow = nrow(nodes_interval), ncol = length(modules_age))
#'     colnames(mods_strength_interval) <- paste0("strength_",modules_age)
#'     nodes_interval <- cbind(nodes_interval, mods_strength_interval)
#'
#'     all_nodes <- nodes_interval[['node']]
#'
#'     node_depth <- nodes_interval %>%
#'       dplyr::select(.data$node, .data$label, .data$depth) %>%
#'       dplyr::rename(node_name = .data$label)
#'
#'     # find relative linkage of each symbiont to all modules in
#'     # the network at age_min
#'     net_age_min <- summary_networks[[as.character(age_min)]]
#'
#'     module_members_age <- unmatched_modules %>%
#'       dplyr::filter(age == age_min)
#'
#'     sym_mod_el <- net_age_min %>%
#'       as.data.frame() %>%
#'       tibble::rownames_to_column('name') %>%
#'       tidyr::pivot_longer(!.data$name, names_to = "host", values_to = "state") %>%
#'       dplyr::left_join(module_members_age, by = c("host" = "name")) %>%
#'       dplyr::filter(.data$state > 0) %>%
#'       dplyr::group_by(.data$name, .data$original_module) %>%             # host's original module only
#'       dplyr::summarise(degree_mod = sum(.data$state), .groups = 'drop_last') %>%
#'       dplyr::mutate(
#'         degree = sum(.data$degree_mod),
#'         prop_mod = .data$degree_mod/.data$degree
#'       ) %>%
#'       dplyr::left_join(
#'         mod_df_sym %>%
#'           dplyr::filter(.data$age == age_min) %>%
#'           dplyr::select(.data$original_module, module_name) %>%
#'           dplyr::distinct(),
#'         by = "original_module"
#'       )
#'
#'     if (nrow(node_depth) > 0) {
#'
#'       for (n in seq_len(nrow(node_depth))) {
#'         node_treeio <- node_depth$node[n]
#'
#'         children <- tree_t %>%
#'           dplyr::filter(.data$parent == node_treeio) %>%
#'           dplyr::pull(.data$label)
#'
#'         # check if children are in between time intervals or in the next network
#'         # if they are in the network, get info from age_min
#'         if (children[1] %in% dplyr::filter(mod_df_sym, .data$age == age_min)$name) {
#'
#'           mod_child1 <- mod_df_sym %>%
#'             dplyr::filter(
#'               .data$age == age_min,
#'               .data$name == children[1]
#'             ) %>%
#'             dplyr::pull(.data$module_name)
#'
#'           if (length(mod_child1) == 0) {
#'             nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                            'child1_mod'] <- NA
#'           } else{
#'             nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                            'child1_mod'] <- mod_child1
#'
#'             # mod_idx_child1 <- mod_df_sym %>%
#'             #   dplyr::filter(.data$age == age_min,
#'             #                 .data$module_name  == mod_child1) %>%
#'             #   dplyr::pull(.data$original_module) %>%
#'             #   unique()
#'
#'             mod_props_child1 <- sym_mod_el %>%
#'               dplyr::ungroup() %>%
#'               dplyr::filter(.data$name == children[1]) %>%
#'               dplyr::select(.data$prop_mod, .data$module_name) %>%
#'               tidyr::pivot_wider(
#'                 names_from = .data$module_name,
#'                 values_from = .data$prop_mod,
#'                 names_prefix = "strength_"
#'               )
#'
#'             child1_mods <- colnames(mod_props_child1)
#'
#'             child1_strengths <- tibble::as_tibble(mods_strength) %>%
#'               dplyr::slice_head()
#'             child1_strengths[,child1_mods] <- mod_props_child1
#'
#'           }
#'         } else if (children[1] %in% nodes_interval$label) {
#'           child1_strengths <- nodes_interval %>%
#'             dplyr::filter(.data$label == children[1]) %>%
#'             dplyr::select(dplyr::contains("strength_"))
#'
#'           mod_child1 <- nodes_interval %>%
#'             dplyr::filter(.data$label == children[1]) %>%
#'             dplyr::pull(.data$module_name)
#'
#'           nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                          'child1_mod'] <- mod_child1
#'         } else {
#'           child1_strengths <- NULL
#'           mod_child1 <- NA
#'         }
#'
#'         if (children[2] %in% dplyr::filter(mod_df_sym, .data$age == age_min)$name) {
#'           mod_child2 <- mod_df_sym %>%
#'             dplyr::filter(.data$age == age_min,
#'                           .data$name == children[2]) %>%
#'             dplyr::pull(module_name)
#'
#'           if (length(mod_child2) == 0) {
#'             nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                            'child2_mod'] <- NA
#'           } else {
#'             nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                            'child2_mod'] <- mod_child2
#'
#'             # mod_idx_child2 <- mod_df_sym %>%
#'             #   dplyr::filter(.data$age == age_min,
#'             #                 .data$module_name  == mod_child2) %>%
#'             #   dplyr::pull(.data$original_module) %>%
#'             #   unique()
#'
#'             mod_props_child2 <- sym_mod_el %>%
#'               dplyr::ungroup() %>%
#'               dplyr::filter(.data$name == children[2]) %>%
#'               dplyr::select(.data$prop_mod, .data$module_name) %>%
#'               tidyr::pivot_wider(names_from = .data$module_name,
#'                                  values_from = .data$prop_mod,
#'                                  names_prefix = "strength_")
#'
#'             child2_mods <- colnames(mod_props_child2)
#'
#'             child2_strengths <- tibble::as_tibble(mods_strength) %>%
#'               dplyr::slice_head()
#'             child2_strengths[,child2_mods] <- mod_props_child2
#'
#'           }
#'         } else if (children[2] %in% nodes_interval$label) {
#'           child2_strengths <- nodes_interval %>%
#'             dplyr::filter(.data$label == children[2]) %>%
#'             dplyr::select(dplyr::contains("strength_"))
#'
#'           mod_child2 <- nodes_interval %>%
#'             dplyr::filter(.data$label == children[2]) %>%
#'             dplyr::pull(.data$module_name)
#'
#'           nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                          'child2_mod'] <- mod_child2
#'         } else {
#'           child2_strengths <- NULL
#'           mod_child2 <- NA
#'         }
#'
#'         # Combine children strengths
#'         children_strengths <- rbind(child1_strengths, child2_strengths)
#'
#'         # Merge submodules of the same module (letter) when each children are in submodules of the same module
#'         letter_mod_children <- c(substring(mod_child1, 1, 1), substring(mod_child2, 1, 1))
#'
#'         if (!is.na(mod_child1) &
#'             !is.na(mod_child2) &
#'             mod_child1 != mod_child2 &
#'             letter_mod_children[1] == letter_mod_children[2]
#'             ) {
#'
#'           submodules <- children_strengths %>%
#'             dplyr::select(tidyselect::matches("\\d")) %>%
#'             colnames()
#'           mod_letter <- sub("\\d","",submodules)[1]
#'
#'           children_strengths <- children_strengths %>%
#'             dplyr::mutate({{mod_letter}} := .data[[submodules[1]]] + .data[[submodules[2]]]) %>%
#'             dplyr::select(-tidyselect::matches("\\d"))
#'         }
#'
#'
#'         if(!is.null(children_strengths)){
#'           if(nrow(children_strengths) == 1){
#'             # divide only strength by 2
#'             mean_strengths <- children_strengths/2
#'             # and find strongest module
#'             strongest_mod_idx <- max.col(as.matrix(mean_strengths), "random")
#'           } else if(nrow(children_strengths) == 2){
#'             # get average strength
#'             mean_strengths <- colMeans(children_strengths)
#'             # and find strongest module
#'             strongest_mod_idx <- max.col(t(as.matrix(mean_strengths)), "random")
#'           }
#'
#'           # add all strengths to data frame of all nodes within the time interval
#'           for (c in seq_along(mean_strengths)) {
#'             nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                            names(mean_strengths)[c]] <- mean_strengths[names(mean_strengths)[c]]
#'           }
#'
#'           # Assign strongest module to node
#'           nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                          'module_name'] <- sub("strength_","",names(mean_strengths)[strongest_mod_idx])
#'           nodes_interval[which(nodes_interval$node == node_depth$node[n]),
#'                          'strength'] <- mean_strengths[strongest_mod_idx]
#'
#'           # Write strengths for nodes in the network at max age
#'           if (node_depth$node_name[n] %in% mod_df_sym_age$name) {
#'             for (c in seq_along(mean_strengths)) {
#'               mod_df_sym_age[which(mod_df_sym_age$name == node_depth$node_name[n]),
#'                              names(mean_strengths)[c]] <- mean_strengths[names(mean_strengths)[c]]
#'             }
#'
#'             # also write the strongest module
#'             mod_df_sym_age[which(mod_df_sym_age$name == node_depth$node_name[n]),
#'                            'strength'] <- mean_strengths[strongest_mod_idx]
#'             mod_df_sym_age[which(mod_df_sym_age$name == node_depth$node_name[n]),
#'                            'module_name'] <- sub("strength_","",names(mean_strengths)[strongest_mod_idx])
#'           }
#'         }
#'       }
#'     }
#'
#'     # Branches that don't split in this time interval #
#'     no_split <- mod_df_sym_age %>%
#'       dplyr::filter(.data$module_name == '0') %>%
#'       dplyr::pull(.data$name)
#'
#'     if (length(no_split) > 0) {
#'       for (i in seq_along(no_split)) {
#'
#'         # assign it to the module the node was in at age_min
#'         mod_no_split <- mod_df_sym %>%
#'           dplyr::filter(.data$age == age_min,
#'                         .data$name == no_split[i]) %>%
#'           dplyr::pull(.data$module_name)
#'
#'         if (length(mod_no_split) == 0) {
#'           mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]), 'module_name'] <- NA
#'         } else{
#'
#'           mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]), 'module_name'] <- mod_no_split
#'
#'           # mod_idx_no_split <- unmatched_modules %>%
#'           #   dplyr::filter(.data$age == age_min,
#'           #                 .data$name == no_split[i]) %>%
#'           #   dplyr::pull(.data$original_module)
#'
#'           # calculate how strongly linked to each module the node was at age min
#'           mod_props_no_split <- sym_mod_el %>%
#'             dplyr::ungroup() %>%
#'             dplyr::filter(.data$name == no_split[i]) %>%
#'             dplyr::select(.data$prop_mod, .data$module_name) %>%
#'             tidyr::pivot_wider(names_from = .data$module_name,
#'                                values_from = .data$prop_mod,
#'                                names_prefix = "strength_")
#'
#'           mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]), 'strength'] <-
#'             mod_props_no_split[,paste0("strength_",mod_no_split)]
#'
#'           no_split_mods <- colnames(mod_props_no_split)
#'
#'           strength_no_split <- tibble::as_tibble(mods_strength) %>%
#'             dplyr::slice_head()
#'           strength_no_split[,no_split_mods] <- mod_props_no_split
#'
#'           for (c in seq_along(strength_no_split)) {
#'             mod_df_sym_age[which(mod_df_sym_age$name == no_split[i]),
#'                            names(strength_no_split)[c]] <- strength_no_split[names(strength_no_split)[c]]
#'           }
#'         }
#'       }
#'     }
#'
#'     # get sum of strength to decide which module name to choose for each original module
#'     idx_name_strength <- mod_df_sym_age %>%
#'       dplyr::select(
#'         .data$original_module,
#'         .data$module_name,
#'         .data$strength
#'       ) %>%
#'       dplyr::group_by(.data$original_module, .data$module_name) %>%
#'       dplyr::summarise(sum_strength = sum(.data$strength), .groups = 'drop_last') %>%
#'       dplyr::slice_max(.data$sum_strength) %>%
#'       dplyr::slice_sample()
#'
#'     # Resolve conflicts in symbiont modules
#'     mods <- sort(unique(mod_df_sym_age$module_name))
#'     valid_mods <- sort(unique(idx_name_strength$module_name))  # unsure if these
#'     invalid_mods <- setdiff(mods, valid_mods)                  # are needed
#'
#'     # first, give the same module name (the strongest) to all nodes in the same original module
#'     for (r in seq_len(nrow(mod_df_sym_age))) {
#'       #if(mod_df_sym_age$module_name[r] %in% invalid_mods){
#'       valid_mod <- idx_name_strength[
#'         which(idx_name_strength$original_module == mod_df_sym_age$original_module[r]),
#'         "module_name"
#'       ]
#'       mod_df_sym_age$module_name[r] <- dplyr::pull(valid_mod)
#'       #}
#'     }
#'
#'     # Are there submodules among the valid modules?
#'     sub_mod_left <- valid_mods[valid_mods %in% all_submodules]
#'     sub_mod_left_updated <- sub_mod_left
#'     full_mods <- dplyr::setdiff(valid_mods, sub_mod_left)
#'
#'     if (length(sub_mod_left) != 0) {
#'       for (l in submod_letters) {
#'         n_sub <- tidyselect::contains(l, vars = sub_mod_left) %>% length()
#'         # If there is only one submodule of a module left
#'         if (n_sub == 1) {
#'           # and there isn't another module with the same letter
#'           if (!(l %in% full_mods)) {
#'             # give it the module name (just letter)
#'             mod_df_sym_age[tidyselect::contains(l, vars = mod_df_sym_age$module_name), "module_name"] <- l
#'             idx_name_strength[tidyselect::contains(l, vars = idx_name_strength$module_name), "module_name"] <- l
#'             sub_mod_left_updated <- setdiff(sub_mod_left, sub_mod_left[tidyselect::contains(l, vars = sub_mod_left)])
#'           }
#'         }
#'       }
#'     }
#'
#'     mods_left <- unique(mod_df_sym_age[['module_name']])
#'
#'     # second, create submodules (at max age) for modules (at min age) that were split
#'     for (m in seq_along(mods_left)) {
#'
#'       # how many original modules (submodules) are linked to this module name?
#'       n_originals <- idx_name_strength %>%
#'         dplyr::filter(.data$module_name == mods_left[m]) %>%
#'         dplyr::pull(.data$original_module) %>%
#'         length()
#'
#'       # if more than 1, give numbers to module names
#'       if (n_originals > 1) {
#'         originals <- idx_name_strength %>%
#'           dplyr::filter(.data$module_name == mods_left[m]) %>%
#'           dplyr::pull(.data$original_module)
#'
#'         for(o in seq_along(originals)){
#'           # if splitting a submodule, don't do nested splits
#'           if(mods_left[m] %in% sub_mod_left_updated){
#'
#'             sub_let <- sub("\\d","",mods_left[m])
#'             nsubs <- tidyselect::contains(sub_let, vars = sub_mod_left_updated) %>% length()
#'
#'             mod_df_sym_age[which(mod_df_sym_age$module_name == mods_left[m] &
#'                                    mod_df_sym_age$original_module == originals[o]),
#'                            'module_name'] <- paste0(sub_let, which(originals == originals[o]) + nsubs)
#'           } else{
#'             mod_df_sym_age[which(mod_df_sym_age$module_name == mods_left[m] &
#'                                    mod_df_sym_age$original_module == originals[o]),
#'                            'module_name'] <- paste0(mods_left[m], which(originals == originals[o]))
#'           }
#'         }
#'       }
#'     }
#'
#'     mod_df_sym <- dplyr::bind_rows(mod_df_sym, mod_df_sym_age)
#'
#'   }
#'
#'   # set modules for hosts
#'   mod_idx_name <- mod_df_sym %>%
#'     dplyr::select(.data$age, .data$original_module, .data$module_name) %>%
#'     distinct()
#'
#'   mod_df_host <- dplyr::left_join(mod_df_host, mod_idx_name, by = c("age","original_module"))
#'
#'   mod_df <- bind_rows(mod_df_sym, mod_df_host) %>%
#'     dplyr::rename(module = module_name)
#'
#'
#'   # replace with user given names
#'
#'   list <- list(mod_df, mod_idx_name)
#'   names(list) <- c("nodes_and_modules_per_age", "original_and_matched_module_names")
#'
#'   return(list)
#'
#' }
# maribraga/evolnets documentation built on Feb. 3, 2025, 6:46 p.m.