# R/expansions_infer.R

# Defines functions: log_posterior, para.log_lh, para.initialiser, fn_log_J_inv, fn_log_J, expansions_infer, run_expansion_inference, standard_priors

# Documented in: expansions_infer, run_expansion_inference, standard_priors

# Number of fixed leading entries in an MCMC state list `x` before the
# per-expansion components: x[[1]] holds the background population size N and
# x[[2]] the partition probabilities, so expansion j is stored at
# x[[j + offset]] (see log_posterior).
offset <- 2

#' Generates a list of priors of the form described in the paper to use with MCMC inference.
#'
#' All hyperparameters are evaluated when `standard_priors` is called, so the
#' returned closures are not affected by later changes to the variables that
#' were passed in.
#'
#' @param expansion_rate corresponds to 'phi' in paper; Poisson rate of the number of expansions
#' @param N_mean_log corresponds to 'mu_{anc}' in paper; meanlog of the log-normal prior on background population size N
#' @param N_sd_log corresponds to 'sigma_{anc}' in paper; sdlog of the log-normal prior on N
#' @param t_mid_rate corresponds to 'lambda_r' in paper; the time to midpoint is exponential with rate `t_mid_rate / N`
#' @param K_sd_log corresponds to 'sigma_{exp}' in paper; sdlog of the log-normal prior on carrying capacity K, with meanlog `log(N)`
#' @param exp_time_nu corresponds to 'nu' in paper; minus the expansion time is gamma distributed with mean `exp_time_nu * N`
#' @param exp_time_kappa corresponds to 'kappa' in paper; the standard deviation of that gamma distribution is `exp_time_kappa * N`
#' @returns the object built by `priorList()` holding prior log-densities and sampling functions named 'prior_i', 'prior_i.sample', 'prior_N', 'prior_N.sample', 'prior_t_mid_given_N', 'prior_t_mid_given_N.sample', 'prior_K_given_N', 'prior_K_given_N.sample', 'prior_t_given_N', 'prior_t_given_N.sample'.
#' @export
#' 
standard_priors <- function(expansion_rate=1, 
                            N_mean_log=3, 
                            N_sd_log=3, 
                            t_mid_rate=5, 
                            K_sd_log=1/2, 
                            exp_time_nu=1/2, 
                            exp_time_kappa=1/2) {
    # Evaluate every hyperparameter now. Without this, the closures below
    # would evaluate these lazy promises only on their first call and pick up
    # whatever value the caller's variable holds at that later time.
    force(expansion_rate)
    force(N_mean_log)
    force(N_sd_log)
    force(t_mid_rate)
    force(K_sd_log)
    force(exp_time_nu)
    force(exp_time_kappa)

    return(priorList(
        prior_i=function(x) dpois(x, expansion_rate, log = TRUE),
        prior_i.sample=function() rpois(1, expansion_rate),
        prior_N=function(x) dlnorm(x, meanlog = N_mean_log, sdlog = N_sd_log, log = TRUE),
        prior_N.sample=function() rlnorm(1, meanlog = N_mean_log, sdlog = N_sd_log),
        prior_t_mid_given_N=function(x, N) dexp(x, t_mid_rate/N, log = TRUE),
        prior_t_mid_given_N.sample=function(N) rexp(1, t_mid_rate/N),
        prior_K_given_N=function(x, N) dlnorm(x, meanlog = log(N), sdlog = K_sd_log, log = TRUE),
        prior_K_given_N.sample=function(N) rlnorm(1, meanlog = log(N), sdlog = K_sd_log),
        prior_t_given_N=function(x, N) {
               # Expansion times must all be strictly negative; if any is not,
               # the whole vector is outside the support and a single -Inf is returned.
               if (all(x < 0)) {
                      out <- dgamma(-x, shape=(exp_time_nu^2)/(exp_time_kappa^2), scale = exp_time_kappa^2 * N / exp_time_nu, log = TRUE)
               } else {
                      out <- -Inf
               }
               return(out) 
        },
        prior_t_given_N.sample=function(N) (-rgamma(1, shape=(exp_time_nu^2)/(exp_time_kappa^2), scale = exp_time_kappa^2 * N / exp_time_nu))
        ))
}


