R/main.R

Defines functions train_model gmix smart_reframer dts plot_graph nnf_crps_ensemble temper

Documented in temper

#' Temporal Encoder–Masked Probabilistic Ensemble Regressor
#'
#' @description
#' Temper trains and deploys a hybrid forecasting model. It couples three parts:
#' a temporal auto-encoder, which shrinks a sliding window of `past` returns into a latent representation of size `latent_dim`;
#' a masked neural decision forest, an ensemble of `n_trees` soft decision trees of depth `depth` in which each tree masks the latent features with learnable keep-logits initialised from `init_prob`, relaxed during training by a Gumbel-perturbed sigmoid with fixed temperature `temperature`;
#' and a CRPS loss (Continuous Ranked Probability Score) blended with a reconstruction term (`lambda_rec × MSE`).
#' Together these yield multi-step probabilistic forecasts and their fan chart. Model weights are optimized with ADAM or another supported optimizer, with optional early stopping.
#'
#' @details
#' The series is converted to one-step simple returns and cut into non-overlapping windows of `past + future` returns. The oldest returns are dropped so that the windows tile the series exactly.
#' In each window, the first `past` returns are the encoder input and the following `future` returns are the target.
#' Windows are split at random into training and validation sets. The forecast is conditioned on the last `past` returns.
#' Each tree's predicted return path is compounded from the last observed level, and a Gaussian mixture is fitted, horizon by horizon, to the resulting ensemble of levels.
#'
#' @param ts Numeric vector of levels (not returns). Its length must be at least `2 * (past + future) + 1`, so that at least two windows are available for training and validation. Missing values are imputed with na_kalman.
#' @param future Integer \eqn{\geq 1}. Forecast horizon: the number of steps ahead to predict.
#' @param past Integer \eqn{\geq 1}. Length of the input window (number of past returns fed to the encoder).
#' @param latent_dim Integer \eqn{\geq 1}. Dimensionality of the autoencoder's latent bottleneck.
#' @param n_trees Integer \eqn{\geq 1}. Number of trees in the neural decision forest ensemble. Usually in the range of 30 to 200. Default: 30.
#' @param depth Integer \eqn{\geq 1}. Depth of each decision tree (i.e., number of binary splits along a root-to-leaf path). Usually in the range of 4 to 12. Default: 6.
#' @param init_prob Numeric strictly in \eqn{(0, 1)}. Initial keep probability of each latent feature in the feature mask (stochastic feature selection). Values closer to 1 keep more features; exactly 0 or 1 give infinite logits and must be avoided. Default: 0.8.
#' @param temperature Positive numeric. Temperature of the Gumbel-perturbed sigmoid used for feature masking during training. Lower values lead to harder (closer to binary) masks; higher values give smoother gradients. Default: 0.5.
#' @param n_bases Integer \eqn{\geq 1}. Maximum number of Gaussian components considered when fitting the predictive mixture of each horizon. Default: 10.
#' @param train_rate Numeric in \eqn{(0, 1)}. Proportion of windows allocated to the training set; at least one window is kept in each set. Windows are assigned at random, not chronologically. The remaining windows form the validation set used for early stopping. Default: 0.7.
#' @param epochs Positive integer. Maximum number of training epochs. Have a look at the loss plot to decide the right number of epochs. Default: 30.
#' @param optimizer Character string. Optimizer to use for training (adam, adamw, sgd, rprop, rmsprop, adagrad, asgd, adadelta). Default: adam.
#' @param lr Positive numeric. Learning rate for the optimizer. Default: 0.005.
#' @param batch Positive integer. Mini-batch size used during training. Default: 32.
#' @param lambda_rec Non-negative numeric. Weight applied to the reconstruction loss relative to the probabilistic CRPS forecasting loss. Default: 0.3.
#' @param patience Positive integer. Number of consecutive epochs without improvement on the validation CRPS before early stopping is triggered. The countdown starts only after the training CRPS has fallen below the validation CRPS at least once. Default: 15.
#' @param verbose Logical. If \code{TRUE}, prints training and validation CRPS every 10 epochs, and a message when early stopping triggers. Default: TRUE.
#' @param alpha Numeric in \eqn{(0, 0.5)}. Tail probability of the fan chart. The band spans the `alpha` and `1 - alpha` quantiles, i.e. a `1 - 2 * alpha` central interval, estimated from 1000 draws of each predictive distribution. Default: 0.1.
#' @param dates Optional \code{Date} vector of the same length as ts. If supplied, fan chart x-axes use calendar dates and forecasts are placed at daily steps after the last date; otherwise, integer time indices are used. Default: NULL.
#' @param seed Optional integer. Used to seed the R and torch random number generators and the mixture fitting, for reproducibility. \code{NULL} skips explicit seeding. Default: 42.
#'
#' @return A named list with four components
#' \describe{
#'   \item{`loss`}{A ggplot in which training and validation CRPS are plotted against epoch number, useful for diagnosing over-/under-fitting.}
#'   \item{`pred_funs`}{A length-`future` list named `t1`, `t2`, .... Each element holds the mixture functions `rfun` (sampler), `dfun` (density), `pfun` (CDF) and `qfun` (quantile function) of that horizon's predictive distribution.}
#'   \item{`plot`}{A ggplot object showing the historical series, median forecast and predictive interval. A print-ready fan chart.}
#'   \item{`time_log`}{A lubridate period with the wall-clock run time, rounded to seconds.}
#' }
#'
#'
#' @examples
#' \donttest{
#' set.seed(2025)
#' ts <- 100 + cumsum(rnorm(250))    # synthetic positive price series
#' fit <- temper(ts, future = 3, past = 20, latent_dim = 5, epochs = 2)
#'
#' # 80 % predictive interval for the 3-step-ahead forecast
#' qfun <- fit$pred_funs$t3$qfun
#' pred_interval_80 <- qfun(c(0.1, 0.9))
#'
#' # Visual diagnostics
#' print(fit$plot)
#' print(fit$loss)
#' }
#'
#' @import torch ggplot2 purrr
#' @importFrom stats  quantile runif dunif punif qunif bw.nrd0 bw.nrd dnorm approxfun fft median sd rt coef pnorm rnorm uniroot
#' @importFrom imputeTS  na_kalman
#' @importFrom scales  number
#' @importFrom lubridate  seconds_to_period period
#' @importFrom utils  tail head
#'
#' @export
temper <- function(ts, future, past, latent_dim, n_trees = 30, depth = 6,
                   init_prob = 0.8, temperature = 0.5, n_bases = 10,
                   train_rate = 0.7, epochs = 30, optimizer = "adam", lr = 0.005, batch = 32,
                   lambda_rec = 0.3, patience = 15, verbose = TRUE,
                   alpha = 0.1, dates = NULL, seed = 42)
{
  start <- Sys.time()

  if (!is.null(seed)) {
    set.seed(seed)
    torch_manual_seed(seed)
  }

  if (anyNA(ts)) ts <- na_kalman(ts)
  scaled_ts <- dts(ts, 1)

  # Non-overlapping windows of `past + future` returns: the first `past`
  # columns are the encoder input, the next `future` columns the target that
  # immediately follows it inside the same window.
  set <- smart_reframer(scaled_ts, past + future, past + future)
  n_obs <- nrow(set)
  if (n_obs < 2) {
    stop("ts is too short: at least two windows of past + future returns are ",
         "needed (length(ts) >= 2 * (past + future) + 1)", call. = FALSE)
  }
  x_set <- set[, seq_len(past), drop = FALSE]
  y_set <- set[, past + seq_len(future), drop = FALSE]

  x <- torch_tensor(x_set)
  y <- torch_tensor(y_set)
  # Condition the forecast on the latest `past` returns so the predicted path
  # starts right after the last observation (it is compounded from tail(ts, 1)).
  new_x <- torch_tensor(matrix(tail(as.numeric(scaled_ts), past), nrow = 1))

  model <- forecasting_model(seq_len     = past,
                             latent_dim  = latent_dim,
                             n_trees     = n_trees,
                             depth       = depth,
                             out_dim     = future,
                             temperature = temperature,
                             init_prob   = init_prob)

  # Random (not chronological) split; keep at least one window on each side.
  n_train <- max(1L, min(n_obs - 1L, floor(train_rate * n_obs)))
  train_idx <- sample.int(n_obs, n_train)
  val_idx <- setdiff(seq_len(n_obs), train_idx)

  trained <- train_model(model,
                         x[train_idx, , drop = FALSE], y[train_idx, , drop = FALSE],
                         epochs = epochs, lr = lr, batch = batch,
                         lambda_rec = lambda_rec, patience = patience,
                         val_x = x[val_idx, , drop = FALSE],
                         val_y = y[val_idx, , drop = FALSE],
                         optimizer = optimizer, verbose = verbose)

  model <- trained$model
  loss_plot <- trained$loss_plot

  # Deterministic masks and no autograd graph for the forecast.
  model$eval()
  with_no_grad({
    y_hat <- model(new_x)
  })
  # pred is 1 x future x n_trees; rebuild the (future x n_trees) matrix
  # explicitly so single-horizon / single-tree cases keep their orientation.
  raw_preds <- matrix(as_array(y_hat$pred[1, , ]$cpu()), nrow = future, ncol = n_trees)

  # Compound each tree's return path from the last observed level.
  paths <- apply(1 + raw_preds, 2, cumprod)
  dim(paths) <- dim(raw_preds)                   # apply() drops dims when future == 1
  proj_space <- t(as.numeric(tail(ts, 1)) * paths)   # rows = trees, cols = horizons

  pred_funs <- lapply(seq_len(future), function(h)
    gmix(proj_space[, h], K.max = n_bases, seed = seed))
  names(pred_funs) <- paste0("t", seq_len(future))

  plot <- plot_graph(ts, pred_funs, alpha = alpha, dates = dates)
  end <- Sys.time()
  time_log <- seconds_to_period(round(difftime(end, start, units = "secs"), 0))

  list(loss = loss_plot, pred_funs = pred_funs, plot = plot, time_log = time_log)
}


