Nothing
#' @title Double machine learning for interactive IV regression models
#'
#' @description
#' Double machine learning for interactive IV regression models.
#'
#' @format [R6::R6Class] object inheriting from [DoubleML].
#'
#' @family DoubleML
#'
#' @details
#' Interactive IV regression (IIVM) models take the form
#'
#' \eqn{Y = \ell_0(D,X) + \zeta},
#'
#' \eqn{Z = m_0(X) + V},
#'
#' with \eqn{E[\zeta|X,Z]=0} and \eqn{E[V|X] = 0}. \eqn{Y} is the outcome
#' variable, \eqn{D \in \{0,1\}} is the binary treatment variable and
#' \eqn{Z \in \{0,1\}} is a binary instrumental variable. Consider the functions
#' \eqn{g_0}, \eqn{r_0} and \eqn{m_0}, where \eqn{g_0} maps the support of
#' \eqn{(Z,X)} to \eqn{R} and \eqn{r_0} and \eqn{m_0}, respectively, map the
#' support of \eqn{(Z,X)} and \eqn{X} to \eqn{(\epsilon, 1-\epsilon)} for some
#' \eqn{\epsilon \in (0, 1/2)}, such that
#'
#' \eqn{Y = g_0(Z,X) + \nu,}
#'
#' \eqn{D = r_0(Z,X) + U,}
#'
#' \eqn{Z = m_0(X) + V,}
#'
#' with \eqn{E[\nu|Z,X]=0}, \eqn{E[U|Z,X]=0} and \eqn{E[V|X]=0}. The target
#' parameter of interest in this model is the local average treatment effect
#' (LATE),
#'
#' \eqn{\theta_0 = \frac{E[g_0(1,X)] - E[g_0(0,X)]}{E[r_0(1,X)] - E[r_0(0,X)]}.}
#'
#'
#' @usage NULL
#'
#' @examples
#' \donttest{
#' library(DoubleML)
#' library(mlr3)
#' library(mlr3learners)
#' library(data.table)
#' set.seed(2)
#' ml_g = lrn("regr.ranger",
#' num.trees = 100, mtry = 20,
#' min.node.size = 2, max.depth = 5)
#' ml_m = lrn("classif.ranger",
#' num.trees = 100, mtry = 20,
#' min.node.size = 2, max.depth = 5)
#' ml_r = ml_m$clone()
#' obj_dml_data = make_iivm_data(
#' theta = 0.5, n_obs = 1000,
#' alpha_x = 1, dim_x = 20)
#' dml_iivm_obj = DoubleMLIIVM$new(obj_dml_data, ml_g, ml_m, ml_r)
#' dml_iivm_obj$fit()
#' dml_iivm_obj$summary()
#' }
#'
#' \dontrun{
#' library(DoubleML)
#' library(mlr3)
#' library(mlr3learners)
#' library(mlr3tuning)
#' library(data.table)
#' set.seed(2)
#' ml_g = lrn("regr.rpart")
#' ml_m = lrn("classif.rpart")
#' ml_r = ml_m$clone()
#' obj_dml_data = make_iivm_data(
#' theta = 0.5, n_obs = 1000,
#' alpha_x = 1, dim_x = 20)
#' dml_iivm_obj = DoubleMLIIVM$new(obj_dml_data, ml_g, ml_m, ml_r)
#' param_grid = list(
#' "ml_g" = paradox::ParamSet$new(list(
#' paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
#' paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
#' "ml_m" = paradox::ParamSet$new(list(
#' paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
#' paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
#' "ml_r" = paradox::ParamSet$new(list(
#' paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
#' paradox::ParamInt$new("minsplit", lower = 1, upper = 2))))
#' # minimum requirements for tune_settings
#' tune_settings = list(
#' terminator = mlr3tuning::trm("evals", n_evals = 5),
#' algorithm = mlr3tuning::tnr("grid_search", resolution = 5))
#' dml_iivm_obj$tune(param_set = param_grid, tune_settings = tune_settings)
#' dml_iivm_obj$fit()
#' dml_iivm_obj$summary()
#' }
#'
#' @export
DoubleMLIIVM = R6Class("DoubleMLIIVM",
  inherit = DoubleML,
  active = list(
    #' @field subgroups (named `list(2)`) \cr
    #' Named `list(2)` with options to adapt to cases with and without the
    #' subgroups of always-takers and never-takers.
    #' The entry `always_takers` (`logical(1)`) specifies whether there are
    #' always takers in the sample. The entry `never_takers` (`logical(1)`)
    #' specifies whether there are never takers in the sample.
    subgroups = function(value) {
      if (missing(value)) {
        return(private$subgroups_)
      } else {
        stop("can't set field subgroups")
      }
    },
    #' @field trimming_rule (`character(1)`) \cr
    #' A `character(1)` specifying the trimming approach.
    trimming_rule = function(value) {
      if (missing(value)) {
        return(private$trimming_rule_)
      } else {
        stop("can't set field trimming_rule")
      }
    },
    #' @field trimming_threshold (`numeric(1)`) \cr
    #' The threshold used for trimming.
    trimming_threshold = function(value) {
      if (missing(value)) {
        return(private$trimming_threshold_)
      } else {
        stop("can't set field trimming_threshold")
      }
    }),
  public = list(
    #' @description
    #' Creates a new instance of this R6 class.
    #'
    #' @param data (`DoubleMLData`) \cr
    #' The `DoubleMLData` object providing the data and specifying the variables
    #' of the causal model.
    #'
    #' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
    #' [`LearnerClassif`][mlr3::LearnerClassif], [`Learner`][mlr3::Learner],
    #' `character(1)`) \cr
    #' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
    #' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
    #' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
    #' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
    #' For binary outcome variables, an object of the class
    #' [`LearnerClassif`][mlr3::LearnerClassif] can be passed, for example
    #' `lrn("classif.cv_glmnet", s = "lambda.min")`.
    #' Alternatively, a [`Learner`][mlr3::Learner] object with public field
    #' `task_type = "regr"` or `task_type = "classif"` can be passed,
    #' respectively, for example of class
    #' [`GraphLearner`][mlr3pipelines::GraphLearner]. \cr
    #' `ml_g` refers to the nuisance function \eqn{g_0(Z,X) = E[Y|X,Z]}.
    #'
    #' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
    #' [`Learner`][mlr3::Learner], `character(1)`) \cr
    #' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
    #' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
    #' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
    #' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
    #' Alternatively, a [`Learner`][mlr3::Learner] object with public field
    #' `task_type = "classif"` can be passed, for example of class
    #' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
    #' be passed with specified parameters, for example
    #' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
    #' `ml_m` refers to the nuisance function \eqn{m_0(X) = E[Z|X]}.
    #'
    #' @param ml_r ([`LearnerClassif`][mlr3::LearnerClassif],
    #' [`Learner`][mlr3::Learner], `character(1)`) \cr
    #' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
    #' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
    #' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
    #' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
    #' Alternatively, a [`Learner`][mlr3::Learner] object with public field
    #' `task_type = "classif"` can be passed, for example of class
    #' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
    #' be passed with specified parameters, for example
    #' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
    #' `ml_r` refers to the nuisance function \eqn{r_0(Z,X) = E[D|X,Z]}.
    #'
    #' @param n_folds (`integer(1)`)\cr
    #' Number of folds. Default is `5`.
    #'
    #' @param n_rep (`integer(1)`) \cr
    #' Number of repetitions for the sample splitting. Default is `1`.
    #'
    #' @param score (`character(1)`, `function()`) \cr
    #' A `character(1)` (`"LATE"` is the only choice) specifying the score
    #' function.
    #' If a `function()` is provided, it must be of the form
    #' `function(y, z, d, g0_hat, g1_hat, m_hat, r0_hat, r1_hat, smpls)` and
    #' the returned output must be a named `list()` with elements `psi_a` and
    #' `psi_b`. Default is `"LATE"`.
    #'
    #' @param subgroups (named `list(2)`) \cr
    #' Named `list(2)` with options to adapt to cases with and without the
    #' subgroups of always-takers and never-takers. The entry
    #' `always_takers` (`logical(1)`) specifies whether there are always takers
    #' in the sample. The entry `never_takers` (`logical(1)`) specifies whether
    #' there are never takers in the sample. Default is
    #' `list(always_takers = TRUE, never_takers = TRUE)`.
    #'
    #' @param trimming_rule (`character(1)`) \cr
    #' A `character(1)` (`"truncate"` is the only choice) specifying the
    #' trimming approach. Default is `"truncate"`.
    #'
    #' @param trimming_threshold (`numeric(1)`) \cr
    #' The threshold used for trimming. Default is `1e-12`.
    #'
    #' @param dml_procedure (`character(1)`) \cr
    #' A `character(1)` (`"dml1"` or `"dml2"`) specifying the double machine
    #' learning algorithm. Default is `"dml2"`.
    #'
    #' @param draw_sample_splitting (`logical(1)`) \cr
    #' Indicates whether the sample splitting should be drawn during
    #' initialization of the object. Default is `TRUE`.
    #'
    #' @param apply_cross_fitting (`logical(1)`) \cr
    #' Indicates whether cross-fitting should be applied. Default is `TRUE`.
    initialize = function(data,
      ml_g,
      ml_m,
      ml_r,
      n_folds = 5,
      n_rep = 1,
      score = "LATE",
      subgroups = list(
        always_takers = TRUE,
        never_takers = TRUE),
      dml_procedure = "dml2",
      trimming_rule = "truncate",
      trimming_threshold = 1e-12,
      draw_sample_splitting = TRUE,
      apply_cross_fitting = TRUE) {

      super$initialize_double_ml(
        data,
        n_folds,
        n_rep,
        score,
        dml_procedure,
        draw_sample_splitting,
        apply_cross_fitting)

      private$check_data(self$data)
      private$check_score(self$score)

      # ml_g may be a regression or a classification learner; ml_m and ml_r
      # must be classification learners.
      ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = TRUE)
      ml_m = private$assert_learner(ml_m, "ml_m", Regr = FALSE, Classif = TRUE)
      ml_r = private$assert_learner(ml_r, "ml_r", Regr = FALSE, Classif = TRUE)

      private$learner_ = list(
        "ml_g" = ml_g,
        "ml_m" = ml_m,
        "ml_r" = ml_r)
      private$initialize_ml_nuisance_params()

      # Fail early with a readable message: nuisance_est() and
      # nuisance_tuning() read both entries via `$`, which otherwise only
      # fails at fit/tune time with "argument is of length zero".
      required_entries = c("always_takers", "never_takers")
      if (!is.list(subgroups) ||
        !all(required_entries %in% names(subgroups))) {
        stop(paste(
          "Invalid subgroups.",
          "subgroups must be a named list with entries",
          "always_takers and never_takers."))
      }

      private$subgroups_ = subgroups
      private$trimming_rule_ = trimming_rule
      private$trimming_threshold_ = trimming_threshold
    }
  ),
  private = list(
    # Backing storage for the read-only active bindings above.
    subgroups_ = NULL,
    trimming_rule_ = NULL,
    trimming_threshold_ = NULL,
    n_nuisance = 3,

    # Creates one empty parameter slot per treatment column for each of the
    # five fitted nuisance components (ml_g0, ml_g1, ml_m, ml_r0, ml_r1).
    initialize_ml_nuisance_params = function() {
      nuisance = vector("list", self$data$n_treat)
      names(nuisance) = self$data$d_cols
      private$params_ = list(
        "ml_g0" = nuisance,
        "ml_g1" = nuisance,
        "ml_m" = nuisance,
        "ml_r0" = nuisance,
        "ml_r1" = nuisance)
      invisible(self)
    },

    # Fits the nuisance learners on the given sample splits and returns the
    # score elements together with the predictions (`preds`) and fitted
    # models (`models`) of every nuisance component.
    # The learners are fitted in the fixed order m, g0, g1, r0, r1.
    nuisance_est = function(smpls, ...) {

      if (self$subgroups$always_takers == FALSE &&
        self$subgroups$never_takers == FALSE) {
        message(paste(
          "If there are no always-takers and no never-takers,",
          "ATE is estimated"))
      }

      # NOTE(review): presumably restricts the training samples to
      # observations with Z = 0 (smpls_0) and Z = 1 (smpls_1) -- confirm
      # against get_cond_samples().
      cond_smpls = get_cond_samples(
        smpls,
        self$data$data_model[[self$data$z_cols]])

      # Feature columns shared by all nuisance learners.
      feature_cols = c(self$data$x_cols, self$data$other_treat_cols)

      m_hat = dml_cv_predict(self$learner$ml_m,
        feature_cols,
        self$data$z_cols,
        self$data$data_model,
        nuisance_id = "nuis_m",
        smpls = smpls,
        est_params = self$get_params("ml_m"),
        return_train_preds = FALSE,
        task_type = private$task_type$ml_m,
        fold_specific_params = private$fold_specific_params)

      g0_hat = dml_cv_predict(self$learner$ml_g,
        feature_cols,
        self$data$y_col,
        self$data$data_model,
        nuisance_id = "nuis_g0",
        smpls = cond_smpls$smpls_0,
        est_params = self$get_params("ml_g0"),
        return_train_preds = FALSE,
        task_type = private$task_type$ml_g,
        fold_specific_params = private$fold_specific_params)

      g1_hat = dml_cv_predict(self$learner$ml_g,
        feature_cols,
        self$data$y_col,
        self$data$data_model,
        nuisance_id = "nuis_g1",
        smpls = cond_smpls$smpls_1,
        est_params = self$get_params("ml_g1"),
        return_train_preds = FALSE,
        task_type = private$task_type$ml_g,
        fold_specific_params = private$fold_specific_params)

      if (self$subgroups$always_takers == FALSE) {
        # Without always-takers no learner is fitted for r0; the prediction
        # is fixed at 0 for every observation.
        r0_hat = list(preds = rep(0, self$data$n_obs), models = NULL)
      } else {
        r0_hat = dml_cv_predict(self$learner$ml_r,
          feature_cols,
          self$data$treat_col,
          self$data$data_model,
          nuisance_id = "nuis_r0",
          smpls = cond_smpls$smpls_0,
          est_params = self$get_params("ml_r0"),
          return_train_preds = FALSE,
          task_type = private$task_type$ml_r,
          fold_specific_params = private$fold_specific_params)
      }

      if (self$subgroups$never_takers == FALSE) {
        # Without never-takers no learner is fitted for r1; the prediction
        # is fixed at 1 for every observation.
        r1_hat = list(preds = rep(1, self$data$n_obs), models = NULL)
      } else {
        r1_hat = dml_cv_predict(self$learner$ml_r,
          feature_cols,
          self$data$treat_col,
          self$data$data_model,
          nuisance_id = "nuis_r1",
          smpls = cond_smpls$smpls_1,
          est_params = self$get_params("ml_r1"),
          return_train_preds = FALSE,
          task_type = private$task_type$ml_r,
          fold_specific_params = private$fold_specific_params)
      }

      # Observed instrument, treatment and outcome.
      z = self$data$data_model[[self$data$z_cols]]
      d = self$data$data_model[[self$data$treat_col]]
      y = self$data$data_model[[self$data$y_col]]

      res = private$score_elements(
        y, z, d,
        g0_hat$preds, g1_hat$preds, m_hat$preds,
        r0_hat$preds, r1_hat$preds,
        smpls)
      res$preds = list(
        "ml_g0" = g0_hat$preds,
        "ml_g1" = g1_hat$preds,
        "ml_m" = m_hat$preds,
        "ml_r0" = r0_hat$preds,
        "ml_r1" = r1_hat$preds)
      res$models = list(
        "ml_g0" = g0_hat$models,
        "ml_g1" = g1_hat$models,
        "ml_m" = m_hat$models,
        "ml_r0" = r0_hat$models,
        "ml_r1" = r1_hat$models)
      return(res)
    },

    # Computes the score elements `psi_a` and `psi_b` from the observed data
    # and the nuisance predictions.
    # With trimming rule "truncate" and a positive threshold, m_hat is first
    # clipped to [threshold, 1 - threshold]. A user-supplied score function
    # receives the clipped m_hat as well.
    score_elements = function(y, z, d,
      g0_hat, g1_hat, m_hat,
      r0_hat, r1_hat,
      smpls) {

      # Outcome and treatment residuals for both instrument values.
      u0_hat = y - g0_hat
      u1_hat = y - g1_hat
      w0_hat = d - r0_hat
      w1_hat = d - r1_hat

      if (self$trimming_rule == "truncate" && self$trimming_threshold > 0) {
        m_hat[m_hat < self$trimming_threshold] = self$trimming_threshold
        m_hat[m_hat > 1 - self$trimming_threshold] = 1 - self$trimming_threshold
      }

      if (is.character(self$score)) {
        if (self$score == "LATE") {
          # psi_b: difference of the g predictions plus the outcome residuals
          # weighted by z / m_hat and (1 - z) / (1 - m_hat).
          psi_b = g1_hat - g0_hat + z * (u1_hat) / m_hat -
            (1 - z) * u0_hat / (1 - m_hat)
          # psi_a: negated analogue built from the r predictions and the
          # treatment residuals.
          psi_a = -1 * (r1_hat - r0_hat + z * (w1_hat) / m_hat -
            (1 - z) * w0_hat / (1 - m_hat))
        }
        psis = list(psi_a = psi_a, psi_b = psi_b)
      } else if (is.function(self$score)) {
        psis = self$score(
          y, z, d, g0_hat, g1_hat, m_hat, r0_hat,
          r1_hat, smpls)
      }
      return(psis)
    },

    # Tunes the nuisance learners. If `tune_on_folds` is FALSE the complete
    # data is used, otherwise one data set per entry of `smpls$train_ids`.
    # The g and r learners are tuned on the rows with Z = 0 and Z = 1
    # separately. Returns a named list with one entry per nuisance component,
    # each holding the tuning result and its `params`.
    nuisance_tuning = function(smpls, param_set, tune_settings,
      tune_on_folds, ...) {

      if (!tune_on_folds) {
        data_tune_list = list(self$data$data_model)
      } else {
        data_tune_list = lapply(
          smpls$train_ids,
          function(x) extract_training_data(self$data$data_model, x))
      }

      # Row indicators for Z == 0 / Z == 1 in every tuning data set.
      indx_g0 = lapply(data_tune_list, function(x) x[[self$data$z_cols]] == 0)
      indx_g1 = lapply(data_tune_list, function(x) x[[self$data$z_cols]] == 1)
      data_tune_list_z0 = lapply(
        seq_along(data_tune_list),
        function(x) data_tune_list[[x]][indx_g0[[x]], ])
      data_tune_list_z1 = lapply(
        seq_along(data_tune_list),
        function(x) data_tune_list[[x]][indx_g1[[x]], ])

      # Feature columns shared by all nuisance learners.
      feature_cols = c(self$data$x_cols, self$data$other_treat_cols)

      tuning_result_m = dml_tune(self$learner$ml_m,
        feature_cols,
        self$data$z_cols,
        data_tune_list,
        nuisance_id = "nuis_m",
        param_set$ml_m, tune_settings,
        tune_settings$measure$ml_m,
        private$task_type$ml_m)

      tuning_result_g0 = dml_tune(self$learner$ml_g,
        feature_cols,
        self$data$y_col,
        data_tune_list_z0,
        nuisance_id = "nuis_g0",
        param_set$ml_g, tune_settings,
        tune_settings$measure$ml_g,
        private$task_type$ml_g)

      tuning_result_g1 = dml_tune(self$learner$ml_g,
        feature_cols,
        self$data$y_col,
        data_tune_list_z1,
        nuisance_id = "nuis_g1",
        param_set$ml_g, tune_settings,
        tune_settings$measure$ml_g,
        private$task_type$ml_g)

      if (self$subgroups$always_takers == TRUE) {
        tuning_result_r0 = dml_tune(self$learner$ml_r,
          feature_cols,
          self$data$treat_col,
          data_tune_list_z0,
          nuisance_id = "nuis_r0",
          param_set$ml_r, tune_settings,
          tune_settings$measure$ml_r,
          private$task_type$ml_r)
      } else {
        # r0 is not estimated in this case; return an empty placeholder with
        # the same list structure as a tuning result.
        tuning_result_r0 = list(list(), "params" = list(list()))
      }

      if (self$subgroups$never_takers == TRUE) {
        tuning_result_r1 = dml_tune(self$learner$ml_r,
          feature_cols,
          self$data$treat_col,
          data_tune_list_z1,
          nuisance_id = "nuis_r1",
          param_set$ml_r, tune_settings,
          tune_settings$measure$ml_r,
          private$task_type$ml_r)
      } else {
        # r1 is not estimated in this case; return an empty placeholder with
        # the same list structure as a tuning result.
        tuning_result_r1 = list(list(), "params" = list(list()))
      }

      tuning_result = list(
        "ml_m" = list(tuning_result_m, params = tuning_result_m$params),
        "ml_g0" = list(tuning_result_g0, params = tuning_result_g0$params),
        "ml_g1" = list(tuning_result_g1, params = tuning_result_g1$params),
        "ml_r0" = list(tuning_result_r0, params = tuning_result_r0$params),
        "ml_r1" = list(tuning_result_r1, params = tuning_result_r1$params))
      return(tuning_result)
    },

    # Asserts that `score` is either a function or the character "LATE".
    check_score = function(score) {
      assert(
        check_character(score),
        check_class(score, "function"))
      if (is.character(score)) {
        valid_score = c("LATE")
        assertChoice(score, valid_score)
      }
      return()
    },

    # Stops with an "Incompatible data" error unless the data object has
    # exactly one treatment variable and exactly one instrumental variable,
    # each holding only integer-like values between 0 and 1.
    check_data = function(obj_dml_data) {
      one_treat = (obj_dml_data$n_treat == 1)
      err_msg = paste(
        "Incompatible data.\n",
        "To fit an IIVM model with DoubleML",
        "exactly one binary variable with values 0 and 1",
        "needs to be specified as treatment variable.")
      if (!one_treat) {
        stop(err_msg)
      }
      # Only evaluated once it is known that there is a single treatment
      # column, so `[[` extracts exactly one variable.
      binary_treat = test_integerish(obj_dml_data$data[[obj_dml_data$d_cols]],
        lower = 0, upper = 1)
      if (!binary_treat) {
        stop(err_msg)
      }

      one_instr = (obj_dml_data$n_instr == 1)
      err_msg = paste(
        "Incompatible data.\n",
        "To fit an IIVM model with DoubleML",
        "exactly one binary variable with values 0 and 1",
        "needs to be specified as instrumental variable.")
      if (!one_instr) {
        stop(err_msg)
      }
      binary_instr = test_integerish(obj_dml_data$data[[obj_dml_data$z_cols]],
        lower = 0, upper = 1)
      if (!binary_instr) {
        stop(err_msg)
      }
      return()
    }
  )
)
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.