#' Run rjmcmc inference on the provided phylogeny and supplied priors
#'
#' This is the standard method for mcmc expansion inference. The phylogeny is
#' preprocessed, `expansions_infer` is run with the supplied priors, and the
#' resulting chain is converted to data frames with `mcmc2data.frame`.
#' @param phy phylogeny to run inference on
#' @param priors a list of prior log-densities and sampling functions with names 'prior_i', 'prior_i.sample', 'prior_N', 'prior_N.sample', 'prior_t_mid_given_N', 'prior_t_mid_given_N.sample', 'prior_K_given_N', 'prior_K_given_N.sample', 'prior_t_given_N', 'prior_t_given_N.sample'. See `standard_priors` for a ready-made set.
#' @param concentration concentration parameter for the dirichlet prior on expansion membership probabilities
#' @param n_it number of MCMC iterations
#' @param thinning mcmc output thinning
#' @param init initial state for the mcmc chain, forwarded to `expansions_infer` (NULL lets it choose one)
#' @return the object built by `expansionsMCMC()` from `phylo_preprocessed` (the preprocessed phylogeny), `priors`, `model_data` and `expansion_data` (the `mcmc.df` and `event.df` tables of the chain) and `metadata` (a list with `n_it`, `thinning` and `effective_it`)
#' @export
#' 
run_expansion_inference <- function(phy, priors, concentration=1, n_it=1e6, thinning=1, init=NULL) {
    phylo_pre <- preprocess_phylo(phy)

    chain <- expansions_infer(phylo_pre,
                              priors$prior_i,
                              priors$prior_N,
                              priors$prior_t_mid_given_N,
                              priors$prior_t_mid_given_N.sample,
                              priors$prior_K_given_N,
                              priors$prior_K_given_N.sample,
                              priors$prior_t_given_N,
                              priors$prior_t_given_N.sample,
                              concentration,
                              n_it = n_it,
                              thinning = thinning,
                              init = init)

    tables <- mcmc2data.frame(chain)

    # Expected number of distinct retained iterations after thinning; a
    # mismatch means the chain output is inconsistent with n_it/thinning.
    kept_it <- floor(n_it / thinning)
    stopifnot(length(unique(tables$mcmc.df$it)) == kept_it)

    expansionsMCMC(phylo_preprocessed = phylo_pre,
                   priors = priors,
                   model_data = tables$mcmc.df,
                   expansion_data = tables$event.df,
                   metadata = list(n_it = n_it, thinning = thinning, effective_it = kept_it))
}