# ---------------------------------------------------------------------------
# Everything below is INTERNAL
# ---------------------------------------------------------------------------

#' @keywords internal
mtry_mask <- nn_module(
  "mtry_mask",

  # One learnable keep-logit per input feature, initialised to
  # logit(init_prob); `tau` is the relaxation temperature.
  initialize = function(input_dim, init_prob = 0.8, temperature = 0.5) {
    keep_logit <- log(init_prob / (1 - init_prob))
    self$log_alpha <- nn_parameter(torch_full(input_dim, keep_logit))
    self$tau <- temperature
  },

  # Training: multiply x by a relaxed mask
  #   sigmoid((log_alpha + g) / tau),  g = -log(-log(u)),  u ~ U[0, 1).
  # Evaluation: keep exactly the features whose logit is positive.
  forward = function(x, hard = FALSE) {
    if (!self$training) {
      keep <- (self$log_alpha > 0)$to(dtype = x$dtype)
      if (hard) keep <- keep$detach()
      return(x * keep)
    }
    noise <- -torch_log(-torch_log(torch_rand_like(x)))
    relaxed <- torch_sigmoid((self$log_alpha + noise) / self$tau)
    x * relaxed
  }
)

#' @keywords internal
sdtree_module <- nn_module(
  "sdtree_module",

  # Soft (differentiable) decision tree of depth `depth`. It has three parts:
  #   - a per-tree feature mask (mtry_mask) applied to the input,
  #   - one sigmoid gate per internal node (2^depth - 1 gates, all produced
  #     by a single linear layer),
  #   - a learnable vector of length `out_dim` per leaf (2^depth leaves).
  # The output is the leaf vectors weighted by the probability of reaching
  # each leaf.
  initialize = function(input_dim, depth = 3, out_dim,
                        temperature = 0.5, init_prob = 0.8) {

    self$mask_layer <- mtry_mask(input_dim, init_prob, temperature)

    self$depth      <- depth
    self$n_internal <- 2^depth - 1
    self$n_leaf     <- 2^depth

    self$gate_layer <- nn_linear(input_dim, self$n_internal)
    self$leaf       <- nn_parameter(torch_randn(self$n_leaf, out_dim))
  },

  # Converts gate probabilities p (batch x n_internal) into leaf-reaching
  # probabilities (batch x 2^depth). Column k of p is the probability of
  # taking the first branch at node k; nodes are numbered level by level
  # (1 node, then 2, then 4, ...).
  # Each level multiplies the current path probabilities by p (first branch)
  # and by 1 - p (second branch), then concatenates the two halves.
  # Because p + (1 - p) = 1 at every node, each row sums to 1.
  # NOTE(review): this relies on torch's `[` treating `idx:(idx + n_lvl - 1)`
  # as a slice that keeps the column dimension even when n_lvl == 1 — confirm.
  .path_probs = function(p) {
    b <- p$size(1); prob <- torch_ones(b, 1, device = p$device); idx <- 1
    for (d in seq_len(self$depth)) {
      n_lvl <- 2^(d - 1); p_lvl <- p[ , idx:(idx + n_lvl - 1) ]
      prob  <- torch_cat(list(prob * p_lvl, prob * (1 - p_lvl)), dim = 2)
      idx   <- idx + n_lvl
    }
    prob
  },

  # x: (batch x input_dim), e.g. the auto-encoder latent code.
  forward = function(x) {
    x_masked  <- self$mask_layer(x)                  # ← feature selection
    gate_p    <- self$gate_layer(x_masked)$sigmoid()
    leaf_prob <- self$.path_probs(gate_p)
    leaf_prob$matmul(self$leaf)                      # (batch × out_dim)
  }
)


