shrinkTPR: Student-t Process Regression with Shrinkage and Normalizing...

shrinkTPR    R Documentation

Student-t Process Regression with Shrinkage and Normalizing Flows

Description

Fits a Student-t process regression (TPR) model (Shah et al. 2014) with triple-gamma shrinkage priors on the inverse length-scale parameters \theta_j. Compared to shrinkGPR, the Student-t process has heavier tails governed by the degrees of freedom \nu, providing greater robustness to outliers. An optional linear mean can be added via formula_mean. The joint posterior is approximated by normalizing flows trained to maximize the ELBO.

Usage

shrinkTPR(
  formula,
  data,
  a = 0.5,
  c = 0.5,
  formula_mean,
  a_mean = 0.5,
  c_mean = 0.5,
  sigma2_rate = 10,
  nu_alpha = 0.5,
  nu_beta = 2,
  kernel_func = kernel_se,
  n_layers = 10,
  n_latent = 10,
  flow_func = sylvester,
  flow_args,
  n_epochs = 1000,
  auto_stop = TRUE,
  cont_model,
  device,
  display_progress = TRUE,
  optim_control
)

Arguments

formula

object of class "formula": a symbolic representation of the model for the covariance equation, as in lm. The response variable and covariates are specified here.

data

optional data frame containing the response variable and the covariates. If not found in data, the variables are taken from environment(formula). No NAs are allowed in the response variable or covariates.

a

positive real number controlling the behavior at the origin of the shrinkage prior for the covariance structure. The default is 0.5.

c

positive real number controlling the tail behavior of the shrinkage prior for the covariance structure. The default is 0.5.

formula_mean

optional formula for the linear mean equation. If provided, the covariates for the mean structure are specified separately from the covariance structure. A response variable is not required in this formula.

a_mean

positive real number controlling the behavior at the origin of the shrinkage prior for the mean structure. The default is 0.5.

c_mean

positive real number controlling the tail behavior of the shrinkage prior for the mean structure. The default is 0.5.

sigma2_rate

positive real number controlling the prior rate parameter for the residual variance. The default is 10.

nu_alpha

positive real number controlling the shape parameter of the shifted gamma prior for the degrees of freedom of the Student-t process. The default is 0.5.

nu_beta

positive real number controlling the rate parameter of the shifted gamma prior for the degrees of freedom of the Student-t process. The default is 2.

kernel_func

function specifying the covariance kernel. The default is kernel_se, a squared exponential kernel. For guidance on how to provide a custom kernel function, see Details.

n_layers

positive integer specifying the number of flow layers in the normalizing flow. The default is 10.

n_latent

positive integer specifying the dimensionality of the latent space for the normalizing flow. The default is 10.

flow_func

function specifying the normalizing flow transformation. The default is sylvester. For guidance on how to provide a custom flow function, see Details.

flow_args

optional named list containing arguments for the flow function. If not provided, default arguments are used. For guidance on how to provide a custom flow function, see Details.

n_epochs

positive integer specifying the number of training epochs. The default is 1000.

auto_stop

logical value indicating whether to enable early stopping based on convergence. The default is TRUE.

cont_model

optional object returned from a previous shrinkTPR call, enabling continuation of training from the saved state.

device

optional device to run the model on, e.g., torch_device("cuda") for GPU or torch_device("cpu") for CPU. Defaults to GPU if available; otherwise, CPU.

display_progress

logical value indicating whether to display progress bars and messages during training. The default is TRUE.

optim_control

optional named list containing optimizer parameters. If not provided, default settings are used.

Details

Model Specification

f is a Student-t process if any finite collection of its function values has a joint multivariate Student-t distribution. Given N observations with d-dimensional covariates x_i, the joint distribution of the function values is thus

(f(x_1), \ldots, f(x_N)) \sim t_N\!\left(\nu,\, \mu(x_1, \ldots, x_N),\, K(\theta, \tau)\right),

which means that f follows a \mathcal{TP}(\nu, \mu, k(\cdot, \cdot\,;\, \theta, \tau)) Student-t process with \nu degrees of freedom, mean function \mu, and covariance kernel k. In contrast to a Gaussian process regression model, the noise is added directly to the kernel, so the likelihood for the observations is Y \sim t_N\!\left(\nu,\, \mu(x_1, \ldots, x_N),\, K(\theta, \tau) + \sigma^2 I\right). The default squared exponential kernel is

k(x, x';\, \theta, \tau) = \frac{1}{\tau} \exp\!\left(-\frac{1}{2} \sum_{j=1}^d \theta_j (x_j - x'_j)^2\right),

where \theta_j \ge 0 are inverse squared length-scales and \tau > 0 is the output scale. Users can specify custom kernels by following the guidelines below, or use one of the other provided kernel functions in kernel_functions.
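To make the role of \theta_j concrete, the kernel formula above can be evaluated in base R. This is a minimal sketch for illustration only: se_kernel is a hypothetical helper, not a function of the package (whose kernels operate on torch tensors), and the inputs are made up.

```r
# Squared exponential kernel k(x, x'; theta, tau) in base R.
# `se_kernel` is a hypothetical helper, not part of shrinkGPR.
se_kernel <- function(x, theta, tau) {
  # Scale column j by sqrt(theta_j), so that squared Euclidean distances
  # between rows equal sum_j theta_j * (x_j - x'_j)^2
  x_scaled <- sweep(x, 2, sqrt(theta), `*`)
  sq_dist <- as.matrix(dist(x_scaled))^2
  exp(-0.5 * sq_dist) / tau
}

set.seed(1)
x <- matrix(runif(10), nrow = 5, ncol = 2)

K_both <- se_kernel(x, theta = c(4, 1), tau = 2)
# With theta_2 = 0 the second covariate no longer affects the kernel
K_first <- se_kernel(x, theta = c(4, 0), tau = 2)
K_x1_only <- se_kernel(x[, 1, drop = FALSE], theta = 4, tau = 2)

all.equal(K_first, K_x1_only, check.attributes = FALSE)
```

Setting \theta_j to zero removes covariate j from the covariance entirely, which is why shrinking \theta_j towards zero acts as variable selection. The diagonal of the kernel matrix equals 1 / \tau.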

If formula_mean is provided, the process mean becomes x_{\mu,i}^\top \beta.

Priors

\theta_j \mid \tau \sim \mathrm{TG}(a, c, \tau), \quad j = 1, \ldots, d,

\tau \sim F(2c, 2a),

\sigma^2 \sim \mathrm{Exp}(\sigma^2_\mathrm{rate}),

\nu - 2 \sim \mathrm{Gamma}(\nu_\alpha, \nu_\beta).

The shift by 2 ensures \nu > 2 so that the process variance is finite. With a mean structure, \beta_k \mid \tau_\mu \sim \mathrm{NGG}(a_\mu, c_\mu, \tau_\mu) and \tau_\mu \sim F(2c_\mu, 2a_\mu), where a_\mu and c_\mu are set via a_mean and c_mean.
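To get a feel for the default hyperparameters, the priors on \sigma^2 and \nu can be simulated in base R. The sketch below assumes only what is stated in the Arguments section: sigma2_rate is a rate, nu_alpha a shape, and nu_beta a rate. The variable names are illustrative.

```r
set.seed(42)
n_draws <- 10000

# Default hyperparameters from the Usage section
sigma2_rate <- 10
nu_alpha <- 0.5
nu_beta <- 2

# sigma^2 ~ Exp(sigma2_rate), prior mean 1 / sigma2_rate = 0.1
sigma2 <- rexp(n_draws, rate = sigma2_rate)

# nu - 2 ~ Gamma(nu_alpha, nu_beta), so nu is never below 2;
# prior mean is 2 + nu_alpha / nu_beta = 2.25
nu <- 2 + rgamma(n_draws, shape = nu_alpha, rate = nu_beta)

summary(sigma2)
summary(nu)
```

Under the defaults, most prior mass for \nu sits close to 2, which corresponds to heavy tails. Increasing nu_alpha or decreasing nu_beta moves prior mass towards larger \nu.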

Inference

The posterior is approximated by a normalizing flow q_\phi trained to maximize the ELBO. auto_stop triggers early stopping when the ELBO shows no significant improvement over the last 100 iterations.
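For reference, the objective can be written out in its standard form. The symbol \omega is introduced here only for illustration and collects the model parameters. Which parameters the package stacks into the flow, and in which transformed form, is an assumption.

```latex
\mathrm{ELBO}(\phi)
  = \mathbb{E}_{q_\phi(\omega)}\!\left[
      \log p(y \mid \omega) + \log p(\omega) - \log q_\phi(\omega)
    \right]
  \le \log p(y),
\qquad \omega = (\theta, \tau, \sigma^2, \nu).
```

With a mean structure, \omega would additionally contain \beta and \tau_\mu. The expectation is typically approximated by averaging over samples drawn from the flow.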

Custom Kernel Functions

Users can define custom kernel functions by passing them to the kernel_func argument. A valid kernel function must follow the same structure as kernel_se. The function must:

  1. Accept arguments thetas (n_latent x d), tau (length n_latent), x (N x d), and optionally x_star (N_new x d).

  2. Return a torch_tensor of dimensions n_latent x N x N (if x_star = NULL) or n_latent x N_new x N (if x_star is provided).

  3. Produce a valid positive semi-definite covariance matrix using torch tensor operations.

See kernel_functions for documented examples.
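The three requirements above can be sketched with a rational quadratic kernel. This is an illustration under stated assumptions, not code from the package: kernel_rq and its alpha argument are hypothetical, the argument shapes are taken from the list above, and the function has not been checked against the package internals. The demo only runs if torch is installed.

```r
# Hypothetical rational quadratic kernel; `kernel_rq` and `alpha` are
# illustrative names, not part of shrinkGPR.
kernel_rq <- function(thetas, tau, x, x_star = NULL, alpha = 1) {
  if (is.null(x_star)) x_star <- x
  # Scale covariate j by sqrt(theta_j), separately for each latent sample
  w <- thetas$sqrt()$unsqueeze(2)         # n_latent x 1 x d
  x_s <- x$unsqueeze(1) * w               # n_latent x N x d
  xs_s <- x_star$unsqueeze(1) * w         # n_latent x N_new x d
  # Pairwise weighted squared distances: n_latent x N_new x N
  diff <- xs_s$unsqueeze(3) - x_s$unsqueeze(2)
  sq_dist <- diff$pow(2)$sum(dim = 4)
  # (1 / tau) * (1 + r^2 / (2 * alpha))^(-alpha)
  (1 + sq_dist / (2 * alpha))$pow(-alpha) / tau$view(c(-1, 1, 1))
}

if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  torch::torch_manual_seed(1)
  thetas <- torch::torch_rand(3, 2)       # n_latent = 3, d = 2
  tau <- torch::torch_rand(3) + 0.5
  x <- torch::torch_rand(5, 2)            # N = 5
  x_star <- torch::torch_rand(4, 2)       # N_new = 4

  K <- kernel_rq(thetas, tau, x)
  K_star <- kernel_rq(thetas, tau, x, x_star)
  print(K$shape)       # 3 5 5
  print(K_star$shape)  # 3 4 5
}
```

If the shapes match, such a function could be passed as kernel_func = kernel_rq. Because everything is built from differentiable torch operations, gradients can flow through the kernel during training.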

Custom Flow Functions

Users can define custom flow functions by implementing an nn_module in torch. The module must have a forward method that accepts a tensor z of shape n_latent x D and returns a list with:

  • zk: the transformed samples, shape n_latent x D.

  • log_diag_j: log-absolute-determinant of the Jacobian, shape n_latent.

See sylvester for a documented example.
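The return contract described above can be illustrated with a very simple flow, an element-wise affine transformation. This is a sketch only: affine_flow is a hypothetical module, and the constructor argument D is an assumption, since this page does not document which arguments shrinkTPR passes when it builds the flow. Consult sylvester for the expected constructor signature before using a custom flow.

```r
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  # Hypothetical element-wise affine flow: zk = z * exp(log_scale) + shift.
  # The constructor argument `D` is an assumption made for this sketch.
  affine_flow <- torch::nn_module(
    classname = "affine_flow",
    initialize = function(D) {
      self$log_scale <- torch::nn_parameter(torch::torch_zeros(D))
      self$shift <- torch::nn_parameter(torch::torch_zeros(D))
    },
    forward = function(z) {
      zk <- z * torch::torch_exp(self$log_scale) + self$shift
      # The Jacobian is diagonal and identical for every row, so the
      # log-absolute-determinant is sum(log_scale), repeated per sample
      log_diag_j <- torch::torch_zeros_like(z[, 1]) + self$log_scale$sum()
      list(zk = zk, log_diag_j = log_diag_j)
    }
  )

  torch::torch_manual_seed(1)
  flow <- affine_flow(D = 4)
  z <- torch::torch_randn(10, 4)          # n_latent = 10, D = 4
  out <- flow(z)
  print(out$zk$shape)
  print(out$log_diag_j$shape)
}
```

An affine flow is far less expressive than the default Sylvester flow. It is shown here only because its Jacobian determinant is easy to verify by hand.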

Value

A list object of class shrinkTPR, containing:

model

The best-performing trained model.

loss

The best loss value (ELBO) achieved during training.

loss_stor

A numeric vector storing the ELBO values at each training iteration.

last_model

The model state at the final iteration.

optimizer

The optimizer object used during training.

model_internals

Internal objects required for predictions and further training, such as model matrices and formulas.

Author(s)

Peter Knaus peter.knaus@wu.ac.at

References

Shah, A., Wilson, A., & Ghahramani, Z. (2014, April). Student-t processes as alternatives to Gaussian processes. In Artificial intelligence and statistics (pp. 877-885). PMLR.

Examples


if (torch::torch_is_installed()) {
  # Simulate data
  set.seed(123)
  torch::torch_manual_seed(123)
  n <- 100
  x <- matrix(runif(n * 2), n, 2)
  y <- sin(2 * pi * x[, 1]) + rnorm(n, sd = 0.1)
  data <- data.frame(y = y, x1 = x[, 1], x2 = x[, 2])

  # Fit TPR model
  res <- shrinkTPR(y ~ x1 + x2, data = data)

  # Check convergence
  plot(res$loss_stor, type = "l", main = "Loss Over Iterations")

  # Check posterior
  samps <- gen_posterior_samples(res, nsamp = 1000)
  boxplot(samps$thetas) # Second theta is pulled towards zero

  # Predict
  x1_new <- seq(from = 0, to = 1, length.out = 100)
  x2_new <- runif(100)
  y_new <- predict(res, newdata = data.frame(x1 = x1_new, x2 = x2_new), nsamp = 2000)

  # Plot
  quants <- apply(y_new, 2, quantile, c(0.025, 0.5, 0.975))
  plot(x1_new, quants[2, ], type = "l", ylim = c(-1.5, 1.5),
       xlab = "x1", ylab = "y", lwd = 2)
  polygon(c(x1_new, rev(x1_new)), c(quants[1, ], rev(quants[3, ])),
          col = adjustcolor("skyblue", alpha.f = 0.5), border = NA)
  points(x[, 1], y)
  curve(sin(2 * pi * x), add = TRUE, col = "forestgreen", lwd = 2, lty = 2)

  # Add mean equation
  res2 <- shrinkTPR(y ~ x1 + x2, formula_mean = ~ x1, data = data)
}


shrinkGPR documentation built on March 30, 2026, 5:06 p.m.