
Defines functions derivs.classe check.pars.classe inflate.pars.classe flatten.pars.classe make.pars.classe make.branches.classe projection.matrix.classe stationary.freq.classe.ev stationary.freq.classe starting.point.classe rootfunc.classe make.initial.conditions.classe make.cache.classe default.argnames.classe make.info.classe make.classe

Documented in make.classe starting.point.classe

## Models should provide:
##   1. make
##   2. info
##   3. make.cache, including initial tip conditions
##   4. initial.conditions(init, pars,t, idx)
##   5. rootfunc(res, pars, ...)

## Common other functions include:
##   stationary.freq
##   starting.point
##   branches

# The full ClaSSE parameter structure is a vector with, for n states:
#     n * n * (n+1) / 2  speciation rates (lambda_ijk, j <= k)
#     n                  extinction rates (mu_i)
#     n * n - n          transition rates (q_ij, i != j)
#   = (n + 3) * n^2 / 2  elements

## 1: make
## Build the ClaSSE likelihood function for a tree and tip states.
##
## Arguments:
##   tree, states, k, sampling.f, strict: passed to make.cache.classe()
##   control: passed to make.all_branches.dtlik()
##
## Returns a function `ll(pars, condition.surv, root, root.p,
## intermediates)` with class c("classe", "dtlik", "function").
make.classe <- function(tree, states, k, sampling.f=NULL, strict=TRUE,
                        control=list()) {
  ## Note that this uses MuSSE's cache...
  cache <- make.cache.classe(tree, states, k, sampling.f, strict)
  initial.conditions <- make.initial.conditions.classe(k)
  ## NOTE(review): the third argument was cut off in the reviewed copy;
  ## `initial.conditions` is the only otherwise-unused local, so it is
  ## restored here -- confirm against make.all_branches.dtlik().
  all_branches <- make.all_branches.dtlik(cache, control,
                                          initial.conditions)
  rootfunc <- rootfunc.classe
  f.pars <- make.pars.classe(k)

  ll <- function(pars, condition.surv=TRUE, root=ROOT.OBS,
                 root.p=NULL, intermediates=FALSE) {
    ## Expanded parameters (with the full Q matrix) drive the branch
    ## calculations...
    pars2 <- f.pars(pars)
    ans <- all_branches(pars2, intermediates)
    ## TODO: This is different to other functions, as the
    ## stationary.freq function assumes the non-expanded case.
    ## However, it would be straightforward to modify stationary.freq
    ## classe to use the expanded case.  At worst, it could just strip
    ## off the extra parameters, but I think that it builds these
    ## anyway.
    ## This will be an issue for creating a split or time version.
    ## ...while the root calculation gets the original `pars`.
    rootfunc(ans, pars, condition.surv, root, root.p, intermediates)
  }
  class(ll) <- c("classe", "dtlik", "function")
  ll
}

## 2: info
## Build the model-information list stored in the cache as `cache$info`.
##
## Arguments:
##   k:   number of states
##   phy: the phylogeny, stored unchanged in the result
##
## NOTE(review): in the reviewed copy of this file the `list(` opener
## and most field lines were missing; only `np`, the section comments
## and the reference string survived.  The other fields below were
## restored to match the section comments and the rest of this file
## (2*k variables, E before D).  `name.pretty`, `ml.default`,
## `mcmc.lowerzero` and `doc` in particular are assumptions -- confirm
## every field against the upstream source before relying on it.
make.info.classe <- function(k, phy) {
  list(name="classe",
       name.pretty="ClaSSE",
       ## Parameters:
       ## (count includes the k diagonal elements of Q; compare `np` in
       ## make.pars.classe)
       np=as.integer((k+3)*k*k/2 + k),
       argnames=default.argnames.classe(k),
       ## Variables:
       ny=as.integer(2*k),
       k=as.integer(k),
       idx.e=as.integer(1:k),
       idx.d=as.integer((k+1):(2*k)),
       ## R version of derivatives function
       derivs=derivs.classe,
       ## Phylogeny:
       phy=phy,
       ## Inference:
       ml.default="subplex",
       mcmc.lowerzero=TRUE,
       ## These are optional
       doc=NULL,
       reference=c(
         "Goldberg (submitted)"))
}
## Default parameter names for a k-state ClaSSE model.
##
## Returns a character vector ordered as
##   lambda<i><j><k>  for each parent i and daughter pair j <= k,
##   mu<i>            for each state i,
##   q<i><j>          for each ordered pair i != j.
## State labels are zero-padded to a common width.
default.argnames.classe <- function(k) {
  ## Width of the zero-padded state labels, e.g. "%02d" once k >= 10.
  fmt <- sprintf("%%0%dd", ceiling(log10(k + .5)))
  sstr <- sprintf(fmt, seq_len(k))
  ## The last argument has length k*(k+1)/2 and is recycled by sprintf
  ## across the k parent states.
  lambda.names <- sprintf("lambda%s%s%s", rep(sstr, each=k*(k+1)/2),
                          rep(rep(sstr, times=seq(k,1,-1)), k),
                          unlist(lapply(seq_len(k), function(i) sstr[i:k])))
  mu.names <- sprintf("mu%s", sstr)
  q.names <- sprintf("q%s%s", rep(sstr, each=k-1),
                     unlist(lapply(seq_len(k), function(i) sstr[-i])))
  c(lambda.names, mu.names, q.names)
}