#' @keywords internal
neural_decision_forest <- nn_module(
  "neural_decision_forest",

  # Ensemble of `n_trees` independent soft decision trees that share the
  # same input/output sizes and masking hyper-parameters.
  initialize = function(n_trees, input_dim, depth = 3, out_dim,
                        temperature = 0.5, init_prob = 0.8) {
    make_tree <- function(i) {
      sdtree_module(input_dim, depth, out_dim, temperature, init_prob)
    }
    self$trees <- nn_module_list(lapply(seq_len(n_trees), make_tree))
  },

  # Runs every tree on x and stacks the per-tree (batch x out_dim) outputs
  # along a new third dimension: (batch x out_dim x n_trees).
  forward = function(x) {
    per_tree <- lapply(self$trees, function(tree) tree(x))
    torch_stack(per_tree, dim = 3)
  }
)

#' @keywords internal
autoencoder <- nn_module(
  "autoencoder",

  # Symmetric MLP auto-encoder:
  #   seq_len -> hidden -> latent_dim -> hidden -> seq_len,
  # where hidden defaults to max(64, 2 * latent_dim).
  initialize = function(seq_len, latent_dim, hidden_dim = NULL) {
    width <- if (is.null(hidden_dim)) max(64, latent_dim * 2) else hidden_dim

    self$seq_len <- seq_len

    self$encoder <- nn_sequential(
      nn_linear(seq_len, width),
      nn_relu(),
      nn_linear(width, latent_dim)
    )

    self$decoder <- nn_sequential(
      nn_linear(latent_dim, width),
      nn_relu(),
      nn_linear(width, seq_len)
    )
  },

  # Returns the latent code and the reconstruction of the input window.
  forward = function(x) {
    latent <- self$encoder(x)
    list(latent = latent, recon = self$decoder(latent))
  }
)