#' Run rjmcmc inference on the provided preprocessed phylogeny, and using provided prior distributions
#'
#' This is a barebones method for advanced usage. For standard usage refer to `run_expansion_inference`
#' @param pre Preprocessed phylogeny
#' @param prior_i Prior log-density on number of expansions
#' @param prior_N Prior log-density on background population size
#' @param prior_t_mid_given_N Prior log-density on expansion time to midpoint given background population size
#' @param prior_t_mid_given_N.sample Sampling function for prior on time to midpoint given background population size
#' @param prior_K_given_N Prior log-density on expansion carrying capacity given background population size
#' @param prior_K_given_N.sample Sampling function for prior on expansion carrying capacity given background population size
#' @param prior_t_given_N Prior log-density on expansion time given background population size
#' @param prior_t_given_N.sample Sampling function for prior on expansion time given background population size.
#'   Currently unused: new expansion times are proposed uniformly between the earliest node time and the latest internal-node time.
#' @param concentration concentration parameter for the dirichlet prior on expansion membership probabilities
#' @param n_it number of MCMC iterations
#' @param thinning mcmc output thinning
#' @param init initial state for the mcmc chain: a list with `x_0` and `i_0`, or NULL to start from no expansions
#' @param debug if TRUE, warn and replace the likelihood with 0 for every state whose prior is finite, so the chain targets the priors alone
#' @return a list of MCMC states, as returned by `rjmcmc`
#' @export
#' 
expansions_infer <- function(pre,
                            prior_i, 
                            prior_N, 
                            prior_t_mid_given_N, 
                            prior_t_mid_given_N.sample, 
                            prior_K_given_N, 
                            prior_K_given_N.sample, 
                            prior_t_given_N,
                            prior_t_given_N.sample,
                            concentration,
                            n_it=1e6, thinning=1, init=NULL, debug=FALSE) {

    if (debug) warning("Running in debug mode with only priors in use.")

    nodes <- pre$nodes.df

    # Lineage times extracted from the root node; used below to fit the
    # starting value of N under a constant-size coalescent.
    all_times <- extract_lineage_times(pre, pre$phy$node.label[pre$root_idx-pre$n_tips], -Inf)

    max_t <- max(nodes$times)
    min_t <- min(nodes$times)
    max_t_nodes <- max(nodes$times[which(!nodes$is_tip)])

    tree_height <- max_t - min_t

    # Proposal for the divergence time of a new expansion: uniform on
    # [min_t, max_t_nodes]. `N` is accepted to match the signature expected by
    # para.initialiser / para.log_lh but is not used.
    div_time_prop.sample <- function(N) runif(1, min_t, max_t_nodes)
    div_time_prop.lh <- function(x, N) log(1/(max_t_nodes-min_t))

    # Negative constant-size coalescent log-likelihood at population size n;
    # Inf for non-positive n so the optimiser never accepts it.
    const_log_lh <- function(n){
        if (n > 0){
            lh <- -coalescent_loglh(all_times$sam.times[[1]],
                        all_times$coal.times[[1]],
                        n,
                        0)
        } else {
            lh <- Inf
        }
        return(lh)
    }

    if (is.null(init)) {
        # Start with no expansions, a single partition holding all the
        # probability, and N at the minimiser of const_log_lh within
        # [0.01, 100] * tree_height.
        N_0 <- optim(1, const_log_lh, lower=(1e-2)*tree_height, upper=(1e2)*tree_height, method="Brent", control = list(maxit = 2000000))$par
        i_0 <- 0
        x_0 <- list(N_0, c(1))
    } else {
        x_0 <- init$x_0
        i_0 <- init$i_0
    }

    # Dirichlet log prior on the partition probabilities; -Inf unless they are
    # strictly positive and sum to one.
    prior_probs <- function(probs) {

        out <- -Inf
        if (abs(sum(probs)-1) < 1e-8 && all(probs > 0)) {
            out <- ddirichlet(t(matrix(probs)), alpha=rep(concentration, length(probs)), log=TRUE)
        }
        return(out) 
    }

    # Target for the sampler. In debug mode the coalescent likelihood is
    # replaced by 0 wherever the prior is finite, so only the priors are used.
    log_target <- function(x, i) {
        out <- log_posterior(x,
                             i,
                             prior_i, 
                             prior_N,
                             prior_t_mid_given_N, 
                             prior_K_given_N, 
                             prior_t_given_N, 
                             prior_probs,
                             pre)
        if (debug && is.finite(out$prior)) {
            out$lh <- 0
        }
        out
    }

    o <- rjmcmc(log_target,
                function(x_prev, i_prev) prop.sampler(x_prev,
                                                      i_prev, 
                                                      pre, 
                                                      function(N) para.initialiser(N, 
                                                                                   prior_t_mid_given_N.sample,
                                                                                   prior_K_given_N.sample, 
                                                                                   div_time_prop.sample,
                                                                                   pre),
                                                      function(x_init, N) para.log_lh(x_init,
                                                                                      N,
                                                                                      prior_t_mid_given_N, 
                                                                                      prior_K_given_N, 
                                                                                      div_time_prop.lh,
                                                                                      pre),
                                                      fn_log_J,
                                                      fn_log_J_inv,
                                                      pop_scale=(tree_height/2) ### good enough approximation of the population size
                                                      ),
                x_0, i_0, n_it, thinning)
    return(o)
}

# Jacobian term handed to prop.sampler for the move that adds an expansion
# (presumably the log-Jacobian of the dimension change). All arguments are
# ignored and the result is always 0.
fn_log_J <- function(i_prev, x_prev, x_next) 0

