Nothing
#' @importFrom R6 "R6Class"
#' @import torch
#' @title R6 class for kernel's hyperparameter sampling
#'
#' @description The KernelParamSampler encapsulates everything needed to
#' resample the non-fixed hyperparameters of a kernel. Each hyperparameter is
#' updated with a univariate slice sampler working on the log of its value,
#' whose target is the marginal log-likelihood plus a log-normal-style prior.
#'
#' @noRd
KernelParamSampler <- R6::R6Class(
  'KernelParamSampler',
  public = list(
    kernel = NULL,
    marginal_ll_eval_fn = NULL,

    # kernel: object exposing `kernel_gen()` and a `parameters` collection.
    # marginal_ll_eval_fn: zero-argument callback returning the marginal
    #   log-likelihood for the kernel's current state (used on the log scale).
    initialize = function(kernel, marginal_ll_eval_fn) {
      self$kernel <- kernel
      self$marginal_ll_eval_fn <- marginal_ll_eval_fn
    },

    # Place the initial slice bracket on the log scale: a window of width
    # `slice_sampling_scale`, randomly shifted below log(param$value), then
    # clipped to the log of the parameter's lower/upper bounds.
    # Returns list(min = <log lower end>, max = <log upper end>).
    initialize_theta_bounds = function(param) {
      width <- param$slice_sampling_scale
      shift <- width * as.numeric(TSR$rand(1)$cpu())
      bracket_low <- max(log(param$value) - shift, log(param$lower_bound))
      bracket_high <- min(bracket_low + width, log(param$upper_bound))
      list(min = bracket_low, max = bracket_high)
    },

    # Unnormalized log-density of a zero-mean normal prior, with precision
    # `hparam_precision`, placed on log(param$value).
    prior_fn = function(param) {
      log_value <- log(param$value)
      -0.5 * param$hparam_precision * log_value^2
    },

    # Draw a value uniformly between theta_min and theta_max
    # (assuming TSR$rand samples from U[0, 1)).
    sample_rand_theta_value = function(theta_min, theta_max) {
      u <- as.numeric(TSR$rand(1)$cpu())
      theta_min + (theta_max - theta_min) * u
    },

    # Slice-sample a new value for `param` and return it.
    # Every candidate is written to `param$value` and the kernel is regenerated
    # before each marginal-likelihood evaluation so the evaluation sees it.
    # When `param` has reference semantics (e.g. R6), the caller's object ends
    # up holding the accepted value.
    sample_param = function(param) {
      bracket <- self$initialize_theta_bounds(param)
      lo <- bracket$min
      hi <- bracket$max
      theta_start <- log(param$value)

      self$kernel$kernel_gen()
      log_target_start <- self$marginal_ll_eval_fn() + self$prior_fn(param)
      # Height of the slice, compared against the density ratio below
      slice_height <- as.numeric(TSR$rand(1)$cpu())

      repeat {
        theta_candidate <- self$sample_rand_theta_value(lo, hi)
        param$value <- exp(theta_candidate)
        self$kernel$kernel_gen()
        log_target_candidate <- self$marginal_ll_eval_fn() + self$prior_fn(param)
        if (exp(log_target_candidate - log_target_start) > slice_height) {
          return(param$value)
        }
        # Candidate rejected: shrink the bracket towards the starting point
        if (theta_candidate < theta_start) {
          lo <- theta_candidate
        } else {
          hi <- theta_candidate
        }
      }
    },

    # Resample, one after the other, every kernel parameter that is not fixed.
    sample = function() {
      for (param in self$kernel$parameters) {
        if (param$is_fixed) {
          next
        }
        self$sample_param(param)
      }
    }
  )
)
#' @title Sample a tensor of random values from a normal multivariate distribution
#'
#' @description Draws one sample from a multivariate normal distribution given
#' its mean vector and an upper triangular matrix `U` derived from its
#' precision matrix. The draw solves `U x = z` for a standard normal `z`, then
#' shifts `x` by the mean. The resulting covariance is `solve(t(U) %*% U)`.
#' NOTE(review): this assumes `precision_upper_tri` is an upper triangular
#' factor with precision = t(U) %*% U (e.g. a transposed lower Cholesky
#' factor); confirm at call sites.
#'
#' @noRd
sample_norm_multivariate <- function(mean_vec, precision_upper_tri) {
  # TODO Open a PR & issue for
  # https://github.com/mlverse/torch/blob/main/R/distributions-multivariate_normal.R L:86
  # torch::distr_multivariate_normal cannot be given a precision matrix because
  # of an operator-precedence problem in its argument check:
  # (!is.null(NULL) + !is.null(1) + !is.null(1)) == F
  # e.g. torch::distr_multivariate_normal(torch::torch_zeros(2), precision_matrix = torch::torch_eye(2))
  # raises an error, hence the manual triangular solve below.
  standard_normal_col <- TSR$randn_like(mean_vec)$unsqueeze(2)
  centered_col <- torch::linalg_solve_triangular(
    precision_upper_tri,
    standard_normal_col,
    upper = TRUE
  )
  centered_col$squeeze() + mean_vec
}
#' @title Cholesky factor and right-hand side for the covariate decomposition
#'
#' @description Builds the observation-masked design matrix `psi` of the
#' covariate decomposition from the spatial and temporal decompositions. It
#' returns `chol_lc`, the Cholesky factor (torch default: lower) of
#' `lambda_c = tau * t(psi) %*% psi + kronecker(I_rank, wish_precision_tensor)`.
#' It also returns `cc`, the solution of
#' `chol_lc %*% cc = t(psi) %*% vec(t(omega * y))`.
#'
#' @noRd
get_cov_decomp_chol <- function(
  spatial_decomp, time_decomp, covs, rank_cp, omega, tau, y, wish_precision_tensor
) {
  # TODO Merge some parts with marginal ll of spatial and temporal
  nb_spatial <- spatial_decomp$shape[1]
  nb_time <- time_decomp$shape[1]

  # Khatri-Rao product of the two decompositions viewed as a
  # (spatial, time, rank) tensor
  kr_cube <- TSR$khatri_rao_prod(spatial_decomp, time_decomp)$reshape(
    c(nb_spatial, nb_time, rank_cp)
  )
  # psi[i, j, l, k] = covs[i, j, k] * kr_cube[i, j, l]
  psi <- torch::torch_einsum('ijk,ijl->ijlk', c(covs, kr_cube))
  psi_dims <- psi$shape
  obs_mask <- omega$unsqueeze(3)$unsqueeze(4)$expand_as(psi)
  # Zero out unobserved entries, then flatten to a
  # (time * spatial) x (rank * covariates) matrix, time being the outer index
  design <- (psi * obs_mask)$permute(c(2, 1, 3, 4))$reshape(
    c(psi_dims[1] * psi_dims[2], psi_dims[3] * psi_dims[4])
  )
  design_t <- design$t()

  # Block-diagonal prior precision: one Wishart precision block per rank
  prior_precision <- TSR$kronecker_prod(TSR$eye(rank_cp), wish_precision_tensor)
  lambda_c <- tau * design_t$matmul(design) + prior_precision
  chol_lc <- torch::linalg_cholesky(lambda_c)

  # Masked response flattened in the same (time-outer) order as `design` rows
  y_masked_vec <- (omega * y)$t()$flatten()
  cc <- torch::linalg_solve(chol_lc, design_t$matmul(y_masked_vec))
  list(chol_lc = chol_lc, cc = cc)
}
#' @title R6 class for the Tau precision hyperparameter sampling
#'
#' @description Encapsulates the state needed to draw new values of the tau
#' precision from a Gamma distribution. Its shape is
#' `a_0 + 0.5 * nb_observations` and its rate is `b_0 + 0.5 * total_sq_error`.
#'
#' @noRd
TauSampler <- R6::R6Class(
  'TauSampler',
  public = list(
    b_0 = NULL,
    a_tau = NULL,

    # The Gamma shape only depends on the number of observations, so it is
    # computed once; `b_0` is kept to build the rate at every draw.
    initialize = function(a_0, b_0, nb_observations) {
      self$b_0 <- b_0
      gamma_shape <- a_0 + 0.5 * nb_observations
      self$a_tau <- TSR$tensor(gamma_shape)
    },

    # Draw a new tau given the current total squared error.
    # Returns the draw wrapped with TSR$tensor.
    sample = function(total_sq_error) {
      gamma_rate <- self$b_0 + 0.5 * total_sq_error
      tau_dist <- torch::distr_gamma(self$a_tau$cpu(), gamma_rate)
      TSR$tensor(tau_dist$sample())
    }
  )
)
# TODO create a PR to add rand wishart in R Torch
#' @title R6 class to sample new precision matrices
#'
#' @description Encapsulates all the behavior that allows to sample new
#' precision matrices from a Wishart distribution. The draw is done on the CPU
#' with `rWishart()` and converted back with `TSR$tensor()`. The result is
#' always a `nb_covariates x nb_covariates` matrix, even when there is a
#' single covariate.
#'
#' @noRd
PrecisionMatrixSampler <- R6::R6Class(
  'PrecisionMatrixSampler',
  public = list(
    nb_covariates = NULL,
    wish_df = NULL,
    wish_precision_tensor = NULL,

    # Wishart degrees of freedom are nb_covariates + rank_cp.
    initialize = function(nb_covariates, rank_cp) {
      self$nb_covariates <- nb_covariates
      self$wish_df <- nb_covariates + rank_cp
    },

    # Draw a new precision matrix. The scale matrix is
    # solve(covs_decomp %*% t(covs_decomp) + I).
    # covs_decomp: presumably a (nb_covariates x rank) tensor, since
    #   covs_decomp %*% t(covs_decomp) must be nb_covariates x nb_covariates.
    # Stores the draw in `self$wish_precision_tensor` and returns it.
    sample = function(covs_decomp) {
      w <- covs_decomp$matmul(covs_decomp$t()) + TSR$eye(self$nb_covariates)
      # Symmetrize before inverting to remove floating-point asymmetry
      wish_sigma <- as.matrix(((w + w$t()) * 0.5)$inverse()$cpu())
      # rWishart() returns a p x p x n array. Subsetting it with `[, , 1]`
      # would drop a 1 x 1 draw (single covariate) to a bare scalar, so the
      # matrix is rebuilt explicitly to always keep its 2-D shape.
      wish_draw <- rWishart(1, self$wish_df, wish_sigma)
      wish_precision_matrix <- matrix(
        wish_draw,
        nrow = self$nb_covariates,
        ncol = self$nb_covariates
      )
      self$wish_precision_tensor <- TSR$tensor(wish_precision_matrix)
      self$wish_precision_tensor
    }
  )
)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.