R/posterior_at_nodes.R

Defines functions posterior_at_nodes

Documented in posterior_at_nodes

#' Calculate the posterior probability of ancestral host repertoires
#'
#' Calculates the posterior probabilities of ancestral host repertoires at
#'   internal nodes of the symbiont tree, from interaction histories sampled during MCMC.
#'
#' For every iteration and node, the host repertoire at the node is taken from the `end_state`
#' of the row with the lowest `transition_time` among the rows for that node, or from the only
#' row for that node when `transition_time` is `NA`. Character `i` of `end_state` is matched to
#' host `i` of `host_tree$tip.label`.
#'
#' @param history Data frame with posterior samples of interaction histories returned from
#'   `read_history()`. Must contain the columns `iteration`, `node_index`, `transition_time`
#'   and `end_state`.
#' @param tree Symbiont tree (class `phylo`).
#' @param host_tree Host tree (class `phylo`).
#' @param nodes Vector of internal nodes for which to calculate the posterior
#'   probability of `state`. Defaults to all internal nodes of `tree`. Duplicates are removed.
#' @param state Which state? Default is 2. For analyses using the 3-state model, give `c(1, 2)` to
#' include both states (where 1 is a potential host and 2 an actual host).
#'
#' @return A list with three elements:
#' \describe{
#'  \item{\code{samples}}{An array of samples x nodes x hosts, containing the state of each
#'  sample (0 where none of the states in \code{state} was found).}
#'  \item{\code{post_states}}{An array of nodes x hosts x state containing the posterior
#'  probability for each state.}
#'  \item{\code{post_repertoires}}{An array of nodes x hosts x repertoire containing the posterior
#'  probability for 1) the \code{"realized"} repertoire which is defined as state 2, and 2) the
#'  \code{"fundamental"} repertoire which is defined as having any state in \code{state}
#'  (usually 1 or 2).}
#' }
#'  The number of samples is the number of iterations in `history`.
#' @export
#' @importFrom rlang .data
#' @import data.table
#'
#' @examples
#' # read parasite and host tree
#' data_path <- system.file("extdata", package = "evolnets")
#' tree <- read_tree_from_revbayes(paste0(data_path,"/tree_pieridae.tre"))
#' host_tree <- ape::read.tree(paste0(data_path,"/host_tree_pieridae.phy"))
#'
#' # read histories sampled during MCMC
#' history <- read_history(paste0(data_path,"/history_thin_pieridae.txt"), burnin = 0)
#'
#' # calculate the posterior probability of host repertoires
#' # at chosen internal nodes of the parasite tree
#' nodes <- c(129:131)
#' pp_at_nodes <- posterior_at_nodes(history, tree, host_tree, nodes)
posterior_at_nodes <- function(history, tree, host_tree, nodes = NULL, state = c(2)) {

  # input checking
  if (!is.data.frame(history)) {
    stop('`history` should be a data.frame, usually generated by `read_history()`')
  }
  # These are exactly the columns read below.
  required_cols <- c('iteration', 'node_index', 'transition_time', 'end_state')
  missing_cols <- setdiff(required_cols, names(history))
  if (length(missing_cols) > 0) {
    stop(
      '`history` needs to have columns `iteration`, `node_index`, `transition_time` and ',
      '`end_state`; missing: ', paste0('`', missing_cols, '`', collapse = ', '), '.'
    )
  }
  if (!inherits(tree, 'phylo')) stop('`tree` should be a phylogeny of class `phylo`.')
  if (!inherits(host_tree, 'phylo')) stop('`host_tree` should be a phylogeny of class `phylo`.')
  if (!is.numeric(state) || length(state) == 0) stop('`state` should be a numeric vector.')
  # Duplicated values would become duplicated factor levels (an error) further down.
  state <- unique(state)

  if (is.null(nodes)) nodes <- (ape::Ntip(tree) + 1):(ape::Ntip(tree) + ape::Nnode(tree))
  if (!is.numeric(nodes)) stop('`nodes` should be a numeric vector.')
  nodes <- unique(nodes)

  iterations <- sort(unique(history$iteration))
  n_iter <- length(iterations)
  if (n_iter == 0) stop('`history` does not contain any iterations.')
  n_hosts <- length(host_tree$tip.label)

  # Copy only the rows of the requested nodes and the columns we need into a new data.table, so
  # the caller's `history` is never modified. Factors make sure we do not accidentally drop
  # iterations or nodes when tabulating.
  keep <- history$node_index %in% nodes
  dat <- data.table::data.table(
    iteration = factor(history$iteration[keep], levels = iterations),
    node_index = factor(history$node_index[keep], levels = nodes),
    transition_time = history$transition_time[keep],
    end_state = history$end_state[keep]
  )

  # Sort by transition time (NA last) and keep the first row of every iteration x node: either
  # the row with the lowest transition time, or the only row (transition time can be NA).
  data.table::setorderv(dat, 'transition_time', na.last = TRUE)
  dat <- unique(dat, by = c('iteration', 'node_index'))

  # For every remaining row, find which hosts (character positions in `end_state`) have one of
  # the requested states, and which state that is. Both lists have one element per row of `dat`.
  pattern <- paste0(state, collapse = '|')
  host_hits <- lapply(
    stringr::str_locate_all(dat$end_state, pattern),
    function(x) x[, 'start']
  )
  state_hits <- stringr::str_extract_all(dat$end_state, pattern)
  n_hits <- lengths(state_hits)

  # Unnest (vectorised) into one entry per found interaction. Factors keep every level so the
  # tables below always have the full iterations x nodes x hosts x states shape.
  hit_iteration <- rep(dat$iteration, n_hits)
  hit_node <- rep(dat$node_index, n_hits)
  hit_host <- factor(unlist(host_hits), levels = seq_len(n_hosts))
  hit_state <- factor(unlist(state_hits), levels = state)

  # Now count all combinations to make our samples array
  counts <- unclass(table(hit_iteration, hit_node, hit_host, hit_state))
  if (!all(counts %in% c(0, 1))) {
    # Just in case something has gone horribly wrong.
    stop('Duplicated combinations detected, please contact the maintainer of evolnets.')
  }

  # One iterations x nodes x hosts array per state, with found interactions set to that state.
  # The slice is reshaped explicitly: plain `[` drops every dimension of extent 1 (e.g. a single
  # requested node or iteration), which would break the dimnames assignment below.
  dims <- dim(counts)[1:3]
  per_state <- lapply(seq_along(state), function(s) {
    slice <- array(counts[, , , s], dim = dims)
    slice[slice == 1] <- state[s]
    slice
  })
  # Merge the per-state arrays: later non-zero entries overwrite earlier ones. A host has a single
  # character in `end_state`, so at most one state is non-zero per iteration x node x host.
  samples <- Reduce(function(x, y) {
    x[y != 0] <- y[y != 0]
    x
  }, per_state)
  dimnames(samples) <- list(seq_len(n_iter), paste0('Index_', nodes), host_tree$tip.label)

  # Posterior probability of each state: count over iterations and divide by their number.
  post_states <- unclass(table(hit_node, hit_host, hit_state)) / n_iter
  dimnames(post_states) <- list(paste0('Index_', nodes), host_tree$tip.label, state)

  # Realized repertoire = probability of state 2; fundamental = summed probability of all states.
  post_repertoires <- array(
    0,
    dim = c(dim(post_states)[1:2], 2),
    dimnames = c(dimnames(post_states)[1:2], list(c('realized', 'fundamental')))
  )
  if (2 %in% state) post_repertoires[, , 'realized'] <- post_states[, , '2']
  post_repertoires[, , 'fundamental'] <- rowSums(post_states, dims = 2)

  list(samples = samples, post_states = post_states, post_repertoires = post_repertoires)
}

# Top-level NULL bindings for data.table column names (`node_index`, `end_state`, `s_index`,
# `states`) that may be referenced through non-standard evaluation. This is the usual way to
# silence R CMD check's "no visible binding for global variable" notes.
node_index <- NULL
end_state <- NULL
s_index <- NULL
states <- NULL
maribraga/evolnets documentation built on Feb. 3, 2025, 6:46 p.m.