# Jacobian term handed to prop.sampler for the move that removes an
# expansion (presumably the inverse of fn_log_J). All arguments, including the
# index of the removed component, are ignored and the result is always 0.
fn_log_J_inv <- function(i_prev, x_prev, x_next, which_mdl_rm) 0

# Draw a fresh expansion component given background population size N.
# The three samplers are called in this order: time to midpoint, carrying
# capacity, divergence time. The divergence branch is then picked uniformly
# among internal branches with parent time < divergence time < child time,
# or set to NA when there is no such branch.
# Returns list(mid time, K, divergence time, divergence branch id).
para.initialiser <- function(N, prop_t_mid_given_N, prop_K_given_N, prop_t_given_N, pre){
    edge_tab <- pre$edges.df
    node_tab <- pre$nodes.df

    t_mid <- prop_t_mid_given_N(N)
    capacity <- prop_K_given_N(N)
    t_div <- prop_t_given_N(N)

    # Branches spanning t_div: child after it, parent before it.
    child_after <- edge_tab$id[which(node_tab$times[edge_tab$node.child] > t_div)]
    spanning <- child_after[which(node_tab$times[edge_tab$node.parent[child_after]] < t_div)]

    # Terminal branches carry zero prior mass, so only internal ones qualify.
    child_is_tip <- node_tab$is_tip[edge_tab$node.child[spanning]]
    candidates <- spanning[which(!child_is_tip)]
    if (length(candidates) == 0) {
        candidates <- NA
    }
    # sample.int (rather than sample) so a single candidate id is returned
    # as-is instead of being treated as a range to sample from.
    branch <- candidates[sample.int(length(candidates), 1)]

    proposal <- vector(mode = "list", length = 4)
    proposal[[1]] <- t_mid
    proposal[[2]] <- capacity
    proposal[[3]] <- t_div
    proposal[[4]] <- branch
    proposal
}

# Log proposal density of an expansion component
# x = list(mid time, K, divergence time, divergence branch id) given
# background population size N. The density is the sum of the three supplied
# log-densities plus log(1/m), where m is the number of internal branches with
# parent time < divergence time < child time. This matches the uniform branch
# draw in para.initialiser.
# Returns -Inf when the branch is NA or is not one of those m branches, since
# the proposal can never produce such a component.
para.log_lh <- function(x, N, prop_t_mid_given_N, prop_K_given_N, prop_t_given_N, pre) {
    edges <- pre$edges.df
    nodes <- pre$nodes.df

    mid.times <- x[[1]]
    K <- x[[2]]
    div.times <- x[[3]]
    div.branch <- x[[4]]

    if (is.na(div.branch)) {
        return(-Inf)
    }

    ### which branches exist at time of divergence
    br_extant_before <- edges$id[which(nodes$times[edges$node.child] > div.times)]
    br_extant_after <- br_extant_before[which(nodes$times[edges$node.parent[br_extant_before]] < div.times)]

    ### filter out terminal branches as those have 0 prior mass
    extant_inner <- br_extant_after[which(!nodes$is_tip[edges$node.child[br_extant_after]])]

    ### the branch is drawn uniformly from extant_inner; anything else (including
    ### the case where extant_inner is empty) has zero proposal probability
    if (!(div.branch %in% extant_inner)) {
        return(-Inf)
    }

    prop_t_mid_given_N(mid.times, N) + prop_K_given_N(K, N) + prop_t_given_N(div.times, N) + log(1/length(extant_inner))
}

