#' Calculate the posterior probability of ancestral host repertoires
#'
#' Group of functions to calculate the posterior probabilities of
#' ancestral host repertoires at internal nodes of the symbiont tree.
#'
#' @param history Data frame with posterior samples of interaction histories, as returned by
#' `read_history()`.
#' @param tree Symbiont tree, an object of class `phylo`.
#' @param host_tree Host tree, an object of class `phylo`.
#' @param nodes Vector of internal node indices for which to calculate the posterior
#' probability of `state`. Defaults to all internal nodes of `tree`.
#' @param state Which state(s) to include. Default is 2 (actual host). For analyses using the
#' 3-state model, give `c(1, 2)` to include both states (where 1 is a potential host and 2 an
#' actual host).
#'
#' @return A list with three elements:
#' \describe{
#' \item{`samples`}{An array of samples x nodes x hosts, containing the state of each sample.}
#' \item{`post_states`}{An array of nodes x hosts x states, containing the posterior probability
#' of each state.}
#' \item{`post_repertoires`}{An array of nodes x hosts x repertoire, containing the posterior
#' probability of 1) the `"realized"` repertoire, defined as state 2, and 2) the
#' `"fundamental"` repertoire, defined as having any state (usually 1 or 2).}
#' }
#' The number of samples is the number of iterations in `history`.
#' @export
#' @importFrom rlang .data
#' @import data.table
#'
#' @examples
#' # read parasite and host tree
#' data_path <- system.file("extdata", package = "evolnets")
#' tree <- read_tree_from_revbayes(paste0(data_path,"/tree_pieridae.tre"))
#' host_tree <- ape::read.tree(paste0(data_path,"/host_tree_pieridae.phy"))
#'
#' # read histories sampled during MCMC
#' history <- read_history(paste0(data_path,"/history_thin_pieridae.txt"), burnin = 0)
#'
#' # calculate the posterior probability of host repertoires
#' # at chosen internal nodes of the parasite tree
#' nodes <- 129:131
#' pp_at_nodes <- posterior_at_nodes(history, tree, host_tree, nodes)
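#'
#' # the returned arrays are named by node index, host and state/repertoire,
#' # so individual posterior probabilities can be extracted by name, e.g. the
#' # probability that each host is an actual host (state 2) at node 130:
#' pp_at_nodes$post_states["Index_130", , "2"]
#'
#' # and the posterior probability of the realized repertoire at that node:
#' pp_at_nodes$post_repertoires["Index_130", , "realized"]
#'
#' # for an analysis under the 3-state model, both states could be requested:
#' # posterior_at_nodes(history, tree, host_tree, nodes, state = c(1, 2))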
posterior_at_nodes <- function(history, tree, host_tree, nodes = NULL, state = c(2)) {
# input checking
if (!is.data.frame(history)) {
stop('`history` should be a data.frame, usually generated by `read_history()`')
}
if (!all(c('node_index', 'iteration', 'transition_type') %in% names(history))) {
stop('`history` needs to have columns `node_index`, `iteration` and `transition_type`.')
}
if (!inherits(tree, 'phylo')) stop('`tree` should be a phylogeny of class `phylo`.')
if (!inherits(host_tree, 'phylo')) stop('`host_tree` should be a phylogeny of class `phylo`.')
if (!is.numeric(state)) stop('`state` should be a numeric vector.')
if (is.null(nodes)) nodes <- (ape::Ntip(tree) + 1):(ape::Ntip(tree) + ape::Nnode(tree))
if (!is.numeric(nodes)) stop('`nodes` should be a numeric vector.')
iterations <- sort(unique(history$iteration))
n_iter <- length(iterations)
# Make factors so we do not accidentally drop iterations or nodes
history$iteration <- factor(history$iteration, iterations)
history$node_index <- factor(history$node_index, nodes)
# Convert to data.table
dat <- data.table::setDT(history)
# Drop all nodes that we won't need
dat <- dat[node_index %in% nodes, ]
# Select columns to reduce memory use
dat <- dat[, c('iteration', 'node_index', 'transition_time', 'end_state')]
# Sort the table by transition time, so we can quickly find the lowest transition time per node
data.table::setorderv(dat, 'transition_time', na.last = TRUE)
# For each node in each iteration, only take the first row. Either this is the row with the
# lowest transition time (since we just sorted), or the only row (since transition time can be NA).
dat <- dat[, .SD[1], by = c('iteration', 'node_index')]
# For each row left, find any interactions of type `state`
dat[
,
`:=`(
# Find for which hosts (their index) we find an interaction matching `state` at that node
s_index = lapply(
stringr::str_locate_all(end_state, paste0(state, collapse = '|')),
function(x) x[, 'start']
),
# And extract which state that is
states = stringr::str_extract_all(end_state, paste0(state, collapse = '|'))
)
]
# Drop to final columns
dat <- dat[, c('iteration', 'node_index', 's_index', 'states')]
# Unnest by unlisting s_index and states for each group
dat <- dat[
,
list(s_index = unlist(s_index), states = unlist(states)),
by = c('iteration', 'node_index')
]
# Convert to factors, so that we don't accidentally drop any levels
dat[, `:=`(
s_index = factor(s_index, levels = seq_along(host_tree$tip.label)),
states = factor(states, levels = state)
)]
# Now count all combinations to make our samples array
all_combs <- unclass(table(dat$iteration, dat$node_index, dat$s_index, dat$states))
if (!all(all_combs %in% c(0, 1))) {
# Just in case something has gone horribly wrong.
stop('Duplicated combinations detected, please contact the maintainer of evolnets.')
}
# We have to assign the correct state, so loop through the possible states and assign the state
# to wherever an interaction was found (i.e. count == 1)
l <- list()
for (s in seq_along(state)) {
a <- all_combs[, , , s]
a[a == 1] <- state[s]
l[[s]] <- a
}
# Then we merge the arrays for each state: take the first array and overwrite the elements
# where the second array is non-zero with elements of the second, then repeat for any further
# states. (In practice `state` contains at most two states, but this generalizes to n states.)
samples_array <- Reduce(function(x, y) { x[y != 0] <- y[y != 0]; return(x) }, l)
# The line above should have the same result as the sum below, but the sum is much slower
#samples_array <- apply(samples_array, 1:3, sum)
# Finally we assign names to the array
dimnames(samples_array) <- list(1:n_iter, paste0("Index_", nodes), host_tree$tip.label)
# For the probabilities arrays, we just need to count and divide by total iterations
# First, make an array of the posterior probabilities for each state
p_st <- unclass(table(dat$node_index, dat$s_index, dat$states))
p_st <- p_st / n_iter
dimnames(p_st) <- list(paste0("Index_", nodes), host_tree$tip.label, state)
# Then, make an array that has the realized repertoire (state 2), and the fundamental repertoire
# (sum of the probabilities of all states)
p_rep <- array(0, c(dim(p_st)[1:2], 2), c(dimnames(p_st)[1:2], list(c('realized', 'fundamental'))))
if (2 %in% state) p_rep[, , 1] <- p_st[, , '2']
p_rep[, , 2] <- apply(p_st, 1:2, sum)
list(samples = samples_array, post_states = p_st, post_repertoires = p_rep)
}
# Declare column names used in data.table non-standard evaluation as NULL to avoid
# R CMD check NOTEs about undefined global variables
node_index <- end_state <- s_index <- states <- NULL