## 3: make.cache (& initial.tip)
## Note: classe uses the same functions as musse here to generate the
## cache and initial tip states.
## Build the cache for ClaSSE: the MuSSE cache with its `info` element
## replaced by the ClaSSE model information.
##
## Stops if k > 31.
make.cache.classe <- function(tree, states, k, sampling.f=NULL,
                             strict=TRUE) {
  if (k > 31)
    stop("No more than 31 states allowed.  Increase in classe-eqs.c.")
  cache <- make.cache.musse(tree, states, k, sampling.f, strict)
  cache$info <- make.info.classe(k, tree)
  cache
}

## 4: initial.conditions:
## save on index computations by wrapping initial.conditions.classe
## Create the node initial-conditions function for an n-state model.
##
## The index vectors are computed once here and captured by the
## returned closure.  The closure has signature
##   function(init, pars, t, is.root=FALSE)
## where `init` has one column per daughter lineage with rows
## 1..n = E and rows (n+1)..(2n) = D.  It returns c(E, D) for the node.
make.initial.conditions.classe <- function(n) {
  ## n = number of states; called k elsewhere but k is used as an index below
  nseq <- seq_len(n)
  ## Row i holds the positions in `pars` of the lambda_ijk for parent i.
  lam.idx <- matrix(seq_len(n*n*(n+1)/2), byrow=TRUE, nrow=n)

  idxD <- (n+1):(2*n)
  ## (j, k) enumerate the daughter-state pairs with j <= k, in the same
  ## order as the columns of lam.idx.
  j <- rep(nseq, times=seq(n,1,-1))
  k <- unlist(lapply(1:n, function(i) nseq[i:n]))
  d <- rep(NA, n)

  initial.conditions.classe <- function(init, pars, t, is.root=FALSE) {
    ## E_i(t), same for N and M
    e <- init[nseq,1]

    ## D_i(t), formed from N and M
    DM <- init[idxD,1]
    DN <- init[idxD,2]
    ## Average over which daughter receives which state.
    DM.DN <- 0.5 * (DM[j] * DN[k] + DM[k] * DN[j])
    for (i in nseq)
      d[i] <- sum(pars[lam.idx[i,]] * DM.DN) # slower with apply

    ## a touch slower but cleaner:
    ##   idxlam = seq_len(n*n*(n+1)/2)
    ##     d = colSums(matrix(pars[idxlam], ncol=n) * DM.DN)
    ## or slightly better (but still not faster than for):
    ##   lamseq = seq_len(n*n*(n+1)/2)
    ##   lam.mat = matrix(lamseq, ncol=n)
    ##     lam.mat[lamseq] = pars[lamseq]
    ##     d = colSums(lam.mat * DM.DN)

    c(e, d)
  }

  initial.conditions.classe
}