# Log prior and log likelihood of the MCMC state `x` holding `i` expansions.
#
# Layout of `x` as read below:
#   - x[[1]] is the background population size N.
#   - x[[2]] holds the partition probabilities.
#   - x[[j + offset]] is the j-th expansion, as list(time to midpoint,
#     carrying capacity K, divergence time, divergence branch id).
# Expansions are sorted by decreasing divergence time, and the first i
# probabilities are permuted to match. The root is appended last to the MRCA
# list, so the entry after them presumably belongs to the root (parent)
# population.
#
# Returns list(prior = ..., lh = ...). Both are -Inf when the state lies
# outside the support: non-positive K/midpoint/N, negative probabilities or
# ones not summing to 1, a divergence time outside its branch, or an
# expansion on a terminal branch. They are also -Inf when any of these checks
# evaluates to NA (e.g. NA/NaN parameters).
log_posterior <- function(x,
                            i, 
                            prior_i, 
                            prior_N, 
                            prior_t_mid_given_N, 
                            prior_K_given_N, 
                            prior_t_given_N, 
                            prior_probs, 
                            pre) 
{
    out_of_support <- list(prior = -Inf, lh = -Inf)

    edges <- pre$edges.df
    nodes <- pre$nodes.df

    n_tips <- pre$n_tips
    root_MRCA <- nodes$lab[pre$root_idx]
    root_div <- -Inf ### the parent population diverges at -Inf

    ### Extract values
    N <- x[[1]]
    probs <- x[[2]]
    if (i > 0) {
        div.times <- sapply(seq_len(i), function(j) x[[j+offset]][[3]])

        ### order expansions by decreasing divergence time
        div_ord <- order(-div.times)
        div.times <- div.times[div_ord]

        mid.times <- sapply(div_ord, function(j) x[[j+offset]][[1]])
        K <- sapply(div_ord, function(j) x[[j+offset]][[2]])
        div.branch <- sapply(div_ord, function(j) x[[j+offset]][[4]])

        probs[seq_len(i)] <- probs[div_ord]
    } else {
        mid.times <- c()
        K <- c()
        div.branch <- c()
        div.times <- c()
    }

    ### Check that all values lie within the support of the prior and likelihood.
    ### isTRUE() turns an NA result (e.g. from NA/NaN parameters) into a rejection
    ### instead of a "missing value where TRUE/FALSE needed" error.
    in_support <- (i >= 0) &&
      all(K > 0) &&
      all(mid.times > 0) &&
      all(N > 0) &&
      all(probs >= 0) &&
      (abs(sum(probs)-1) < 1e-8) &&
      all(!is.na(div.branch)) &&
      all(div.times > nodes$times[edges$node.parent[div.branch]]) &&
      all(div.times < nodes$times[edges$node.child[div.branch]])
    if (!isTRUE(in_support)) {
        return(out_of_support)
    }

    MRCAs <- sapply(edges$node.child[div.branch], function(node) if (node > n_tips) pre$phy$node.label[node-n_tips] else NA)
    ### Terminal branches have no MRCA label; an expansion there is a zero set.
    if (any(is.na(MRCAs))) {
        return(out_of_support)
    }

    prior <- prior_i(i) + prior_probs(probs) + prior_N(N) +
         lgamma(i+1) ### Correction for exchangeable RVs
    if (i > 0) {
        prior <- prior +
                 sum(prior_t_mid_given_N(mid.times, N)) +
                 sum(prior_K_given_N(K,N)) +
                 sum(prior_t_given_N(div.times,N)) -
                 lgamma(i+1) ### prior on divergence events
    }
    ### NOTE(review): the two lgamma(i+1) terms cancel for every i
    ### (lgamma(1) == 0 when i == 0); confirm this is intended.

    MRCAs_root <- c(MRCAs, root_MRCA) ### add root for parent population
    div.times_root <- c(div.times, root_div)

    ### structured coal likelihood accepts rates: midpoint time m becomes rate (1/m)^2
    rates <- sapply(mid.times, function(m) (1/m)**2)

    structured.log_lh <- structured_coal.likelihood(pre,
                                                    MRCAs_root,
                                                    div.times_root,
                                                    rates,
                                                    K,
                                                    N)

    ### Log prior of the partition sizes: sum over partitions of count * log(probability).
    ### Empty partitions contribute 0 (convention 0 * log(0) = 0). Computing
    ### log(0) * 0 directly yields NaN and would make the whole prior NaN.
    partition_counts <- structured.log_lh$partition_counts
    partition_prior <- sum(vapply(seq_along(probs), function(k) {
        n_k <- partition_counts[[k]]
        if (n_k == 0) 0 else log(probs[k]) * n_k
    }, numeric(1)))

    prior <- prior + partition_prior
    return(list(prior=prior, lh=structured.log_lh$log_lh))
}
# dhelekal/CaveDive documentation built on June 11, 2024, 4:32 p.m.