#' @keywords internal
forecasting_model <- nn_module(
  "forecasting_model",

  # Auto-encoder front end feeding its latent code into a masked neural
  # decision forest with `out_dim` outputs per tree.
  initialize = function(seq_len, latent_dim, n_trees, depth, out_dim,
                        temperature, init_prob) {
    self$ae <- autoencoder(seq_len, latent_dim)
    self$forest <- neural_decision_forest(
      n_trees     = n_trees,
      input_dim   = latent_dim,
      depth       = depth,
      out_dim     = out_dim,
      temperature = temperature,
      init_prob   = init_prob
    )
  },

  # x: (batch x seq_len) input windows.
  # Returns pred, a (batch x out_dim x n_trees) ensemble forecast, and
  # recon, a (batch x seq_len) reconstruction of x.
  forward = function(x) {
    encoded <- self$ae(x)
    list(
      pred  = self$forest$forward(encoded$latent),
      recon = encoded$recon
    )
  }
)

#' @keywords internal
nnf_crps_ensemble <- function(pred, target, fair = FALSE) {
  # Ensemble CRPS averaged over batch and horizon:
  #   CRPS = mean_i |x_i - y| - 0.5 * mean_{i,j} |x_i - x_j|
  # pred:   tensor (batch x horizon x m), one slice per ensemble member.
  # target: (batch x horizon) tensor or R matrix.
  # fair:   if TRUE, use the fair estimator. It averages |x_i - x_j| over the
  #         m * (m - 1) off-diagonal pairs, i.e. it rescales the spread term
  #         by m / (m - 1), and needs m >= 2. The default keeps the plain
  #         estimator.

  if (!inherits(target, "torch_tensor"))
    target <- torch_tensor(target)

  # (B x D) -> (B x D x 1) so it broadcasts against every member.
  target <- target$unsqueeze(3)$to(dtype = pred$dtype, device = pred$device)

  m <- pred$size(3)                                   # ensemble size

  # 1) mean |y - x_i|                                   (B x D)
  term1 <- torch_abs(pred - target)$mean(dim = 3)

  # 2) mean |x_i - x_j| over all pairs via broadcasting (B x D x m x m)
  spread <- torch_abs(pred$unsqueeze(4) - pred$unsqueeze(3))$mean(dim = c(3, 4))

  if (isTRUE(fair)) {
    if (m < 2) stop("fair CRPS needs at least two ensemble members", call. = FALSE)
    spread <- spread * (m / (m - 1))                  # diagonal terms are zero
  }

  (term1 - 0.5 * spread)$mean()                       # scalar
}

