Nothing
#' Simulate truncated multivariate normal (TMVT) using the Vecchia approximation
#'
#' @import GpGp
#' @import truncnorm
#'
#' @param lower lower bound vector for TMVT
#' @param upper upper bound vector for TMVT
#' @param delta MVT shifting parameter
#' @param df degrees of freedom
#' @param locs location (feature) matrix n X d
#' @param covName covariance function name from the `GpGp` package
#' @param covParms parameters for `covName`
#' @param m Vecchia conditioning set size
#' @param sigma dense covariance matrix, not needed when `locs` is not null
#' @param N number of samples required
#' @param verbose verbose level
#' @return n X N matrix of generated samples
#'
#' @export
mvrandt <- function(lower, upper, delta, df, locs = NULL,
covName = "matern15_isotropic", covParms = c(1.0, 0.1, 0.0),
m = 30, sigma = NULL, N = 1e3, verbose = FALSE) {
## standardize the input MVT prob -----------------------------
lower <- lower - delta
upper <- upper - delta
if (is.null(sigma)) {
n <- nrow(locs)
use_sigma <- FALSE
margin_sd <- sqrt(covParms[1])
upper <- upper / margin_sd
lower <- lower / margin_sd
covParms[1] <- 1
} else {
n <- nrow(sigma)
use_sigma <- TRUE
margin_sd <- sqrt(diag(sigma))
upper <- upper / margin_sd
lower <- lower / margin_sd
sigma <- t(t(sigma / margin_sd) / margin_sd)
}
if (n > 2000) {
stop("The package currently only support sampling TMVN/TMVT distributions
no greater than 2000 dimensions due to fast decaying acceptance rate")
}
if (any(lower < -10)) {
lower[lower < -10] <- -10
}
if (any(upper < -10)) {
upper[upper < -10] <- -10
}
if (any(lower > 10)) {
lower[lower > 10] <- 10
}
if (any(upper > 10)) {
upper[upper > 10] <- 10
}
if (any(upper < lower)) {
stop("Invalid MVT probability. Truncated marginal
probabilities have negative value(s)\n")
}
lower_upper <- matrix(0, n, 2)
lower_upper[, 1] <- lower
lower_upper[, 2] <- upper
lower <- lower_upper[, 1]
upper <- lower_upper[, 2]
## reorder --------------------------------
if (use_sigma) {
ord <- Vecc_reorder(lower, upper, m, covMat = sigma)$order
} else {
ord <- Vecc_reorder(
lower, upper, m, locs, covName, covParms
)$order
}
lower <- lower[ord]
upper <- upper[ord]
if (use_sigma) {
sigma <- sigma[ord, ord, drop = FALSE]
} else {
locs <- locs[ord, , drop = FALSE]
}
## find nearest neighbors for Vecchia --------------------------------
if (use_sigma) {
NN <- find_nn_corr(sigma, m)
} else {
NN <- GpGp::find_ordered_nn(locs, m)
}
## find Vecchia approx object -----------------------------------
if (use_sigma) {
vecc_obj <- vecc_cond_mean_var_sp(NN, covMat = sigma)
idx <- which(vecc_obj$cond_var < 0.01)
if (length(idx) > 0) {
diag(sigma)[idx] <- diag(sigma)[idx] + 0.01
vecc_obj <- vecc_cond_mean_var_sp(NN, covMat = sigma)
}
} else {
vecc_obj <- vecc_cond_mean_var_sp(NN,
locs = locs, covName = covName,
covParms = covParms
)
if (any(vecc_obj$cond_var < 0.01)) {
covParms[length(covParms)] <- 0.01 # nugget
vecc_obj <- vecc_cond_mean_var_sp(NN,
locs = locs, covName = covName,
covParms = covParms
)
}
}
## generate MVT samples ---------------------------------
samp_Vecc_ord <- mvtrnd_wrap(
lower, upper, df,
NN = NN, veccObj = vecc_obj, N = N, verbose = verbose
)
ord_rev <- integer(n)
ord_rev[ord] <- 1:n
samp_Vecc <- margin_sd * samp_Vecc_ord[ord_rev, ] + delta
return(samp_Vecc)
}
# TEST -------------------------------------------------------
# library(GpGp)
# library(mvtnorm)
# library(TruncatedNormal)
# library(VeccTMVN)
# set.seed(123)
# n1 <- 10
# n2 <- 10
# n <- n1 * n2
# locs <- as.matrix(expand.grid((1:n1) / n1, (1:n2) / n2))
# covparms <- c(2, 0.1, 0)
# mu <- rep(1, n)
# nu <- 10
# N <- 1000
# cov_mat <- matern15_isotropic(covparms, locs)
# a <- rep(-1, n)
# b <- rep(-0, n)
# samp_TN <- TruncatedNormal::mvrandt(
# a, b, cov_mat, df = nu,
# n = N, mu = mu
# )
# samp_Vecc <- VeccTMVN::mvrandt(
# a, b, mu, nu, locs, "matern15_isotropic", covparms,
# m = 30, N = N, verbose = TRUE
# )
# ## histogram for verification -------------------
# cat("Mean averaged absolute difference of prediction",
# mean(abs(rowMeans(samp_TN) - rowMeans(samp_Vecc))), "\n")
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.