#' Temporal Encoder–Masked Probabilistic Ensemble Regressor
#'
#' @description
#' Temper trains and deploys a hybrid forecasting model that couples a temporal autoencoder with a
#' masked neural decision forest. The autoencoder compresses a sliding window of length `past` into a
#' latent representation of size `latent_dim`; the forest is an ensemble of `n_trees` soft decision
#' trees of depth `depth`, whose feature-level dropout is governed by `init_prob` and relaxed through a
#' Gumbel–Softmax with parameter `temperature`. Training minimizes a Continuous Ranked Probability
#' Score (CRPS) loss blended with a reconstruction term (`lambda_rec` × MSE), yielding multi-step
#' probabilistic forecasts and a fan chart. Model weights are optimized with Adam or another supported
#' optimizer, with optional early stopping.
#'
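#' @details
#' At each training step the model minimizes the composite loss
#' \eqn{L = \mathrm{CRPS}(\hat{y}, y) + \lambda_{rec} \, \mathrm{MSE}(\hat{x}, x)},
#' where the CRPS term is computed over the ensemble of tree outputs and the MSE term compares the
#' autoencoder reconstruction with the input window (see the internal \code{train_model} and
#' \code{nnf_crps_ensemble} helpers).
#'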
#' @param ts Numeric vector of length at least `past + future`. The input time series in levels (not log-returns). Missing values are automatically imputed with \code{imputeTS::na_kalman}.
#' @param future Integer \eqn{\geq 1}. Forecast horizon: the number of steps ahead to predict.
#' @param past Integer \eqn{\geq 1}. Length of the sliding window used to feed the encoder.
#' @param latent_dim Integer \eqn{\geq 1}. Dimensionality of the autoencoder's latent bottleneck.
#' @param n_trees Integer \eqn{\geq 1}. Number of trees in the neural decision forest ensemble. Usually in the range of 30 to 200. Default: 30.
#' @param depth Integer \eqn{\geq 1}. Depth of each decision tree (i.e., number of binary splits). Usually in the range of 4 to 12. Default: 6.
#' @param init_prob Numeric in \eqn{(0, 1)}. Initial probability that each feature is kept by the stochastic feature mask (used for stochastic feature selection). Values close to 0 mean the feature is almost always dropped; values close to 1 mean it is almost always kept. Default: 0.8.
#' @param temperature Positive numeric. Temperature parameter for the Gumbel–Softmax distribution used during feature masking. Lower values lead to harder (closer to binary) masks; higher values encourage smoother gradients. Default: 0.5.
#' @param n_bases Integer \eqn{\geq 1}. Maximum number of Gaussian components considered when fitting the mixture to the forecast ensemble. Default: 10.
#' @param train_rate Numeric in \eqn{(0, 1)}. Proportion of samples allocated to the training set. The remaining samples form the validation set used for early stopping. Default: 0.7.
#' @param epochs Positive integer. Maximum number of training epochs. Inspect the loss plot to choose a suitable number of epochs. Default: 30.
#' @param optimizer Character string. Optimizer to use for training (adam, adamw, sgd, rprop, rmsprop, adagrad, asgd, adadelta). Default: adam.
#' @param lr Positive numeric. Learning rate for the optimizer. Default: 0.005.
#' @param batch Positive integer. Mini-batch size used during training. Default: 32.
#' @param lambda_rec Non-negative numeric. Weight applied to the reconstruction loss relative to the probabilistic CRPS forecasting loss. Default: 0.3.
#' @param patience Positive integer. Number of consecutive epochs without improvement on the validation CRPS before early stopping is triggered. Default: 15.
#' @param verbose Logical. If \code{TRUE}, prints training and validation CRPS every 10 epochs during training. Default: TRUE.
#' @param alpha Numeric in \eqn{(0, 1)}. Tail probability defining the predictive interval shown in the fan chart: the band spans the `alpha` and `1 - alpha` quantiles (e.g., `alpha = 0.1` gives an 80% interval). Default: 0.1.
#' @param dates Optional \code{Date} vector of the same length as ts. If supplied, fan chart x-axes use calendar dates; otherwise, integer time indices are used. Default: NULL.
#' @param seed Optional integer. Used to seed both R and Torch random number generators for reproducibility. Default: 42.
#'
#' @return A named list with four components
#' \describe{
#' \item{`loss`}{A ggplot in which training and validation CRPS are plotted against epoch number, useful for diagnosing over-/under-fitting.}
#' \item{`pred_funs`}{A length-`future` list, one element per forecast step (named `t1`, `t2`, ...). Each element contains four functions describing the predictive distribution fitted as a Gaussian mixture: `dfun` (density), `pfun` (CDF), `qfun` (quantile function) and `rfun` (sampler).}
#' \item{`plot`}{A ggplot object showing the historical series, median forecast and predictive interval. A print-ready fan chart.}
#' \item{`time_log`}{A `lubridate` period recording the wall-clock training time.}
#' }
#'
#'
#' @examples
#' \donttest{
#' set.seed(2025)
#' ts <- cumsum(rnorm(250)) # synthetic price series
#' fit <- temper(ts, future = 3, past = 20, latent_dim = 5, epochs = 2)
#'
#' # 80% predictive interval for the 3-step-ahead forecast
#' qfun <- fit$pred_funs$t3$qfun
#' pred_interval_80 <- c(qfun(0.1), qfun(0.9))
#'
#' # Visual diagnostics
#' print(fit$plot)
#' print(fit$loss)
#' }
#'
#' @import torch ggplot2 purrr
#' @importFrom stats quantile runif dunif punif qunif bw.nrd0 bw.nrd dnorm approxfun fft median sd rt coef pnorm rnorm uniroot
#' @importFrom imputeTS na_kalman
#' @importFrom scales number
#' @importFrom lubridate seconds_to_period period
#' @importFrom utils tail head
#'
#' @export
temper <- function(ts, future, past, latent_dim, n_trees = 30, depth = 6,
init_prob = 0.8, temperature = 0.5, n_bases = 10,
train_rate = 0.7, epochs = 30, optimizer = "adam", lr = 0.005, batch = 32,
lambda_rec = 0.3, patience = 15, verbose = TRUE,
alpha = 0.1, dates = NULL, seed = 42)
{
start <- Sys.time()
set.seed(seed)
torch_manual_seed(seed)
if(anyNA(ts)){ts <- na_kalman(ts)}
scaled_ts <- dts(ts, 1)
set <- smart_reframer(scaled_ts, past + future, past + future)
x_set <- set[, 1:past, drop = FALSE]
y_set <- set[, (past + 1):(past + future), drop = FALSE]
x <- head(torch_tensor(x_set), -1)
new_x <- tail(torch_tensor(x_set), 1)
y <- tail(torch_tensor(y_set), -1)
model <- forecasting_model(seq_len = past,
latent_dim = latent_dim,
n_trees = n_trees,
depth = depth,
out_dim = future,
temperature,
init_prob)
train_idx <- sample.int(nrow(x), train_rate * nrow(x))
val_idx <- setdiff(1:nrow(x), train_idx)
trained <- train_model(model, x[train_idx,,drop=FALSE], y[train_idx,,drop=FALSE],
epochs, lr, batch, lambda_rec, patience, x[val_idx,,drop=FALSE], y[val_idx,,drop=FALSE], optimizer, verbose)
model <- trained$model
loss_plot <- trained$loss_plot
y_hat <- model(new_x)
raw_preds <- as.matrix(torch_stack(y_hat$pred)$squeeze())
proj_space <- t(apply(raw_preds, 2, function(x) tail(ts, 1) * cumprod(1 + x)))
pred_funs <- apply(proj_space, 2, function(x) gmix(x, K.max = n_bases, seed = seed))
names(pred_funs) <- paste0("t", 1:future)
plot <- plot_graph(ts, pred_funs, alpha = alpha, dates = dates)
end <- Sys.time()
time_log <- seconds_to_period(round(difftime(end, start, units = "secs"), 0))
out <- list(loss = loss_plot, pred_funs = pred_funs, plot = plot, time_log = time_log)
return(out)
}
# ---------------------------------------------------------------------------
# Everything below is INTERNAL
# ---------------------------------------------------------------------------
#' @keywords internal
mtry_mask <- nn_module(
"mtry_mask",
initialize = function(input_dim, init_prob = 0.8, temperature = 0.5) {
self$log_alpha <- nn_parameter(
torch_full(input_dim, log(init_prob / (1 - init_prob))))
self$tau <- temperature
},
forward = function(x, hard = FALSE) {
if (self$training) {
u <- torch_rand_like(x)
gumbel <- -torch_log(-torch_log(u))
m_soft <- torch_sigmoid((self$log_alpha + gumbel) / self$tau)
return(x * m_soft)
} else {
m_det <- (self$log_alpha > 0)$to(dtype = x$dtype)
if (hard) m_det <- m_det$detach()
return(x * m_det)
}
}
)
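# The mask above relaxes a per-feature Bernoulli keep/drop decision with a
# binary-concrete (Gumbel-Softmax) trick: m = sigmoid((log_alpha + g) / tau),
# where g = -log(-log(u)) and u ~ Uniform(0, 1). A minimal plain-R sketch of a
# single draw, kept as a comment for illustration only:
#   log_alpha <- log(0.8 / (1 - 0.8))            # init_prob = 0.8
#   u         <- runif(5)                        # one uniform draw per feature
#   g         <- -log(-log(u))                   # Gumbel(0, 1) noise
#   m_soft    <- plogis((log_alpha + g) / 0.5)   # temperature = 0.5
# Lower temperatures push m_soft towards {0, 1}; at evaluation time the module
# instead applies the deterministic mask log_alpha > 0.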
#' @keywords internal
sdtree_module <- nn_module(
"sdtree_module",
initialize = function(input_dim, depth = 3, out_dim,
temperature = 0.5, init_prob = 0.8) {
self$mask_layer <- mtry_mask(input_dim, init_prob, temperature)
self$depth <- depth
self$n_internal <- 2^depth - 1
self$n_leaf <- 2^depth
self$gate_layer <- nn_linear(input_dim, self$n_internal)
self$leaf <- nn_parameter(torch_randn(self$n_leaf, out_dim))
},
.path_probs = function(p) {
b <- p$size(1); prob <- torch_ones(b, 1, device = p$device); idx <- 1
for (d in seq_len(self$depth)) {
n_lvl <- 2^(d - 1); p_lvl <- p[ , idx:(idx + n_lvl - 1) ]
prob <- torch_cat(list(prob * p_lvl, prob * (1 - p_lvl)), dim = 2)
idx <- idx + n_lvl
}
prob
},
forward = function(x) {
x_masked <- self$mask_layer(x) # ← feature selection
gate_p <- self$gate_layer(x_masked)$sigmoid()
leaf_prob <- self$.path_probs(gate_p)
leaf_prob$matmul(self$leaf) # (batch × out_dim)
}
)
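# In .path_probs() each leaf probability is the product of the gate probabilities
# met along its root-to-leaf path. A worked example for depth = 2, with gates
# p1 (root) and p2, p3 (second level), illustrative only:
#   p1 <- 0.7; p2 <- 0.2; p3 <- 0.9
#   leaf_prob <- c(p1 * p2, (1 - p1) * p3, p1 * (1 - p2), (1 - p1) * (1 - p3))
#   sum(leaf_prob)   # 1, so the tree output is a convex combination of leaf values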
#' @keywords internal
neural_decision_forest <- nn_module(
"neural_decision_forest",
initialize = function(n_trees, input_dim, depth = 3, out_dim,
temperature = 0.5, init_prob = 0.8) {
self$trees <- nn_module_list(lapply(seq_len(n_trees), function(i)
sdtree_module(input_dim, depth, out_dim,
temperature, init_prob)))
},
forward = function(x) {
torch_stack(lapply(self$trees, function(t) t(x)), dim = 3)
}
)
#' @keywords internal
autoencoder <- nn_module(
"autoencoder",
initialize = function(seq_len, latent_dim, hidden_dim = NULL)
{
if (is.null(hidden_dim))
hidden_dim <- max(64, latent_dim * 2)
self$seq_len <- seq_len
self$encoder <- nn_sequential(
nn_linear(seq_len, hidden_dim),
nn_relu(),
nn_linear(hidden_dim, latent_dim)
)
self$decoder <- nn_sequential(
nn_linear(latent_dim, hidden_dim),
nn_relu(),
nn_linear(hidden_dim, seq_len)
)
},
forward = function(x) {
z <- self$encoder(x)
rec <- self$decoder(z)
list(latent = z, recon = rec)
}
)
#' @keywords internal
forecasting_model <- nn_module(
"forecasting_model",
initialize = function(seq_len,
latent_dim,
n_trees,
depth,
out_dim,
temperature,
init_prob) {
self$ae <- autoencoder(seq_len, latent_dim)
self$forest <- neural_decision_forest(n_trees = n_trees,
input_dim = latent_dim,
depth = depth,
out_dim = out_dim,
temperature = temperature,
init_prob = init_prob)
},
forward = function(x) { # x: batch × seq_len
ae_out <- self$ae(x)
pred <- self$forest$forward(ae_out$latent) # batch × out_dim × n_trees
list(pred = pred, recon = ae_out$recon)
}
)
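# Shapes flowing through the model (inferred from the modules above):
#   input window   : batch x seq_len (= past)
#   latent code    : batch x latent_dim          (autoencoder bottleneck)
#   reconstruction : batch x seq_len             (decoder output, feeds the MSE term)
#   forecasts      : batch x future x n_trees    (one trajectory per soft tree)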
#' @keywords internal
nnf_crps_ensemble <- function(pred, target) {
if (!inherits(target, "torch_tensor"))
target <- torch_tensor(target)
target <- target$unsqueeze(3)$to(dtype = pred$dtype, device = pred$device)
m <- pred$size(3) # n_trees (ensemble size)
# 1) mean |y − x_i|
term1 <- torch_abs(pred - target)$mean(dim = 3) # (B × D)
# 2) mean |x_i − x_j|
# compute pairwise differences with broadcasting
diff <- torch_abs(pred$unsqueeze(4) - pred$unsqueeze(3)) # B×D×m×m
term2 <- diff$mean(dim = c(3,4)) # (B × D)
crps <- term1 - 0.5 * term2 # (B × D)
crps$mean() # scalar
}
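# nnf_crps_ensemble() implements the standard ensemble CRPS estimator
#   CRPS(F, y) ≈ mean_i |x_i - y| - 0.5 * mean_{i,j} |x_i - x_j|
# averaged over batch and horizon. An equivalent plain-R sketch for a single
# observation, illustrative only:
#   x <- rnorm(30)                                   # ensemble members (one per tree)
#   y <- 0.2                                         # realized value
#   crps <- mean(abs(x - y)) - 0.5 * mean(abs(outer(x, x, "-")))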
#' @keywords internal
plot_graph <- function(ts, pred_funs, alpha = 0.05, dates = NULL, line_size = 1.3, label_size = 11,
forcat_band = "seagreen2", forcat_line = "seagreen4", hist_line = "gray43",
label_x = "Horizon", label_y= "Forecasted Var", date_format = "%b-%Y")
{
preds <- Reduce(rbind, map(pred_funs, ~ quantile(.x$rfun(1000), probs = c(alpha, 0.5, (1-alpha)))))
colnames(preds) <- c("lower", "median", "upper")
future <- nrow(preds)
if(is.null(dates)){x_hist <- 1:length(ts)} else {x_hist <- as.Date(as.character(dates))}
if(is.null(dates)){x_forcat <- length(ts) + 1:nrow(preds)} else {x_forcat <- as.Date(as.character(tail(dates, 1)))+ 1:future}
forecast_data <- data.frame(x_forcat = x_forcat, preds)
historical_data <- data.frame(x_all = c(x_hist, x_forcat), y_all = c(ts = ts, pred = preds[, "median"]))
plot <- ggplot()+ geom_line(data = historical_data, aes(x = .data$x_all, y = .data$y_all), color = hist_line, linewidth = line_size)
plot <- plot + geom_ribbon(data = forecast_data, aes(x = .data$x_forcat, ymin = .data$lower, ymax = .data$upper), alpha = 0.3, fill = forcat_band)
plot <- plot + geom_line(data = forecast_data, aes(x = .data$x_forcat, y = .data$median), color = forcat_line, linewidth = line_size)
if(!is.null(dates)){plot <- plot + scale_x_date(name = paste0("\n", label_x), date_labels = date_format)}
if(is.null(dates)){plot <- plot + scale_x_continuous(name = paste0("\n", label_x))}
plot <- plot + scale_y_continuous(name = paste0(label_y, "\n"), labels = number)
plot <- plot + ylab(label_y) + theme_bw()
plot <- plot + theme(axis.text=element_text(size=label_size), axis.title=element_text(size=label_size + 2))
return(plot)
}
#' @keywords internal
dts <- function(ts, lag = 1)
{
scaled_ts <- tail(ts, -lag)/head(ts, -lag)-1
scaled_ts[!is.finite(scaled_ts)] <- NA
if(anyNA(scaled_ts)){scaled_ts <- na_kalman(scaled_ts)}
return(scaled_ts)
}
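# dts() rescales the series into one-lag relative returns, r_t = ts[t + 1] / ts[t] - 1.
# temper() later maps forecasted returns back to levels with tail(ts, 1) * cumprod(1 + r).
# A quick round-trip sketch, illustrative only:
#   ts <- c(100, 105, 110)
#   r  <- dts(ts, 1)            # 0.0500 0.0476 (approximately)
#   ts[1] * cumprod(1 + r)      # 105 110, recovering the original levels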
#' @keywords internal
smart_reframer <- function(ts, seq_len, stride)
{
n_length <- length(ts)
if(seq_len > n_length | stride > n_length){stop("vector too short for sequence length or stride")}
if(n_length%%seq_len > 0){ts <- tail(ts, - (n_length%%seq_len))}
n_length <- length(ts)
idx <- seq(from = 1, to = (n_length - seq_len + 1), by = 1)
reframed <- t(sapply(idx, function(x) ts[x:(x+seq_len-1)]))
if(seq_len == 1){reframed <- t(reframed)}
idx <- rev(seq(nrow(reframed), 1, - stride))
reframed <- reframed[idx,,drop = FALSE]
colnames(reframed) <- paste0("t", 1:seq_len)
return(reframed)
}
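# smart_reframer() drops the first n %% seq_len observations so the remaining length
# is a multiple of seq_len, slides a window of length seq_len over the series and
# keeps rows spaced stride apart. For example (illustrative only):
#   smart_reframer(1:10, seq_len = 4, stride = 4)
#   #      t1 t2 t3 t4
#   # [1,]  3  4  5  6
#   # [2,]  7  8  9 10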
#' @keywords internal
gmix <- function(x,
K.max = 10, # upper bound on clusters
seed = 1, # k-means reproducibility
...) { # extra args to kmeans()
stopifnot(is.numeric(x), is.vector(x))
n <- length(x)
## -------------------------------------------------------------- ##
## 1. Compute WSS for k = 1 … K.max ----------------------------- ##
## -------------------------------------------------------------- ##
wss <- numeric(K.max)
set.seed(seed)
for (k in 1:K.max) {
wss[k] <- stats::kmeans(x, centers = k, ...)$tot.withinss
}
## -------------------------------------------------------------- ##
## 2. Detect the elbow (max curvature) ------------------------- ##
## -------------------------------------------------------------- ##
# First & second finite differences of the WSS curve
d1 <- -diff(wss) # drop in WSS going from k to k + 1 (length K.max - 1)
d2 <- diff(d1) # curvature proxy; d2[j] refers to k = j + 1 (length K.max - 2)
k_hat <- which.max(d2) + 1L # shift the index so k_hat is the k of maximum curvature
## Fallback if curvature is flat (rare): default to 2 clusters
if (length(k_hat) == 0 || is.na(k_hat) || k_hat < 2) k_hat <- 2
## -------------------------------------------------------------- ##
## 3. Final k-means with k = k̂ -------------------------------- ##
## -------------------------------------------------------------- ##
set.seed(seed)
km <- stats::kmeans(x, centers = k_hat, ...)
clusters <- km$cluster
centers <- as.numeric(km$centers)
sizes <- km$size
weights <- sizes / n # π_g
# Component SDs (protect 1-point clusters with a small positive value)
sigmas <- vapply(split(x, clusters), function(v)
if (length(v) > 1) sd(v) else 1e-8, numeric(1))
## -------------------------------------------------------------- ##
## 4. Mixture helpers ------------------------------------------ ##
## -------------------------------------------------------------- ##
pdf_fun <- function(z)
vapply(z, function(zz) sum(weights * dnorm(zz, centers, sigmas)), numeric(1))
cdf_fun <- function(z)
vapply(z, function(zz) sum(weights * pnorm(zz, centers, sigmas)), numeric(1))
lower <- min(x) - 6 * max(sigmas)
upper <- max(x) + 6 * max(sigmas)
icdf_fun <- function(p) {
stopifnot(all(p >= 0 & p <= 1))
vapply(p, function(pp) uniroot(function(q) cdf_fun(q) - pp, interval = c(lower, upper))$root, numeric(1))
}
sampler_fun <- function(n) {
g <- sample(seq_along(weights), n, TRUE, prob = weights)
rnorm(n, centers[g], sigmas[g])
}
pred_funs <- list(rfun = sampler_fun, dfun = pdf_fun, pfun = cdf_fun, qfun = icdf_fun)
attr(pred_funs, "kmeans") <- km
attr(pred_funs, "weights") <- weights
attr(pred_funs, "centers") <- centers
attr(pred_funs, "sigmas") <- sigmas
attr(pred_funs, "wss") <- wss
attr(pred_funs, "k_hat") <- k_hat
return(pred_funs)
}
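# gmix() summarizes a forecast ensemble as a k-means-based Gaussian mixture and
# returns the four distribution functions used downstream. A minimal usage sketch,
# illustrative only:
#   fns   <- gmix(rnorm(500, mean = 10, sd = 2), K.max = 10)
#   fns$qfun(c(0.1, 0.5, 0.9))   # approximate 10th, 50th and 90th percentiles
#   fns$pfun(10)                 # P(X <= 10), roughly 0.5 here
#   draws <- fns$rfun(1000)      # simulate new draws from the fitted mixture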
#' @keywords internal
train_model <- function(model,
x, y, # training tensors
epochs = 150,
lr = 1e-3,
batch = 32,
lambda_rec = 0.3,
patience = 15,
val_x = NULL, # optional validation
val_y = NULL,
optimizer = c("adam", "adamw", "sgd", "rprop", "rmsprop", "adagrad", "asgd", "adadelta"),
verbose = TRUE) {
optimizer <- match.arg(optimizer)
opt <- switch(
optimizer,
adam = optim_adam(model$parameters, lr = lr),
adamw = optim_adamw(model$parameters, lr = lr),
adadelta = optim_adadelta(model$parameters, lr = lr),
adagrad = optim_adagrad(model$parameters, lr = lr),
asgd = optim_asgd(model$parameters, lr = lr),
sgd = optim_sgd(model$parameters, lr = lr), # plain SGD
rprop = optim_rprop(model$parameters, lr = lr),
rmsprop = optim_rmsprop(model$parameters, lr = lr)
)
n <- x$size(1)
best_state <- NULL
best_val <- Inf
no_imp <- 0
train_hist <- numeric(0)
val_hist <- numeric(0)
## flag: start early-stopping countdown only *after*
## train_CRPS < val_CRPS at least once
allow_early <- is.null(val_x) || is.null(val_y)
for (e in 1:epochs) {
# ---- TRAIN --------------------------------------------------------
model$train()
batch_crps <- c()
for (idx in split(sample.int(n), ceiling(seq_len(n)/batch))) {
xb <- x[idx,,drop=FALSE]; yb <- y[idx,,drop=FALSE]
opt$zero_grad()
out <- model(xb)
loss_f <- nnf_crps_ensemble(out$pred, yb$to(dtype = out$pred$dtype))
loss_rec <- nnf_mse_loss(out$recon, xb$to(dtype = out$recon$dtype))
loss <- loss_f + lambda_rec * loss_rec
loss$backward(); opt$step()
batch_crps <- c(batch_crps, loss_f$item())
}
train_epoch <- mean(batch_crps)
train_hist <- c(train_hist, train_epoch)
# ---- VALIDATION ---------------------------------------------------
if (!is.null(val_x) && !is.null(val_y)) {
model$eval()
with_no_grad({
val_pred <- model$forward(val_x)
val_epoch <- nnf_crps_ensemble(val_pred$pred, val_y$to(dtype = val_pred$pred$dtype))$item()
})
val_hist <- c(val_hist, val_epoch)
# allow early-stop only after train < val at least once
if (!allow_early && train_epoch < val_epoch)
allow_early <- TRUE
metric <- val_epoch
} else {
metric <- train_epoch # no validation set
}
if (verbose) {
if(e%%10==0){
msg <- sprintf("epoch %3d train %.6f", e, train_epoch)
if (length(val_hist) > 0)
msg <- sprintf("%s val %.6f", msg, val_epoch)
cat(msg, "\n")}
}
# ---- Best-model bookkeeping --------------------------------------
if (metric < best_val) {
best_val <- metric
best_state<- model$state_dict()
no_imp <- 0
} else if (allow_early) {
no_imp <- no_imp + 1
}
# ---- Early-stop decision -----------------------------------------
if (allow_early && no_imp >= patience) {
if (verbose) cat("Early stop.\n")
break
}
}
if (!is.null(best_state)) model$load_state_dict(best_state)
## ── build history data frame ---------------------------------------
epoch_vec <- seq_along(train_hist)
history <- data.frame(epoch = epoch_vec, set = "train", crps = train_hist)
if (length(val_hist) > 0) {
history <- rbind(history, data.frame(epoch = epoch_vec, set = "validation", crps = val_hist))
}
## ── ggplot ----------------------------------------------------------
loss_plot <- ggplot(history,
ggplot2::aes(x = .data$epoch, y = .data$crps, colour = .data$set)) +
geom_line(linewidth = 1) +
geom_point(size = 1.2) +
scale_colour_manual(values = c("#1B9E77", "#D95F02"),
breaks = c("train", "validation"),
na.translate = FALSE) +
labs(x = "Epoch",
y = "CRPS (ensemble)",
colour = NULL) +
theme_light()
## ── return ----------------------------------------------------------
out <- list(model = model,
history = history,
loss_plot = loss_plot)
out
}