## Combine the root values into a log-likelihood.
##
## Arguments:
##   res:  list with `vals` (E then D at the root) and `lq`
##   pars: parameter vector; the first k*k*(k+1)/2 entries are read as
##         the speciation rates
##   condition.surv: if TRUE, divide D by the root survival term
##   root, root.p: passed to root_p_calc()
##   intermediates: if TRUE, attach `res` (with root.p added) and
##         `vals` as attributes of the result
##
## Returns the log-likelihood.
rootfunc.classe <- function(res, pars, condition.surv, root, root.p,
                            intermediates) {
  vals <- res$vals
  lq <- res$lq
  k <- length(vals)/2

  i <- seq_len(k)
  d.root <- vals[-i]

  ## TODO: This could be tidied up:
  root.equi <- function(pars) stationary.freq.classe(pars, k)
  root.p <- root_p_calc(d.root, pars, root, root.p, root.equi)
  if ( condition.surv ) {
    ## species in state i are subject to all lambda_ijk speciation rates
    nsum <- k*(k+1)/2
    ## One column per parent state, so colSums gives the total
    ## speciation rate out of each state.
    lambda <- colSums(matrix(pars[1:(nsum*k)], nrow=nsum))
    e.root <- vals[i]
    d.root <- d.root / sum(root.p * lambda * (1 - e.root)^2)
  }

  if ( root == ROOT.ALL ) {
    loglik <- log(d.root) + sum(lq)
  } else {
    loglik <- log(sum(root.p * d.root)) + sum(lq)
  }

  if ( intermediates ) {
    res$root.p <- root.p
    attr(loglik, "intermediates") <- res
    attr(loglik, "vals") <- vals
  }

  loglik
}


## Additional functions
## Heuristic starting point
## based on starting.point.geosse()
## Heuristic starting parameter vector for a k-state ClaSSE model.
##
## Arguments:
##   tree: phylogeny, read through Ntip() and branching.times()
##   k:    number of states
##   eps:  assumed extinction/speciation ratio
##
## Returns a vector named by default.argnames.classe(k).  The overall
## speciation rate `s` is split evenly over the k*(k+1)/2 lambdas of
## each state, every mu is `x` and every q is `q`.
starting.point.classe <- function(tree, k, eps=0.5) {
  if (eps == 0) {
    s <- (log(Ntip(tree)) - log(2)) / max(branching.times(tree))
    x <- 0
    ## This branch used to assign an unused `d`, leaving `q` undefined,
    ## so `rep(q, ...)` below picked up base::q and failed.
    q <- s/10
  } else {
    n <- Ntip(tree)
    r <- ( log( (n/2) * (1 - eps*eps) + 2*eps + (1 - eps)/2 *
           sqrt( n * (n*eps*eps - 8*eps + 2*n*eps + n))) - log(2)
         ) / max(branching.times(tree))
    s <- r / (1 - eps)
    x <- s * eps
    q <- s - x
  }
  p <- c( rep(s / (k*(k+1)/2), k*k*(k+1)/2 ), rep(x, k), rep(q, k*(k-1)) )
  names(p) <- default.argnames.classe(k)
  p
}

## Stationary state frequencies for a ClaSSE parameter vector.
##
## For k == 2 a closed form is used, with `pars` read in the order
##   lambda111 lambda112 lambda122 lambda211 lambda212 lambda222
##   mu1 mu2 q12 q21
## Any other k is delegated to stationary.freq.classe.ev().
##
## Returns a vector of k frequencies, or NA when the two-state
## quadratic has more than one root in [0, 1].
stationary.freq.classe <- function(pars, k) {
  if (k == 2) {
    ## Difference in net diversification between states 1 and 2.
    g <- (sum(pars[1:3]) - pars[7]) - (sum(pars[4:6]) - pars[8])
    ## Tolerance for treating g as zero, scaled by the rates.
    eps <- sum(pars[1:8]) * 1e-14
    ss1 <- pars[9]  + 2*pars[3] + pars[2]  # shift from 1
    ss2 <- pars[10] + 2*pars[4] + pars[5]  # shift from 2

    if ( abs(g) < eps ) {
      if (ss1 + ss2 == 0) {
        eqfreq <- 0.5
      } else {
        eqfreq <- ss2/(ss1 + ss2)
      }
      eqfreq <- c(eqfreq, 1 - eqfreq)
    } else {
      roots <- quadratic.roots(g, ss2 + ss1 - g, -ss2)
      eqfreq <- roots[roots >= 0 & roots <= 1]
      if ( length(eqfreq) > 1 ) {
        eqfreq <- NA
      } else {
        eqfreq <- c(eqfreq, 1 - eqfreq)
      }
    }
  } else { ## also works for k=2, but much slower
    eqfreq <- stationary.freq.classe.ev(pars, k)
  }

  eqfreq
}

