crit_logEI: Logarithm of Expected Improvement criterion

View source: R/optim.R


Logarithm of Expected Improvement criterion

Description

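Computes the logarithm of the Expected Improvement (EI) for minimization. This is numerically more stable than EI itself, which can underflow to zero where an improvement is very unlikely.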
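For a plug-in value cst, predictive mean mu and predictive standard deviation sigma, EI for minimization is sigma * h(z) with z = (cst - mu) / sigma and h(z) = z * pnorm(z) + dnorm(z), so log EI = log(sigma) + log h(z). When z is very negative, both terms of h(z) underflow or cancel. The base-R sketch below evaluates log h(z) piecewise, in the spirit of Ament et al. (2024). It is not the hetGP implementation: the helper names `log1mexp`, `log_h` and `log_ei`, the branch thresholds (-1 and -50), the use of `pnorm(..., log.p = TRUE)` in place of `erfcx`, and the two-term asymptotic correction are all assumptions of this sketch.

```r
## Hypothetical helper (not part of hetGP): log(1 - exp(x)) for x <= 0,
## switching formulas at -log(2) to avoid cancellation
log1mexp <- function(x) {
  ifelse(x > -log(2), log(-expm1(x)), log1p(-exp(x)))
}

## Hypothetical helper: log h(z), where h(z) = z * pnorm(z) + dnorm(z)
log_h <- function(z) {
  out <- numeric(length(z))

  ## z > -1: the direct formula has no harmful cancellation
  direct <- z > -1
  zd <- z[direct]
  out[direct] <- log(dnorm(zd) + zd * pnorm(zd))

  ## -50 < z <= -1: h(z) = dnorm(z) * (1 - |z| * pnorm(z) / dnorm(z)),
  ## with the ratio formed on the log scale so nothing underflows
  mid <- z <= -1 & z > -50
  zm <- z[mid]
  r <- log(abs(zm)) + pnorm(zm, log.p = TRUE) - dnorm(zm, log = TRUE)
  out[mid] <- dnorm(zm, log = TRUE) + log1mexp(r)

  ## z <= -50: asymptotic expansion of the Mills ratio,
  ## h(z) ~ dnorm(z) / z^2 * (1 - 3 / z^2 + 15 / z^4)
  far <- z <= -50
  zf <- z[far]
  out[far] <- dnorm(zf, log = TRUE) - 2 * log(abs(zf)) +
    log1p(-3 / zf^2 + 15 / zf^4)

  out
}

## Hypothetical helper: log EI for minimization from predictive mean and sd
log_ei <- function(mu, sigma, cst) {
  log(sigma) + log_h((cst - mu) / sigma)
}
```

For instance, `log_h(-40)` stays finite, whereas the naive `log(dnorm(-40) - 40 * pnorm(-40))` returns `-Inf` because both terms underflow to 0 in double precision.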

Usage

crit_logEI(x, model, cst = NULL, preds = NULL)

Arguments

x

matrix of new designs, one point per row (size n x d)

model

homGP or hetGP model, or their TP equivalents, including inverse matrices. For TP models, the computation relies on the regular EI.

cst

optional plug-in value used in the EI; see Details

preds

optional predictions at x, to avoid recomputing them if already available

Details

cst is classically the observed minimum in the deterministic case. In the noisy case, the minimum of the predictive mean works well.
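Both choices of cst can be passed explicitly. The sketch below assumes hetGP is installed; the simulated data (10 unique designs with 5 replicates each, noise sd 0.5) are arbitrary. It takes the minimum of the predictive mean at the unique designs `model$X0` as cst for the noisy case, and computes predictions on the grid once so they can be reused through `preds`.

```r
library(hetGP)  # assumed installed; provides f1d, mleHomGP and crit_logEI
set.seed(1)

## Illustrative noisy data: 10 unique designs, 5 replicates each (arbitrary)
X <- matrix(rep(seq(0, 1, length.out = 10), each = 5), ncol = 1)
Z <- as.vector(f1d(X)) + rnorm(nrow(X), sd = 0.5)
model <- mleHomGP(X = X, Z = Z, lower = 0.001, upper = 1)

## Deterministic case: observed minimum (shown for reference; data here are noisy)
cst_det <- min(Z)

## Noisy case: minimum of the predictive mean at the unique designs
cst_noisy <- min(predict(x = model$X0, object = model)$mean)

## Predict once on a grid and pass the result through 'preds'
Xgrid <- matrix(seq(0, 1, length.out = 101), ncol = 1)
preds <- predict(x = Xgrid, object = model)
logEI_noisy <- crit_logEI(Xgrid, model, cst = cst_noisy, preds = preds)

## Candidate for the next evaluation: grid point with the largest logEI
x_next <- Xgrid[which.max(logEI_noisy), , drop = FALSE]
```

Since log is monotone, maximizing logEI selects the same point as maximizing EI whenever EI is not numerically zero.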

Note

This is a beta version at this point.

References

Ament, S., Daulton, S., Eriksson, D., Balandat, M., & Bakshy, E. (2024). Unexpected improvements to expected improvement for Bayesian optimization. Advances in Neural Information Processing Systems, 36.

See Also

crit_EI for the regular EI criterion, e.g., to compare the outcomes
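A quick comparison, assuming hetGP is installed and using arbitrary simulated data: with the same cst and the same predictions, log(crit_EI(...)) is expected to nearly coincide with crit_logEI(...) wherever EI is comfortably positive. Far from the incumbent, EI itself can become numerically zero, which is the regime the log version targets.

```r
library(hetGP)  # assumed installed; provides f1d, mleHomGP, crit_EI, crit_logEI
set.seed(2)

## Small noisy 1D problem (arbitrary illustration data)
X <- matrix(runif(40), ncol = 1)
Z <- as.vector(f1d(X)) + rnorm(nrow(X), sd = 0.5)
model <- mleHomGP(X = X, Z = Z, lower = 0.001, upper = 1)

## Same plug-in value and predictions for both criteria
Xgrid <- matrix(seq(0, 1, length.out = 201), ncol = 1)
preds <- predict(x = Xgrid, object = model)
cst <- min(predict(x = model$X0, object = model)$mean)

ei <- crit_EI(Xgrid, model, cst = cst, preds = preds)
logei <- crit_logEI(Xgrid, model, cst = cst, preds = preds)

## Where EI is comfortably positive, both should agree on the log scale
ok <- ei > 1e-3
max(abs(log(ei[ok]) - logei[ok]))

## Number of grid points where EI itself is numerically zero
sum(ei == 0)
```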

Examples

## Optimization example
library(hetGP)
set.seed(42)

## Noise field via standard deviation
noiseFun <- function(x, coef = 1.1, scale = 1){
  if(is.null(nrow(x)))
    x <- matrix(x, nrow = 1)
  return(scale*(coef + cos(x * 2 * pi)))
}

## Test function defined in [0,1]
ftest <- function(x){
  if(is.null(nrow(x)))
    x <- matrix(x, ncol = 1)
  return(f1d(x) + rnorm(nrow(x), mean = 0, sd = noiseFun(x)))
}

n_init <- 10 # number of unique designs
N_init <- 100 # total number of points
X <- seq(0, 1, length.out = n_init)
X <- matrix(X[sample(1:n_init, N_init, replace = TRUE)], ncol = 1)
Z <- ftest(X)

## Predictive grid
ngrid <- 51
xgrid <- seq(0,1, length.out = ngrid)
Xgrid <- matrix(xgrid, ncol = 1)

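## Fit a heteroskedastic GP
model <- mleHetGP(X = X, Z = Z, lower = 0.001, upper = 1)

## Predict once on the grid and reuse the predictions in logEI
preds <- predict(x = Xgrid, model)
logEIgrid <- crit_logEI(Xgrid, model, preds = preds)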

## Left axis: true function (dotted blue), data, predictive mean (red),
## 90% intervals without (red dashed) and with (green dashed) the noise
par(mar = c(3,3,2,3)+0.1)
plot(xgrid, f1d(xgrid), type = 'l', lwd = 1, col = "blue", lty = 3,
     xlab = '', ylab = '', ylim = c(-8,16))
points(X, Z)
lines(Xgrid, preds$mean, col = 'red', lwd = 2)
lines(Xgrid, qnorm(0.05, preds$mean, sqrt(preds$sd2)), col = 2, lty = 2)
lines(Xgrid, qnorm(0.95, preds$mean, sqrt(preds$sd2)), col = 2, lty = 2)
lines(Xgrid, qnorm(0.05, preds$mean, sqrt(preds$sd2 + preds$nugs)), col = 3, lty = 2)
lines(Xgrid, qnorm(0.95, preds$mean, sqrt(preds$sd2 + preds$nugs)), col = 3, lty = 2)

## Right axis: logEI over the grid (cyan)
par(new = TRUE)
plot(NA, NA, xlim = c(0, 1), ylim = range(logEIgrid), axes = FALSE, ylab = "", xlab = "")
lines(xgrid, logEIgrid, lwd = 2, col = 'cyan')
axis(side = 4)
mtext(side = 4, line = 2, expression(logEI(x)), cex = 0.8)
mtext(side = 2, line = 2, expression(f(x)), cex = 0.8)

hetGP documentation built on Sept. 11, 2024, 6:56 p.m.