#' Spectrally Deconfounded Tree
#'
#' Estimates a regression tree using spectral deconfounding.
#' A regression tree belongs to the class of step functions
#' \eqn{f(X) = \sum_{m = 1}^M 1_{\{X \in R_m\}} c_m}, where \eqn{R_m},
#' \eqn{m = 1, \ldots, M}, are rectangular regions dividing \eqn{\mathbb{R}^p}
#' into \eqn{M} parts, and each region has response level \eqn{c_m \in \mathbb{R}}.
#' For the training data, we can write the step function as \eqn{f(\mathbf{X}) = \mathcal{P} c}
#' where \eqn{\mathcal{P} \in \{0, 1\}^{n \times M}} is an indicator matrix encoding
#' to which region an observation belongs and \eqn{c \in \mathbb{R}^M} is a vector
#' containing the levels corresponding to the different regions. This function then minimizes
#' \deqn{(\hat{\mathcal{P}}, \hat{c}) = \text{argmin}_{\mathcal{P}' \in \{0, 1\}^{n \times M}, c' \in \mathbb{R}^{M}} \frac{||Q(\mathbf{Y} - \mathcal{P}' c')||_2^2}{n}.}
#' We find \eqn{\hat{\mathcal{P}}} using the tree structure and repeated splitting of the leaves,
#' similar to the original CART algorithm \insertCite{Breiman2017ClassificationTrees}{SDModels}.
#' Since comparing all possibilities for \eqn{\mathcal{P}} is computationally infeasible,
#' we let the tree grow greedily. Given the current tree, we iterate over all leaves
#' and all possible splits. We choose the split that reduces the spectral loss the most,
#' and after each split we re-estimate all leaf levels
#' \eqn{\hat{c} = \text{argmin}_{c' \in \mathbb{R}^M} \frac{||Q\mathbf{Y} - Q\mathcal{P} c'||_2^2}{n},}
#' which is just a linear regression problem. This is repeated until no split decreases
#' the loss by more than a minimum loss decrease.
#' The minimum loss decrease equals a cost-complexity parameter \eqn{cp} times
#' the initial loss when only an overall mean is estimated.
#' The cost-complexity parameter \eqn{cp} controls the complexity of a regression tree
#' and acts as a regularization parameter.
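#'
#' As a minimal conceptual sketch (not the exact implementation), for a fixed
#' partition the leaf levels are obtained by ordinary least squares on the
#' spectrally transformed data; here \code{P}, \code{Q} and \code{Y} are
#' placeholders for the indicator matrix, the spectral transformation and the
#' response:
#' \preformatted{
#' c_hat <- qr.coef(qr(Q %*% P), Q %*% Y)               # leaf levels
#' loss  <- sum((Q %*% (Y - P %*% c_hat))^2) / nrow(P)  # spectral loss
#' }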
#' @references
#' \insertAllCited{}
#' @author Markus Ulmer
#' @param formula Object of class \code{formula} describing the model to fit
#' of the form \code{y ~ x1 + x2 + ...} where \code{y} is a numeric response and
#' \code{x1, x2, ...} are vectors of covariates. Interactions are not supported.
#' @param data Training data of class \code{data.frame} containing the variables in the model.
#' @param x Matrix of covariates, alternative to \code{formula} and \code{data}.
#' @param y Vector of responses, alternative to \code{formula} and \code{data}.
#' @param max_leaves Maximum number of leaves for the grown tree.
#' @param cp Complexity parameter, minimum loss decrease to split a node.
#' A split is only performed if the loss decrease is larger than \code{cp * initial_loss},
#' where \code{initial_loss} is the loss of the initial estimate using only a stump.
#' @param min_sample Minimum number of observations per leaf.
#' A split is only performed if both resulting leaves have at least
#' \code{min_sample} observations.
#' @param mtry Number of randomly selected covariates to consider for a split,
#' if \code{NULL} all covariates are available for each split.
#' @param fast If \code{TRUE}, only the optimal splits in the new leaves are
#' evaluated and the previously optimal splits and their potential loss decreases are reused.
#' If \code{FALSE}, all possible splits in all leaves are re-evaluated after every split.
#' @param Q_type Type of deconfounding, one of 'trim', 'pca', 'no_deconfounding'.
#' 'trim' corresponds to the Trim transform \insertCite{Cevid2020SpectralModels}{SDModels}
#' as implemented in the Doubly debiased lasso \insertCite{Guo2022DoublyConfounding}{SDModels},
#' 'pca' to the PCA transformation \insertCite{Paul2008PreconditioningProblems}{SDModels},
#' and 'no_deconfounding' to the identity (no transformation).
#' See \code{\link{get_Q}}.
#' @param trim_quantile Quantile for Trim transform,
#' only needed for trim, see \code{\link{get_Q}}.
#' @param q_hat Assumed confounding dimension, only needed for pca,
#' see \code{\link{get_Q}}.
#' @param Qf Spectral transformation, if \code{NULL}
#' it is internally estimated using \code{\link{get_Q}}.
#' @param A Numerical Anchor of class \code{matrix}. See \code{\link{get_W}}.
#' @param gamma Strength of distributional robustness, \eqn{\gamma \in [0, \infty]}.
#' See \code{\link{get_W}}.
#' @param gpu If \code{TRUE}, the calculations are performed on the GPU,
#' provided it is properly set up.
#' @param mem_size Number of split candidates that can be evaluated at once.
#' This is a trade-off between memory and speed; it can be decreased if either
#' the memory is not sufficient or the GPU is too small.
#' @param max_candidates Maximum number of split points that are
#' proposed at each node for each covariate.
#' @param Q_scale Should data be scaled to estimate the spectral transformation?
#' Default is \code{TRUE} to avoid reducing the signal of high-variance covariates;
#' we are not aware of a scenario where scaling hurts.
#' @return Object of class \code{SDTree} containing
#' \item{predictions}{Predictions for the training set.}
#' \item{tree}{The estimated tree of class \code{Node} from \insertCite{Glur2023Data.tree:Structure}{SDModels}.
#' The tree contains the information about all the splits and the resulting estimates.}
#' \item{var_names}{Names of the covariates in the training data.}
#' \item{var_importance}{Variable importance of the covariates.
#' The variable importance is calculated as the sum of the decrease in the loss
#' function resulting from all splits that use this covariate.}
#' @seealso \code{\link{simulate_data_nonlinear}}, \code{\link{regPath.SDTree}},
#' \code{\link{prune.SDTree}}, \code{\link{partDependence}}
#' @examples
#' set.seed(1)
#' n <- 10
#' X <- matrix(rnorm(n * 5), nrow = n)
#' y <- sign(X[, 1]) * 3 + rnorm(n)
#' model <- SDTree(x = X, y = y, cp = 0.5)
#'
#' \donttest{
#' set.seed(42)
#' # simulation of confounded data
#' sim_data <- simulate_data_step(q = 2, p = 15, n = 100, m = 2)
#' X <- sim_data$X
#' Y <- sim_data$Y
#' train_data <- data.frame(X, Y)
#' # causal parents of y
#' sim_data$j
#'
#' tree_plain_cv <- cvSDTree(Y ~ ., train_data, Q_type = "no_deconfounding")
#' tree_plain <- SDTree(Y ~ ., train_data, Q_type = "no_deconfounding", cp = 0)
#'
#' tree_causal_cv <- cvSDTree(Y ~ ., train_data)
#' tree_causal <- SDTree(y = Y, x = X, cp = 0)
#'
#' # check regularization path of variable importance
#' path <- regPath(tree_causal)
#' plot(path)
#'
#' tree_plain <- prune(tree_plain, cp = tree_plain_cv$cp_min)
#' tree_causal <- prune(tree_causal, cp = tree_causal_cv$cp_min)
#' plot(tree_causal)
#' plot(tree_plain)
#' }
#' @export
SDTree <- function(formula = NULL, data = NULL, x = NULL, y = NULL, max_leaves = NULL,
cp = 0.01, min_sample = 5, mtry = NULL, fast = TRUE,
Q_type = 'trim', trim_quantile = 0.5, q_hat = 0, Qf = NULL,
A = NULL, gamma = 0.5, gpu = FALSE, mem_size = 1e+7, max_candidates = 100,
Q_scale = TRUE){
if(gpu) gpu_type <- ifelse(GPUmatrix::installTorch(), 'torch', 'tensorflow')
input_data <- data.handler(formula = formula, data = data, x = x, y = y)
X <- input_data$X
Y <- input_data$Y
# number of observations
n <- nrow(X)
# number of covariates
p <- ncol(X)
if(is.null(max_leaves)) max_leaves <- n
max_leaves <- max_leaves - 1
mem_size <- mem_size / n
# check validity of input
if(n != length(Y)) stop('X and Y must have the same number of observations')
if(max_leaves < 0) stop('max_leaves must be at least 1')
if(min_sample < 1) stop('min_sample must be larger than 0')
if(cp < 0) stop('cp must be at least 0')
if(!is.null(mtry) && mtry < 1) stop('mtry must be larger than 0')
if(!is.null(mtry) && mtry > p) stop('mtry must be at most p')
if(n < 2 * min_sample) stop('n must be at least 2 * min_sample')
if(max_candidates < 1) stop('max_candidates must be at least 1')
# estimate spectral transformation
if(!is.null(A)){
if(is.null(gamma)) stop('gamma must be provided if A is provided')
if(is.vector(A)) A <- matrix(A)
if(!is.matrix(A)) stop('A must be a matrix')
if(nrow(A) != n) stop('A must have n rows')
Wf <- get_Wf(A, gamma, gpu)
}else {
Wf <- function(v) v
}
if(is.null(Qf)){
if(!is.null(A)){
Qf <- function(v) get_Qf(Wf(X), Q_type, trim_quantile, q_hat, gpu, Q_scale)(Wf(v))
}else{
Qf <- get_Qf(X, Q_type, trim_quantile, q_hat, gpu, Q_scale)
}
}else{
if(!is.function(Qf)) stop('Qf must be a function')
if(length(Qf(rnorm(n))) != n) stop('Qf must map from n to n')
}
# calculate first estimate
E <- matrix(1, n, 1)
E_tilde <- Qf(E)
if(gpu){
E_tilde <- gpu.matrix(E_tilde, type = gpu_type)
}
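# Ue holds the normalized columns of the transformed indicator matrix and is
# used when evaluating candidate splits against the part of Y_tilde not yet explained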
Ue <- E_tilde / sqrt(sum(E_tilde ** 2))
Y_tilde <- Qf(Y)
# solve linear model
if(gpu && gpu_type == 'tensorflow'){
c_hat <- lm.fit(as.matrix(E_tilde), as.matrix(Y_tilde))$coefficients
}else{
c_hat <- qr.coef(qr(E_tilde), Y_tilde)
c_hat <- as.numeric(c_hat)
}
loss_start <- as.numeric(sum((Y_tilde - c_hat) ** 2) / n)
loss_temp <- loss_start
# initialize tree
tree <- data.tree::Node$new(name = '1', value = as.numeric(c_hat),
dloss = as.numeric(loss_start),
cp = 10, n_samples = n)
# memory for optimal splits
memory <- list()
potential_splits <- 1
# variable importance
var_imp <- rep(0, p)
names(var_imp) <- colnames(X)
after_mtry <- 0
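# grow the tree greedily: in each iteration evaluate candidate splits, perform the
# best one and stop once the loss decrease falls below cp * loss_start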
for(i in 1:max_leaves){
# if fast = FALSE, re-evaluate all possible splits in every leaf at each iteration
# (slower, but can give a slightly better tree)
if(!fast){
potential_splits <- 1:i
to_small <- sapply(potential_splits,
function(x){sum(E[, x]) < min_sample*2})
potential_splits <- potential_splits[!to_small]
}
# iterate over leaves that need (re-)evaluation and find their optimal splits
for(branch in potential_splits){
E_branch <- E[, branch]
index <- which(E_branch == 1)
X_branch <- matrix(X[index, ], nrow = length(index))
s <- find_s(X_branch, max_candidates = max_candidates)
n_splits <- nrow(s)
if(min_sample > 1) {
s <- s[-c(0:(min_sample - 1), (n_splits - min_sample + 2):(n_splits+1)), ]
}
s <- matrix(s, ncol = p)
all_n_splits <- apply(s, 2, function(x) length(unique(x)))
all_idx <- cumsum(all_n_splits)
eval <- matrix(-Inf, nrow(s), p)
done_splits <- 0
p_top <- 0
while(p_top < p){
c_all_idx <- all_idx - done_splits
p_low <- p_top + 1
possible <- which(c_all_idx < mem_size)
p_top <- possible[length(possible)]
c_n_splits <- sum(all_idx[p_top], -all_idx[p_low-1])
E_next <- matrix(0, n, c_n_splits)
for(j in p_low:p_top){
s_j <- s[, j]
s_j <- unique(s_j)
for(i_s in 1:all_n_splits[j]){
E_next[index[X_branch[, j] > s_j[i_s]], sum(c_all_idx[j-1], i_s)] <- 1
}
}
if(gpu) E_next <- gpu.matrix(E_next, type = gpu_type)
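# potential loss decrease of each candidate split: squared projection of Y_tilde
# onto the candidate direction returned by Qf_temp, scaled by its squared norm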
U_next_prime <- Qf_temp(E_next, Ue, Qf)
U_next_size <- colSums(U_next_prime ** 2)
dloss <- as.numeric(crossprod(U_next_prime, Y_tilde))**2 / U_next_size
dloss[is.na(dloss)] <- 0
for(m in p_low:p_top){
eval[1:all_n_splits[m], m] <- dloss[sum(c_all_idx[m-1], 1):c_all_idx[m]]
}
done_splits <- done_splits + c_n_splits
}
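# store, for every covariate, the best split found in this leaf:
# (loss decrease, covariate index, split point, leaf index)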
is_opt <- apply(eval, 2, which.max)
memory[[branch]] <- t(sapply(1:p, function(j)
c(eval[is_opt[j], j], j, unique(s[, j])[is_opt[j]], branch)))
}
if(i > after_mtry && !is.null(mtry)){
Losses_dec <- lapply(memory, function(branch){
branch[sample(1:p, mtry), ]})
Losses_dec <- do.call(rbind, Losses_dec)
}else {
Losses_dec <- do.call(rbind, memory)
}
loc <- which.max(Losses_dec[, 1])
best_branch <- Losses_dec[loc, 4]
j <- Losses_dec[loc, 2]
s <- Losses_dec[loc, 3]
if(Losses_dec[loc, 1] <= 0){
break
}
# divide the observations in the selected leaf
index <- which(E[, best_branch] == 1)
index_n_branches <- index[X[index, j] > s]
# new indicator matrix
E <- cbind(E, matrix(0, n, 1))
E[index_n_branches, best_branch] <- 0
E[index_n_branches, i+1] <- 1
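# update the transformed indicator matrix incrementally: recompute the column of the
# split leaf and obtain the new leaf's column as the difference to the old column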
E_tilde_branch <- E_tilde[, best_branch]
suppressWarnings({
E_tilde[, best_branch] <- Qf(E[, best_branch])
})
E_tilde <- cbind(E_tilde, matrix(E_tilde_branch - E_tilde[, best_branch]))
if(gpu && gpu_type == 'tensorflow'){
c_hat <- lm.fit(as.matrix(E_tilde), as.matrix(Y_tilde))$coefficients
}else{
c_hat <- qr.coef(qr(E_tilde), Y_tilde)
}
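# extend the basis Ue with the normalized column corresponding to the new leaf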
u_next_prime <- Qf_temp(E[, i + 1], Ue, Qf)
Ue <- cbind(Ue, u_next_prime / sqrt(sum(u_next_prime ** 2)))
# check if loss decrease is larger than minimum loss decrease
# and if linear model could be estimated
if(sum(is.na(as.matrix(c_hat))) > 0){
warning('singular matrix QE, tree might be too large, consider increasing cp')
break
}
loss_dec <- as.numeric(loss_temp - loss(Y_tilde, E_tilde %*% c_hat))
loss_temp <- loss_temp - loss_dec
if(loss_dec <= cp * loss_start){
break
}
# add loss decrease to variable importance
var_imp[j] <- var_imp[j] + loss_dec
# select leaf to split
if(tree$height == 1){
leave <- tree
}else{
leaves <- tree$leaves
leave <- leaves[[which(tree$Get('name', filterFun = data.tree::isLeaf) == best_branch)]]
}
# save split rule
leave$j <- j
leave$s <- s
leave$res_dloss <- loss_dec
# add new leaves
leave$AddChild(best_branch, value = 0, dloss = loss_dec,
cp = loss_dec / loss_start, decision = 'yes',
n_samples = sum(E[, best_branch] == 1))
leave$AddChild(i + 1, value = 0, dloss = loss_dec,
cp = loss_dec / loss_start, decision = 'no',
n_samples = sum(E[, i + 1] == 1))
# add estimates to tree leaves
c_hat <- as.numeric(c_hat)
for(l in tree$leaves){
l$value <- c_hat[as.numeric(l$name)]
}
# the two new partitions need to be checked for optimal splits in next iteration
potential_splits <- c(best_branch, i + 1)
# a partition with fewer than 2 * min_sample unique observations
# is not available for further splits
to_small <- sapply(potential_splits, function(x){
new_samples <- nrow(unique(matrix(X[as.logical(E[, x]),], nrow = sum(E[, x]))))
if(is.null(new_samples)) new_samples <- 0
(new_samples < min_sample * 2)
})
if(sum(to_small) > 0){
for(el in potential_splits[to_small]){
# too small partitions cannot decrease the loss
memory[[el]] <- matrix(0, p, 4)
}
potential_splits <- potential_splits[!to_small]
}
}
if(i == max_leaves){
warning('maximum number of iterations was reached, consider increasing max_leaves!')
}
# predictions for the training set
f_X_hat <- predict_outsample(tree, X)
var_names <- colnames(data.frame(X))
names(var_imp) <- var_names
# labels for the nodes
tree$Do(split_names, filterFun = data.tree::isNotLeaf, var_names = var_names)
tree$Do(leave_names, filterFun = data.tree::isLeaf)
# cp_max: maximum cp over a node and all splits below it; siblings share the same value
tree$Do(function(node) node$cp_max <- max(node$Get('cp')))
tree$Do(function(node) {
cp_max <- data.tree::Aggregate(node, 'cp_max', max)
node$children[[1]]$cp_max <- cp_max
node$children[[2]]$cp_max <- cp_max
}, filterFun = data.tree::isNotLeaf
)
res <- list(predictions = f_X_hat, tree = tree,
var_names = var_names, var_importance = var_imp)
class(res) <- 'SDTree'
res
}