SDTree | R Documentation
Description:

Estimates a regression tree using spectral deconfounding.
A regression tree is part of the function class of step functions
f(X) = \sum_{m = 1}^M 1_{\{X \in R_m\}} c_m, where (R_m)_{m = 1, \ldots, M}
are regions dividing \mathbb{R}^p into M rectangular parts and each region
has response level c_m \in \mathbb{R}.
For the training data, we can write the step function as
f(\mathbf{X}) = \mathcal{P} c, where \mathcal{P} \in \{0, 1\}^{n \times M}
is an indicator matrix encoding to which region an observation belongs and
c \in \mathbb{R}^M is a vector containing the levels corresponding to the
different regions. This function then minimizes

(\hat{\mathcal{P}}, \hat{c}) = \text{argmin}_{\mathcal{P}' \in \{0, 1\}^{n \times M},\, c' \in \mathbb{R}^M} \frac{\|Q(\mathbf{Y} - \mathcal{P}' c')\|_2^2}{n}.
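As a minimal illustration (our own sketch, not the package's internal code), the following builds the indicator matrix \mathcal{P} for a fixed two-region partition and evaluates the objective above; Q is taken as the identity here, i.e. no deconfounding:

# sketch: indicator matrix and spectral loss for a fixed partition
n <- 20
X <- matrix(rnorm(n * 2), nrow = n)
Y <- ifelse(X[, 1] > 0, 3, -3) + rnorm(n)

# P for the two regions R_1 = {x: x_1 <= 0} and R_2 = {x: x_1 > 0}
P <- cbind(X[, 1] <= 0, X[, 1] > 0) * 1
Q <- diag(n)  # identity for illustration; SDTree uses a spectral transformation

# ||Q(Y - P c)||_2^2 / n for given leaf levels c
spectral_loss <- function(Q, Y, P, c) sum((Q %*% (Y - P %*% c))^2) / length(Y)
spectral_loss(Q, Y, P, c = c(-3, 3))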
We find \hat{\mathcal{P}} by using the tree structure and repeatedly
splitting the leaves, similar to the original CART algorithm
(Breiman et al. 2017). Since comparing all possibilities for \mathcal{P}
is computationally infeasible, we grow the tree greedily. Given the current
tree, we iterate over all leaves and all possible splits. We choose the one
that reduces the spectral loss the most, and after each split we estimate
all the leaf levels

\hat{c} = \text{argmin}_{c' \in \mathbb{R}^M} \frac{\|Q\mathbf{Y} - Q\mathcal{P} c'\|_2^2}{n},

which is just a linear regression problem. This is repeated until the loss
decrease after a split falls below a minimum loss decrease.
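Continuing the sketch above, the leaf levels for a fixed partition are an ordinary least squares fit of Q\mathbf{Y} on Q\mathcal{P}:

# OLS estimate of the leaf levels given the partition encoded in P
c_hat <- lm.fit(Q %*% P, as.vector(Q %*% Y))$coefficients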
The minimum loss decrease equals the cost-complexity parameter cp times the
initial loss, i.e. the loss when only an overall mean is estimated. The
cost-complexity parameter cp controls the complexity of the regression tree
and acts as a regularization parameter.
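The stopping rule can be written in a few lines; a sketch reusing the objects from above (the names are ours):

# initial loss of the mean-only model (a single leaf containing everything)
P0 <- matrix(1, n, 1)
c0 <- lm.fit(Q %*% P0, as.vector(Q %*% Y))$coefficients
initial_loss <- spectral_loss(Q, Y, P0, c0)
cp <- 0.01
# a split from loss_before to loss_after is only accepted if
# (loss_before - loss_after) > cp * initial_loss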
Usage:

SDTree(
formula = NULL,
data = NULL,
x = NULL,
y = NULL,
max_leaves = NULL,
cp = 0.01,
min_sample = 5,
mtry = NULL,
fast = TRUE,
Q_type = "trim",
trim_quantile = 0.5,
q_hat = 0,
Qf = NULL,
A = NULL,
gamma = 0.5,
gpu = FALSE,
mem_size = 1e+07,
max_candidates = 100,
Q_scale = TRUE
)
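For example, the formula and the matrix interface fit the same model (a minimal sketch; train_data, X and Y are illustrative names):

fit1 <- SDTree(Y ~ ., data = train_data)  # formula interface
fit2 <- SDTree(x = X, y = Y)              # matrix interface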
Arguments:

formula: Object of class formula describing the model to fit.
data: Training data of class data.frame containing the variables in the model.
x: Matrix of covariates, alternative to formula and data.
y: Vector of responses, alternative to formula and data.
max_leaves: Maximum number of leaves for the grown tree.
cp: Complexity parameter, minimum loss decrease to split a node. A split is only performed if the loss decrease is larger than cp times the initial loss, i.e. the loss when only an overall mean is estimated.
min_sample: Minimum number of observations per leaf. A split is only performed if both resulting leaves have at least min_sample observations.
mtry: Number of randomly selected covariates to consider for a split; if NULL, all covariates are available for each split.
fast: If TRUE, only the splits in the new leaves are evaluated after a split and the previously computed candidate splits are reused; if FALSE, all candidate splits are reevaluated after every split.
Q_type: Type of deconfounding, one of 'trim', 'pca', 'no_deconfounding'. 'trim' corresponds to the Trim transform (Ćevid et al. 2020) as implemented in the doubly debiased lasso (Guo et al. 2022), 'pca' to the PCA transformation (Paul et al. 2008). See get_Q. A numerical sketch of the Trim transform follows the argument list.
trim_quantile: Quantile for the Trim transform, only needed for 'trim'; see get_Q.
q_hat: Assumed confounding dimension, only needed for 'pca'; see get_Q.
Qf: Spectral transformation; if NULL, it is estimated internally using get_Q.
A: Numerical anchor of class matrix; see get_W.
gamma: Strength of distributional robustness, gamma in [0, Inf); see get_W.
gpu: If TRUE, the calculations are performed on the GPU, provided it is properly set up.
mem_size: Number of split candidates that can be evaluated at once. This is a trade-off between memory and speed and can be decreased if either the memory is not sufficient or the GPU is too small.
max_candidates: Maximum number of split points that are proposed at each node for each covariate.
Q_scale: Should the data be scaled to estimate the spectral transformation? Default is TRUE, so as not to reduce the signal of high-variance covariates.
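To illustrate what the Trim transform controlled by Q_type and trim_quantile computes, the following sketch derives it from the singular value decomposition as described by Ćevid et al. (2020). This is an illustration only: the helper name is ours, the package estimates Q internally via get_Q, and implementation details may differ.

# sketch of the Trim transform: singular values above the trim_quantile
# quantile are shrunk to that quantile; directions orthogonal to col(X)
# are left unchanged
trim_transform <- function(X, trim_quantile = 0.5) {
  sv <- svd(X)
  tau <- quantile(sv$d, trim_quantile)
  shrink <- pmin(1, tau / sv$d)
  diag(nrow(X)) - sv$u %*% diag(1 - shrink) %*% t(sv$u)
}

X <- matrix(rnorm(100 * 15), nrow = 100)
Q <- trim_transform(X)  # n x n spectral transformation
Y <- rnorm(100)
QY <- Q %*% Y           # deconfounded response entering the loss above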
Value:

Object of class SDTree containing:
predictions: Predictions for the training set.
tree: The estimated tree of class Node (from the data.tree package).
var_names: Names of the covariates in the training data.
var_importance: Variable importance of the covariates, calculated as the sum of the decrease in the loss function resulting from all splits that use the covariate.
Author(s):

Markus Ulmer
See Also:

simulate_data_nonlinear, regPath.SDTree, prune.SDTree, partDependence
Examples:

set.seed(1)
n <- 10
X <- matrix(rnorm(n * 5), nrow = n)
y <- sign(X[, 1]) * 3 + rnorm(n)
model <- SDTree(x = X, y = y, cp = 0.5)
set.seed(42)
# simulation of confounded data
sim_data <- simulate_data_step(q = 2, p = 15, n = 100, m = 2)
X <- sim_data$X
Y <- sim_data$Y
train_data <- data.frame(X, Y)
# causal parents of y
sim_data$j
# estimate trees with and without spectral deconfounding
tree_plain_cv <- cvSDTree(Y ~ ., train_data, Q_type = "no_deconfounding")
tree_plain <- SDTree(Y ~ ., train_data, Q_type = "no_deconfounding", cp = 0)
tree_causal_cv <- cvSDTree(Y ~ ., train_data)
tree_causal <- SDTree(y = Y, x = X, cp = 0)
# check regularization path of variable importance
path <- regPath(tree_causal)
plot(path)
# prune the trees to the cross-validated optimal cp
tree_plain <- prune(tree_plain, cp = tree_plain_cv$cp_min)
tree_causal <- prune(tree_causal, cp = tree_causal_cv$cp_min)
plot(tree_causal)
plot(tree_plain)
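# out-of-sample predictions (a sketch: assumes the predict method for
# SDTree objects accepts new data via newdata, as for other R models)
predict(tree_plain, newdata = train_data)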