
View source: R/distanceCompare.R

distCompare    R Documentation

Compares Optimal Transport Distances Between WpProj and Original Models

Description

[Experimental] Computes the Wasserstein distance between the original model and each of the supplied WpProj models.

Usage

distCompare(
  models,
  target = list(parameters = NULL, predictions = NULL),
  power = 2,
  method = "exact",
  quantity = c("parameters", "predictions"),
  parallel = NULL,
  transform = function(x) {
    return(x)
  },
  ...
)

Arguments

models

A list of models produced by WpProj methods.

target

The target to which the models are compared. Should be a list with a "parameters" element, used to compare the models' parameters, and a "predictions" element, used to compare the models' predictions.

power

The power parameter of the Wasserstein distance.

method

Which method to use to compute or approximate the Wasserstein distance. Should be one of the values returned by transport_options().

quantity

Which quantities to compare: "parameters", "predictions", or both.

parallel

Parallel backend to use for the foreach package. See doParallel::registerDoParallel() for more details.

transform

A function applied to the predictions before the distances are computed. Defaults to the identity.

...

Other options passed to the wasserstein() distance function.
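For reference, the standard definition of the p-Wasserstein distance that the power argument refers to is sketched below, assuming a Euclidean ground cost; the exact normalization used inside the package is not stated on this page. Here p denotes the power argument (not the number of covariates used in the example), and Π(μ, ν) is the set of couplings with marginals μ and ν. For two empirical distributions that each place equal weight on S points, such as two sets of S posterior draws, an optimal coupling can be taken to be a permutation σ:

```latex
W_p(\mu, \nu) = \left( \inf_{\pi \in \Pi(\mu, \nu)} \int \lVert x - y \rVert^p \, d\pi(x, y) \right)^{1/p},
\qquad
W_p(\hat\mu_S, \hat\nu_S) = \left( \min_{\sigma} \frac{1}{S} \sum_{s=1}^{S} \lVert x_s - y_{\sigma(s)} \rVert^p \right)^{1/p}
```

Approximate methods such as Sinkhorn replace the exact minimization with a cheaper, regularized problem.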

Details

In the returned data frames, dist is the Wasserstein distance, nactive is the number of active variables in the model, groups is the name identifying the model, and method is the method used to calculate the distance (e.g., exact, sinkhorn). If the list in models is named, these names are used as the group names; otherwise, the group names are created from the call to the WpProj method.
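To make the dist column more concrete, here is a minimal base-R sketch of the exact p-Wasserstein distance in the simplest special case: two one-dimensional samples of equal size with equal weights, where the optimal coupling pairs the sorted values. The helper wasserstein_1d() is hypothetical and is not part of WpProj; distCompare() itself delegates to the wasserstein() function, which also handles multivariate draws.

```r
# Hypothetical helper, not part of WpProj: exact p-Wasserstein distance
# between two equally weighted 1-D samples of the same size.
wasserstein_1d <- function(x, y, p = 2) {
  stopifnot(length(x) == length(y), p >= 1)
  # In one dimension the optimal coupling matches order statistics.
  diffs <- abs(sort(x) - sort(y))
  # Average the p-th powers of the matched differences, then take the p-th root.
  mean(diffs^p)^(1 / p)
}

# Shifting every value by the same constant c gives a distance of |c|.
wasserstein_1d(c(0, 1, 2), c(1, 2, 3), p = 2)
#> [1] 1
```

Because the matching is based on sorted values, the distance does not depend on the order in which draws are stored.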

Value

An object of class distcompare with slots parameters, predictions, and p. The parameters and predictions slots are data frames whose columns are described in Details. The p slot is the power of the Wasserstein distance used in the calculation.

Examples

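if (rlang::is_installed("stats")) {
  # Simulate data: n observations, p covariates, s posterior draws
  n <- 32
  p <- 10
  s <- 21
  x <- matrix(stats::rnorm(p * n), nrow = n, ncol = p)
  beta <- (1:10) / 10
  y <- x %*% beta + stats::rnorm(n)

  # Approximate posterior draws of the coefficients (p x s) and the
  # corresponding draws of the linear predictor (n x s)
  post_beta <- matrix(beta, nrow = p, ncol = s) + stats::rnorm(p * s, 0, 0.1)
  post_mu <- x %*% post_beta

  # Projection with a lasso penalty
  fit1 <- WpProj(X = x, eta = post_mu, power = 2.0,
                 options = list(penalty = "lasso")
  )
  # Binary program projection using the lasso solver with an MCP penalty
  fit2 <- WpProj(X = x, eta = post_mu, theta = post_beta, power = 2.0,
                 method = "binary program", solver = "lasso",
                 options = list(solver.options = list(penalty = "mcp"))
  )

  # Compare both projections to the full posterior; the list names
  # "L1" and "BP" become the group labels
  dc <- distCompare(models = list("L1" = fit1, "BP" = fit2),
                    target = list(parameters = post_beta, predictions = post_mu))
  plot(dc)
}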

WpProj documentation built on May 29, 2024, 7:55 a.m.