#' @keywords internal
plot_graph <- function(ts, pred_funs, alpha = 0.05, dates = NULL, line_size = 1.3, label_size = 11,
                       forcat_band = "seagreen2", forcat_line = "seagreen4", hist_line = "gray43",
                       label_x = "Horizon", label_y= "Forecasted Var", date_format = "%b-%Y")
{
  # Fan chart: historical series, median forecast and the band between the
  # `alpha` and `1 - alpha` quantiles. The quantiles are estimated from 1000
  # draws of each horizon's sampler (`rfun`).
  probs <- c(alpha, 0.5, 1 - alpha)
  # do.call(rbind, ...) always returns a matrix, even for a single horizon.
  preds <- do.call(rbind, map(pred_funs, ~ quantile(.x$rfun(1000), probs = probs)))
  colnames(preds) <- c("lower", "median", "upper")
  future <- nrow(preds)
  n_hist <- length(ts)

  # Integer positions without dates; daily steps after the last date otherwise.
  # Both x vectors always share one type, so one x scale fits every layer.
  if (is.null(dates)) {
    x_hist <- seq_len(n_hist)
    x_forcat <- n_hist + seq_len(future)
  } else {
    x_hist <- as.Date(as.character(dates))
    x_forcat <- as.Date(as.character(tail(dates, 1))) + seq_len(future)
  }

  forecast_data <- data.frame(x_forcat = x_forcat, preds)
  # History continued by the median forecast, drawn as one line.
  historical_data <- data.frame(x_all = c(x_hist, x_forcat),
                                y_all = c(as.numeric(ts), preds[, "median"]))

  plot <- ggplot() +
    geom_line(data = historical_data, aes(x = .data$x_all, y = .data$y_all),
              color = hist_line, linewidth = line_size) +
    geom_ribbon(data = forecast_data,
                aes(x = .data$x_forcat, ymin = .data$lower, ymax = .data$upper),
                alpha = 0.3, fill = forcat_band) +
    geom_line(data = forecast_data, aes(x = .data$x_forcat, y = .data$median),
              color = forcat_line, linewidth = line_size)

  if (is.null(dates)) {
    plot <- plot + scale_x_continuous(name = paste0("\n", label_x))
  } else {
    plot <- plot + scale_x_date(name = paste0("\n", label_x), date_labels = date_format)
  }

  plot <- plot +
    scale_y_continuous(name = paste0(label_y, "\n"), labels = number) +
    ylab(label_y) +
    theme_bw() +
    theme(axis.text = element_text(size = label_size),
          axis.title = element_text(size = label_size + 2))

  plot
}

