Nothing
## ----setup, include = FALSE---------------------------------------------------
# Chunk defaults for the rendered vignette: collapse source and output into a
# single block and prefix printed output lines with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
library(cramR)
library(DT)
## -----------------------------------------------------------------------------
library(data.table)
# Simulate a randomized experiment with heterogeneous treatment effects.
#
# Individuals fall into three groups defined by their covariates:
# - Positive effect (+1):  binary == 1 and discrete <= 2
# - Adverse effect (-1):   binary == 0 and discrete >= 4
# - Neutral effect (+0.1): everyone else
#
# Argument:
#   n  Number of individuals to simulate.
# Value:
#   A list with elements
#     X  data.table of covariates (binary, discrete, continuous)
#     D  0/1 treatment indicator (each individual treated with prob. 0.5)
#     Y  numeric outcome
generate_data <- function(n) {
  # NOTE: the random draws below are made in a fixed order (covariates,
  # treatment, treated noise, control noise); reordering them would change
  # the data produced for a given seed.
  covariates <- data.table(
    binary = rbinom(n, size = 1, prob = 0.5),         # Binary variable
    discrete = sample(1:5, size = n, replace = TRUE), # Discrete variable
    continuous = rnorm(n)                             # Continuous variable
  )

  # Binary treatment assignment (50% treated)
  treated <- rbinom(n, size = 1, prob = 0.5)

  # Group membership masks; the two groups cannot overlap because they
  # require different values of `binary`.
  is_positive <- covariates$binary == 1 & covariates$discrete <= 2
  is_adverse <- covariates$binary == 0 & covariates$discrete >= 4

  # Start everyone at the neutral effect, then overwrite the two other groups
  effect <- rep(0.1, length(treated))
  effect[is_positive] <- 1
  effect[is_adverse] <- -1

  # Treated outcome = effect + noise; control outcome = pure noise
  noise_treated <- rnorm(n, mean = 0, sd = 1)
  noise_control <- rnorm(n)
  outcome <- treated * (effect + noise_treated) +
    (1 - treated) * noise_control

  list(X = covariates, D = treated, Y = outcome)
}
# Draw one reproducible sample of `n` individuals and unpack its components
n <- 1000
set.seed(123)
data <- generate_data(n)
X <- data[["X"]]  # covariates
D <- data[["D"]]  # treatment indicator
Y <- data[["Y"]]  # outcome
## -----------------------------------------------------------------------------
# batch:
#   Either a single integer giving the number of batches, or a vector/list
#   holding the batch assignment of every individual.
batch <- 20

# model_type:
#   Model used to estimate treatment effects; one of 'causal_forest',
#   's_learner' or 'm_learner'.
#   Note: model_type may also be NULL, in which case custom_fit and
#   custom_predict must be specified to use your own model.
model_type <- "causal_forest"

# learner_type:
#   - NULL when model_type == 'causal_forest'
#   - 'ridge', 'fnn' or 'caret' when model_type is 's_learner' or 'm_learner'
learner_type <- NULL

# baseline_policy:
#   Policy to compare against, given as a list with one 0/1 entry per
#   individual. If NULL, a default baseline policy of zeros is created.
#   Examples:
#   - All-control baseline: as.list(rep(0, nrow(X))) or NULL
#   - Randomized baseline:  as.list(sample(c(0, 1), nrow(X), replace = TRUE))
baseline_policy <- as.list(numeric(nrow(X)))

# parallelize_batch:
#   The cram method learns T policies, with T the number of batches.
#   - TRUE:  the policies are learned in parallel.
#   - FALSE: the policies are learned sequentially using the efficient
#            data.table structure (recommended for lightweight training).
#   Defaults to FALSE.
parallelize_batch <- FALSE

# model_params:
#   Model-specific parameters (more details in the article
#   "Cram Policy part 2"). NULL falls back to these defaults:
#   - causal_forest: list(num.trees = 100)
#   - ridge: list(alpha = 1)
#   - caret: list(formula = Y ~ ., caret_params = list(method = "lm",
#            trControl = trainControl(method = "none")))
#   - fnn (Feedforward Neural Network):
#       input_shape <- if (model_type == "s_learner") ncol(X) + 1 else ncol(X)
#       default_model_params <- list(
#         input_layer = list(units = 64, activation = 'relu',
#                            input_shape = input_shape),
#         layers = list(
#           list(units = 32, activation = 'relu')
#         ),
#         output_layer = list(units = 1, activation = 'linear'),
#         compile_args = list(optimizer = 'adam', loss = 'mse'),
#         fit_params = list(epochs = 5, batch_size = 32, verbose = 0)
#       )
model_params <- NULL

# alpha:
#   Significance level for the confidence intervals; 0.05 (the default)
#   corresponds to 95% intervals.
alpha <- 0.05

# Run the Cram policy method with the settings chosen above
result <- cram_policy(
  X, D, Y,
  batch             = batch,
  model_type        = model_type,
  learner_type      = learner_type,
  baseline_policy   = baseline_policy,
  parallelize_batch = parallelize_batch,
  model_params      = model_params,
  alpha             = alpha
)

# Display the results
print(result)
## -----------------------------------------------------------------------------
# Show the `raw_results` element of the object returned by cram_policy()
result[["raw_results"]]
## -----------------------------------------------------------------------------
# Show the `interactive_table` element of the object returned by cram_policy()
result[["interactive_table"]]
## -----------------------------------------------------------------------------
# Inspect the `final_policy_model` element: first its class, then a summary
class(result[["final_policy_model"]])
summary(result[["final_policy_model"]])
## ----cleanup-autograph, include=FALSE-----------------------------------------
# Best-effort removal of "__autograph_generated_file*.py" files left in the
# session's temporary directory. Errors are deliberately silenced: failing to
# clean up must never break the vignette build.
autograph_files <- list.files(
  path = tempdir(),
  pattern = "^__autograph_generated_file.*\\.py$",
  full.names = TRUE
)
if (length(autograph_files) >= 1L) {
  try(
    unlink(autograph_files, recursive = TRUE, force = TRUE),
    silent = TRUE
  )
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.