eval_pred_dens: Evaluate Predictive Densities

View source: R/pred_functions.R

Description

eval_pred_dens evaluates the posterior predictive density of the response at a set of candidate values, for a single set of test covariates, based on a fitted shrinkGPR, shrinkTPR, shrinkMVGPR, or shrinkMVTPR model.

Usage

eval_pred_dens(x, mod, data_test, nsamp = 100, log = FALSE)

Arguments

x

For univariate models (shrinkGPR, shrinkTPR): a numeric vector of response values at which to evaluate the density. For multivariate models (shrinkMVGPR, shrinkMVTPR): a numeric matrix with M columns, one per response variable, where each row is a candidate response vector.

mod

A shrinkGPR, shrinkTPR, shrinkMVGPR, or shrinkMVTPR object representing the fitted model.

data_test

A data frame with a single row containing the covariate values of the test point. Its variables must match those used when fitting the model.

nsamp

Positive integer specifying the number of posterior samples to use for the evaluation. Default is 100.

log

Logical; if TRUE, returns the log predictive density. Default is FALSE.

Details

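This function approximates the predictive density by marginalizing over nsamp posterior samples drawn from the fitted model, that is, by averaging the predictive densities obtained conditional on each sampled set of parameters. If the model includes a mean equation, the corresponding covariates in data_test are used to compute the predictive mean.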
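As a rough illustration of this kind of Monte Carlo marginalization, the base-R sketch below assumes that each posterior draw yields a Gaussian conditional predictive distribution with a hypothetical mean `mu_s[s]` and standard deviation `sigma_s[s]`. For the shrinkTPR model, a Student-t kernel would take the place of `dnorm`. This is not the package's internal code. `mc_pred_dens` and the simulated draws are illustrative only. The log version uses the log-sum-exp trick so that very small densities do not underflow to zero.

```r
# Hypothetical sketch (not shrinkGPR internals): Monte Carlo predictive density
#   p(y | data) ~= (1 / S) * sum_s N(y; mu_s[s], sigma_s[s]^2)
# where mu_s and sigma_s (both of length S) are assumed per-draw means and sds.
mc_pred_dens <- function(y, mu_s, sigma_s, log = FALSE) {
  n <- length(y)
  S <- length(mu_s)
  # n x S matrix: entry [i, s] is log N(y[i]; mu_s[s], sigma_s[s]^2)
  log_dens <- matrix(
    dnorm(rep(y, times = S), mean = rep(mu_s, each = n),
          sd = rep(sigma_s, each = n), log = TRUE),
    nrow = n, ncol = S
  )
  # Log-sum-exp over draws, row by row, to avoid underflow of tiny densities
  m <- apply(log_dens, 1, max)
  out <- m + log(rowSums(exp(log_dens - m))) - log(S)
  if (log) out else exp(out)
}

# Usage with simulated (hypothetical) posterior draws
set.seed(1)
mu_s <- rnorm(100, mean = -0.95, sd = 0.05)
sigma_s <- rep(0.1, 100)
mc_pred_dens(c(-1.2, -1, 0), mu_s, sigma_s)
mc_pred_dens(c(-1.2, -1, 0), mu_s, sigma_s, log = TRUE)
```

Working on the log scale and exponentiating only at the end keeps the result finite even for evaluation points far in the tails. That is one reason to prefer `log = TRUE` when the densities are used in scoring rules.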

Value

A numeric vector of predictive densities (or log predictive densities if log = TRUE), with one entry per element of x for univariate models or one entry per row of x for multivariate models.
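For the multivariate models, x is a matrix and one value is returned per row. The base-R sketch below shows that shape contract. It assumes that each posterior draw yields a Gaussian predictive distribution with a hypothetical mean vector and covariance matrix. The shrinkMVTPR case would use a multivariate t kernel instead. `mc_mv_pred_dens` and `mvn_logdens` are illustrative names and are not functions of the package.

```r
# Log density of a multivariate normal via its Cholesky factor (base R only).
mvn_logdens <- function(x, mu, Sigma) {
  R <- chol(Sigma)                             # Sigma = t(R) %*% R, R upper triangular
  z <- backsolve(R, x - mu, transpose = TRUE)  # solves t(R) %*% z = x - mu
  -0.5 * sum(z^2) - sum(log(diag(R))) - 0.5 * length(mu) * log(2 * pi)
}

# Hypothetical Monte Carlo predictive density for each row of an n x M matrix x.
# mu_list / Sigma_list hold one assumed mean vector / covariance matrix per draw.
mc_mv_pred_dens <- function(x, mu_list, Sigma_list, log = FALSE) {
  stopifnot(is.matrix(x), ncol(x) == length(mu_list[[1]]))
  S <- length(mu_list)
  log_dens <- matrix(NA_real_, nrow = nrow(x), ncol = S)
  for (s in seq_len(S)) {
    # Evaluate draw s at every candidate response vector (row of x)
    log_dens[, s] <- apply(x, 1, mvn_logdens, mu = mu_list[[s]], Sigma = Sigma_list[[s]])
  }
  # Average over draws on the log scale (log-sum-exp) for numerical stability
  m <- apply(log_dens, 1, max)
  out <- m + log(rowSums(exp(log_dens - m))) - log(S)
  if (log) out else exp(out)
}

# Usage: three candidate bivariate responses, two hypothetical posterior draws
x_eval <- rbind(c(0, 0), c(1, -1), c(2, 2))
mu_list <- list(c(0, 0), c(0.1, -0.1))
Sigma_list <- list(diag(2), matrix(c(1, 0.3, 0.3, 1), 2))
mc_mv_pred_dens(x_eval, mu_list, Sigma_list)
```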

Examples


if (torch::torch_is_installed()) {
  # Simulate data
  set.seed(123)
  torch::torch_manual_seed(123)
  n <- 100
  x <- matrix(runif(n * 2), n, 2)
  y <- sin(2 * pi * x[, 1]) + rnorm(n, sd = 0.1)
  data <- data.frame(y = y, x1 = x[, 1], x2 = x[, 2])

  # Fit GPR model
  res <- shrinkGPR(y ~ x1 + x2, data = data)

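  # Test covariates and response values at which to evaluate the predictive density
  data_test <- data.frame(x1 = 0.8, x2 = 0.5)
  eval_points <- c(-1.2, -1, 0)

  # Values are not auto-printed inside a braced block, so print explicitly
  print(eval_pred_dens(eval_points, res, data_test))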

  # eval_pred_dens is vectorized over x, so it can be passed to curve()
  curve(eval_pred_dens(x, res, data_test), from = -1.5, to = -0.5)
  # Mark the noise-free mean of y at x1 = 0.8
  abline(v = sin(2 * pi * 0.8), col = "red")
}


shrinkGPR documentation built on March 30, 2026, 5:06 p.m.