#' @keywords internal
dts <- function(ts, lag = 1)
{
  # Simple returns over `lag` steps: ts[t + lag] / ts[t] - 1.
  scaled_ts <- tail(ts, -lag) / head(ts, -lag) - 1
  # Ratios against zero levels (or NA inputs) are non-finite: mark them as
  # missing and impute them. The check must look at the returns, not at `ts`;
  # otherwise NAs created here would never be imputed.
  scaled_ts[!is.finite(scaled_ts)] <- NA
  if (anyNA(scaled_ts)) scaled_ts <- na_kalman(scaled_ts)
  scaled_ts
}


#' @keywords internal
smart_reframer <- function(ts, seq_len, stride)
{
  # Cut `ts` into windows of `seq_len` consecutive values. Every `stride`-th
  # window is kept, counting back from the most recent one, so the last kept
  # window always ends at the last value. Returns a matrix with one window per
  # row (oldest first) and columns t1..t<seq_len>.
  n_length <- length(ts)
  if (seq_len > n_length || stride > n_length) {
    stop("vector too short for sequence length or stride")
  }
  if (seq_len < 1 || stride < 1) {
    stop("seq_len and stride must be positive")
  }
  win <- seq_len  # alias: the argument name shadows base::seq_len

  # Drop the oldest values so the length is an exact multiple of the window.
  remainder <- n_length %% win
  if (remainder > 0) ts <- tail(ts, -remainder)
  n_length <- length(ts)

  # Build all stride-1 windows at once: row i holds ts[i:(i + win - 1)].
  starts <- base::seq_len(n_length - win + 1)
  offsets <- base::seq_len(win) - 1L
  positions <- outer(starts, offsets, "+")
  reframed <- matrix(ts[as.vector(positions)], nrow = length(starts), ncol = win)

  keep <- rev(seq(nrow(reframed), 1, -stride))
  reframed <- reframed[keep, , drop = FALSE]
  colnames(reframed) <- paste0("t", base::seq_len(win))
  reframed
}

