R/ROKET.R

Defines functions kernTEST run_myOTs run_myOT

Documented in kernTEST run_myOT run_myOTs

# Package Functions

# ----------
# Optimal Transport Functions
# ----------

#' @title run_myOT
#' @description Runs balanced or unbalanced optimal transport
#'	on two input vectors
#' @param XX A numeric vector of positive masses
#' @param YY A numeric vector of positive masses
#' @param COST A numeric matrix of non-negative values
#'	representing the costs to transport masses between
#'	features of \code{XX} and \code{YY}. The rows of \code{COST}
#'	and features of \code{XX} need to be aligned.
#'	The columns of \code{COST} and features of \code{YY}
#'	need to be aligned.
#' @param EPS A positive numeric value representing the
#'	tuning parameter for entropic regularization.
#' @param LAMBDA1 A non-negative numeric value representing
#'	the tuning parameter penalizing the distance between \code{XX}
#'	and the row sums of the optimal transport matrix.
#' @param LAMBDA2 A non-negative numeric value representing
#'	the tuning parameter penalizing the distance between \code{YY}
#'	and the column sums of the optimal transport matrix.
#' @param balance Boolean set to \code{TRUE} to run balanced
#'	optimal transport regardless of LAMDA1 and LAMBDA2. 
#'	Otherwise run unbalanced optimal transport.
#' @param conv A positive numeric value to determine 
#'	algorithmic convergence. The default value is \code{1e-5}.
#' @param max_iter A positive integer denoting the maximum
#'	iterations to run the algorithm.
#' @param verbose Boolean value to display verbose function output.
#' @param show_iter A positive integer to display iteration details
#'	at multiples of \code{show_iter} but only if \code{verbose = TRUE}.
#' @return A R list containing the optimal transport matrix and 
#'	associated distance metric.
#' 
#' @export
run_myOT = function(XX,YY,COST,EPS,LAMBDA1,LAMBDA2 = NULL,
	balance = FALSE,conv = 1e-5,max_iter = 3e3,
	verbose = TRUE,show_iter = 50){
	
	nXX = names(XX)
	nYY = names(YY)
	nrCOST = rownames(COST)
	ncCOST = colnames(COST)
	
	# Check inputs
	if( is.null(nXX) ) stop("Set names(XX)")
	if( is.null(nYY) ) stop("Set names(YY)")
	if( is.null(ncCOST) ) stop("Set colnames(COST)")
	if( is.null(nrCOST) ) stop("Set rownames(COST)")
	if( length(XX) != nrow(COST) ) stop("length(XX) != nrow(COST)")
	if( length(YY) != ncol(COST) ) stop("length(YY) != ncol(COST)")
	if( !all(nXX == nrCOST) ) stop("row COST name mismatch")
	if( !all(nYY == ncCOST) ) stop("column COST name mismatch")
	if( sum(XX) <= 0 ) stop("XX needs to have mass > 0")
	if( sum(YY) <= 0 ) stop("YY needs to have mass > 0")
	if( is.null(LAMBDA2) ) LAMBDA2 = LAMBDA1
	
	# If there are zeros in XX or YY, subset
	XX2 = XX[XX > 0]
	YY2 = YY[YY > 0]
	COST_XY = COST[names(XX2),names(YY2),drop = FALSE]
	
	if( sum(XX2) <= 0 ) stop("XX needs to have mass > 0")
	if( sum(YY2) <= 0 ) stop("YY needs to have mass > 0")
	
	# Run OT
	OT = Rcpp_run_OT(XX = XX2,YY = YY2,COST_XY = COST_XY,
		EPS = EPS,LAMBDA1 = LAMBDA1,LAMBDA2 = LAMBDA2,
		balance = balance,highLAM_lowMU = TRUE,
		conv = conv,max_iter = max_iter,show = verbose,
		show_iter = show_iter)
	
	# Add dimnames
	OT_fin = matrix(0,length(XX),length(YY))
	OT_fin = smart_names(MAT = OT_fin,
		ROW = nXX,COL = nYY)
	OT_fin[names(XX2),names(YY2)] = OT
	
	# Calculate OT metrics
	DIST = sum(COST * OT_fin)
	
	out = list(OT = OT_fin,DIST = DIST)
	return(out)
	
}