## like stationary.freq.geosse()
## Stationary frequencies from the eigen-decomposition of the
## projection matrix.
##
## Returns the eigenvector belonging to the largest eigenvalue, scaled
## to sum to one.
stationary.freq.classe.ev <- function(pars, k) {
  A <- projection.matrix.classe(pars, k)
  ## continuous time, so the dominant eigenvalue is the largest one
  ## return its eigenvector, normalized
  evA <- eigen(A)
  ## which.max() always picks exactly one column.  The previous
  ## `which(values == max(values))` selected several columns on a tie
  ## and then normalised them jointly.
  i <- which.max(evA$values)
  v <- evA$vectors[, i]
  v / sum(v)
}

## Build the k x k projection matrix A from a ClaSSE parameter vector
## (lambdas, then mus, then qs; see default.argnames.classe).
##
## Off-diagonal A[to, from] collects the lambdas that put a daughter in
## state `to`, plus q_from,to.  Diagonal i is the total speciation rate
## of state i, minus mu_i, minus the off-diagonal sum of column i.
projection.matrix.classe <- function(pars, k) {
  A <- matrix(0, nrow=k, ncol=k)

  nsum <- k*(k+1)/2
  kseq <- seq_len(k)
  pars.lam <- pars[seq_len(nsum*k)]
  pars.mu <- pars[nsum*k + kseq]
  ## Everything after the lambdas and mus.  Negative indexing avoids
  ## the backwards sequence seq(a, b) produces when there are no q's
  ## (k = 1).
  pars.q <- pars[-seq_len((nsum+1)*k)]

  ## array indices of lambda's in parameter vector
  idx.lam <- cbind(rep(kseq, each=nsum), rep(rep(kseq, times=seq(k,1,-1)), k),
                   unlist(lapply(kseq, function(i) i:k)))
  ## transpose of matrix indices of q's in parameter vector
  idx.q <- cbind(unlist(lapply(kseq, function(i) (kseq)[-i])),
                 rep(kseq, each=k-1))

  ## take care of off-diagonal elements
  for (n in seq_len(nsum*k)) {
    ## add this lambda to A[daughter states, parent state]
    ## (separate steps in case the daughter states are the same)
    r <- idx.lam[n,]
    A[r[2], r[1]] <- A[r[2], r[1]] + pars.lam[n]
    A[r[3], r[1]] <- A[r[3], r[1]] + pars.lam[n]
  }
  A[idx.q] <- A[idx.q] + pars.q

  ## fix the diagonal elements
  ## (the loop above also wrote to the diagonal; discard that first)
  diag(A) <- 0
  diag(A) <- -colSums(A) + unlist(lapply(kseq, function(i)
                 sum(pars.lam[seq((i-1)*nsum+1, i*nsum)]) - pars.mu[i]))

  A
}

## For historical and debugging purposes, not used directly in the
## calculations, but branches function is generated this way
## internally.
## Thin wrapper: build the branches function from the model info held
## in the cache.
make.branches.classe <- function(cache, control) {
  info <- cache$info
  make.branches.dtlik(info, control)
}

## Parameter manipulation
## Create the parameter-expansion function for a k-state model.
##
## The returned function(pars) checks `pars` with check.pars.classe()
## and returns the lambdas and mus followed by the full k x k Q matrix
## (flattened column by column), whose diagonal is set so that each row
## sums to zero.
make.pars.classe <- function(k) {
  np0 <- as.integer((k+3)*k*k/2)  # number of actual params
  qmat <- matrix(0, k, k)
  ## (from, to) index pairs of the off-diagonal q's, in parameter order
  idx.qmat <- cbind(rep(1:k, each=k-1),
               unlist(lapply(1:k, function(i) (1:k)[-i])))
  x <- k * k * (k + 1) / 2 + k    # number of lambdas plus mus
  idx.lm <- seq_len(x)
  idx.q <- seq(x+1, np0)
  function(pars) {
    check.pars.classe(pars, k)
    ## `qmat` is modified as a local copy; the zero template in the
    ## enclosing environment is untouched between calls.
    qmat[idx.qmat] <- pars[idx.q]
    diag(qmat) <- -rowSums(qmat)
    c(pars[idx.lm], qmat)
  }
}

