R/ws_maaipm.R

#' Solves the 2-Wasserstein barycenter problem for N probability measures on R^d using an interior point method.
#' @description This is a wrapper function for several methods that solve the 2-Wasserstein barycenter problem. It
#' offers three methods:
#' "fixed": Finds the best approximation of the 2-Wasserstein barycenter of N finitely supported input
#' measures on a given support set, using a modified MAAIPM algorithm to solve the corresponding linear program.
#' "free": Finds an approximation of the 2-Wasserstein barycenter using the MAAIPM method by
#' alternating between updating the weights and the positions of a candidate barycenter.
#' "multiscale": Finds the best approximation of the 2-Wasserstein barycenter of N finitely supported input
#' measures using a multi-scale version of a modified MAAIPM method. Given a starting grid, it solves the fixed-support
#' 2-Wasserstein barycenter problem on this grid. Then the grid is refined by splitting each grid point into 4 new ones.
#' Afterwards, all grid points whose mass is below a certain threshold are removed from the support. This procedure is
#' repeated until a prespecified resolution is reached. Note that the generated grids are assumed to lie in [0,1]^2.
#' @param data.list A list of objects from which the barycenter should be computed. Each element should be one of the following:
#' a matrix representing an image; a path to a file containing an image;
#' a \link[transport]{wpp-object};
#' a list containing an entry named `positions` with the support of the measure and an entry named `weights` containing the weights of the support points;
#' a list containing an entry named `positions` specifying the support of a measure with uniform weights.
#' @param method A string determining which method is used. The available options, "fixed", "free" and "multiscale",
#' are described above.
#' @param support The role of this parameter depends on the method used.
#' "fixed": This is a d x M matrix containing the positions of the fixed-support of the barycenter in R^d.
#' "free": This is a d x M matrix containing the initial positions of the barycenter approximation.
#' "multiscale": A vector with four entries. The first two are integers giving the resolution of the initial grid.
#' The third entry is another integer specifying how many times the grid should be refined. The fourth entry is
#' a real number specifying the threshold under which mass is considered to be zero.
#' @param wmaxIter An integer specifying the maximum number of weight iterations to be performed.
#' @param pmaxIter An integer specifying the maximum number of position iterations to be performed. Only used by the
#' "free" method.
#' @param return_type A string specifying the format in which the barycenter should be returned. For all methods
#' the options "default" (giving a list with an entry `positions` containing the support of the barycenter
#' and an entry `weights` containing the weights of the barycenter) and "wpp" (giving a \link[transport]{wpp-object})
#' are available. Additionally, for the "fixed" method there is the type "vec" which returns a vector of length
#' M, containing the weights of the barycenter on the given support and for the "multiscale" method there is the 
#' option "mat" which returns the barycenter in matrix form on a grid of the final resolution.
#' @param thresh A real number specifying a stopping criterion based on the magnitude of change between consecutive 
#' iterations. If one encounters numerical instabilities in the computations in the form of either returned NaNs
#' or warnings notifying the user about near singular matrices, this parameter can be increased to avoid this. 
#' @param threads An integer specifying the number of threads used for computations.
#' @return For details on the returned value refer to the parameter return_type.
#' @examples
#' #Generate a dataset consisting of measures supported on discretized nested ellipses.
#' N<-5 #The number of measures generated
#' M<-20 #The number of points each ellipse is discretized into
#' data.list<-vector("list",N)
#' set.seed(42)
#' ell.num<-3 #The number of ellipses in each measure.
#' #This loop actually generates the measures for the example.
#' for (i in 1:N){
#'   pos.full<-matrix(0,0,2)
#'   nesting.depth<-ell.num
#'   for (k in 1:nesting.depth){
#'     t.vec<-seq(0,2*pi,length.out=M)
#'     pos<-cbind(cos(t.vec)*runif(1,0.2,1),sin(t.vec)*runif(1,0.2,1))/(3^(k-1)) 
#'     theta<-runif(1,0,2*pi)
#'     rotation<-matrix(c(cos(theta),sin(theta),-1*sin(theta),cos(theta)),2,2)
#'     pos.full<-rbind(pos.full,pos%*%rotation)
#'   }
#'   W<-rep(1,M*nesting.depth)
#'   W<-W/sum(W)
#'   data.list[[i]]<-transport::wpp((pos.full+1)/2,W)
#' }
#' #Using the multiscale method
#' system.time(bary.ms<-WSGeometry::ws_bary_maaipm(data.list,method="multiscale",
#' support=c(8,8,3,10^-4),wmaxIter=100,return_type="mat",thresh=6*10^-4,threads=1))
#' \donttest{
#' #Using the fixed support method
#' support<-t(WSGeometry::grid_positions(20,20))
#' system.time(bary.fixed<-WSGeometry::ws_bary_maaipm(data.list,method="fixed",
#' support=support,wmaxIter=100,return_type="wpp",thresh=6*10^-4,threads=1))
#' #Using the free support method
#' support<-t(WSGeometry::grid_positions(8,8))
#' system.time(bary.free<-WSGeometry::ws_bary_maaipm(data.list,method="free",
#' support=support,wmaxIter=100,pmaxIter=25,return_type="wpp",thresh=6*10^-4,threads=1))
#' 
#' #The outputs can be conveniently visualised using the image function for the "mat" output
#' #and the plot-method for the wpp-objects provided by the transport package.
#' image(bary.ms)
#' plot(bary.fixed)
#' plot(bary.free)
#' }
#' @references Ge, Dongdong, et al. "Interior-Point Methods Strike Back: Solving the Wasserstein Barycenter Problem."
#' Advances in Neural Information Processing Systems 32 (2019): 6894-6905.
#'
#' Heinemann, Klatt and Munk. "Kantorovich-Rubinstein Distance and Barycenter for Finitely Supported Measures:
#' Foundations and Algorithms." arXiv preprint, https://arxiv.org/pdf/2112.03581.pdf.
#' @export
ws_bary_maaipm<-function(data.list,method="fixed",support,wmaxIter,pmaxIter,return_type="default",thresh=10^-3,threads=1){
  N<-length(data.list)
  types<-lapply(data.list,type_check)
  data.list<-mapply(process_data,data.list,types,SIMPLIFY = FALSE)
  run<-FALSE
  if (method=="fixed"){
    run<-TRUE
    res<-maaipm_fixed_wrap(lapply(lapply(data.list,"[[",1),t),lapply(data.list,"[[",2),support,wmaxIter,NULL,thresh,threads)
    if (return_type=="default"){
      return(list(positions=support,weights=res))
    }
    if (return_type=="wpp"){
      return(wpp(t(support),res))
    }
    if (return_type=="vec"){
      return(res)
    }
    warning("The chosen return type is not supported. Using the default instead.")
    return(list(positions=support,weights=res))
  }
  if (method=="free"){
    run<-TRUE
    res<-maaipm_free_wrap(lapply(lapply(data.list,"[[",1),t),lapply(data.list,"[[",2),support,wmaxIter,pmaxIter,NULL,thresh,threads)
    if (return_type=="default"){
      return(list(positions=res[[1]],weights=res[[2]]))
    }
    if (return_type=="wpp"){
      return(wpp(t(res[[1]]),res[[2]]))
    }
    warning("The chosen return type is not supported. Using the default instead.")
    return(list(positions=res[[1]],weights=res[[2]]))
  }
  if (method=="multiscale"){
    run<-TRUE
    res<-ws_multi_scale_bary(lapply(lapply(data.list,"[[",1),t),lapply(data.list,"[[",2),support[1:2],support[3],support[4],wmaxIter,NULL,thresh,threads)
    if (return_type=="default"){
      return(list(positions=res[[1]],weights=res[[2]]))
    }
    if (return_type=="wpp"){
      return(wpp(t(res[[1]]),res[[2]]))
    }
    if (return_type=="mat"){
      return(res[[3]])
    }
    warning("The chosen return type is not supported. Using the default instead.")
    return(list(positions=res[[1]],weights=res[[2]]))
  }
  if (!run){
    stop("Invalid method chosen. The available methods are 'fixed', 'free' and 'multiscale'. ",
         "Please refer to the documentation of this function for more details.")
  }
}
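
# Hypothetical sketch (not part of WSGeometry): one refinement step of the "multiscale" method as
# described in the documentation above. Each grid point is split into 4 new points, and points whose
# mass falls below the threshold are removed. Assumptions, not taken from the package source:
# grid points are cell centres of a regular n1 x n2 grid on [0,1]^2, each child inherits a quarter of
# its parent's mass, and each refinement doubles the resolution along both axes. Under that last
# assumption, support = c(8, 8, 3, 10^-4) from the example above would end on a grid of
# 8 * 2^3 = 64 points per axis.
refine_grid_sketch <- function(positions, weights, res, thresh) {
  # positions: 2 x M matrix of cell centres (one column per point)
  # weights:   length-M vector of masses on those points
  # res:       c(n1, n2), the current grid resolution
  stopifnot(nrow(positions) == 2, ncol(positions) == length(weights))
  dx <- 1 / (4 * res[1])
  dy <- 1 / (4 * res[2])
  # Offsets from a parent centre to the centres of its four children.
  offsets <- rbind(c(-dx, dx, -dx, dx),
                   c(-dy, -dy, dy, dy))
  # Repeat every parent column 4 times and add the matching offsets.
  parent <- rep(seq_len(ncol(positions)), each = 4)
  child_pos <- positions[, parent, drop = FALSE] +
    offsets[, rep(1:4, times = ncol(positions)), drop = FALSE]
  # Assumed mass of each child: a quarter of its parent's mass.
  child_w <- weights[parent] / 4
  # Drop children whose mass is below the threshold.
  keep <- child_w >= thresh
  list(positions = child_pos[, keep, drop = FALSE], res = 2 * res)
}
# Illustrative call (kept as a comment so nothing runs when the package is loaded):
# g <- refine_grid_sketch(matrix(c(0.5, 0.5), 2, 1), 1, c(1, 1), 10^-4)
# g$positions holds the child centres (0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75).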

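# Hypothetical sketch (not part of WSGeometry): rasterising a barycenter returned with
# return_type = "default" into a matrix, in the spirit of the "mat" return type of the "multiscale"
# method. Assumptions: positions is a 2 x M matrix with coordinates in [0,1]^2, rows of the result
# index the first coordinate and columns the second, and mass falling into the same cell is summed.
# The package's own "mat" output may use a different orientation.
weights_to_grid_sketch <- function(positions, weights, res) {
  mat <- matrix(0, res[1], res[2])
  # Map each coordinate to a 1-based cell index; pmin keeps a coordinate equal to 1 in the last cell.
  i <- pmin(floor(positions[1, ] * res[1]) + 1, res[1])
  j <- pmin(floor(positions[2, ] * res[2]) + 1, res[2])
  # Accumulate the mass of every support point into its cell.
  for (k in seq_along(weights)) {
    mat[i[k], j[k]] <- mat[i[k], j[k]] + weights[k]
  }
  mat
}
# Illustrative call, e.g. for the result bary of a "multiscale" run with return_type = "default"
# (kept as a comment):
# image(weights_to_grid_sketch(bary$positions, bary$weights, c(64, 64)))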
WSGeometry documentation built on Dec. 15, 2021, 1:08 a.m.