inst/doc/cram_policy_part_1.R

## ----setup, include = FALSE---------------------------------------------------
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(cramR)
library(DT)

## -----------------------------------------------------------------------------
library(data.table)
# Function to generate sample data with heterogeneous treatment effects:
# - Positive effect group
# - Neutral effect group
# - Adverse effect group
generate_data <- function(n) {
  X <- data.table(
    binary = rbinom(n, 1, 0.5),                 # Binary variable
    discrete = sample(1:5, n, replace = TRUE),  # Discrete variable
    continuous = rnorm(n)                       # Continuous variable
  )

  # Binary treatment assignment (50% treated)
  D <- rbinom(n, 1, 0.5)

  # Define heterogeneous treatment effects based on X
  treatment_effect <- ifelse(
    X[, binary] == 1 & X[, discrete] <= 2,        # Group 1: Positive effect
    1,
    ifelse(X[, binary] == 0 & X[, discrete] >= 4, # Group 3: Adverse effect
           -1,
           0.1)                                   # Group 2: Neutral effect
  )

  # Outcome depends on treatment effect + noise
  Y <- D * (treatment_effect + rnorm(n, mean = 0, sd = 1)) +
    (1 - D) * rnorm(n)

  return(list(X = X, D = D, Y = Y))
}

# Generate a sample dataset
set.seed(123)
n <- 1000
data <- generate_data(n)
X <- data$X
D <- data$D
Y <- data$Y


## -----------------------------------------------------------------------------
# Options for batch:
# Either an integer specifying the number of batches or a vector/list of batch assignments for all individuals
batch <- 20
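# An equivalent alternative (a hedged sketch of the explicit-assignment form;
# the exact accepted format is described in ?cram_policy): give each
# individual one of the 20 batch labels directly.
# batch <- sample(rep(1:20, length.out = n))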

# Model type for estimating treatment effects
# Options for model_type: 'causal_forest', 's_learner', 'm_learner'
# Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your custom model
model_type <- "causal_forest"  

# Options for learner_type:
# if model_type == 'causal_forest', choose NULL
# if model_type == 's_learner' or 'm_learner', choose between 'ridge', 'fnn' and 'caret'
learner_type <- NULL  
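
# If you set model_type to NULL and supply a custom model, a hypothetical
# sketch (argument names are assumed here, not taken from the package;
# see ?cram_policy and the "Cram Policy part 2" article for the exact
# signatures expected by custom_fit and custom_predict):
# custom_fit <- function(X, Y, D) {
#   # fit any learner of treatment effects and return the fitted object
#   grf::causal_forest(as.matrix(X), Y, D, num.trees = 100)
# }
# custom_predict <- function(model, X, D) {
#   # return one treatment-effect estimate per row of X
#   predict(model, newdata = as.matrix(X))$predictions
# }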

# Baseline policy to compare against (list of 0/1 for each individual)
# Options for baseline_policy:
# A list representing the baseline policy assignment for each individual.
# If NULL, a default baseline policy of zeros is created.
# Examples of baseline policy: 
# - All-control baseline: as.list(rep(0, nrow(X))) or NULL
# - Randomized baseline: as.list(sample(c(0, 1), nrow(X), replace = TRUE))
baseline_policy <- as.list(rep(0, nrow(X)))  

# Whether to parallelize batch processing. The Cram method learns T policies,
# where T is the number of batches: they are learned in parallel when
# parallelize_batch is TRUE, and sequentially (using the efficient data.table
# structure) when parallelize_batch is FALSE, which is recommended for
# lightweight training. Defaults to FALSE.
parallelize_batch <- FALSE

# Model-specific parameters (more details in the article "Cram Policy part 2")
# If NULL, defaults depend on the model/learner type:
# - causal_forest: list(num.trees = 100)
# - ridge: list(alpha = 1)
# - caret: list(formula = Y ~ ., caret_params = list(method = "lm", trControl = trainControl(method = "none")))
# - fnn (Feedforward Neural Network): see below
# input_shape <- if (model_type == "s_learner") ncol(X) + 1 else ncol(X)
# default_model_params <- list(
#       input_layer = list(units = 64, activation = 'relu', input_shape = input_shape),
#       layers = list(
#         list(units = 32, activation = 'relu')
#       ),
#       output_layer = list(units = 1, activation = 'linear'),
#       compile_args = list(optimizer = 'adam', loss = 'mse'),
#       fit_params = list(epochs = 5, batch_size = 32, verbose = 0)
#     )
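# For example, to grow a larger forest than the default (a hedged sketch;
# for model_type = "causal_forest" these parameters are expected to be
# passed through to grf::causal_forest):
# model_params <- list(num.trees = 500)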
model_params <- NULL  


# Significance level for confidence intervals; alpha = 0.05 yields 95% intervals
alpha <- 0.05  

# Run the Cram policy method
result <- cram_policy(
  X, D, Y,
  batch = batch,
  model_type = model_type,
  learner_type = learner_type,
  baseline_policy = baseline_policy,
  parallelize_batch = parallelize_batch,
  model_params = model_params,
  alpha = alpha
)

# Display the results
print(result)


## -----------------------------------------------------------------------------
result$raw_results

## -----------------------------------------------------------------------------
result$interactive_table

## -----------------------------------------------------------------------------
class(result$final_policy_model)
summary(result$final_policy_model)
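
# A hedged sketch: assuming the final policy model is a grf causal forest
# (as expected when model_type = "causal_forest"), estimate CATEs on fresh
# covariates and turn them into a simple treat-if-positive rule. The
# treat-if-positive thresholding here is illustrative, not the package's
# internal policy extraction.
new_data <- generate_data(100)
tau_hat <- predict(result$final_policy_model, newdata = as.matrix(new_data$X))$predictions
learned_policy <- as.integer(tau_hat > 0)
table(learned_policy)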


## ----cleanup-autograph, include=FALSE-----------------------------------------
autograph_files <- list.files(tempdir(), pattern = "^__autograph_generated_file.*\\.py$", full.names = TRUE)
if (length(autograph_files) > 0) {
  try(unlink(autograph_files, recursive = TRUE, force = TRUE), silent = TRUE)
}
