R/wasserstein.R

Defines functions wasserstein

Documented in wasserstein

wasserstein <- function(x1, x2, check = FALSE)
{
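  # Wasserstein distance between two Gaussian distributions, each fitted
  # to a sample through its sample mean and (co)variance.
  #
  # x1, x2: the samples; numeric vectors, or matrices/data frames with
  #         the same number of columns (one row per observation).
  # check:  passed on unchanged to wassersteinpar().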

  p <- NCOL(x1)
  if (NCOL(x2) != p)
    stop("x1 and x2 must both be vectors or have the same number of columns.")
  
  if (p == 1) {
    # Univariate Gaussian densities:
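    if (is.data.frame(x1)) x1 <- x1[, 1]
    if (is.data.frame(x2)) x2 <- x2[, 1]
    if (is.matrix(x1)) x1 <- drop(x1)
    if (is.matrix(x2)) x2 <- drop(x2)
    # Coerce both samples to plain numeric vectors before estimating
    # their means and variances
    x1 <- as.numeric(x1)
    x2 <- as.numeric(x2)
    return(wassersteinpar(mean(x1), var(x1), mean(x2), var(x2), check = check))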
  } else {
    # Multivariate Gaussian densities:
    return(wassersteinpar(colMeans(x1), var(x1), colMeans(x2), var(x2), check = check))
  }
}
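wasserstein() only estimates the Gaussian parameters of each sample and delegates the distance itself to wassersteinpar(), whose source is not shown here. For reference, the standard closed form for the 2-Wasserstein distance between N(m1, V1) and N(m2, V2) is W2^2 = ||m1 - m2||^2 + tr(V1 + V2 - 2 (V2^{1/2} V1 V2^{1/2})^{1/2}). The base-R sketch below implements that formula. gauss_w2(), sqrtm_sym() and sample_gauss_w2() are hypothetical helpers, not part of dad, and it is an assumption (not confirmed by this file) that wassersteinpar() uses the same formula; whether it returns W2 or its square, and what its check argument tests, are not covered by this sketch.

```r
# Hypothetical helpers, NOT part of the dad package.
# Closed form assumed here (standard 2-Wasserstein between Gaussians):
#   W2^2 = ||m1 - m2||^2 + tr(v1 + v2 - 2 (v2^{1/2} v1 v2^{1/2})^{1/2})

# Square root of a symmetric positive semi-definite matrix via its
# eigendecomposition; tiny negative eigenvalues from rounding are clipped to 0.
sqrtm_sym <- function(a) {
  e <- eigen(a, symmetric = TRUE)
  vals <- pmax(e$values, 0)
  e$vectors %*% diag(sqrt(vals), nrow = length(vals)) %*% t(e$vectors)
}

# W2 between N(m1, v1) and N(m2, v2); scalars are treated as 1 x 1 matrices.
gauss_w2 <- function(m1, v1, m2, v2) {
  v1 <- as.matrix(v1)
  v2 <- as.matrix(v2)
  r2 <- sqrtm_sym(v2)
  inner <- r2 %*% v1 %*% r2
  inner <- (inner + t(inner)) / 2   # re-symmetrize against rounding error
  cross <- sqrtm_sym(inner)
  d2 <- sum((m1 - m2)^2) + sum(diag(v1 + v2 - 2 * cross))
  sqrt(max(d2, 0))                  # guard against a tiny negative d2
}

# Plug-in version mirroring wasserstein(): estimate each sample's mean and
# (co)variance, then apply the closed form. Vectors become n x 1 matrices.
sample_gauss_w2 <- function(x1, x2) {
  x1 <- as.matrix(x1)
  x2 <- as.matrix(x2)
  stopifnot(ncol(x1) == ncol(x2))
  gauss_w2(colMeans(x1), var(x1), colMeans(x2), var(x2))
}
```

For example, gauss_w2(0, 1, 3, 4) returns sqrt(10), about 3.162: in one dimension the formula reduces to (m1 - m2)^2 + (s1 - s2)^2 = 9 + 1. The eigendecomposition-based square root is used because both covariance matrices are symmetric positive semi-definite, which keeps the sketch free of extra packages.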

dad documentation built on Aug. 30, 2023, 5:06 p.m.