# R/asr_mk_model.R
#
# Defines functions asr_mk_model
#
# Documented in asr_mk_model

# Ancestral state reconstruction (ASR) for discrete characters using a fixed-rates continuous-time Markov model (a.k.a. "Mk model").
# Requires that states (or prior distributions) are known for all tips.
# The transition matrix can either be provided, or can be estimated via maximum-likelihood fitting (via fit_mk).
# Returns the log-likelihood of the model, the transition matrix and (optionally) the likelihoods of ancestral states
# ("marginal ancestral likelihoods") for all internal nodes of the tree.
# The marginal ancestral likelihoods of an ancestral node form a vector of size Nstates, the i-th entry of which specifies
# the probability that the tree (if reroot==TRUE) or the descending subtree (if reroot==FALSE) would be as observed,
# if the node's state was i.
# Uses the rerooting method introduced by Yang et al (1995) to infer marginal likelihoods of ancestral states.
# This function works similarly to phytools::rerootingMethod().
#
# Input contract:
#   Exactly one of tip_states or tip_priors must be non-NULL. Hard tip_states are converted internally into
#   near-certain tip priors: 1-1e-8 for the observed state, with the remaining 1e-8 spread evenly over the other states.
#   tip_priors must be a 2D array with exactly one row per tip (in the order of tree$tip.label).
#   Invalid input is reported via stop().
#
# Return value (named list):
#   If the transition matrix had to be fitted and fitting failed: list(success=FALSE, error=<message>).
#   Otherwise:
#     success               : TRUE
#     Nstates               : number of states used
#     transition_matrix     : the provided or fitted Nstates x Nstates transition matrix
#     loglikelihood         : log-likelihood of the tip data under the model
#     AIC                   : AIC of the fitted model (NULL if transition_matrix was provided)
#     ancestral_likelihoods : (only if include_ancestral_likelihoods) Nnodes x Nstates matrix of marginal likelihoods
#     ancestral_states      : (only if include_ancestral_likelihoods) per-node index of the state with the largest likelihood
asr_mk_model = function(tree, 
						tip_states,									# 1D integer array of size Ntips. Can also be NULL.
						Nstates 				= NULL,				# number of possible states. Can be NULL.
						tip_priors 				= NULL,				# 2D numerical array of size Ntips x Nstates. Can also be NULL.
						rate_model 				= "ER",				# either "ER" or "SYM" or "ARD" or "SUEDE" or an integer vector mapping entries of the transition matrix to a set of independent rate parameters. The format and interpretation is the same as for index.matrix generated by the function get_transition_index_matrix(..). Only relevant if transition_matrix is NULL.
						transition_matrix 		= NULL,				# either NULL, or a transition matrix of size Nstates x Nstates, such that transition_matrix^T * p gives the rate of change of probability vector p. If NULL, the transition matrix will be fitted via maximum-likelihood. The convention is that [i,j] gives the transition rate i-->j.
						include_ancestral_likelihoods = TRUE,		# (bool) include the marginal ancestral state likelihoods for all nodes in the returned values
						reroot 					= TRUE,				# (bool) Use the rerooting method by [Yang 1995] to obtain the likelihoods for each node. This requires that the model be time-reversible. If FALSE, likelihoods will only be local, i.e. based on descending subtrees
						root_prior 				= "auto",			# can be 'auto', 'flat', 'stationary', 'empirical', 'likelihoods', 'max_likelihood' or a numeric vector of size Nstates, specifying the prior probabilities for the tree's root. Used to define the tree's likelihood, based on the root's marginal likelihoods
						Ntrials 				= 1,				# (int) number of trials (starting points) for fitting the transition matrix. Only relevant if transition_matrix=NULL.
						optim_algorithm		 	= "nlminb",			# either "optim" or "nlminb". What algorithm to use for fitting.
						optim_max_iterations	= 200,				# maximum number of iterations of the optimization algorithm (per trial)
						optim_rel_tol			= 1e-8,				# relative tolerance when optimizing the objective function
						store_exponentials 		= TRUE,				# (bool) passed through unchanged to the C++ reconstruction routine (presumably caches matrix exponentials; confirm there)
						check_input 			= TRUE,				# (bool) perform some basic sanity checks on the input data. Set this to FALSE if you're certain your input data is valid.
						Nthreads 				= 1){				# (integer) number of threads for running multiple fitting trials in parallel
	Ntips 			= length(tree$tip.label);
	Nnodes 			= tree$Nnode;
	Nedges 			= nrow(tree$edge);
	loglikelihood 	= NULL; # value will be calculated as we go
	got_tip_states	= (!is.null(tip_states))
	
	# exactly one of tip_states / tip_priors must be provided
	if((!is.null(tip_states)) && (!is.null(tip_priors))){
		stop("ERROR: tip_states and tip_priors are both non-NULL, but exactly one of them should be NULL")
	}else if(is.null(tip_states) && is.null(tip_priors)){
		stop("ERROR: tip_states and tip_priors are both NULL, but exactly one of them should be non-NULL")
	}
	
	# create tip priors if needed
	state_names = NULL
	if(!is.null(tip_states)){
		if(!is.numeric(tip_states)) stop(sprintf("ERROR: tip_states must be integers"))
		if(length(tip_states)==0) stop("ERROR: tip_states is non-NULL but empty")
		if(length(tip_states)!=Ntips) stop(sprintf("ERROR: Length of tip_states (%d) is not the same as the number of tips in the tree (%d)",length(tip_states),Ntips));
		if(is.null(Nstates)) Nstates = max(tip_states);
		if(check_input){
			min_tip_state = min(tip_states)
			max_tip_state = max(tip_states)
			if((min_tip_state<1) || (max_tip_state>Nstates)) stop(sprintf("ERROR: tip_states must be integers between 1 and %d, but found values between %d and %d",Nstates,min_tip_state,max_tip_state))
			if((!is.null(names(tip_states))) && any(names(tip_states)!=tree$tip.label)) stop("ERROR: Names in tip_states and tip labels in tree don't match (must be in the same order).")
		}
		# convert hard states into near-certain priors (avoids exact zeros in the likelihood computation)
		tip_priors = matrix(1e-8/(Nstates-1), nrow=Ntips, ncol=Nstates);
		tip_priors[cbind(1:Ntips, tip_states)] = 1.0-1e-8;
	}else{
		if(length(dim(tip_priors))!=2) stop("ERROR: tip_priors must be a 2D array of size Ntips x Nstates")
		if(nrow(tip_priors)==0) stop("ERROR: tip_priors is non-NULL but has zero rows")
		# the priors are flattened row-wise and handed to compiled code, so the row count must match exactly
		if(nrow(tip_priors)!=Ntips) stop(sprintf("ERROR: Number of rows in tip_priors (%d) is not the same as the number of tips in the tree (%d)",nrow(tip_priors),Ntips))
		if(is.null(Nstates)){
			Nstates = ncol(tip_priors);
		}else if(Nstates != ncol(tip_priors)){
			stop(sprintf("ERROR: Nstates (%d) differs from the number of columns in tip_priors (%d)",Nstates,ncol(tip_priors)))
		}
		if(!is.null(colnames(tip_priors))) state_names = colnames(tip_priors);
		if(check_input){
			if(any(tip_priors>1.0)) stop(sprintf("ERROR: Some tip_priors are larger than 1.0 (max was %g)",max(tip_priors)))
			if(any(tip_priors<0)) stop(sprintf("ERROR: Some tip_priors are negative (min was %g)",min(tip_priors)))
			# any(...) reduces the element-wise comparison to a single logical, as required by &&
			if((!is.null(rownames(tip_priors))) && (!is.null(tree$tip.label)) && any(rownames(tip_priors)!=tree$tip.label)) stop("ERROR: Row names in tip_priors and tip labels in tree don't match")
		}
	}


	# estimate transition matrix if needed
	modelAIC = NULL
	if(is.null(transition_matrix)){
		fit = fit_mk(	trees					= tree,
						Nstates					= Nstates,
						tip_states				= (if(got_tip_states) tip_states else NULL),
						tip_priors 				= (if(got_tip_states) NULL else tip_priors),
						rate_model 				= rate_model,
						root_prior 				= root_prior,
						Ntrials 				= Ntrials,
						max_model_runtime		= -1,
						optim_algorithm		 	= optim_algorithm,
						optim_max_iterations	= optim_max_iterations,
						optim_rel_tol			= optim_rel_tol,
						check_input 			= FALSE,
						Nthreads 				= Nthreads)
		if(!fit$success){
			return(list(success=FALSE, error=sprintf("Could not fit transition rate matrix: %s",fit$error)))
		}
		transition_matrix 	= fit$transition_matrix
		loglikelihood 		= fit$loglikelihood
		modelAIC			= fit$AIC
	}else{
		if(check_input){
			# make sure this is a valid transition matrix (each row must sum to zero, up to a relative tolerance)
			row_sums = rowSums(transition_matrix)
			if(any(abs(row_sums)>1e-6*max(abs(transition_matrix)))) stop(sprintf("Entries in transition_matrix do not sum up to 0.0 for each row; found row-sums between %g and %g\n",min(row_sums),max(row_sums)));
			# check row & column names
			CT = colnames(transition_matrix)
			RT = rownames(transition_matrix)
			CP = colnames(tip_priors)
			if((!is.null(CT)) && (!is.null(RT)) && (!all(CT==RT))) stop(sprintf("ERROR: Row names and column names of transition_matrix are not the same"))
			if((!is.null(CT)) && (!is.null(CP)) && (!all(CT==CP))) stop(sprintf("ERROR: Column names of transition_matrix and column names of tip_priors are not the same"))
			if((!is.null(RT)) && (!is.null(CP)) && (!all(RT==CP))) stop(sprintf("ERROR: Row names of transition_matrix and column names of tip_priors are not the same"))
		}
	}
	
	
	# figure out prior distribution for root if needed
	if(root_prior[1]=="auto"){
		root_prior_type = "max_likelihood"
		root_prior_probabilities = numeric(0)
	}else if(root_prior[1]=="flat"){
		root_prior_type = "custom"
		root_prior_probabilities = rep(1.0/Nstates, times=Nstates);
	}else if(root_prior[1]=="empirical"){
		root_prior_type = "custom"
		root_prior_probabilities = colSums(tip_priors)/nrow(tip_priors);
		root_prior_probabilities = root_prior_probabilities/sum(root_prior_probabilities);
	}else if(root_prior[1]=="stationary"){
		root_prior_type = "custom"
		root_prior_probabilities = get_stationary_distribution(transition_matrix);
	}else if(root_prior[1]=="max_likelihood"){
		root_prior_type = "max_likelihood"
		root_prior_probabilities = numeric(0)
	}else{
		# the user provided their own root_prior probability vector
		if(length(root_prior)!=Nstates) stop(sprintf("ERROR: root_prior has length %d, expected %d",length(root_prior),Nstates))
		if(check_input){
			if((!is.null(names(root_prior))) && (!is.null(state_names)) && (!all(names(root_prior)==state_names))) stop("ERROR: Names in root_prior don't match up with state names");
			if(any(root_prior<0)) stop(sprintf("ERROR: root_prior contains negative values (down to %g)",min(root_prior)))
			if(abs(1.0-sum(root_prior))>1e-6) stop(sprintf("ERROR: Entries in root prior do not sum up to 1 (sum=%.10g)",sum(root_prior)))
		}
		root_prior_type = "custom"
		root_prior_probabilities = root_prior
	}
	
	
	# calculate loglikelihood and ancestral states if needed
	if(is.null(loglikelihood) || include_ancestral_likelihoods){
		# eigendecomposition = get_eigendecomposition_if_available(transition_matrix);
		reconstruction = ASR_with_fixed_rates_Markov_model_CPP(	Ntips 							= Ntips,
																Nnodes							= Nnodes,
																Nedges							= Nedges,
																Nstates							= Nstates,
																tree_edge 						= as.vector(t(tree$edge))-1,		# flatten in row-major format and make indices 0-based
																edge_length		 				= (if(is.null(tree$edge.length)) rep(1.0, times=Nedges) else tree$edge.length),
																transition_matrix				= as.vector(t(transition_matrix)),	# flatten in row-major format
																eigenvalues						= numeric(), # disable eigendecomposition method for exponentiation
																EVmatrix						= numeric(), # disable eigendecomposition method for exponentiation
																inverse_EVmatrix				= numeric(), # disable eigendecomposition method for exponentiation
																prior_probabilities_per_tip 	= as.vector(t(tip_priors)),			# flatten in row-major format
																root_prior_type					= root_prior_type,
																root_prior						= root_prior_probabilities,
																include_ancestral_likelihoods	= include_ancestral_likelihoods,
																reroot 							= reroot,
																runtime_out_seconds				= 0,
																exponentiation_accuracy			= 1e-3,
																max_polynomials					= 1000,
																store_exponentials				= store_exponentials);
		loglikelihood = reconstruction$loglikelihood;
		if(include_ancestral_likelihoods){
			ancestral_likelihoods = matrix(reconstruction$ancestral_likelihoods, ncol=Nstates, byrow=TRUE, dimnames=list(tree$node.label,state_names)) # unflatten
		}
	}
	
	# return results
	results = list(	success					= TRUE, 
					Nstates					= Nstates,
					transition_matrix		= transition_matrix,
					loglikelihood			= loglikelihood,
					AIC						= modelAIC)
	if(include_ancestral_likelihoods){
		results$ancestral_likelihoods 	= ancestral_likelihoods
		results$ancestral_states		= sapply(seq_len(Nnodes), FUN=function(node) which.max(ancestral_likelihoods[node,])) # maximum-likelihood ancestral states
	}
	return(results)
}







# calculate the eigendecomposition of a matrix Q
# if eigendecomposition is not available (i.e. eigenvector-matrix does not have full rank), returns empty vectors
#get_eigendecomposition_if_available = function(Q){
#	eigenvalues 		= numeric();
#	EVmatrix 			= numeric();
#	inverse_EVmatrix 	= numeric();
#	decomposition 		= eigen(Q, only.values=FALSE);
#	if((length(decomposition$values)==ncol(Q)) && (Matrix::rankMatrix(decomposition$vectors, method="tolNorm2", tol=1e-6)==ncol(Q))){
#		eigenvalues 	 = decomposition$values
#		EVmatrix		 = decomposition$vectors
#		inverse_EVmatrix = solve(EVmatrix)
#	}
#	return(list(eigenvalues=eigenvalues, EVmatrix=EVmatrix, inverse_EVmatrix=inverse_EVmatrix))
#}

# Try the castor package in your browser
#
# Any scripts or data that you put into this service are public.
#
# castor documentation built on Aug. 18, 2023, 1:07 a.m.