# Nothing
# Fit a fixed-rates continuous-time Markov model (aka. "Mk model")
# Requires that states (or prior likelihoods) are known for all tips
# The transition matrix is estimated via maximum-likelihood.
#
# Returns a named list. If anything goes wrong the list is list(success=FALSE, error=<message>).
# Otherwise it contains: success=TRUE, Nstates, transition_matrix (Nstates x Nstates matrix of fitted rates),
# loglikelihood, Niterations, Nevaluations, converged, guess_rate (the rate used to build the default first guess)
# and AIC (=2*Nrates - 2*loglikelihood). If Nbootstraps>0 it additionally contains the Nstates x Nstates matrices
# CI50lower, CI50upper, CI95lower, CI95upper, standard_errors and medians, computed from the bootstrapped transition matrices.
#
# Relies on helpers defined elsewhere in this package: get_transition_index_matrix, get_independent_sister_tips,
# get_pairwise_distances, get_transition_matrix_from_rate_vector, get_stationary_distribution, Mk_loglikelihood_CPP,
# extract_independent_rates_from_transition_matrix, get_smallest_items, unlist_with_nulls and simulate_mk_model.
fit_mk = function(  trees,                          # either a single tree in phylo format, or a list of trees
                    Nstates,                        # integer, number of possible states
                    tip_states              = NULL, # either a 1D integer vector of size Ntips (if trees[] is a single tree) or a list of 1D integer vectors (if trees[] is a list of trees), listing the state of each tip in each tree. Can also be NULL, in which case tip_priors must be provided.
                    tip_priors              = NULL, # either a 2D numerical matrix of size Ntips x Nstates (if trees[] is a single tree) or a list of 2D numerical matrixes (if trees[] is a list of trees), listing the prior likelihoods for each state at each tip and for each tree. Can also be NULL.
                    rate_model              = "ER", # either "ER" or "SYM" or "ARD" or "SUEDE" or an integer vector mapping entries of the transition matrix to a set of independent rate parameters. The format and interpretation is the same as for index.matrix generated by the function get_transition_index_matrix(..).
                    root_prior              = "auto", # can be 'auto' (treated like 'max_likelihood'), 'flat', 'stationary', 'empirical', 'max_likelihood', 'likelihoods' or a numeric vector of size Nstates, specifying the prior probabilities for the tree's root. Used to define the tree's likelihood, based on the root's marginal likelihoods
                    oldest_ages             = NULL, # optional numeric or numeric vector of size Ntrees, specifying the oldest age to consider for each tree. Typically this will be <=root_age. If NULL, this will be set to each tree's root_age. May contain NAs or NaNs; for these trees the oldest age will be the root_age.
                    guess_transition_matrix = NULL, # optional 2D numeric matrix, specifying a first guess for the transition rate matrix. May contain NAs.
                    Ntrials                 = 1,    # (int) number of trials (starting points) for fitting the transition matrix. The first trial starts from the default first guess; any additional trials start from the most promising randomly scouted parameter values (see Nscouts).
                    Nscouts                 = NULL, # integer, number of randomly chosen parameter values to consider as possible fitting starts (only Ntrials-1 of those will be kept). Each scout costs only one evaluation of the loglikelihood function, which is much cheaper than a whole fitting trial. If NULL, this is automatically chosen based on the number of fitted parameters and Ntrials. Only relevant if Ntrials>1, since the first trial uses the default first guess (i.e., without scouting).
                    max_model_runtime       = NULL, # maximum time (in seconds) to allocate for each likelihood evaluation per tree. Use this to escape from badly parameterized models during fitting (this will likely cause the affected fitting trial to fail). If NULL or <=0, this option is ignored.
                    optim_algorithm         = "nlminb", # either "optim" or "nlminb". What algorithm to use for fitting.
                    optim_max_iterations    = 200,  # maximum number of iterations of the optimization algorithm (per trial)
                    optim_rel_tol           = 1e-8, # relative tolerance when optimizing the objective function
                    check_input             = TRUE, # (bool) perform some basic sanity checks on the input data. Set this to FALSE if you're certain your input data is valid.
                    Nthreads                = 1,    # (integer) number of threads for running multiple fitting trials in parallel
                    Nbootstraps             = 0,    # integer optional number of parametric-bootstrap samples for estimating confidence intervals of fitted parameters. If 0, no parametric bootstrapping is performed. Typical values are 10-100.
                    Ntrials_per_bootstrap   = NULL, # integer optional number of fitting trials for each bootstrap sampling. If NULL, this is set equal to Ntrials. A smaller Ntrials_per_bootstrap will reduce computation, at the expense of increasing the estimated confidence intervals (i.e. yielding more conservative estimates of confidence).
                    verbose                 = FALSE, # boolean, specifying whether to print informative messages
                    diagnostics             = FALSE, # boolean, specifying whether to print detailed info (such as log-likelihood) at every iteration of the fitting. For debugging purposes mainly.
                    verbose_prefix          = ""){   # string, specifying the line prefix when printing messages. Only relevant if verbose==TRUE.
  # basic input checking
  if(verbose) cat(sprintf("%sChecking input variables..\n",verbose_prefix))
  if((!is.null(tip_states)) && (!is.null(tip_priors))) return(list(success=FALSE, error="tip_states and tip_priors are both non-NULL, but exactly one of them should be NULL"))
  if(is.null(tip_states) && is.null(tip_priors)) return(list(success=FALSE, error="tip_states and tip_priors are both NULL, but exactly one of them should be non-NULL"))
  if(inherits(trees,"phylo")){
    # trees[] is actually a single tree
    trees = list(trees)
    Ntrees = 1
  }else if(inherits(trees,"list")){
    # trees[] is a list of trees
    Ntrees = length(trees)
  }else{
    return(list(success=FALSE,error=sprintf("Unknown data format '%s' for input trees[]: Expected a list of phylo trees or a single phylo tree",class(trees)[1])))
  }
  # an empty tree list would otherwise make the per-tree loops below index out of bounds
  if(Ntrees==0) return(list(success=FALSE, error="No trees provided: trees[] is an empty list"))
  if(!is.null(tip_states)){
    if(!(inherits(tip_states,"list") && (length(tip_states)==Ntrees))){
      # something is wrong with tip_states, perhaps we can fix it (single tree given a plain vector of states)
      if((Ntrees==1) && (length(tip_states)==length(trees[[1]]$tip.label))){
        tip_states = list(unlist(tip_states))
      }else{
        return(list(success=FALSE, error=sprintf("Invalid input format for tip_states: Expected a list of vectors, each listing the tip states of a specific tree")))
      }
    }
  }
  if(!is.null(tip_priors)){
    if(!(inherits(tip_priors,"list") && (length(tip_priors)==Ntrees))){
      # something is wrong with tip_priors, perhaps we can fix it (single tree given a plain matrix of priors)
      # check dim() first: nrow()/ncol() return NULL for dimensionless objects, which would break the && chain
      if((Ntrees==1) && (!is.null(dim(tip_priors))) && (nrow(tip_priors)==length(trees[[1]]$tip.label)) && (ncol(tip_priors)==Nstates)){
        tip_priors = list(tip_priors)
      }else{
        return(list(success=FALSE, error=sprintf("Invalid input format for tip_priors: Expected a list of matrixes, each listing the tip priors of a specific tree")))
      }
    }
  }
  if(is.null(max_model_runtime)) max_model_runtime = 0
  if(is.null(Ntrials_per_bootstrap)) Ntrials_per_bootstrap = max(1,Ntrials)
  if(!is.null(oldest_ages)){
    if(length(oldest_ages)==1){
      oldest_ages = rep(oldest_ages,times=Ntrees)
    }else if(length(oldest_ages)!=Ntrees){
      return(list(success=FALSE, error=sprintf("Invalid number of oldest_ages[]; expected either 1 value or %d values (=Ntrees)",Ntrees)))
    }
  }
  if(!is.null(guess_transition_matrix)){
    # a dimensionless guess (e.g. a plain vector) would make nrow()/ncol() return NULL and crash the comparison below
    if(is.null(dim(guess_transition_matrix))) return(list(success=FALSE, error=sprintf("Guess transition matrix must be a %d x %d matrix",Nstates,Nstates)))
    if((nrow(guess_transition_matrix)!=Nstates) || (ncol(guess_transition_matrix)!=Nstates)) return(list(success=FALSE, error=sprintf("Guess transition matrix has incorrect dimensions (%d x %d); expected a %d x %d matrix",nrow(guess_transition_matrix),ncol(guess_transition_matrix),Nstates,Nstates)))
  }
  # keep the user's (unmodified) guess so that bootstrap refits start from the same information
  original_guess_transition_matrix = guess_transition_matrix

  # prepare priors if needed
  if(!is.null(tip_states)){
    # convert known tip states into (nearly) one-hot prior likelihoods
    tip_priors = vector(mode="list", Ntrees)
    for(tr in seq_len(Ntrees)){
      focal_tip_states = tip_states[[tr]]
      focal_tree = trees[[tr]]
      Ntips = length(focal_tree$tip.label)
      if(!is.numeric(focal_tip_states)) return(list(success=FALSE, error=sprintf("tip_states for tree %d are not integers",tr)))
      if(length(focal_tip_states)==0) return(list(success=FALSE, error=sprintf("tip_states for tree %d are non-NULL but empty",tr)))
      if(length(focal_tip_states)!=Ntips) return(list(success=FALSE, error=sprintf("Length of tip_states (%d) for tree %d is not the same as the number of tips in the tree (%d)",length(focal_tip_states),tr,Ntips)))
      if(check_input){
        min_tip_state = min(focal_tip_states)
        max_tip_state = max(focal_tip_states)
        if((min_tip_state<1) || (max_tip_state>Nstates)) return(list(success=FALSE, error=sprintf("tip_states must be integers between 1 and %d, but found values between %d and %d for tree %d",Nstates,min_tip_state,max_tip_state,tr)))
        if((!is.null(names(focal_tip_states))) && any(names(focal_tip_states)!=focal_tree$tip.label)) return(list(success=FALSE, error=sprintf("Names in tip_states and tip labels in tree %d don't match (must be in the same order)",tr)))
      }
      # small non-zero priors for non-observed states avoid exact zeros in the likelihood
      focal_tip_priors = matrix(1e-8/(Nstates-1), nrow=Ntips, ncol=Nstates)
      focal_tip_priors[cbind(1:Ntips, focal_tip_states)] = 1.0-1e-8
      tip_priors[[tr]] = focal_tip_priors
    }
  }else{
    for(tr in seq_len(Ntrees)){
      focal_tree = trees[[tr]]
      focal_tip_priors = tip_priors[[tr]]
      Ntips = length(focal_tree$tip.label)
      if(is.null(dim(focal_tip_priors))) return(list(success=FALSE, error=sprintf("ERROR: tip_priors for tree %d is not a matrix",tr)))
      if(nrow(focal_tip_priors)==0) return(list(success=FALSE, error=sprintf("ERROR: tip_priors for tree %d is non-NULL but has zero rows",tr)))
      if(Nstates != ncol(focal_tip_priors)){
        return(list(success=FALSE, error=sprintf("ERROR: Nstates (%d) differs from the number of columns in tip_priors (%d), for tree %d",Nstates,ncol(focal_tip_priors),tr)))
      }
      # row count must match the tree, since the priors are passed (flattened) to the likelihood routine per tip
      if(nrow(focal_tip_priors)!=Ntips) return(list(success=FALSE, error=sprintf("ERROR: tip_priors for tree %d has %d rows, but the tree has %d tips",tr,nrow(focal_tip_priors),Ntips)))
      if(check_input){
        if(any(focal_tip_priors>1.0)) return(list(success=FALSE, error=sprintf("ERROR: Some tip_priors are larger than 1.0 (max was %g), for tree %d",max(focal_tip_priors),tr)))
        # any() is required: the comparison is elementwise, and && only accepts scalars
        if((!is.null(rownames(focal_tip_priors))) && (!is.null(focal_tree$tip.label)) && any(rownames(focal_tip_priors)!=focal_tree$tip.label)) return(list(success=FALSE, error=sprintf("ERROR: Row names in tip_priors and tip labels in tree %d don't match",tr)))
      }
    }
  }
  NtotalTips = sum(sapply(1:Ntrees, FUN=function(tr) length(trees[[tr]]$tip.label)))

  # figure out prior distribution for root if needed
  if(root_prior[1]=="auto"){
    root_prior_type = "max_likelihood"
    root_prior_probabilities = numeric(0)
  }else if(root_prior[1]=="flat"){
    root_prior_type = "custom"
    root_prior_probabilities = rep(1.0/Nstates, times=Nstates)
  }else if(root_prior[1]=="empirical"){
    root_prior_type = "custom"
    root_prior_probabilities = sapply(1:Nstates, FUN=function(state) sum(sapply(1:Ntrees, FUN=function(tr) sum(tip_priors[[tr]][,state])))/NtotalTips)
    root_prior_probabilities = root_prior_probabilities/sum(root_prior_probabilities) # normalize empirical probabilities
  }else if(root_prior[1]=="stationary"){
    # root prior cannot yet be calculated, and will be determined at every fitting iteration
    root_prior_type = NULL
    root_prior_probabilities = NULL
  }else if(root_prior[1]=="max_likelihood"){
    root_prior_type = "max_likelihood"
    root_prior_probabilities = numeric(0)
  }else if(root_prior[1]=="likelihoods"){
    root_prior_type = "likelihoods"
    root_prior_probabilities = numeric(0)
  }else{
    # the user provided their own root_prior probability vector
    if(length(root_prior)!=Nstates) return(list(success=FALSE, error=sprintf("ERROR: root_prior has length %d, expected %d",length(root_prior),Nstates)))
    if(check_input){
      if(any(root_prior<0)) return(list(success=FALSE, error=sprintf("ERROR: root_prior contains negative values (down to %g)",min(root_prior))))
      if(abs(1.0-sum(root_prior))>1e-6) return(list(success=FALSE, error=sprintf("ERROR: Entries in root prior do not sum up to 1 (sum=%.10g)",sum(root_prior))))
    }
    root_prior_type = "custom"
    root_prior_probabilities = root_prior
  }

  # get set of independent rates to be fitted (as a condensed vector), and the corresponding index_matrix (which links each entry in the transition matrix to an independent rate parameter)
  temp_results = get_transition_index_matrix(Nstates,rate_model)
  index_matrix = temp_results$index_matrix
  Nrates = temp_results$Nrates

  # figure out reasonable first guess for the transition rates between states
  # use independent contrasts to get the scale of transition rates
  first_guess_rate = 0
  NICs = 0
  for(tr in seq_len(Ntrees)){
    if(is.null(tip_states) || is.null(tip_states[[tr]])){
      focal_tip_states = max.col(tip_priors[[tr]])
    }else{
      focal_tip_states = tip_states[[tr]]
    }
    focal_tree = trees[[tr]]
    focal_tip_pairs = get_independent_sister_tips(focal_tree) # get independent contrasts
    phylogenetic_distances = get_pairwise_distances(focal_tree, A=focal_tip_pairs[,1], B=focal_tip_pairs[,2], check_input=FALSE)
    valids = which(phylogenetic_distances>0)
    phylogenetic_distances = phylogenetic_distances[valids]
    focal_tip_pairs = focal_tip_pairs[valids,,drop=FALSE]
    transitions = abs(focal_tip_states[focal_tip_pairs[,1]]-focal_tip_states[focal_tip_pairs[,2]])
    first_guess_rate = first_guess_rate + sum(transitions/phylogenetic_distances)
    NICs = NICs + length(valids)
  }
  # without any usable contrast, 0/0 would give NaN and break the check below; fall back to the alternative guesses instead
  first_guess_rate = (if(NICs>0) first_guess_rate/NICs else 0)
  if(first_guess_rate==0){
    if(is.null(guess_transition_matrix) || all(is.na(guess_transition_matrix))){
      first_guess_rate = mean(sapply(1:Ntrees, FUN=function(tr) Nstates/((if(is.null(trees[[tr]]$edge.length)) 1 else mean(trees[[tr]]$edge.length))*log(length(trees[[tr]]$tip.label))/log(2.0))))
    }else{
      first_guess_rate = mean(abs(as.vector(guess_transition_matrix)), na.rm=TRUE)
    }
  }
  if(diagnostics) cat(sprintf("%s  First guess rate: %.5g\n",verbose_prefix,first_guess_rate))
  if(is.null(guess_transition_matrix) || all(is.na(guess_transition_matrix))){
    guess_transition_matrix = get_transition_matrix_from_rate_vector(rates = rep(first_guess_rate, Nrates), index_matrix = index_matrix, Nstates = Nstates)
  }else{
    guess_transition_matrix[is.na(guess_transition_matrix)] = first_guess_rate
  }
  # make sure first-guess transition matrix is a valid transition matrix (rows sum to zero)
  diag(guess_transition_matrix) = 0
  diag(guess_transition_matrix) = -rowSums(guess_transition_matrix)

  # fitted rates are bounded within first_guess_rate * 10^(+-power_range)
  power_range = 6

  # define objective function to be minimized (negated log-likelihood, summed over all trees)
  objective_function = function(dense_rates, trial){
    if(any(is.nan(dense_rates)) || any(is.infinite(dense_rates))){
      if(diagnostics) cat(sprintf("%s  Trial %d: Objective requested for invalid rates: %s\n",verbose_prefix,trial,paste(sprintf("%.4g", dense_rates), collapse=", ")))
      return(Inf)
    }
    Q = get_transition_matrix_from_rate_vector(dense_rates, index_matrix, Nstates)
    if(root_prior[1]=="stationary"){
      # local override: the root prior depends on the current Q
      root_prior_type = "custom"
      root_prior_probabilities = get_stationary_distribution(Q)
    }
    loglikelihood = 0
    for(tr in seq_len(Ntrees)){
      focal_tree = trees[[tr]]
      results = Mk_loglikelihood_CPP( Ntips                       = length(focal_tree$tip.label),
                                      Nnodes                      = focal_tree$Nnode,
                                      Nedges                      = nrow(focal_tree$edge),
                                      Nstates                     = Nstates,
                                      tree_edge                   = as.vector(t(focal_tree$edge))-1, # flatten in row-major format and make indices 0-based,
                                      edge_length                 = (if(is.null(focal_tree$edge.length)) numeric() else focal_tree$edge.length),
                                      transition_matrix           = as.vector(t(Q)), # flatten in row-major format
                                      prior_probabilities_per_tip = as.vector(t(tip_priors[[tr]])), # flatten in row-major format
                                      root_prior_type             = root_prior_type,
                                      root_prior                  = root_prior_probabilities,
                                      oldest_age                  = (if(is.null(oldest_ages)) -1 else (if(is.finite(oldest_ages[tr])) oldest_ages[tr] else -1)),
                                      runtime_out_seconds         = max_model_runtime,
                                      exponentiation_accuracy     = 1e-3,
                                      max_polynomials             = 1000)
      if((!results$success) || is.na(results$loglikelihood) || is.nan(results$loglikelihood)){
        if(diagnostics) cat(sprintf("%s  Trial %d, tree %d (%d tips): Model evaluation failed: %s\n",verbose_prefix,trial,tr,length(focal_tree$tip.label),results$error))
        return(Inf)
      }
      loglikelihood = loglikelihood + results$loglikelihood
    }
    if(diagnostics) cat(sprintf("%s  Trial %d: loglikelihood %.10g\n",verbose_prefix,trial,loglikelihood))
    return(-loglikelihood)
  }

  # fit starting from a given vector of independent rates
  fit_single_trial = function(start_dense_rates, trial){
    # rescale parameters to order ~1 for better optimizer conditioning; guard against an all-zero (or invalid) start, which would divide by zero
    rate_scale = mean(abs(start_dense_rates))
    if((!is.finite(rate_scale)) || (rate_scale<=0)) rate_scale = 1
    if(optim_algorithm == "optim"){
      fit = stats::optim( start_dense_rates/rate_scale,
                          function(x) objective_function(x*rate_scale, trial),
                          method  = "L-BFGS-B",
                          lower   = rep(first_guess_rate/(10**power_range), Nrates)/rate_scale,
                          upper   = rep((10**power_range)*first_guess_rate, Nrates)/rate_scale,
                          control = list(maxit=optim_max_iterations))
      LL = -fit$value
      # optim reports counts for (function, gradient); keep only function evaluations, consistent with the nlminb branch
      Nevaluations = fit$counts[1]
      Niterations = NA
      converged = (fit$convergence==0)
    }else{
      fit = stats::nlminb(start_dense_rates/rate_scale,
                          function(x) objective_function(x*rate_scale, trial),
                          lower   = rep(0, Nrates)/rate_scale,
                          upper   = rep((10**power_range)*first_guess_rate, Nrates)/rate_scale,
                          control = list(iter.max=optim_max_iterations, eval.max=optim_max_iterations*Nrates*10, rel.tol=optim_rel_tol, step.min=1e-5))
      LL = -fit$objective
      Nevaluations = fit$evaluations[1]
      Niterations = fit$iterations
      converged = (fit$convergence==0)
    }
    fit$par = fit$par * rate_scale # undo rescaling
    if(diagnostics) cat(sprintf("%s  Trial %d: Final loglikelihood %.10g, converged = %d\n",verbose_prefix,trial,LL,converged))
    return(list(LL=LL, Nevaluations=Nevaluations, Niterations=Niterations, converged=converged, fit=fit))
  }

  default_start = extract_independent_rates_from_transition_matrix(guess_transition_matrix, index_matrix)
  if(Ntrials>1){
    # randomly choose multiple parameter starting points and keep the Ntrials-1 most promising ones (i.e., with smallest objective values) plus the default_start
    Nscouts = (if(is.null(Nscouts)) min(10000,10*Nrates*Ntrials) else max(Ntrials-1,Nscouts))
    if(verbose) cat(sprintf("%sGenerating %d random parameter starts and selecting the most promising ones..\n",verbose_prefix,Nscouts))
    # scouts are spread progressively wider (in log-space) around first_guess_rate
    starts_pool = lapply(seq_len(Nscouts), FUN=function(k) first_guess_rate * 10**runif(n=Nrates, min=-((k/Nscouts)^2)*power_range/2, max=((k/Nscouts)^2)*power_range/2))
    # compute the objective values for each start in the pool
    if((Nthreads>1) && (.Platform$OS.type!="windows")){
      start_objectives = unlist(parallel::mclapply(starts_pool, FUN = function(dense_rates) objective_function(dense_rates, trial=-1), mc.cores = min(Nthreads, length(starts_pool)), mc.preschedule = TRUE, mc.cleanup = TRUE))
    }else{
      start_objectives = sapply(starts_pool, FUN = function(dense_rates) objective_function(dense_rates, trial=-1))
    }
    # keep only the most promising starts (i.e., with lowest non-nan objectives), but always include default_start
    start_objectives[!is.finite(start_objectives)] = NaN
    starts_pool = c(list(default_start), starts_pool[get_smallest_items(values=start_objectives, N=Ntrials-1, check_nan=TRUE)])
  }else{
    # only consider the default guess as starting point
    starts_pool = list(default_start)
  }
  Ntrials = length(starts_pool)

  # run one or more independent fitting trials
  if((Ntrials>1) && (Nthreads>1) && (.Platform$OS.type!="windows")){
    # run trials in parallel using multiple forks
    # Note: Forks (and hence shared memory) are not available on Windows
    if(verbose) cat(sprintf("%sFitting %d free parameters (%d trials, parallelized)..\n",verbose_prefix,Nrates,Ntrials))
    fits = parallel::mclapply(  seq_len(Ntrials),
                                FUN = function(trial) fit_single_trial(start_dense_rates=starts_pool[[trial]], trial=trial),
                                mc.cores = min(Nthreads, Ntrials),
                                mc.preschedule = FALSE,
                                mc.cleanup = TRUE)
  }else{
    # run in serial mode
    if(verbose) cat(sprintf("%sFitting %d free parameters (%s)..\n",verbose_prefix,Nrates,(if(Ntrials==1) "1 trial" else sprintf("%d trials",Ntrials))))
    fits = vector(mode="list", Ntrials)
    for(trial in seq_len(Ntrials)){
      fits[[trial]] = fit_single_trial(start_dense_rates=starts_pool[[trial]], trial=trial)
    }
  }

  # extract information from best fit (note that some fits may have LL=NaN or NA)
  LLs = unlist_with_nulls(sapply(1:Ntrials, function(trial) fits[[trial]]$LL))
  valids = which((!is.na(LLs)) & (!is.nan(LLs)) & (!is.null(LLs)) & (!is.infinite(LLs)) & sapply(1:Ntrials, function(trial) (!any(is.null(fits[[trial]]$fit$par))) && all(is.finite(fits[[trial]]$fit$par))))
  if(length(valids)==0) return(list(success=FALSE, error="Fitting failed for all trials"))
  best = valids[which.max(LLs[valids])]
  loglikelihood = fits[[best]]$LL
  fitted_rates = fits[[best]]$fit$par
  transition_matrix = get_transition_matrix_from_rate_vector(fitted_rates, index_matrix, Nstates)

  if(Nbootstraps>0){
    if(verbose) cat(sprintf("%sEstimating confidence intervals using %d parametric bootstraps..\n",verbose_prefix,Nbootstraps))
    # each row holds one bootstrapped transition matrix, flattened column-wise; failed bootstraps remain NA
    bootstrap_params_flat = matrix(NA,nrow=Nbootstraps,ncol=Nstates*Nstates)
    for(b in seq_len(Nbootstraps)){
      # simulate model with fitted parameters
      if(verbose) cat(sprintf("%s  Bootstrap #%d..\n",verbose_prefix,b))
      bootstrap_tip_states = vector(mode="list", Ntrees)
      for(tr in seq_len(Ntrees)){
        bootstrap_tip_states[[tr]] = simulate_mk_model( tree                = trees[[tr]],
                                                        Q                   = transition_matrix,
                                                        root_probabilities  = "stationary",
                                                        include_tips        = TRUE,
                                                        include_nodes       = FALSE,
                                                        Nsimulations        = 1)$tip_states
      }
      # fit model to simulated tree
      fit = fit_mk( trees                   = trees,
                    Nstates                 = Nstates,
                    tip_states              = bootstrap_tip_states,
                    rate_model              = rate_model,
                    root_prior              = root_prior,
                    guess_transition_matrix = original_guess_transition_matrix,
                    Ntrials                 = Ntrials_per_bootstrap,
                    Nscouts                 = Nscouts,
                    max_model_runtime       = max_model_runtime,
                    optim_algorithm         = optim_algorithm,
                    optim_max_iterations    = optim_max_iterations,
                    optim_rel_tol           = optim_rel_tol,
                    check_input             = FALSE,
                    Nthreads                = Nthreads,
                    Nbootstraps             = 0,
                    verbose                 = verbose,
                    diagnostics             = diagnostics,
                    verbose_prefix          = paste0(verbose_prefix,"    "))
      if(!fit$success){
        if(verbose) cat(sprintf("%s  WARNING: Fitting failed for this bootstrap: %s\n",verbose_prefix,fit$error))
      }else{
        bootstrap_params_flat[b,] = as.vector(fit$transition_matrix)
      }
    }
    # calculate standard errors and confidence intervals from distribution of bootstrapped parameters
    standard_errors = matrix(sqrt(pmax(0, colMeans(bootstrap_params_flat^2, na.rm=TRUE) - colMeans(bootstrap_params_flat, na.rm=TRUE)^2)), nrow=Nstates, byrow=FALSE)
    quantiles_flat  = sapply(1:ncol(bootstrap_params_flat), FUN=function(p) quantile(bootstrap_params_flat[,p], probs=c(0.25, 0.75, 0.025, 0.975, 0.5), na.rm=TRUE, type=8))
    CI50lower       = matrix(quantiles_flat[1,], nrow=Nstates, byrow=FALSE)
    CI50upper       = matrix(quantiles_flat[2,], nrow=Nstates, byrow=FALSE)
    CI95lower       = matrix(quantiles_flat[3,], nrow=Nstates, byrow=FALSE)
    CI95upper       = matrix(quantiles_flat[4,], nrow=Nstates, byrow=FALSE)
    medians         = matrix(quantiles_flat[5,], nrow=Nstates, byrow=FALSE)
  }

  # return results
  results = list( success           = TRUE,
                  Nstates           = Nstates,
                  transition_matrix = transition_matrix,
                  loglikelihood     = loglikelihood,
                  Niterations       = fits[[best]]$Niterations, # may be NA, depending on the optimization algorithm
                  Nevaluations      = fits[[best]]$Nevaluations, # may be NA, depending on the optimization algorithm
                  converged         = fits[[best]]$converged,
                  guess_rate        = first_guess_rate,
                  AIC               = 2*Nrates - 2*loglikelihood)
  if(Nbootstraps>0){
    results$CI50lower       = CI50lower
    results$CI50upper       = CI50upper
    results$CI95lower       = CI95lower
    results$CI95upper       = CI95upper
    # appended after the CIs so that existing field order is unchanged
    results$standard_errors = standard_errors
    results$medians         = medians
  }
  return(results)
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.