
predict.DNAmf    R Documentation

Predictive posterior mean and variance for DNAmf object with nonseparable kernel.

Description

The function computes the predictive posterior mean and variance of the DNAmf model at new input locations, using closed-form expressions determined by the chosen nonseparable kernel.

Usage

## S3 method for class 'DNAmf'
predict(object, x, targett = 0, nimpute = 50, ...)

Arguments

object

A fitted DNAmf object.

x

A vector or matrix of new input locations at which to predict.

targett

A numeric value of the target tuning parameter at which to predict. Default is 0.

nimpute

Number of imputations for non-nested designs. Default is 50.

...

Additional arguments for compatibility with the generic method predict.

Details

The predict.DNAmf function internally calls closed_form, which in turn calls h1_sqex, h2_sqex, and h2_sqex_single for kernel="sqex", or h1_matern, h2_matern, and h2_matern_single for kernel="matern1.5" or kernel="matern2.5", to recursively compute the closed-form posterior mean and variance at each level.

Given a model fitted by DNAmf, the posterior mean and variance are computed from closed-form expressions derived recursively across levels. The exact formulas depend on the chosen kernel.
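To illustrate the kind of Gaussian expectation such closed forms rely on, the sketch below computes the expected value of a one-dimensional squared-exponential kernel when one of its arguments is a Gaussian random variable, and checks it against Monte Carlo. The parameterization k(a, b) = exp(-(a - b)^2 / theta) and the helper name expected_sqex are assumptions for illustration only; this is not the package's internal h1_sqex, whose scaling and arguments may differ.

```r
# Hypothetical helper (not part of DNAmf): E[k(Y, b)] for the squared-exponential
# kernel k(a, b) = exp(-(a - b)^2 / theta), where Y ~ N(mu, s2).
# Closed form: sqrt(theta / (theta + 2 s2)) * exp(-(mu - b)^2 / (theta + 2 s2)).
expected_sqex <- function(mu, s2, b, theta) {
  denom <- theta + 2 * s2
  sqrt(theta / denom) * exp(-(mu - b)^2 / denom)
}

# Monte Carlo check of the closed form
set.seed(1)
mu <- 0.3; s2 <- 0.2; b <- -0.1; theta <- 0.5
y <- rnorm(1e6, mean = mu, sd = sqrt(s2))
mc <- mean(exp(-(y - b)^2 / theta))
cf <- expected_sqex(mu, s2, b, theta)
print(c(closed_form = cf, monte_carlo = mc))
```

When s2 = 0 the expression reduces to the kernel itself, which is why the predictive variance of a lower level enters the next level only through the widened length-scale theta + 2 * s2 and the shrinkage factor in front.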

If the fitted model was constructed with non-nested designs (nested=FALSE), the function generates nimpute sets of imputations for pseudo outputs via imputer.
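The sketch below shows one common way to pool predictions obtained from several imputations: treat them as an equally weighted mixture and apply the law of total variance (average of the variances plus the variance of the means). Both this pooling rule and the helper name pool_imputations are assumptions for illustration; consult the package source for the rule predict.DNAmf actually uses.

```r
# Hypothetical helper (not part of DNAmf): pool predictions from nimpute imputations.
# mu_mat, sig2_mat: nimpute x n matrices, one row per imputation.
pool_imputations <- function(mu_mat, sig2_mat) {
  mu <- colMeans(mu_mat)                          # mixture mean
  within <- colMeans(sig2_mat)                    # E[Var]
  between <- colMeans(sweep(mu_mat, 2, mu)^2)     # Var[E], mixture (1/m) form
  list(mu = mu, sig2 = within + between)
}

# Simulated per-imputation predictions at n = 5 locations
set.seed(2)
nimpute <- 50; n <- 5
mu_mat <- matrix(rnorm(nimpute * n), nrow = nimpute)
sig2_mat <- matrix(runif(nimpute * n, 0.1, 0.5), nrow = nimpute)
pooled <- pool_imputations(mu_mat, sig2_mat)
str(pooled)
```

Under this rule the pooled variance is never smaller than the average per-imputation variance, so disagreement between imputations widens the predictive band.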

For further details, see Heo, Boutelet, and Sung (2025+, <arXiv:2506.08328>).

Value

A list containing the predictive posterior mean and variance at each level, together with the computation time:

  • mu_1, sig2_1, ..., mu_L, sig2_L: Vectors of the predictive posterior mean and variance at each level.

  • mu: A vector of the predictive posterior mean at the target tuning parameter targett.

  • sig2: A vector of the predictive posterior variance at the target tuning parameter targett.

  • time: Total computation time in seconds.
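The returned list can be post-processed directly. The sketch below extracts the components for a given level by name and turns any mean/variance pair into pointwise Gaussian predictive intervals (99% by default, matching the qnorm(0.995) bands in the Examples). The helpers pred_level and predictive_interval and the mock object are hypothetical; pred_mock only mimics the element names listed above and stands in for the output of predict(fit.DNAmf, x, targett = 0).

```r
# Hypothetical helper: pull mu_l and sig2_l for level l out of the returned list.
pred_level <- function(pred, l) {
  list(mu = pred[[paste0("mu_", l)]], sig2 = pred[[paste0("sig2_", l)]])
}

# Hypothetical helper: pointwise Gaussian predictive intervals from mu and sig2.
predictive_interval <- function(pred, level = 0.99) {
  z <- qnorm(1 - (1 - level) / 2)
  half <- z * sqrt(pred$sig2)
  data.frame(mu = pred$mu, lower = pred$mu - half, upper = pred$mu + half)
}

# Mock object with the same element names as the Value section (not real output)
pred_mock <- list(mu_1 = c(0.2, 0.8), sig2_1 = c(0.01, 0.02),
                  mu = c(0.1, 0.9), sig2 = c(0.04, 0.09), time = 0.1)
ci <- predictive_interval(pred_mock)                   # target level targett
ci1 <- predictive_interval(pred_level(pred_mock, 1))   # level 1
print(ci)
```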

See Also

DNAmf for the user-level function.

Examples

### Non-Additive example ###
library(DNAmf)  # provides DNAmf() and the predict method
library(RNAmf)  # provides NestedX() used below

### Non-Additive Function ###
fl <- function(x, t){
  term1 <- sin(10 * pi * x / (5+t))
  term2 <- 0.2 * sin(8 * pi * x)
  term1 + term2
}

### training data ###
n1 <- 13; n2 <- 10; n3 <- 7; n4 <- 4; n5 <- 1;
m1 <- 2.5; m2 <- 2.0; m3 <- 1.5; m4 <- 1.0; m5 <- 0.5;
d <- 1
eps <- sqrt(.Machine$double.eps)
x <- seq(0,1,0.01)

### fix seed to reproduce the result ###
set.seed(1)

### generate initial nested design ###
NestDesign <- NestedX(c(n1,n2,n3,n4,n5),d)

X1 <- NestDesign[[1]]
X2 <- NestDesign[[2]]
X3 <- NestDesign[[3]]
X4 <- NestDesign[[4]]
X5 <- NestDesign[[5]]

y1 <- fl(X1, t=m1)
y2 <- fl(X2, t=m2)
y3 <- fl(X3, t=m3)
y4 <- fl(X4, t=m4)
y5 <- fl(X5, t=m5)

### fit a DNAmf ###
fit.DNAmf <- DNAmf(X=list(X1, X2, X3, X4, X5), y=list(y1, y2, y3, y4, y5), kernel="sqex",
                   t=c(m1,m2,m3,m4,m5), multi.start=10, constant=TRUE)

### predict ###
pred.DNAmf <- predict(fit.DNAmf, x, targett=0)
predydiffu <- pred.DNAmf$mu
predsig2diffu <- pred.DNAmf$sig2

### RMSE ###
print(sqrt(mean((predydiffu-fl(x, t=0))^2))) # 0.1162579

### visualize the emulation performance ###
oldpar <- par(mfrow = c(2,3))
create_plot_base <- function(i, mesh_size, x, pred_mu, pred_sig2,
                             X_points = NULL, y_points = NULL, add_points = TRUE, yylim) {
  lower <- pred_mu - qnorm(0.995) * sqrt(pred_sig2)
  upper <- pred_mu + qnorm(0.995) * sqrt(pred_sig2)

  plot(x, pred_mu, type = "n", ylim = c(-yylim, yylim), xlab = "", ylab = "",
       main = paste0("Mesh size = ", mesh_size), axes = FALSE)
  box()

  polygon(c(x, rev(x)), c(upper, rev(lower)),
          col = adjustcolor("blue", alpha.f = 0.2), border = NA)
  lines(x, pred_mu, col = "blue", lwd = 2)
  lines(x, fl(x, mesh_size), lty = 2, col = "black", lwd = 2)

  if (add_points && !is.null(X_points) && !is.null(y_points)) {
    points(X_points, y_points, col = "red", pch = 16, cex = 1.3)
  }
}

mesh_sizes <- c(m1, m2, m3, m4, m5, 0)
mu_list <- list(pred.DNAmf$mu_1, pred.DNAmf$mu_2, pred.DNAmf$mu_3,
                pred.DNAmf$mu_4, pred.DNAmf$mu_5, pred.DNAmf$mu)
sig2_list <- list(pred.DNAmf$sig2_1, pred.DNAmf$sig2_2, pred.DNAmf$sig2_3,
                  pred.DNAmf$sig2_4, pred.DNAmf$sig2_5, pred.DNAmf$sig2)
X_list <- list(X1, X2, X3, X4, X5, NULL)
y_list <- list(y1, y2, y3, y4, y5, NULL)

plots <- mapply(function(i, m, mu, sig2, X, y) {
  create_plot_base(i, m, x, mu, sig2, X, y, add_points = !is.null(X), yylim=1.5)
}, i = 1:6, m = mesh_sizes, mu = mu_list, sig2 = sig2_list,
X = X_list, y = y_list, SIMPLIFY = FALSE)
par(oldpar)


DNAmf documentation built on June 23, 2025, 5:08 p.m.