#' @keywords internal
gmix <- function(x,
                 K.max = 10,     # upper bound on clusters
                 seed   = 1,     # k-means reproducibility
                 ...) {          # extra args to kmeans()
  # Fit a 1-D Gaussian mixture to the sample `x` through k-means:
  #   1. pick the number of components at the sharpest elbow of the
  #      within-cluster sum of squares (WSS) curve;
  #   2. use cluster proportions, means and SDs as mixture weights, centres
  #      and scales.
  # Returns list(rfun, dfun, pfun, qfun) with the fit details as attributes.
  stopifnot(is.numeric(x), is.vector(x), length(x) > 0)
  n <- length(x)

  ## -------------------------------------------------------------- ##
  ## 1.  Compute WSS for k = 1 … K.max ----------------------------- ##
  ## -------------------------------------------------------------- ##
  # kmeans() refuses more centres than distinct data points.
  K.max <- max(1L, min(K.max, length(unique(x))))
  set.seed(seed)
  wss <- vapply(seq_len(K.max),
                function(k) stats::kmeans(x, centers = k, ...)$tot.withinss,
                numeric(1))

  ## -------------------------------------------------------------- ##
  ## 2.  Detect the elbow (max curvature) ------------------------- ##
  ## -------------------------------------------------------------- ##
  # Discrete curvature at k: WSS(k-1) - 2 WSS(k) + WSS(k+1).
  # With d1[j] = WSS(j) - WSS(j+1) and d2 = diff(d1), the curvature at k
  # equals -d2[k - 1], so the sharpest elbow is at which.min(d2) + 1.
  d1  <- -diff(wss)                  # length K.max-1
  d2  <- diff(d1)                    # length K.max-2
  k_hat <- which.min(d2) + 1L

  ## Fallback when curvature is undefined (K.max < 3): two clusters,
  ## unless fewer distinct values are available.
  if (length(k_hat) == 0 || is.na(k_hat) || k_hat < 2) k_hat <- 2L
  k_hat <- min(k_hat, K.max)

  ## -------------------------------------------------------------- ##
  ## 3.  Final k-means with k = k̂ -------------------------------- ##
  ## -------------------------------------------------------------- ##
  set.seed(seed)
  km <- stats::kmeans(x, centers = k_hat, ...)

  clusters <- km$cluster
  centers  <- as.numeric(km$centers)
  sizes    <- km$size
  weights  <- sizes / n                           # π_g

  # Component SDs. Single-point or zero-spread clusters get a tiny positive
  # scale so the normal density/CDF stay well defined.
  sigmas <- vapply(split(x, clusters), function(v)
    if (length(v) > 1) sd(v) else 1e-8, numeric(1))
  sigmas <- pmax(sigmas, 1e-8)

  ## -------------------------------------------------------------- ##
  ## 4.  Mixture helpers ------------------------------------------ ##
  ## -------------------------------------------------------------- ##
  pdf_fun <- function(z)
    vapply(z, function(zz) sum(weights * dnorm(zz, centers, sigmas)), numeric(1))

  cdf_fun <- function(z)
    vapply(z, function(zz) sum(weights * pnorm(zz, centers, sigmas)), numeric(1))

  lower <- min(x) - 6 * max(sigmas)
  upper <- max(x) + 6 * max(sigmas)
  icdf_fun <- function(p) {
    stopifnot(all(p >= 0 & p <= 1))
    vapply(p, function(pp) {
      # The support is unbounded, so the extreme quantiles are infinite.
      if (pp <= 0) return(-Inf)
      if (pp >= 1) return(Inf)
      # The CDF is increasing: "upX" widens the bracket if needed.
      uniroot(function(q) cdf_fun(q) - pp, interval = c(lower, upper),
              extendInt = "upX")$root
    }, numeric(1))
  }

  sampler_fun <- function(n) {
    g <- sample(seq_along(weights), n, TRUE, prob = weights)
    rnorm(n, centers[g], sigmas[g])
  }

  pred_funs <- list(rfun = sampler_fun, dfun = pdf_fun, pfun = cdf_fun, qfun = icdf_fun)

  attr(pred_funs, "kmeans") <- km
  attr(pred_funs, "weights") <- weights
  attr(pred_funs, "centers") <- centers
  attr(pred_funs, "sigmas") <- sigmas
  attr(pred_funs, "wss") <- wss
  attr(pred_funs, "k_hat") <- k_hat

  pred_funs
}


