calc_pred_moments: Calculate Predictive Moments

View source: R/pred_functions.R

calc_pred_moments    R Documentation

Calculate Predictive Moments

Description

calc_pred_moments calculates the predictive means and variances for a fitted shrinkGPR, shrinkTPR, shrinkMVGPR, or shrinkMVTPR model at new data points.

Usage

calc_pred_moments(object, newdata, nsamp = 100)

Arguments

object

A shrinkGPR, shrinkTPR, shrinkMVGPR, or shrinkMVTPR object representing the fitted univariate or multivariate Gaussian or Student-t process regression model.

newdata

Optional data frame containing the covariates for the new data points. If missing, the training data is used.

nsamp

Positive integer specifying the number of posterior samples to use for the calculation. Default is 100.

Details

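This function computes predictive moments for each of nsamp posterior samples from the fitted model, so every returned array has one slice per sample along its first dimension; summarizing over that dimension marginalizes over the posterior. If a mean equation was included in the model, the corresponding covariates are used to calculate the predictive mean.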
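Because the arrays listed under Value keep one slice per posterior sample, a single marginal mean and variance per data point can be obtained with the law of total variance, Var(y) = E[Var(y | sample)] + Var(E[y | sample]). The sketch below is a minimal illustration of that step for a univariate model. The helper aggregate_moments() is hypothetical and not part of shrinkGPR, and it assumes means is an nsamp x N_new matrix and vars an nsamp x N_new x N_new array of base R numerics; the toy inputs are simulated, not model output.

```r
# Hypothetical helper, NOT part of shrinkGPR: collapses per-sample predictive
# moments into one marginal mean and variance per data point via the law of
# total variance: Var(y) = E[Var(y | sample)] + Var(E[y | sample]).
# Assumes `means` is nsamp x N_new and `vars` is nsamp x N_new x N_new,
# both base R numerics (not torch tensors).
aggregate_moments <- function(means, vars) {
  nsamp <- dim(vars)[1]
  n_new <- dim(vars)[2]
  means <- matrix(means, nrow = nsamp)

  # Per-sample predictive variances: the diagonal of each N_new x N_new slice
  sample_vars <- matrix(
    vapply(seq_len(n_new), function(i) vars[, i, i], numeric(nsamp)),
    nrow = nsamp
  )

  mean_total <- colMeans(means)
  # E[Var] + Var[E], with Var[E] taken over the nsamp draws (divisor nsamp)
  var_total <- colMeans(sample_vars) + colMeans(sweep(means, 2, mean_total)^2)

  list(mean = mean_total, var = var_total)
}

# Toy input with the documented shapes (simulated, not model output)
set.seed(1)
nsamp <- 50
n_new <- 3
toy_means <- matrix(rnorm(nsamp * n_new), nsamp, n_new)
toy_vars <- array(0, dim = c(nsamp, n_new, n_new))
for (s in seq_len(nsamp)) toy_vars[s, , ] <- diag(runif(n_new, 0.5, 1))

summary_moments <- aggregate_moments(toy_means, toy_vars)
```

With the model fitted in the Examples section, the call would be aggregate_moments(momes$means, momes$vars), provided both elements are base R arrays; if they come back as torch tensors, convert them with torch::as_array() first.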

Value

For univariate models (shrinkGPR, shrinkTPR), a list with:

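  • means: An array of predictive means of shape nsamp x N_new, with the first dimension corresponding to posterior samples and the second to data points (N_new is the number of rows of newdata, or of the training data if newdata is missing).

  • vars: An array of predictive variances of shape nsamp x N_new x N_new, with the first dimension corresponding to posterior samples and the second and third to data points, i.e., one N_new x N_new predictive covariance matrix per sample.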

For a shrinkTPR model, the list additionally includes:

  • nu: A vector of posterior degrees of freedom of length nsamp.
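This page does not state whether the shrinkTPR vars slices are covariances or Student-t scale matrices. If they are scale matrices (an assumption, not confirmed here), the degrees of freedom nu are needed to convert them, since a multivariate Student-t with scale matrix S and nu degrees of freedom has covariance nu / (nu - 2) * S, finite only for nu > 2. The helper t_scale_to_cov() below is hypothetical and not part of shrinkGPR; the toy inputs are simulated.

```r
# Hypothetical helper, NOT part of shrinkGPR.
# ASSUMPTION: vars[s, , ] is the scale matrix of a multivariate Student-t
# predictive distribution with nu[s] degrees of freedom. Its covariance is
# nu / (nu - 2) times the scale matrix and is finite only for nu > 2;
# slices with nu <= 2 are set to NA.
t_scale_to_cov <- function(vars, nu) {
  stopifnot(dim(vars)[1] == length(nu))
  out <- vars
  for (s in seq_along(nu)) {
    if (nu[s] > 2) {
      out[s, , ] <- vars[s, , ] * nu[s] / (nu[s] - 2)
    } else {
      out[s, , ] <- NA_real_
    }
  }
  out
}

# Toy input: two posterior draws, one data point (simulated, not model output)
tpr_vars <- array(c(1, 2), dim = c(2, 1, 1))
tpr_nu <- c(4, 2)
tpr_cov <- t_scale_to_cov(tpr_vars, tpr_nu)
```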

For multivariate models (shrinkMVGPR, shrinkMVTPR), a list with:

  • means: An array of predictive means of shape nsamp x N_new x M, where M is the number of response variables.

  • K: An array of posterior row covariance matrices of shape nsamp x N_new x N_new.

  • Omega: An array of posterior column covariance matrices of shape nsamp x M x M.

  • nu: (shrinkMVTPR only) A vector of posterior degrees of freedom of length nsamp.
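For a shrinkMVGPR model, the row and column covariances can be combined into the covariance of all N_new x M predictions for one posterior sample. The sketch below assumes (this page does not confirm it) that, given sample s, the N_new x M predictive matrix Y is matrix normal with row covariance K[s, , ] and column covariance Omega[s, , ]; then Cov(vec(Y)) = Omega %x% K with columns stacked, and Var(Y[i, j]) = K[i, i] * Omega[j, j]. The helpers slice_matrix(), vec_covariance(), and marginal_variances() are hypothetical, not part of shrinkGPR, and the toy arrays are simulated.

```r
# Hypothetical helpers, NOT part of shrinkGPR.
# ASSUMPTION: for posterior sample s, the N_new x M predictive matrix Y is
# matrix normal with row covariance K[s, , ] and column covariance
# Omega[s, , ]. Then Cov(vec(Y)) = Omega %x% K, where vec() stacks columns.
slice_matrix <- function(arr, s) {
  # Keep a square matrix even when the slice is 1 x 1
  matrix(arr[s, , ], nrow = dim(arr)[2])
}

vec_covariance <- function(K, Omega, s) {
  kronecker(slice_matrix(Omega, s), slice_matrix(K, s))
}

marginal_variances <- function(K, Omega, s) {
  # Var(Y[i, j]) = K[i, i] * Omega[j, j]; returns an N_new x M matrix
  outer(diag(slice_matrix(K, s)), diag(slice_matrix(Omega, s)))
}

# Toy input with the documented shapes (simulated, not model output)
set.seed(2)
nsamp <- 4
n_new <- 3
m <- 2
K <- array(0, dim = c(nsamp, n_new, n_new))
Omega <- array(0, dim = c(nsamp, m, m))
for (s in seq_len(nsamp)) {
  A <- matrix(rnorm(n_new^2), n_new)
  B <- matrix(rnorm(m^2), m)
  K[s, , ] <- crossprod(A) + diag(n_new)    # positive definite row covariance
  Omega[s, , ] <- crossprod(B) + diag(m)    # positive definite column covariance
}

S1 <- vec_covariance(K, Omega, 1)      # (N_new * M) x (N_new * M)
V1 <- marginal_variances(K, Omega, 1)  # N_new x M
```

The diagonal of S1, read in column-stacked order, matches V1, so marginal_variances() gives the per-entry variances without forming the full Kronecker product. For shrinkMVTPR, K and Omega would likely need the same degrees-of-freedom caveat as in the univariate t case.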

Examples


if (torch::torch_is_installed()) {
  # Simulate data
  set.seed(123)
  torch::torch_manual_seed(123)
  n <- 100
  x <- matrix(runif(n * 2), n, 2)
  y <- sin(2 * pi * x[, 1]) + rnorm(n, sd = 0.1)
  data <- data.frame(y = y, x1 = x[, 1], x2 = x[, 2])

  # Fit GPR model
  res <- shrinkGPR(y ~ x1 + x2, data = data)

  # Calculate predictive moments
  momes <- calc_pred_moments(res, nsamp = 100)
}


shrinkGPR documentation built on March 30, 2026, 5:06 p.m.