DoubleMLIRM: Double machine learning for interactive regression models

DoubleMLIRM    R Documentation

Double machine learning for interactive regression models

Description

Double machine learning for interactive regression models.

Format

R6::R6Class object inheriting from DoubleML.

Details

Interactive regression (IRM) models take the form

Y = g_0(D,X) + U,

D = m_0(X) + V,

with E[U|X,D] = 0 and E[V|X] = 0. Y is the outcome variable and D \in \{0,1\} is the binary treatment variable. We consider estimation of average treatment effects when treatment effects are fully heterogeneous. Target parameters of interest in this model are the average treatment effect (ATE),

\theta_0 = E[g_0(1,X) - g_0(0,X)]

and the average treatment effect on the treated (ATTE),

\theta_0 = E[g_0(1,X) - g_0(0,X)|D=1].
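
To make the two target parameters concrete, the following base R sketch simulates data from an IRM with known g_0 and m_0 and evaluates the ATE and ATTE by averaging g_0(1,X) - g_0(0,X) over the full sample and over the treated units. The functional forms of g0 and m0 and all variable names are illustrative assumptions and are not part of DoubleML.

```r
set.seed(1)
n = 10000
x = rnorm(n)

# Hypothetical nuisance functions, chosen only for illustration
m0 = function(x) plogis(x)                # propensity score E[D | X]
g0 = function(d, x) d * (1 + x) + x^2     # outcome regression E[Y | X, D]

d = rbinom(n, size = 1, prob = m0(x))     # D = m0(X) + V with E[V | X] = 0
y = g0(d, x) + rnorm(n)                   # Y = g0(D, X) + U with E[U | X, D] = 0

# Unit-level effects g0(1, X) - g0(0, X), which equal 1 + X in this design
effect = g0(1, x) - g0(0, x)

ate_oracle = mean(effect)                 # sample analogue of the ATE
atte_oracle = mean(effect[d == 1])        # sample analogue of the ATTE
```

Because the effect grows with X and units with larger X are more likely to be treated, the ATTE exceeds the ATE in this design. In practice g_0 and m_0 are unknown and are estimated with the learners ml_g and ml_m.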

Super class

DoubleML::DoubleML -> DoubleMLIRM

Active bindings

trimming_rule

(character(1))
A character(1) specifying the trimming approach.

trimming_threshold

(numeric(1))
The threshold used for trimming.
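
As an illustration of what truncation means here, the sketch below clips estimated propensity scores to the interval [threshold, 1 - threshold] so that inverse-probability weights stay finite. The helper name truncate_propensity is hypothetical, and the code is a minimal sketch of the idea rather than the package's internal implementation.

```r
# Hypothetical helper: clip propensity predictions away from 0 and 1
truncate_propensity = function(m_hat, threshold = 1e-12) {
  stopifnot(threshold >= 0, threshold < 0.5)
  pmin(pmax(m_hat, threshold), 1 - threshold)
}

m_hat = c(0, 0.003, 0.4, 0.998, 1)
# Values below 0.01 are raised to 0.01; values above 0.99 are lowered to 0.99
m_trimmed = truncate_propensity(m_hat, threshold = 0.01)
```

With the very small default of 1e-12, truncation only guards against division by exactly 0 or 1; a larger threshold limits the influence of extreme propensity estimates more strongly.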

Methods

Public methods

Inherited methods

Method new()

Creates a new instance of this R6 class.

Usage
DoubleMLIRM$new(
  data,
  ml_g,
  ml_m,
  n_folds = 5,
  n_rep = 1,
  score = "ATE",
  trimming_rule = "truncate",
  trimming_threshold = 1e-12,
  dml_procedure = "dml2",
  draw_sample_splitting = TRUE,
  apply_cross_fitting = TRUE
)
Arguments
data

(DoubleMLData)
The DoubleMLData object providing the data and specifying the variables of the causal model.

ml_g

(LearnerRegr, LearnerClassif, Learner, character(1))
A learner of the class LearnerRegr, which is available from mlr3 or its extension packages mlr3learners or mlr3extralearners. For a binary outcome variable Y, an object of the class LearnerClassif can be passed instead, for example lrn("classif.cv_glmnet", s = "lambda.min"). Alternatively, a Learner object with public field task_type = "regr" or task_type = "classif" can be passed, for example of class GraphLearner.
ml_g refers to the nuisance function g_0(D,X) = E[Y|X,D].

ml_m

(LearnerClassif, Learner, character(1))
A learner of the class LearnerClassif, which is available from mlr3 or its extension packages mlr3learners or mlr3extralearners. Alternatively, a Learner object with public field task_type = "classif" can be passed, for example of class GraphLearner. The learner can be passed with pre-specified parameters, for example lrn("classif.cv_glmnet", s = "lambda.min").
ml_m refers to the nuisance function m_0(X) = E[D|X].

n_folds

(integer(1))
Number of folds. Default is 5.

n_rep

(integer(1))
Number of repetitions for the sample splitting. Default is 1.

score

(character(1), function())
A character(1) ("ATE" or "ATTE") or a function() specifying the score function. If a function() is provided, it must be of the form function(y, d, g0_hat, g1_hat, m_hat, smpls) and the returned output must be a named list() with elements psi_a and psi_b. Default is "ATE".

trimming_rule

(character(1))
A character(1) ("truncate" is the only choice) specifying the trimming approach. Default is "truncate".

trimming_threshold

(numeric(1))
The threshold used for trimming. Default is 1e-12.

dml_procedure

(character(1))
A character(1) ("dml1" or "dml2") specifying the double machine learning algorithm. Default is "dml2".

draw_sample_splitting

(logical(1))
Indicates whether the sample splitting should be drawn during initialization of the object. Default is TRUE.

apply_cross_fitting

(logical(1))
Indicates whether cross-fitting should be applied. Default is TRUE.
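
As a sketch of the function interface accepted by score, the code below writes the doubly robust ATE score in the linear form psi = psi_a * theta + psi_b and solves the empirical moment condition mean(psi) = 0 for theta on toy inputs. The name score_ate_manual and the toy vectors are hypothetical; the signature follows the one stated for the score argument, and smpls is accepted but not used in this sketch.

```r
# Hypothetical custom score with the signature described for `score`
score_ate_manual = function(y, d, g0_hat, g1_hat, m_hat, smpls) {
  u0_hat = y - g0_hat                       # residuals under g_0(0, X)
  u1_hat = y - g1_hat                       # residuals under g_0(1, X)
  psi_b = g1_hat - g0_hat +
    d * u1_hat / m_hat -
    (1 - d) * u0_hat / (1 - m_hat)
  psi_a = rep(-1, length(y))
  list(psi_a = psi_a, psi_b = psi_b)
}

# Toy nuisance predictions, for illustration only
y = c(1.0, 2.0, 0.5, 3.0)
d = c(0, 1, 0, 1)
g0_hat = c(0.8, 1.0, 0.6, 1.5)
g1_hat = c(1.8, 2.2, 1.4, 2.5)
m_hat = rep(0.5, 4)

psi = score_ate_manual(y, d, g0_hat, g1_hat, m_hat, smpls = NULL)

# Solve mean(psi_a * theta + psi_b) = 0 for theta
theta_hat = -mean(psi$psi_b) / mean(psi$psi_a)
```

A function of this shape could then be supplied as score = score_ate_manual when constructing the object, assuming the nuisance predictions are passed in the order given by the signature.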


Method clone()

The objects of this class are cloneable with this method.

Usage
DoubleMLIRM$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

See Also

Other DoubleML: DoubleML, DoubleMLIIVM, DoubleMLPLIV, DoubleMLPLR

Examples


library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
set.seed(2)
ml_g = lrn("regr.ranger",
  num.trees = 100, mtry = 20,
  min.node.size = 2, max.depth = 5)
ml_m = lrn("classif.ranger",
  num.trees = 100, mtry = 20,
  min.node.size = 2, max.depth = 5)
obj_dml_data = make_irm_data(theta = 0.5)
dml_irm_obj = DoubleMLIRM$new(obj_dml_data, ml_g, ml_m)
dml_irm_obj$fit()
dml_irm_obj$summary()

## Not run: 
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(mlr3tuning)
library(data.table)
set.seed(2)
ml_g = lrn("regr.rpart")
ml_m = lrn("classif.rpart")
obj_dml_data = make_irm_data(theta = 0.5)
dml_irm_obj = DoubleMLIRM$new(obj_dml_data, ml_g, ml_m)

param_grid = list(
  "ml_g" = paradox::ps(
    cp = paradox::p_dbl(lower = 0.01, upper = 0.02),
    minsplit = paradox::p_int(lower = 1, upper = 2)),
  "ml_m" = paradox::ps(
    cp = paradox::p_dbl(lower = 0.01, upper = 0.02),
    minsplit = paradox::p_int(lower = 1, upper = 2)))

# minimum requirements for tune_settings
tune_settings = list(
  terminator = mlr3tuning::trm("evals", n_evals = 5),
  algorithm = mlr3tuning::tnr("grid_search", resolution = 5))
dml_irm_obj$tune(param_set = param_grid, tune_settings = tune_settings)
dml_irm_obj$fit()
dml_irm_obj$summary()

## End(Not run)


DoubleML documentation built on June 22, 2024, 10:50 a.m.