#' @keywords internal
train_model <- function(model,
                        x, y,                     # training tensors
                        epochs      = 150,
                        lr          = 1e-3,
                        batch       = 32,
                        lambda_rec  = 0.3,
                        patience    = 15,
                        val_x       = NULL,       # optional validation
                        val_y       = NULL,
                        optimizer = c("adam", "adamw", "sgd", "rprop", "rmsprop", "adagrad", "asgd", "adadelta"),
                        verbose     = TRUE) {
  # Mini-batch training with loss CRPS(pred, y) + lambda_rec * MSE(recon, x).
  # The best epoch is tracked by validation CRPS (training CRPS when no
  # validation set is given) and its weights are restored at the end.
  # Returns list(model, history, loss_plot); the model is left in eval mode.

  optimizer <- match.arg(optimizer)

  n <- x$size(1)
  if (n < 1) stop("the training set is empty", call. = FALSE)
  if (epochs < 1) stop("epochs must be at least 1", call. = FALSE)

  opt <- switch(
    optimizer,
    adam     = optim_adam(model$parameters, lr = lr),
    adamw    = optim_adamw(model$parameters, lr = lr),
    adadelta = optim_adadelta(model$parameters, lr = lr),
    adagrad  = optim_adagrad(model$parameters, lr = lr),
    asgd     = optim_asgd(model$parameters, lr = lr),
    sgd      = optim_sgd(model$parameters, lr = lr),  # plain SGD
    rprop    = optim_rprop(model$parameters, lr = lr),
    rmsprop  = optim_rmsprop(model$parameters, lr = lr)
  )

  has_val <- !is.null(val_x) && !is.null(val_y)

  # state_dict() tensors share storage with the live parameters, and the
  # optimizer updates those in place. Clone them so the best-epoch snapshot
  # is not silently overwritten by later steps.
  snapshot_state <- function() lapply(model$state_dict(), function(p) p$clone())

  best_state <- NULL
  best_val   <- Inf
  no_imp     <- 0

  train_hist <- numeric(0)
  val_hist   <- numeric(0)

  ## Early-stopping countdown starts only after train CRPS < val CRPS at
  ## least once (immediately when there is no validation set).
  allow_early <- !has_val

  for (e in seq_len(epochs)) {

    # ---- TRAIN --------------------------------------------------------
    model$train()
    batches <- split(sample.int(n), ceiling(seq_len(n) / batch))
    batch_crps <- numeric(length(batches))

    for (b in seq_along(batches)) {
      idx <- batches[[b]]
      xb <- x[idx, , drop = FALSE]; yb <- y[idx, , drop = FALSE]

      opt$zero_grad()
      out <- model(xb)

      loss_f   <- nnf_crps_ensemble(out$pred, yb$to(dtype = out$pred$dtype))
      loss_rec <- nnf_mse_loss(out$recon, xb$to(dtype = out$recon$dtype))
      loss     <- loss_f + lambda_rec * loss_rec

      loss$backward(); opt$step()
      batch_crps[b] <- loss_f$item()
    }

    train_epoch <- mean(batch_crps)
    train_hist  <- c(train_hist, train_epoch)

    # ---- VALIDATION ---------------------------------------------------
    if (has_val) {
      model$eval()
      with_no_grad({
        val_pred <- model$forward(val_x)
        val_epoch <- nnf_crps_ensemble(val_pred$pred, val_y$to(dtype = val_pred$pred$dtype))$item()
      })
      val_hist <- c(val_hist, val_epoch)

      if (!allow_early && train_epoch < val_epoch)
        allow_early <- TRUE

      metric <- val_epoch
    } else {
      metric <- train_epoch            # no validation set
    }

    if (verbose && e %% 10 == 0) {
      msg <- sprintf("epoch %3d  train %.6f", e, train_epoch)
      if (length(val_hist) > 0)
        msg <- sprintf("%s  val %.6f", msg, val_epoch)
      cat(msg, "\n")
    }

    # ---- Best-model bookkeeping --------------------------------------
    if (metric < best_val) {
      best_val   <- metric
      best_state <- snapshot_state()
      no_imp     <- 0
    } else if (allow_early) {
      no_imp <- no_imp + 1
    }

    # ---- Early-stop decision -----------------------------------------
    if (allow_early && no_imp >= patience) {
      if (verbose) cat("Early stop.\n")
      break
    }
  }

  if (!is.null(best_state)) model$load_state_dict(best_state)
  model$eval()

  ## ── build history data frame (validation rows only if available) ----
  epoch_vec <- seq_along(train_hist)
  history <- data.frame(epoch = epoch_vec, set = "train", crps = train_hist)
  if (length(val_hist) > 0) {
    history <- rbind(
      history,
      data.frame(epoch = epoch_vec, set = "validation", crps = val_hist)
    )
  }

  ## ── ggplot ----------------------------------------------------------
  loss_plot <- ggplot(history,
                      ggplot2::aes(x = .data$epoch, y = .data$crps, colour = .data$set)) +
    geom_line(linewidth = 1) +
    geom_point(size = 1.2) +
    scale_colour_manual(values = c("#1B9E77", "#D95F02"),
                        breaks = c("train", "validation"),
                        na.translate = FALSE) +
    labs(x = "Epoch",
         y = "CRPS (ensemble)",
         colour = NULL) +
    theme_light()

  list(model     = model,
       history   = history,
       loss_plot = loss_plot)
}

Try the temper package in your browser

Any scripts or data that you put into this service are public.

temper documentation built on Aug. 8, 2025, 7:43 p.m.