#' @title run_myOTs
#' @inheritParams run_myOT
#' @param ZZ A numeric matrix of non-negative mass to transport.
#'	Rows correspond to features (e.g. genes) and columns
#'	correspond to samples or individuals. Each column must have
#'	strictly positive mass
#' @param COST A numeric square matrix of non-negative values
#'	representing the non-negative costs to transport 
#'	masses between pairs of features
#' @param ncores A positive integer for the number of cores/threads
#'	to reduce computational runtime when running for loops
#' @param show_iter A positive integer to display iteration details
#'	at multiples of \code{show_iter} but only if \code{verbose = TRUE}.
#' @return A R numeric matrix of pairwise distances.
#' @export
run_myOTs = function(ZZ,COST,EPS,LAMBDA1,LAMBDA2 = NULL,
	balance,conv = 1e-5,max_iter = 3e3,ncores = 1,
	verbose = TRUE,show_iter = 50){
	
	# Check inputs
	mass = colSums(ZZ)
	if( any(mass <= 0) )
		stop("A subset of samples have non-positive mass")
	if( nrow(COST) != ncol(COST) ) stop("COST not square")
	if( !all(COST == t(COST)) ) stop("COST not symmetric")
	if( is.null(LAMBDA2) ) LAMBDA2 = LAMBDA1
	
	# Subset features of ZZ with any mass
	mass_rows = apply(ZZ,1,function(xx){
		sum(xx[xx > 0])
	})
	ZZ = ZZ[mass_rows > 0,,drop = FALSE]
	nZZ = rownames(ZZ)
	nCOST = rownames(COST)
	if( nrow(ZZ) > nrow(COST) )
		stop("More features in ZZ than features in COST")
	if( !all(nZZ %in% nCOST) )
		stop("Some features in ZZ not among features in COST")
	COST = COST[nZZ,nZZ,drop = FALSE]
	
	if( !all(colnames(COST) == rownames(ZZ)) )
		stop("rownames mismatch")
	
	# Run OT across all individuals
	out_OT = Rcpp_run_full_OT(COST = COST,ZZ = ZZ,
		EPS = EPS,LAMBDA1 = LAMBDA1,LAMBDA2 = LAMBDA2,
		balance = balance,highLAM_lowMU = TRUE,
		conv = conv,max_iter = max_iter,ncores = ncores,
		show = verbose,show_iter = show_iter)
	
	DIST = smart_names(out_OT$DIST,
		ROW = colnames(ZZ),COL = colnames(ZZ))
	
	return(DIST)
	
}


# ----------
# Kernel Regression Hypothesis Testing
# ----------

#' @title kernTEST
#' @inheritParams run_myOTs
#' @param RESI A numeric vector of null model residuals
#'	\code{names(RESI)} must be set to maintain sample ordering
#'	for survival regression, otherwise set \code{RESI} to \code{NULL}.
#' @param KK An array containing double-centered positive semi-definite
#'	kernel matrices. Refer to \code{MiRKAT::D2K()} for transforming 
#'	distance matrices to kernel matrices. The \code{dimnames(KK)[[1]]} and 
#'	\code{dimnames(KK)[[2]]} must match \code{names(RESI)}.
#'	Also set dimnames(KK)[[3]] to keep track of each kernel matrix.
#' @param YY A numeric vector of continuous outcomes to be fitted
#'	in a linear model. Defaults to NULL for survival model.
#' @param XX A numeric data matrix with first column for intercept,
#'	a column of ones.
#' @param OMNI A matrix of zeros and ones. Each column corresponds to a
#'	distance matrix while each row corresponds to an omnibus test. Set
#'	\code{rownames(OMNI)} for labeling outputted p-values and 
#'	\code{colnames(OMNI)} which should match \code{dimnames(KK)[[3]]}.
#' @param nPERMS A positive integer to specify the number of
#'	permutation-based p-value calculation
#' @return A R list of p-values and omnibus p-values.
#' @export
kernTEST = function(RESI = NULL,KK,YY = NULL,XX = NULL,
	OMNI,nPERMS = 1e5,ncores = 1){
	
	REG = NULL
	samp_names = NULL
	if( !is.null(RESI) ){
		REG = "SURV"
		samp_names = names(RESI)
	}
	if( !is.null(YY) && !is.null(XX) ){
		REG = "CONT"
		if( is.null(names(YY)) ) stop("Add names(YY)")
		if( is.null(rownames(XX)) ) stop("Add rownames(XX)")
		if( !all(names(YY) == rownames(XX)) )
			stop("name labeling mismatch")
		samp_names = names(YY)
	}
	
	if( is.null(REG) || is.null(samp_names) )
		stop("not an encoded model")
	
	if( is.null(dimnames(KK)[[3]]) || is.null(dimnames(KK)[[1]]) 
		|| is.null(dimnames(KK)[[2]]) )
		stop("Specify dimnames(KK) to track sample order and kernel matrices")
	
	if( is.null(samp_names) ){
		if( REG == "SURV" ) stop("Specify names(RESI)")
		if( REG == "CONT" ) stop("Specify names(YY) and rownames(XX)")
	}
	
	if( !all(samp_names == dimnames(KK[[1]])) )
		stop("names(RESI) != dimnames(KK[[1]]")
	if( !all(samp_names == dimnames(KK[[2]])) )
		stop("names(RESI) != dimnames(KK[[2]]")
	
	# Check OMNI object
	if( !all(colnames(OMNI) == dimnames(KK)[[3]]) )
		stop("colnames(OMNI) != dimnames(KK[[3]]")
	if( is.null(rownames(OMNI)) )
		stop("Specify rownames(OMNI)")
	if( !all(c(OMNI) %in% c(0,1)) )
		stop("OMNI should take values 0 or 1")
	if( any(rowSums(OMNI) == 0) )
		stop("Each row of OMNI should contain at least one non-zero element")
	
	if( REG == "SURV" ){
		out_test = Rcpp_KernTest(RESI = RESI,cKK = KK,
			OMNI = OMNI,nPERMS = nPERMS,ncores = ncores)
	} else if( REG == "CONT" ){
		out_test = Rcpp_KernTest_FL(YY = YY,XX = XX,
			cKK = KK,OMNI = OMNI,nPERMS = nPERMS)
	}
	
	names(out_test$PVALs) = dimnames(KK)[[3]]
	names(out_test$omni_PVALs) = rownames(OMNI)
	
	return(out_test)
	
}


###

Try the ROKET package in your browser

Any scripts or data that you put into this service are public.

ROKET documentation built on April 4, 2025, 2:28 a.m.