policy_tree: Fit a policy with exact tree search

policy_tree R Documentation

Description

Finds the optimal depth-k tree, i.e. the tree that maximizes the sum of rewards, by exhaustive search. If the optimal action is the same in both the left and right leaf of a node, the node is pruned.
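
To make the idea of exhaustive search concrete, here is a minimal base-R sketch for the simplest case, a depth-1 tree. The helper best_stump and the toy data are hypothetical and exist only for illustration; this is not the package's implementation, which is far more efficient. The sketch tries every feature and every threshold, gives each child its best action, and keeps the split with the largest total reward.

```r
# Brute-force search for the best depth-1 policy tree.
# Illustration only: `best_stump` is a hypothetical helper, not a policytree function.
best_stump <- function(X, Gamma) {
  best <- list(reward = -Inf)
  for (j in seq_len(ncol(X))) {
    # Candidate thresholds: every distinct value except the largest,
    # so that both children are non-empty.
    cuts <- head(sort(unique(X[, j])), -1)
    for (cut in cuts) {
      left <- X[, j] <= cut
      # Total reward of each action within each child.
      left.sums <- colSums(Gamma[left, , drop = FALSE])
      right.sums <- colSums(Gamma[!left, , drop = FALSE])
      reward <- max(left.sums) + max(right.sums)
      if (reward > best$reward) {
        best <- list(reward = reward, feature = j, threshold = cut,
                     left.action = which.max(left.sums),
                     right.action = which.max(right.sums))
      }
    }
  }
  # If both leaves prescribe the same action, the split is redundant.
  best$prune <- best$left.action == best$right.action
  best
}

# Toy data (assumed): action 2 pays +1 when the first feature exceeds 5, else -1.
X <- cbind(rep(1:10, each = 10), rep(1:10, times = 10))
Gamma <- cbind(rep(0, 100), ifelse(X[, 1] > 5, 1, -1))
stump <- best_stump(X, Gamma)
stump[c("feature", "threshold", "left.action", "right.action", "prune")]
```

On this toy data the search recovers the split on the first feature at 5, with action 1 on the left and action 2 on the right, so nothing is pruned. A depth-k search nests this loop k times, which is where the exponents in the runtime come from.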

Usage

policy_tree(
  X,
  Gamma,
  depth = 2,
  split.step = 1,
  min.node.size = 1,
  verbose = TRUE
)

Arguments

X

The covariates used. Dimension N*p where p is the number of features.

Gamma

The rewards for each action. Dimension N*d where d is the number of actions.

depth

The depth of the fitted tree. Default is 2.

split.step

An optional approximation parameter controlling how many possible splits are considered when performing tree search. split.step = 1 (default) considers every possible split, while split.step = 10 considers splitting at every 10th sample and may yield a substantial speedup for dense features. Manually rounding or re-encoding continuous covariates with very high cardinality in a problem-specific manner allows for finer-grained control of the accuracy/runtime tradeoff and may in some cases be the preferred approach.
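
The two ways of thinning the candidate splits can be compared with a small base-R sketch. The helper candidate_splits is hypothetical and only mimics the idea of "every step-th sample" on a single feature; how the package applies split.step internally may differ.

```r
# Hypothetical helper (not part of policytree): candidate split points for one
# feature when only every `step`-th sorted sample is considered.
candidate_splits <- function(x, step = 1) {
  sorted <- sort(x)
  unique(sorted[seq_along(sorted) %% step == 0])
}

set.seed(42)
x <- rnorm(1000)

# Every sample is a candidate split.
n.exact <- length(candidate_splits(x, step = 1))
# Only every 10th sorted sample is a candidate split.
n.step10 <- length(candidate_splits(x, step = 10))
# Manual alternative: rounding lowers the cardinality of the feature itself.
n.rounded <- length(candidate_splits(round(x, 1), step = 1))

c(exact = n.exact, step10 = n.step10, rounded = n.rounded)
```

Skipping samples places candidates by rank, whereas rounding places them on a fixed grid of values, which is one reason a problem-specific re-encoding can give more control over where the tree is allowed to split.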

min.node.size

An integer indicating the smallest terminal node size permitted. Default is 1.

verbose

Give verbose output. Default is TRUE.

Details

Exact tree search is intended as a way to find shallow (i.e. depth 2 or 3) globally optimal tree-based policies on datasets of "moderate" size. The amortized runtime of exact tree search is O(p^k n^k (log n + d) + pn log n) where p is the number of features, n the number of distinct observations, d the number of treatments, and k >= 1 the tree depth. Due to the exponents in this expression, exact tree search will not scale to datasets of arbitrary size.

As an example, the runtime of a depth-2 tree scales quadratically with the number of observations, so doubling the number of samples will roughly quadruple the runtime. Because n refers to the number of distinct observations, substantial speedups can be gained when the features are discrete (with all binary features, the runtime will be approximately linear in n). It is therefore beneficial to round down or re-encode very dense data to a lower cardinality (the optional parameter split.step emulates this, though rounding/re-encoding allows for finer-grained control).
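
The scaling can be checked with a few lines of arithmetic on the dominant term of the stated bound. The function search_cost below is a hypothetical helper that ignores constants and the pn log n term, so its output is a relative growth factor under those assumptions, not a timing prediction.

```r
# Dominant term of the stated runtime bound, p^k n^k (log n + d), without constants.
# Hypothetical helper for illustration; not part of policytree.
search_cost <- function(n, p, d, k) p^k * n^k * (log(n) + d)

p <- 10  # number of features
d <- 2   # number of treatments
k <- 2   # tree depth

# Growth factor when the number of distinct observations doubles from 5000 to 10000.
ratio.double <- search_cost(10000, p, d, k) / search_cost(5000, p, d, k)

# Growth factor when rounding reduces 5000 distinct observations to 500.
ratio.rounded <- search_cost(500, p, d, k) / search_cost(5000, p, d, k)

c(doubling = ratio.double, rounding = ratio.rounded)
```

The doubling factor comes out slightly above four because of the logarithmic term, while cutting the number of distinct observations tenfold shrinks the term by more than a factor of one hundred at depth 2.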

Value

A policy_tree object.

References

Athey, Susan, and Stefan Wager. "Policy Learning With Observational Data." Econometrica 89.1 (2021): 133-161.

Sverdrup, Erik, Ayush Kanodia, Zhengyuan Zhou, Susan Athey, and Stefan Wager. "policytree: Policy learning via doubly robust empirical welfare maximization over trees." Journal of Open Source Software 5, no. 50 (2020): 2232.

Zhou, Zhengyuan, Susan Athey, and Stefan Wager. "Offline multi-action policy learning: Generalization and optimization." Operations Research 71.1 (2023).

See Also

hybrid_policy_tree for building deeper trees.

Examples


library(policytree)

# Construct doubly robust scores using a causal forest.
n <- 10000
p <- 10
# Discretizing continuous covariates decreases runtime for policy learning.
X <- round(matrix(rnorm(n * p), n, p), 2)
colnames(X) <- make.names(1:p)
W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
Y <- X[, 3] + W * tau + rnorm(n)
c.forest <- grf::causal_forest(X, Y, W)

# Retrieve doubly robust scores.
dr.scores <- double_robust_scores(c.forest)

# Learn a depth-2 tree on a training set.
train <- sample(1:n, n / 2)
tree <- policy_tree(X[train, ], dr.scores[train, ], 2)
tree

# Evaluate the tree on a test set.
test <- -train

# One way to assess the policy is to check whether the mean outcomes in each leaf node (group)
# that the test set samples are assigned to agree with the prescribed policy.

# Get the leaf node assigned to each test sample.
node.id <- predict(tree, X[test, ], type = "node.id")

# Doubly robust estimates of E[Y(control)] and E[Y(treated)] by leaf node.
values <- aggregate(dr.scores[test, ], by = list(leaf.node = node.id),
                    FUN = function(dr) c(mean = mean(dr), se = sd(dr) / sqrt(length(dr))))
print(values, digits = 1)

# Take cost of treatment into account by, for example, offsetting the objective
# with an estimate of the average treatment effect.
ate <- grf::average_treatment_effect(c.forest)
cost.offset <- ate[["estimate"]]
dr.scores[, "treated"] <- dr.scores[, "treated"] - cost.offset
tree.cost <- policy_tree(X, dr.scores, 2)

# Predict treatment assignment for each sample.
predicted <- predict(tree, X)

# If there are too many covariates to make tree search computationally feasible, then one
# approach is to consider for example only the top features according to GRF's variable importance.
var.imp <- grf::variable_importance(c.forest)
top.5 <- order(var.imp, decreasing = TRUE)[1:5]
tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)


policytree documentation built on June 22, 2024, 9:47 a.m.