## These two functions are intended to make the classe parameters easier to
## visualize and populate, since they get unwieldy with more than two states.
## The speciation rate array is indexed lambda[parent state, daughter1 state,
## daughter2 state].  The transition matrix is indexed q[from state, to state].
## Elements that are not parameters get NA: daughter2 > daughter 1, from = to.
## The parameter list might be a good way to work with constrain(), eventually.

## Input: list containing lambda_ijk array, mu vector, q_ij array, num states
## Output: parameter vector, ordered as default.argnames.classe describes
## Flatten a parameter list into a named parameter vector.
##
## `parlist` has elements lambda (k x k x k array indexed
## [parent, daughter1, daughter2]), mu (length k), q (k x k matrix
## indexed [from, to]) and nstates (k).  Only the lambda entries with
## daughter1 <= daughter2 and the off-diagonal q entries are read.
##
## Returns the vector named by default.argnames.classe(k).
flatten.pars.classe <- function(parlist) {
  k <- parlist$nstates
  kseq <- seq_len(k)

  idx.lam <- cbind( rep(kseq, each=k*(k+1)/2),
                    rep(rep(kseq, times=seq(k,1,-1)), k),
                    unlist(lapply(kseq, function(i) i:k)) )

  idx.q <- cbind( rep(kseq, each=k-1),
                  unlist(lapply(kseq, function(i) (kseq)[-i])) )

  pars <- c(parlist$lambda[idx.lam], parlist$mu, parlist$q[idx.q])
  names(pars) <- default.argnames.classe(k)
  pars
}

## Output: list containing lambda_ijk array, mu vector, q_ij array, num states
## Input: parameter vector, ordered as default.argnames.classe describes
## Expand a parameter vector into a list of arrays.
##
## Returns list(lambda, mu, q, nstates): lambda is a k x k x k array
## indexed [parent, daughter1, daughter2], mu an unnamed vector of
## length k, and q a k x k matrix indexed [from, to].  Entries that are
## not parameters stay NA.
inflate.pars.classe <- function(pars, k) {
  check.pars.classe(pars, k)
  kseq <- seq_len(k)

  Lam <- array(NA, dim=rep(k, 3))  # 3 = parent + 2 daughters
  idx <- cbind(rep(kseq, each=k*(k+1)/2), rep(rep(kseq, times=seq(k,1,-1)), k),
               unlist(lapply(kseq, function(i) i:k)))
  j <- nrow(idx)                   # number of lambda parameters
  Lam[idx] <- pars[seq_len(j)]

  Mu <- unname(pars[j + kseq])

  Q <- array(NA, dim=rep(k, 2))
  idx <- cbind(rep(kseq, each=k-1), unlist(lapply(kseq, function(i) kseq[-i])))
  ## Everything after the lambdas and mus.  Negative indexing avoids a
  ## backwards seq() when there are no q's (k = 1).
  Q[idx] <- pars[-seq_len(j + k)]

  list(lambda=Lam, mu=Mu, q=Q, nstates=k)
}

## Validate a ClaSSE parameter vector by handing it to
## check.pars.nonnegative() with the expected count for k states.
check.pars.classe <- function(pars, k) {
  n.expected <- (k+3)*k*k/2
  check.pars.nonnegative(pars, n.expected)
}

## Placeholder for an R version of the ClaSSE derivatives.
## Not implemented: always stops with an error.
derivs.classe <- function(t, y, pars) {
  ## TODO: Need to write this (Emma: Do you have a copy somewhere?)
  stop("Not yet possible")
}

Try the diversitree package in your browser

Any scripts or data that you put into this service are public.

diversitree documentation built on Sept. 8, 2023, 5:54 p.m.