# R/fit_symmetric_mk.R

# Defines functions: fit_symmetric_mk

# Fit a fixed-rates time-reversible continuous-time Markov model for discrete character evolution (Mk model) based on a set of independent contrasts
# Time reversibility means that the transition-rates matrix is symmetric.
# The input tree may include monofurcations and multifurcations, but note that multifurcations are internally first split into bifurcations
#
# Outline of the procedure:
#	1. For each tree, extract independent pairs of sister tips, optionally filter them (basal pairs only, maximum MRCA age), and record their phylodistance and the states of the two tips.
#	2. Pool the contrasts of all trees and maximize the likelihood of the observed state pairs, as computed by TR_Mk_loglikelihood_ICs_CPP(), over the independent rate parameters.
#	3. Repeat the optimization from Ntrials starting points and keep the trial with the highest loglikelihood.
#
# Returns a named list. The list always contains the element "success" (logical).
# If success==FALSE, the list contains an additional element "error" (string) describing the problem.
# If success==TRUE, the list contains the elements Nstates, transition_matrix, loglikelihood, Ncontrasts, phylodistances, Niterations, Nevaluations, converged and guess_rate.
fit_symmetric_mk = function(trees, 									# either a single tree in phylo format, or a list of trees
							Nstates,								# integer, number of possible states
							tip_states,								# either a 1D vector of size Ntips (if trees[] is a single tree) or a list of 1D vectors (if trees[] is a list of trees), listing the state of each tip in each tree.
							phylodistance_matrixes	= NULL,			# optional list of length Ntrees, each element being a phylodistance matrix for the tips of a tree. Hence phylodistance_matrixes[[tr]][i,j] is the phylogenetic (patristic) distance between tips i & j in tree tr. Can be used to specify distances between tips regardless of the edge lengths in the trees (i.e., trees are only used for topology). Some phylodistances may be NA or NaN; the corresponding tip pairs will be omitted from the fitting.
							only_basal_tip_pairs	= FALSE,		# logical, specifying whether only immediate sister tips should be considered, i.e. tip pairs with at most 2 edges between the two tips
							rate_model 				= "ER",			# either "ER" or "SYM". Any other value (including non-symmetric models such as "ARD" or "SUEDE", or a custom index vector) is rejected and leads to a return value with success=FALSE.
							max_MRCA_ages			= NULL,			# Optional numeric or numeric vector of size Ntrees, specifying the maximum age of MRCAs of independent contrasts to consider. Typically this will be <=root_age. If NULL, then the filter is ignored. May contain NAs or NaNs; for these trees the filter is ignored.
							optim_algorithm		 	= "nlminb",		# either "optim" or "nlminb". What algorithm to use for fitting. Any value other than "optim" selects nlminb.
							optim_max_iterations	= 200,			# maximum number of iterations of the optimization algorithm (per trial)
							optim_rel_tol			= 1e-8,			# relative tolerance when optimizing the objective function
							check_input 			= TRUE,			# (bool) perform some basic sanity checks on the input data. Note: This argument is currently not referenced anywhere in this function; it is kept for interface compatibility.
							Ntrials 				= 1,			# (int) number of trials (starting points) for fitting the transition matrix. The first trial starts at the first-guess rate, subsequent trials start at randomly perturbed rates.
							Nthreads 				= 1,			# (integer) number of threads for running multiple fitting trials in parallel
							max_model_runtime		= NULL,			# maximum time (in seconds) to allocate for each likelihood evaluation per tree. Use this to escape from badly parameterized models during fitting (this will likely cause the affected fitting trial to fail). If NULL or <=0, this option is ignored.
							verbose					= FALSE,		# boolean, specifying whether to print informative messages
							verbose_prefix			= ""){			# string, specifying the line prefix when printing messages. Only relevant if verbose==TRUE.

	# basic input checking
	if(verbose) cat(sprintf("%sChecking input variables..\n",verbose_prefix))
	# rate_model must be a single string naming a symmetric model.
	# The explicit type/length guards avoid evaluating if() on a condition of length>1 (an R error) when a vector is passed.
	if(!(is.character(rate_model) && (length(rate_model)==1) && (rate_model %in% c("ER", "SYM")))){
		return(list(success=FALSE, error=sprintf("Unknown or non-symmetric rate model '%s'",paste(rate_model,collapse=","))))
	}
	if(inherits(trees,"phylo")){
		# trees[] is actually a single tree
		trees = list(trees)
		Ntrees = 1
	}else if("list" %in% class(trees)){
		# trees[] is a list of trees
		Ntrees = length(trees)
	}else{
		return(list(success=FALSE,error=sprintf("Unknown data format '%s' for input trees[]: Expected a list of phylo trees or a single phylo tree",class(trees)[1])))
	}
	if(!(("list" %in% class(tip_states)) && (length(tip_states)==Ntrees))){
		# something is wrong with tip_states, perhaps we can fix it
		if((Ntrees==1) && (length(tip_states)==length(trees[[1]]$tip.label))){
			# a plain vector of tip states was provided for a single tree, so wrap it into a list
			tip_states = list(unlist(tip_states))
		}else{
			return(list(success=FALSE, error=sprintf("Invalid input format for tip_states: Expected a list of vectors, each listing the tip states of a specific tree")))
		}
	}
	if(is.null(max_model_runtime)) max_model_runtime = 0
	if(!is.null(max_MRCA_ages)){
		if(length(max_MRCA_ages)==1){
			max_MRCA_ages = rep(max_MRCA_ages,times=Ntrees)
		}else if(length(max_MRCA_ages)!=Ntrees){
			return(list(success=FALSE, error=sprintf("Invalid number of max_MRCA_ages[]; expected either 1 value or %d values (=Ntrees)",Ntrees)))
		}
	}else{
		# NA means "no age filter" for the respective tree
		max_MRCA_ages = rep(NA,times=Ntrees)
	}
# 	if(is.null(sampling_fractions)){
# 		sampling_fractions = rep(1,times=Nstates)
# 	}else if(length(sampling_fractions)==1){
# 		sampling_fractions = rep(1,times=Nstates) # all sampling fractions are the same, so wlog we can set them all to 1
# 	}else if(length(sampling_fractions)!=Nstates){
# 		return(list(success=FALSE, error=sprintf("Invalid number of sampling_fractions (%d); expected either 1 value or %d values (=Nstates)",length(sampling_fractions),Nstates)))
# 	}

	# loop through the trees and extract independent contrasts between sister clades
	# Note: seq_len() is used instead of 1:Ntrees, so that an empty list of trees results in zero iterations (and a clean failure further below)
	phylodistances 		= numeric()
	transitions			= matrix(NA, nrow=0, ncol=2)
	for(i in seq_len(Ntrees)){
		tree 					= trees[[i]]
		tip_states_this_tree  	= tip_states[[i]]
		max_MRCA_age_this_tree	= max_MRCA_ages[[i]]
		
		# make sure tree does not have multifurcations
		tree = multifurcations_to_bifurcations(tree)$tree
		
		# extract independent pairs of sister tips (independent contrasts)
		tip_pairs = get_independent_sister_tips(tree)
		if(only_basal_tip_pairs){
			# calculate number of edges between tip pairs
			edge_counts = get_pairwise_distances(tree, A=tip_pairs[,1], B=tip_pairs[,2], as_edge_counts=TRUE, check_input=FALSE)
			# only keep tip pairs with at most 2 edges connecting them
			keep_pairs 	= which(edge_counts<=2)
			tip_pairs 	= tip_pairs[keep_pairs,,drop=FALSE]
		}
		if(is.finite(max_MRCA_age_this_tree)){
			# only keep tip pairs whose MRCA is not older than the specified threshold
			# The age of an MRCA is measured as its distance from the farthest clade from the root
			MRCAs 			= get_pairwise_mrcas(tree, tip_pairs[,1], tip_pairs[,2], check_input=FALSE)
			clade_heights	= castor::get_all_distances_to_root(tree)
			tree_span		= max(clade_heights)
			keep_pairs		= which(tree_span-clade_heights[MRCAs]<=max_MRCA_age_this_tree)
			tip_pairs 		= tip_pairs[keep_pairs,,drop=FALSE]
		}

		# determine phylodistances between sister tips
		if(is.null(phylodistance_matrixes)){
			# calculate phylodistances from tree
			phylodistances_this_tree = get_pairwise_distances(tree, A=tip_pairs[,1], B=tip_pairs[,2], check_input=FALSE)
		}else{
			# get phylodistances from provided phylodistance matrixes
			phylodistance_matrix_this_tree = phylodistance_matrixes[[i]]
			if(nrow(phylodistance_matrix_this_tree)!=length(tree$tip.label)) return(list(success=FALSE,error=sprintf("ERROR: Input phylodistance_matrix #%d has %d rows, but expected %d rows (number of tips in the input tree)",i,nrow(phylodistance_matrix_this_tree),length(tree$tip.label))))
			if(ncol(phylodistance_matrix_this_tree)!=length(tree$tip.label)) return(list(success=FALSE,error=sprintf("ERROR: Input phylodistance_matrix #%d has %d columns, but expected %d columns (number of tips in the input tree)",i,ncol(phylodistance_matrix_this_tree),length(tree$tip.label))))
			# indexing a matrix by a 2-column matrix extracts one entry per row, i.e. one distance per tip pair
			phylodistances_this_tree = phylodistance_matrix_this_tree[tip_pairs]
		}

		# omit tip pairs with zero or non-finite phylogenetic distance, because in that case the likelihood density is pathological
		keep_pairs 					= which(is.finite(phylodistances_this_tree) & (phylodistances_this_tree>0))
		tip_pairs					= tip_pairs[keep_pairs,,drop=FALSE]
		phylodistances_this_tree	= phylodistances_this_tree[keep_pairs]
		
		if(nrow(tip_pairs)==0) next; # no valid tip pairs found in this tree
		
		# determine state transitions for this tree's tip_pairs
		transitions_this_tree = cbind(tip_states_this_tree[tip_pairs[,1]], tip_states_this_tree[tip_pairs[,2]])

		# append this tree's independent contrasts to master lists
		phylodistances 	= c(phylodistances, phylodistances_this_tree)
		transitions		= rbind(transitions, transitions_this_tree)
	}
	NC = length(phylodistances)
	if(NC==0) return(list(success=FALSE, error="No valid tip pairs left for extracting independent contrasts"))
	if(all(transitions[,1]==transitions[,2])) return(list(success=FALSE, error="None of the independent contrasts comprise state transitions"))

	# figure out reasonable first guess for the transition rate between states
	first_guess_rate = mean(abs(transitions[,1]-transitions[,2])/phylodistances)
	if(first_guess_rate==0){
		# fallback guess based on the number of states, the mean edge length and the number of tips of each tree
		first_guess_rate = mean(sapply(seq_len(Ntrees), FUN=function(tr) Nstates/((if(is.null(trees[[tr]]$edge.length)) 1 else mean(trees[[tr]]$edge.length))*log(length(trees[[tr]]$tip.label))/log(2.0))))
	}

	# get set of independent rates to be fitted (as a condensed vector), and the corresponding index_matrix (which links each entry in the transition matrix to an independent rate parameter)
	indexing		= get_transition_index_matrix(Nstates,rate_model)
	index_matrix	= indexing$index_matrix
	Nrates 			= indexing$Nrates
	
	
	# define objective function to be minimized (negated log-likelihood)
	# Returns Inf for invalid rates or if the likelihood calculation failed
	objective_function = function(dense_rates){
		if(any(is.nan(dense_rates)) || any(is.infinite(dense_rates))) return(Inf);
		Q = get_transition_matrix_from_rate_vector(dense_rates, index_matrix, Nstates);
		results = TR_Mk_loglikelihood_ICs_CPP(	Nstates					= Nstates,
												phylodistances			= phylodistances,
												transitions				= as.vector(t(transitions-1)),	# flatten in row-major format, convert to 0-based
												transition_matrix		= as.vector(t(Q)),				# flatten in row-major format
												runtime_out_seconds		= max_model_runtime,
												exponentiation_accuracy	= 1e-3,
												max_polynomials			= 1000)
		if((!results$success) || is.na(results$loglikelihood) || is.nan(results$loglikelihood)) return(Inf)
		return(-results$loglikelihood);
	}
	
	# fit starting with various starting_rates, keep track of best fit
	# The first trial starts at the first-guess rate, other trials start at rates randomly spread over power_range orders of magnitude around the first guess
	# Rates are rescaled by rate_scale, so that the optimizer operates on parameters of order 1
	fit_single_trial = function(trial){
		power_range = 8
		initial_dense_rates = if(trial==1) rep(first_guess_rate, Nrates) else first_guess_rate * 10^runif(n=Nrates, min=-power_range/2, max=power_range/2);
		rate_scale = mean(abs(initial_dense_rates))
		if(optim_algorithm == "optim"){
			# Note: L-BFGS-B ignores the control option 'reltol' (with a warning) and instead uses 'factr', expressed in multiples of the machine precision.
			# Hence, the requested relative tolerance is converted to factr.
			fit = stats::optim(	initial_dense_rates/rate_scale, 
								function(x) objective_function(x*rate_scale), 
								method 	= "L-BFGS-B", 
								lower 	= rep(first_guess_rate/(10^power_range), Nrates)/rate_scale,
								upper	= rep((10^power_range)*first_guess_rate, Nrates)/rate_scale,
								control = list(maxit=optim_max_iterations, factr=optim_rel_tol/.Machine$double.eps))
			LL 				= -fit$value;
			Nevaluations 	= fit$counts # counts as reported by optim (function & gradient evaluations)
			Niterations		= NA
			converged 		= (fit$convergence==0)
		}else{
			fit = stats::nlminb(initial_dense_rates/rate_scale, 
								function(x) objective_function(x*rate_scale), 
								lower=rep(0, Nrates)/rate_scale, 
								upper=rep((10^power_range)*first_guess_rate, Nrates)/rate_scale,
								control = list(iter.max=optim_max_iterations, eval.max=optim_max_iterations*Nrates*10, rel.tol=optim_rel_tol, step.min=1e-4))
			LL 				= -fit$objective;
			Nevaluations 	= fit$evaluations[1]
			Niterations		= fit$iterations
			converged		= (fit$convergence==0)
		}
		fit$par = fit$par * rate_scale # convert fitted parameters back to original scale
		return(list(LL=LL, Nevaluations=Nevaluations, Niterations=Niterations, converged=converged, fit=fit))
	}

	# wrapper that turns an R error during a trial into a failed-trial record, so that a single bad trial does not abort the remaining trials
	fit_single_trial_safely = function(trial){
		tryCatch(	fit_single_trial(trial),
					error = function(e) list(LL=NA_real_, Nevaluations=NA, Niterations=NA, converged=FALSE, fit=list(par=NULL), error=conditionMessage(e)))
	}
	
	# run one or more independent fitting trials
	if((Ntrials>1) && (Nthreads>1) && (.Platform$OS.type!="windows")){
		# run trials in parallel using multiple forks
		# Note: Forks (and hence shared memory) are not available on Windows
		if(verbose) cat(sprintf("%sFitting %d free parameters (%d trials, parallelized)..\n",verbose_prefix,Nrates,Ntrials))
		fits = parallel::mclapply(	seq_len(Ntrials), 
									FUN = function(trial) fit_single_trial_safely(trial), 
									mc.cores = min(Nthreads, Ntrials), 
									mc.preschedule = FALSE, 
									mc.cleanup = TRUE);
	}else{
		# run in serial mode
		if(verbose) cat(sprintf("%sFitting %d free parameters (%s)..\n",verbose_prefix,Nrates,(if(Ntrials==1) "1 trial" else sprintf("%d trials",Ntrials))))
		fits = vector("list", max(0,Ntrials))
		for(trial in seq_len(Ntrials)){
			fits[[trial]] = fit_single_trial_safely(trial)
		}
	}
	
	
	# extract information from best fit (note that some fits may have LL=NaN or NA, or may be missing altogether e.g. if a forked process died)
	LLs = vapply(seq_along(fits), FUN.VALUE=numeric(1), FUN=function(trial){
		f = fits[[trial]]
		if(is.list(f) && is.numeric(f$LL) && (length(f$LL)==1)) as.numeric(f$LL) else NA_real_
	})
	valid_pars = vapply(seq_along(fits), FUN.VALUE=logical(1), FUN=function(trial){
		f = fits[[trial]]
		is.list(f) && is.list(f$fit) && (!is.null(f$fit$par)) && all(is.finite(f$fit$par))
	})
	valids = which(is.finite(LLs) & valid_pars)
	if(length(valids)==0) return(list(success=FALSE, error="Fitting failed for all trials"));
	best 				= valids[which.max(LLs[valids])]
	loglikelihood		= fits[[best]]$LL
	fitted_rates 		= fits[[best]]$fit$par
	transition_matrix 	= get_transition_matrix_from_rate_vector(fitted_rates, index_matrix, Nstates)

	
	#####################################


	# return results
	return(list(success				= TRUE, 
				Nstates				= Nstates, 
				transition_matrix	= transition_matrix, 
				loglikelihood		= loglikelihood,
				Ncontrasts			= NC,
				phylodistances 		= phylodistances,
				Niterations			= fits[[best]]$Niterations,  # may be NA, depending on the optimization algorithm
				Nevaluations		= fits[[best]]$Nevaluations, # may be NA, depending on the optimization algorithm
				converged			= fits[[best]]$converged,
				guess_rate			= first_guess_rate))
}

# Try the castor package in your browser

# Any scripts or data that you put into this service are public.

# castor documentation built on Aug. 18, 2023, 1:07 a.m.