R/gauss_medianpd.R

Defines functions gaussmedpd

Documented in gaussmedpd

#' Wasserstein Median of Gaussian Distributions in \eqn{\mathbf{R}^p}
#' 
#' Given a collection of \eqn{p}-dimensional Gaussian distributions \eqn{\mathcal{N}(\mu_i, \Sigma_i)} for \eqn{i=1,\ldots,n}, 
#' compute the Wasserstein median.
#' 
#' @param means an \eqn{(n\times p)} matrix whose rows are mean vectors.
#' @param vars a \eqn{(p\times p\times n)} array where each slice is a covariance matrix.
#' @param weights a weight of each distribution; if \code{NULL} (default), uniform weight is set. Otherwise, it should be a length-\eqn{n} vector of nonnegative weights.
#' @param ... extra parameters including \describe{
#' \item{abstol}{stopping criterion for iterations (default: 1e-8). Must be a single finite number; values below \code{100*.Machine$double.eps} are raised to that floor.}
#' \item{maxiter}{maximum number of iterations (default: 496). Must be a single finite number; it is rounded and raised to at least 5.}
#' }
#' 
#' @examples 
#' \donttest{
#' #----------------------------------------------------------------------
#' #                         Three Gaussians in R^2
#' #----------------------------------------------------------------------
#' # GENERATE PARAMETERS
#' # means
#' par_mean = rbind(c(-4,0), c(0,0), c(5,-1))
#' 
#' # covariances
#' par_vars = array(0,c(2,2,3))
#' par_vars[,,1] = cbind(c(2,-1),c(-1,2))
#' par_vars[,,2] = cbind(c(4,+1),c(+1,4))
#' par_vars[,,3] = diag(c(4,1))
#' 
#' # COMPUTE THE MEDIAN
#' gmeds = gaussmedpd(par_mean, par_vars)
#' 
#' # COMPUTE THE BARYCENTER 
#' gmean = gaussbarypd(par_mean, par_vars)
#' 
#' # GET COORDINATES FOR DRAWING
#' pt_type1 = gaussvis2d(par_mean[1,], par_vars[,,1])
#' pt_type2 = gaussvis2d(par_mean[2,], par_vars[,,2])
#' pt_type3 = gaussvis2d(par_mean[3,], par_vars[,,3])
#' pt_gmean = gaussvis2d(gmean$mean, gmean$var)
#' pt_gmeds = gaussvis2d(gmeds$mean, gmeds$var)
#' 
#' # VISUALIZE
#' opar <- par(no.readonly=TRUE)
#' plot(pt_gmean, lwd=2, col="red", type="l",
#'      main="Three Gaussians", xlab="", ylab="", 
#'      xlim=c(-6,8), ylim=c(-2.5,2.5))
#' lines(pt_gmeds, lwd=2, col="blue")
#' lines(pt_type1, lty=2, lwd=5)
#' lines(pt_type2, lty=2, lwd=5)
#' lines(pt_type3, lty=2, lwd=5)
#' abline(h=0, col="grey80", lty=3)
#' abline(v=0, col="grey80", lty=3)
#' legend("topright", legend=c("Median","Barycenter"),
#'        lwd=2, lty=1, col=c("blue","red"))
#' par(opar)
#' }
#' 
#' @return a named list containing \describe{
#' \item{mean}{a length-\eqn{p} vector for mean of the estimated median distribution.}
#' \item{var}{a \eqn{(p\times p)} matrix for variance of the estimated median distribution.}
#' }
#' 
#' @seealso [T4transport::gaussmed1d()] for univariate case.
#' @concept gaussian
#' @export
gaussmedpd <- function(means, vars, weights=NULL, ...){
  # --------------------------------------------------------------------------
  # INPUT : EXPLICIT
  name.f <- "gaussmedpd"
  
  # data : joint validity of 'means' and 'vars' is checked by the package
  # helper gauss_checknd().
  if (!gauss_checknd(means, vars)){
    stop("* gaussmedpd : input 'means' and 'vars' are not valid.")
  }
  N <- base::nrow(means)
  
  # weight : NULL means uniform weights; otherwise validated by valid_weight()
  weight <- valid_weight(weights, N, "weights", name.f)
  
  # --------------------------------------------------------------------------
  # INPUT : IMPLICIT
  # Optional tuning parameters are looked up in '...' by EXACT name ('[['
  # does not partially match). A missing or NULL entry falls back to the
  # documented default.
  params <- list(...)
  
  # maximum number of iterations : rounded, with a floor of 5
  maxiter <- params[["maxiter"]]
  if (is.null(maxiter)){
    par_iter <- 496
  } else {
    # Without this guard, max(5, round(NA)) is NA and a vector input is
    # silently collapsed to its largest element.
    if ((length(maxiter) != 1L) || (!is.numeric(maxiter)) || (!is.finite(maxiter))){
      stop("* gaussmedpd : 'maxiter' must be a single finite number.")
    }
    par_iter <- max(5, round(maxiter))
  }
  
  # absolute tolerance : coerced to double, floored at 100 * machine epsilon
  abstol <- params[["abstol"]]
  if (is.null(abstol)){
    par_tol <- 1e-8
  } else {
    abstol <- as.double(abstol)
    # Without this guard, NA (e.g. from a non-numeric string) propagates
    # through max() and an NA tolerance would reach the solver.
    if ((length(abstol) != 1L) || (!is.finite(abstol))){
      stop("* gaussmedpd : 'abstol' must be a single finite number.")
    }
    par_tol <- max(100*.Machine$double.eps, abstol)
  }
  
  # --------------------------------------------------------------------------
  # COMPUTE
  # run the product-manifold algorithm (package-internal solver)
  out_run <- gauss_median_general(means, vars, weight, par_tol, par_iter)
  
  # --------------------------------------------------------------------------
  # RETURN
  # coerce to a plain vector / matrix so the returned shapes are stable
  list(mean = as.vector(out_run$mean), var = as.matrix(out_run$var))
}

Try the T4transport package in your browser

Any scripts or data that you put into this service are public.

T4transport documentation built on April 12, 2023, 12:37 p.m.