# R/double_ml.R

#' @title Abstract class DoubleML
#'
#' @description
#' Abstract base class that can't be initialized.
#'
#'
#' @format [R6::R6Class] object.
#'
#' @family DoubleML
DoubleML = R6Class("DoubleML",
  active = list(
    #' @field all_coef (`matrix()`) \cr
    #' Estimates of the causal parameter(s) for the `n_rep` different sample
    #' splits after calling `fit()`.
    all_coef = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field all_coef")
      }
      private$all_coef_
    },

    #' @field all_dml1_coef (`array()`) \cr
    #' Estimates of the causal parameter(s) for the `n_rep` different sample
    #' splits after calling `fit()` with `dml_procedure = "dml1"`.
    all_dml1_coef = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field all_dml1_coef")
      }
      private$all_dml1_coef_
    },

    #' @field all_se (`matrix()`) \cr
    #' Standard errors of the causal parameter(s) for the `n_rep` different
    #' sample splits after calling `fit()`.
    all_se = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field all_se")
      }
      private$all_se_
    },

    #' @field apply_cross_fitting (`logical(1)`) \cr
    #' Indicates whether cross-fitting should be applied. Default is `TRUE`.
    apply_cross_fitting = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field apply_cross_fitting")
      }
      private$apply_cross_fitting_
    },

    #' @field boot_coef (`matrix()`) \cr
    #' Bootstrapped coefficients for the causal parameter(s) after calling
    #' `fit()` and `bootstrap()`.
    boot_coef = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field boot_coef")
      }
      private$boot_coef_
    },

    #' @field boot_t_stat (`matrix()`) \cr
    #' Bootstrapped t-statistics for the causal parameter(s) after calling
    #' `fit()` and `bootstrap()`.
    boot_t_stat = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field boot_t_stat")
      }
      private$boot_t_stat_
    },

    #' @field coef (`numeric()`) \cr
    #' Estimates for the causal parameter(s) after calling `fit()`.
    coef = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field coef")
      }
      private$coef_
    },

    #' @field data ([`data.table`][data.table::data.table()])\cr
    #' Data object.
    data = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field data")
      }
      private$data_
    },

    #' @field dml_procedure (`character(1)`) \cr
    #' A `character()` (`"dml1"` or `"dml2"`) specifying the double machine
    #' learning algorithm. Default is `"dml2"`.
    dml_procedure = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field dml_procedure")
      }
      private$dml_procedure_
    },

    #' @field draw_sample_splitting (`logical(1)`) \cr
    #' Indicates whether the sample splitting should be drawn during
    #' initialization of the object. Default is `TRUE`.
    draw_sample_splitting = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field draw_sample_splitting")
      }
      private$draw_sample_splitting_
    },

    #' @field learner (named `list()`) \cr
    #' The machine learners for the nuisance functions.
    learner = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field learner")
      }
      private$learner_
    },

    #' @field n_folds (`integer(1)`) \cr
    #' Number of folds. Default is `5`.
    n_folds = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field n_folds")
      }
      private$n_folds_
    },

    #' @field n_rep (`integer(1)`) \cr
    #' Number of repetitions for the sample splitting. Default is `1`.
    n_rep = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field n_rep")
      }
      private$n_rep_
    },

    #' @field params (named `list()`) \cr
    #' The hyperparameters of the learners.
    params = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field params")
      }
      private$params_
    },

    #' @field psi (`array()`) \cr
    #' Value of the score function
    #' \eqn{\psi(W;\theta, \eta)=\psi_a(W;\eta) \theta + \psi_b (W; \eta)}
    #' after calling `fit()`.
    psi = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field psi")
      }
      private$psi_
    },

    #' @field psi_a (`array()`) \cr
    #' Value of the score function component \eqn{\psi_a(W;\eta)} after
    #' calling `fit()`.
    psi_a = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field psi_a")
      }
      private$psi_a_
    },

    #' @field psi_b (`array()`) \cr
    #' Value of the score function component \eqn{\psi_b(W;\eta)} after
    #' calling `fit()`.
    psi_b = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field psi_b")
      }
      private$psi_b_
    },

    #' @field predictions (`array()`) \cr
    #' Predictions of the nuisance models after calling
    #' `fit(store_predictions=TRUE)`.
    predictions = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field predictions")
      }
      private$predictions_
    },

    #' @field models (`array()`) \cr
    #' The fitted nuisance models after calling
    #' `fit(store_models=TRUE)`.
    models = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field models")
      }
      private$models_
    },

    #' @field pval (`numeric()`) \cr
    #' p-values for the causal parameter(s) after calling `fit()`.
    pval = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field pval")
      }
      private$pval_
    },

    #' @field score (`character(1)`, `function()`) \cr
    #' A `character(1)` or `function()` specifying the score function.
    score = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field score")
      }
      private$score_
    },

    #' @field se (`numeric()`) \cr
    #' Standard errors for the causal parameter(s) after calling `fit()`.
    se = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field se")
      }
      private$se_
    },

    #' @field smpls (`list()`) \cr
    #' The partition used for cross-fitting.
    smpls = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field smpls")
      }
      private$smpls_
    },

    #' @field smpls_cluster (`list()`) \cr
    #' The partition of clusters used for cross-fitting.
    smpls_cluster = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field smpls_cluster")
      }
      private$smpls_cluster_
    },

    #' @field t_stat (`numeric()`) \cr
    #' t-statistics for the causal parameter(s) after calling `fit()`.
    t_stat = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field t_stat")
      }
      private$t_stat_
    },

    #' @field tuning_res (named `list()`) \cr
    #' Results from hyperparameter tuning.
    tuning_res = function(value) {
      # Read-only binding: an assignment attempt is rejected, a plain read
      # hands back the privately stored value.
      if (!missing(value)) {
        stop("can't set field tuning_res")
      }
      private$tuning_res_
    }),

  public = list(
    #' @description
    #' DoubleML is an abstract class that can't be initialized.
    initialize = function() {
      # Always fails: the base class is abstract and must not be
      # instantiated directly.
      msg = "DoubleML is an abstract class that can't be initialized."
      stop(msg)
    },
    #' @description
    #' Print DoubleML objects.
    print = function() {
      # Writes a multi-section overview to the console (data, score and
      # algorithm, learners, resampling), then delegates to `summary()` for
      # the fit results. Returns the object invisibly.

      # The banner shows the first class of the object, i.e. the concrete
      # model class.
      class_name = class(self)[1]
      header = paste0(
        "================= ", class_name,
        " Object ==================\n")

      # Cluster variables are only listed for cluster data.
      if (private$is_cluster_data) {
        cluster_info = paste0(
          "Cluster variable(s): ",
          paste0(self$data$cluster_cols, collapse = ", "),
          "\n")
      } else {
        cluster_info = ""
      }
      data_info = paste0(
        "Outcome variable: ", self$data$y_col, "\n",
        "Treatment variable(s): ", paste0(self$data$d_cols, collapse = ", "),
        "\n",
        "Covariates: ", paste0(self$data$x_cols, collapse = ", "), "\n",
        "Instrument(s): ", paste0(self$data$z_cols, collapse = ", "), "\n",
        cluster_info,
        "No. Observations: ", self$data$n_obs, "\n")

      # NOTE(review): a score that is neither a character nor a function
      # leaves `score_info` undefined and makes the cat() call below fail.
      if (is.character(self$score)) {
        score_info = paste0(
          "Score function: ", self$score, "\n",
          "DML algorithm: ", self$dml_procedure, "\n")
      } else if (is.function(self$score)) {
        score_info = paste0(
          "Score function: User specified score function \n",
          "DML algorithm: ", self$dml_procedure, "\n")
      }

      # Look the learner names up once instead of on every iteration.
      lrn_names = self$learner_names()
      learner_info = character(length(self$learner))
      for (i_lrn in seq_along(self$learner)) {
        this_learner = self$learner[[i_lrn]]
        if (inherits(this_learner, "Learner")) {
          # Learner objects are shown by their id.
          learner_info[i_lrn] = paste0(
            lrn_names[[i_lrn]], ": ",
            this_learner$id, "\n")
        } else {
          # Anything else is pasted as stored.
          learner_info[i_lrn] = paste0(
            lrn_names[[i_lrn]], ": ",
            self$learner[i_lrn], "\n")
        }
      }
      if (private$is_cluster_data) {
        resampling_info = paste0(
          "No. folds per cluster: ", private$n_folds_per_cluster, "\n",
          "No. folds: ", self$n_folds, "\n",
          "No. repeated sample splits: ", self$n_rep, "\n",
          "Apply cross-fitting: ", self$apply_cross_fitting, "\n")
      } else {
        resampling_info = paste0(
          "No. folds: ", self$n_folds, "\n",
          "No. repeated sample splits: ", self$n_rep, "\n",
          "Apply cross-fitting: ", self$apply_cross_fitting, "\n")
      }
      cat(header, "\n",
        "\n------------------ Data summary      ------------------\n",
        data_info,
        "\n------------------ Score & algorithm ------------------\n",
        score_info,
        "\n------------------ Machine learner   ------------------\n",
        learner_info,
        "\n------------------ Resampling        ------------------\n",
        resampling_info,
        "\n------------------ Fit summary       ------------------\n ",
        sep = "")
      self$summary()

      invisible(self)
    },

    #' @description
    #' Estimate DoubleML models.
    #'
    #' @param store_predictions (`logical(1)`) \cr
    #' Indicates whether the predictions for the nuisance functions should be
    #' stored in field `predictions`. Default is `FALSE`.
    #'
    #'
    #' @param store_models (`logical(1)`) \cr
    #' Indicates whether the fitted models for the nuisance functions should be
    #' stored in field `models` if you want to analyze the models or extract
    #' information like variable importance. Default is `FALSE`.
    #'
    #' @return self
    fit = function(store_predictions = FALSE, store_models = FALSE) {

      # Optional storage is set up before any estimation runs.
      if (store_predictions) {
        private$initialize_predictions()
      }
      if (store_models) {
        private$initialize_models()
      }

      # TODO: insert check for tuned params
      for (i_rep in 1:self$n_rep) {
        private$i_rep = i_rep

        for (i_treat in 1:self$data$n_treat) {
          private$i_treat = i_treat

          # With several treatment variables the data model is switched to
          # the treatment processed in this iteration.
          if (self$data$n_treat > 1) {
            self$data$set_data_model(self$data$d_cols[i_treat])
          }

          # Step 1: nuisance estimation yields the score components, which
          # are stored per repetition and treatment.
          nuisance_res = private$nuisance_est(private$get__smpls())
          private$psi_a_[, private$i_rep, private$i_treat] = nuisance_res$psi_a
          private$psi_b_[, private$i_rep, private$i_treat] = nuisance_res$psi_b
          if (store_predictions) {
            private$store_predictions(nuisance_res$preds)
          }
          if (store_models) {
            private$store_models(nuisance_res$models)
          }

          # Step 2: causal parameter for this repetition / treatment.
          private$all_coef_[private$i_treat, private$i_rep] = private$est_causal_pars()
          # Step 3: score values (these depend on the estimate from step 2).
          private$psi_[, private$i_rep, private$i_treat] = private$compute_score()

          # Step 4: standard error of the causal parameter.
          private$all_se_[private$i_treat, private$i_rep] = private$se_causal_pars()
        }
      }

      private$agg_cross_fit()

      # Test statistics and two-sided p-values from the normal distribution.
      private$t_stat_ = self$coef / self$se
      private$pval_ = 2 * pnorm(-abs(self$t_stat))

      # Label all result vectors with the treatment variable names.
      treat_names = self$data$d_cols
      names(private$pval_) = treat_names
      names(private$t_stat_) = treat_names
      names(private$se_) = treat_names
      names(private$coef_) = treat_names

      invisible(self)
    },

    #' @description
    #' Multiplier bootstrap for DoubleML models.
    #'
    #' @param method (`character(1)`) \cr
    #' A `character(1)` (`"Bayes"`, `"normal"` or `"wild"`) specifying the
    #' multiplier bootstrap method.
    #'
    #' @param n_rep_boot (`integer(1)`) \cr
    #' The number of bootstrap replications.
    #'
    #' @return self
    bootstrap = function(method = "normal", n_rep_boot = 500) {
      # The score values are all NA until fit() has been run.
      not_fitted = all(is.na(self$psi))
      if (not_fitted) {
        stop("Apply fit() before bootstrap().")
      }
      assert_choice(method, c("normal", "Bayes", "wild"))
      assert_count(n_rep_boot, positive = TRUE)
      if (private$is_cluster_data) {
        stop("bootstrap not yet implemented with clustering.")
      }

      private$initialize_boot_arrays(n_rep_boot)

      for (i_rep in 1:self$n_rep) {
        private$i_rep = i_rep

        # Without cross-fitting the number of weights follows the size of
        # the first test fold; otherwise all observations are used.
        n_obs = if (self$apply_cross_fitting) {
          self$data$n_obs
        } else {
          length(private$get__smpls()$test_ids[[1]])
        }
        weights = draw_weights(method, n_rep_boot, n_obs)

        for (i_treat in 1:self$data$n_treat) {
          private$i_treat = i_treat

          boot_res = private$compute_bootstrap(weights, n_rep_boot)

          # Columns of the result arrays that belong to this repetition.
          col_first = (private$i_rep - 1) * private$n_rep_boot + 1
          col_last = private$i_rep * private$n_rep_boot
          cols = col_first:col_last
          private$boot_coef_[private$i_treat, cols] = boot_res$boot_coef
          private$boot_t_stat_[private$i_treat, cols] = boot_res$boot_t_stat
        }
      }
      invisible(self)
    },
    #' @description
    #' Draw sample splitting for DoubleML models.
    #'
    #' The samples are drawn according to the attributes `n_folds`, `n_rep`
    #' and `apply_cross_fitting`.
    #'
    #' @return self
    split_samples = function() {
      # Draws the train/test partition(s) with mlr3 resamplings instantiated
      # on dummy tasks and stores the result in `private$smpls_` (for cluster
      # data the cluster-level partition also goes to
      # `private$smpls_cluster_`). Returns the object invisibly.
      dummy_task = Task$new("dummy_resampling", "regr", self$data$data)

      if (self$apply_cross_fitting) {
        if (private$is_cluster_data) {
          all_smpls = list()
          all_smpls_cluster = list()
          for (i_rep in 1:self$n_rep) {
            # Split the unique values of every cluster variable into folds.
            smpls_cluster_vars = list()
            for (i_var in 1:self$data$n_cluster_vars) {
              clusters = unique(self$data$data_model[[self$data$cluster_cols[i_var]]])
              n_clusters = length(clusters)

              dummy_task = Task$new(
                "dummy_resampling", "regr",
                data.table(dummy_var = rep(0, n_clusters)))
              dummy_resampling_scheme = rsmp("repeated_cv",
                folds = private$n_folds_per_cluster,
                repeats = 1)$instantiate(dummy_task)
              train_ids = lapply(
                1:(private$n_folds_per_cluster),
                function(x) clusters[dummy_resampling_scheme$train_set(x)])
              test_ids = lapply(
                1:(private$n_folds_per_cluster),
                function(x) clusters[dummy_resampling_scheme$test_set(x)])

              smpls_cluster_vars[[i_var]] = list(
                train_ids = train_ids,
                test_ids = test_ids)
            }
            smpls = list(train_ids = list(), test_ids = list())
            smpls_cluster = list(train_ids = list(), test_ids = list())
            # Every row of `cart` is one combination of per-variable fold
            # indices.
            cart = expand.grid(lapply(
              1:self$data$n_cluster_vars,
              function(x) 1:private$n_folds_per_cluster))
            for (i_smpl in 1:(self$n_folds)) {
              ind_train = rep(TRUE, self$data$n_obs)
              ind_test = rep(TRUE, self$data$n_obs)
              this_cluster_smpl_train = list()
              this_cluster_smpl_test = list()
              # An observation is in the train (test) set only if its cluster
              # is in the train (test) clusters of every cluster variable.
              for (i_var in 1:self$data$n_cluster_vars) {
                i_fold = cart[i_smpl, i_var]
                train_clusters = smpls_cluster_vars[[i_var]]$train_ids[[i_fold]]
                test_clusters = smpls_cluster_vars[[i_var]]$test_ids[[i_fold]]
                this_cluster_smpl_train[[i_var]] = train_clusters
                this_cluster_smpl_test[[i_var]] = test_clusters
                xx = self$data$data_model[[self$data$cluster_cols[i_var]]] %in% train_clusters
                ind_train = ind_train & xx
                xx = self$data$data_model[[self$data$cluster_cols[i_var]]] %in% test_clusters
                ind_test = ind_test & xx
              }
              smpls$train_ids[[i_smpl]] = seq(self$data$n_obs)[ind_train]
              smpls$test_ids[[i_smpl]] = seq(self$data$n_obs)[ind_test]
              smpls_cluster$train_ids[[i_smpl]] = this_cluster_smpl_train
              smpls_cluster$test_ids[[i_smpl]] = this_cluster_smpl_test
            }
            all_smpls[[i_rep]] = smpls
            all_smpls_cluster[[i_rep]] = smpls_cluster
          }
          smpls = all_smpls
          private$smpls_cluster_ = all_smpls_cluster
        } else {
          # Repeated K-fold CV; the flat list of n_folds * n_rep iterations
          # is regrouped into one entry per repetition below.
          dummy_resampling_scheme = rsmp("repeated_cv",
            folds = self$n_folds,
            repeats = self$n_rep)$instantiate(dummy_task)
          train_ids = lapply(
            1:(self$n_folds * self$n_rep),
            function(x) dummy_resampling_scheme$train_set(x))
          test_ids = lapply(
            1:(self$n_folds * self$n_rep),
            function(x) dummy_resampling_scheme$test_set(x))

          smpls = lapply(1:self$n_rep, function(i_repeat) {
            list(
              train_ids = train_ids[((i_repeat - 1) * self$n_folds + 1):
              (i_repeat * self$n_folds)],
              test_ids = test_ids[((i_repeat - 1) * self$n_folds + 1):
              (i_repeat * self$n_folds)])
          })
        }
      } else {
        if (self$n_folds == 2) {
          # Two folds without cross-fitting: a single 50/50 holdout split.
          dummy_resampling_scheme = rsmp("holdout", ratio = 0.5)$instantiate(dummy_task)
          train_ids = list(dummy_resampling_scheme$train_set(1))
          test_ids = list(dummy_resampling_scheme$test_set(1))

          smpls = list(list(train_ids = train_ids, test_ids = test_ids))

        } else if (self$n_folds == 1) {
          # One fold: the "insample" resampling is used.
          dummy_resampling_scheme = rsmp("insample")$instantiate(dummy_task)

          train_ids = lapply(
            1:(self$n_folds * self$n_rep),
            function(x) dummy_resampling_scheme$train_set(x))
          test_ids = lapply(
            1:(self$n_folds * self$n_rep),
            function(x) dummy_resampling_scheme$test_set(x))

          smpls = lapply(1:self$n_rep, function(i_repeat) {
            list(
              train_ids = train_ids[((i_repeat - 1) * self$n_folds + 1):
              (i_repeat * self$n_folds)],
              test_ids = test_ids[((i_repeat - 1) * self$n_folds + 1):
              (i_repeat * self$n_folds)])
          })
        } else {
          # No splitting rule exists for any other fold count; fail with a
          # clear message instead of an undefined `smpls` further down.
          stop(paste(
            "Sample splitting without cross-fitting is only implemented",
            "for n_folds = 1 or n_folds = 2."))
        }
      }
      private$smpls_ = smpls
      invisible(self)
    },
    #' @description
    #' Set the sample splitting for DoubleML models.
    #'
    #' The attributes `n_folds` and `n_rep` are derived from the provided
    #' partition.
    #'
    #' @param smpls (`list()`) \cr
    #' A nested `list()`. The outer list needs to provide an entry per
    #' repeated sample splitting (length of the list is set as `n_rep`).
    #' The inner list is a named `list()` with names `train_ids` and `test_ids`.
    #' The entries in `train_ids` and `test_ids` must be partitions per fold
    #' (length of `train_ids` and `test_ids` is set as `n_folds`).
    #'
    #' @return self
    #'
    #' @examples
    #' library(DoubleML)
    #' library(mlr3)
    #' set.seed(2)
    #' obj_dml_data = make_plr_CCDDHNR2018(n_obs=10)
    #' dml_plr_obj = DoubleMLPLR$new(obj_dml_data,
    #'                               lrn("regr.rpart"), lrn("regr.rpart"))
    #'
    #' # simple sample splitting with two folds and without cross-fitting
    #' smpls = list(list(train_ids = list(c(1, 2, 3, 4, 5)),
    #'                   test_ids = list(c(6, 7, 8, 9, 10))))
    #' dml_plr_obj$set_sample_splitting(smpls)
    #'
    #' # sample splitting with two folds and cross-fitting but no repeated cross-fitting
    #' smpls = list(list(train_ids = list(c(1, 2, 3, 4, 5), c(6, 7, 8, 9, 10)),
    #'                   test_ids = list(c(6, 7, 8, 9, 10), c(1, 2, 3, 4, 5))))
    #' dml_plr_obj$set_sample_splitting(smpls)
    #'
    #' # sample splitting with two folds and repeated cross-fitting with n_rep = 2
    #' smpls = list(list(train_ids = list(c(1, 2, 3, 4, 5), c(6, 7, 8, 9, 10)),
    #'                   test_ids = list(c(6, 7, 8, 9, 10), c(1, 2, 3, 4, 5))),
    #'              list(train_ids = list(c(1, 3, 5, 7, 9), c(2, 4, 6, 8, 10)),
    #'                   test_ids = list(c(2, 4, 6, 8, 10), c(1, 3, 5, 7, 9))))
    #' dml_plr_obj$set_sample_splitting(smpls)
    set_sample_splitting = function(smpls) {
      # Validates an externally supplied sample splitting, derives `n_rep`,
      # `n_folds` and `apply_cross_fitting` from it, stores it in
      # `private$smpls_` and re-initializes the result arrays.
      # Returns the object invisibly.
      if (private$is_cluster_data) {
        stop(paste(
          "Externally setting the sample splitting for DoubleML is",
          "not yet implemented with clustering."))
      }
      if (test_list(smpls, names = "unnamed")) {
        # Unnamed outer list: one entry per repeated sample split.
        lapply(smpls, function(x) check_smpl_split(x, self$data$n_obs))

        n_folds_each_train_smpl = vapply(
          smpls, function(x) length(x$train_ids),
          integer(1L))

        if (!all(n_folds_each_train_smpl == n_folds_each_train_smpl[1])) {
          stop("Different number of folds for repeated cross-fitting.")
        }

        smpls_are_partitions = vapply(
          smpls,
          function(x) check_is_partition(x$test_ids, self$data$n_obs),
          FUN.VALUE = TRUE)

        if (all(smpls_are_partitions)) {
          # Scalar conditions, so the short-circuit `&&` is used.
          if (length(smpls) == 1 &&
            n_folds_each_train_smpl[1] == 1 &&
            check_is_partition(smpls[[1]]$train_ids, self$data$n_obs)) {
            # Single split, single fold, train ids form a partition as well:
            # treated as one fold without cross-fitting.
            private$n_rep_ = 1
            private$n_folds_ = 1
            private$apply_cross_fitting_ = FALSE
            private$smpls_ = smpls
          } else {
            private$n_rep_ = length(smpls)
            private$n_folds_ = n_folds_each_train_smpl[1]
            private$apply_cross_fitting_ = TRUE
            lapply(
              smpls,
              function(x) {
                check_smpl_split(x, self$data$n_obs,
                  check_intersect = TRUE)
              })
            private$smpls_ = smpls
          }
        } else {
          # Test ids that do not form a partition are only accepted for a
          # single (train, test) tuple in a single split.
          if (n_folds_each_train_smpl[1] != 1) {
            stop(paste(
              "Invalid partition provided.",
              "Tuples (train_ids, test_ids) for more than one fold",
              "provided that don't form a partition."))
          }
          if (length(smpls) != 1) {
            stop(paste(
              "Repeated sample splitting without cross-fitting not",
              "implemented."))
          }
          private$n_rep_ = length(smpls)
          private$n_folds_ = 2
          private$apply_cross_fitting_ = FALSE
          lapply(
            smpls,
            function(x) {
              check_smpl_split(x, self$data$n_obs,
                check_intersect = TRUE)
            })
          private$smpls_ = smpls
        }
      } else {
        # A single split passed directly (not wrapped in an outer list).
        check_smpl_split(smpls, self$data$n_obs)
        private$n_rep_ = 1
        n_folds = length(smpls$train_ids)
        if (check_is_partition(smpls$test_ids, self$data$n_obs)) {
          if (n_folds == 1 && check_is_partition(smpls$train_ids, self$data$n_obs)) {
            private$n_folds_ = 1
            private$apply_cross_fitting_ = FALSE
            private$smpls_ = list(smpls)
          } else {
            private$n_folds_ = n_folds
            private$apply_cross_fitting_ = TRUE
            check_smpl_split(smpls, self$data$n_obs,
              check_intersect = TRUE)
            private$smpls_ = list(smpls)
          }
        } else {
          if (n_folds != 1) {
            stop(paste(
              "Invalid partition provided.",
              "Tuples (train_ids, test_ids) for more than one fold",
              "provided that don't form a partition."))
          }
          private$n_folds_ = 2
          private$apply_cross_fitting_ = FALSE
          check_smpl_split(smpls, self$data$n_obs,
            check_intersect = TRUE)
          private$smpls_ = list(smpls)
        }
      }

      private$initialize_arrays()

      invisible(self)
    },
    #' @description
    #' Hyperparameter-tuning for DoubleML models.
    #'
    #' The hyperparameter-tuning is performed using the tuning methods provided
    #' in the [mlr3tuning](https://mlr3tuning.mlr-org.com/) package. For more
    #' information on tuning in [mlr3](https://mlr3.mlr-org.com/), we refer to
    #' the section on parameter tuning in the
    #' [mlr3 book](https://mlr3book.mlr-org.com/optimization.html#tuning).
    #'
    #' @param param_set (named `list()`) \cr
    #' A named `list` with a parameter grid for each nuisance model/learner
    #' (see method `learner_names()`). The parameter grid must be an object of
    #' class [ParamSet][paradox::ParamSet].
    #'
    #' @param tune_settings (named `list()`) \cr
    #' A named `list()` with arguments passed to the hyperparameter-tuning with
    #' [mlr3tuning](https://mlr3tuning.mlr-org.com/) to set up
    #' [TuningInstance][mlr3tuning::TuningInstanceSingleCrit] objects.
    #' `tune_settings` has entries
    #' * `terminator` ([Terminator][bbotk::Terminator]) \cr
    #' A [Terminator][bbotk::Terminator] object. Specification of `terminator`
    #' is required to perform tuning.
    #' * `algorithm` ([Tuner][mlr3tuning::Tuner] or `character(1)`) \cr
    #' A [Tuner][mlr3tuning::Tuner] object (recommended) or key passed to the
    #' respective dictionary to specify the tuning algorithm used in
    #' [tnr()][mlr3tuning::tnr()]. `algorithm` is passed as an argument to
    #' [tnr()][mlr3tuning::tnr()]. If `algorithm` is not specified by the users,
    #' default is set to `"grid_search"`. If set to `"grid_search"`, then
    #' additional argument `"resolution"` is required.
    #' * `rsmp_tune` ([Resampling][mlr3::Resampling] or `character(1)`)\cr
    #' A [Resampling][mlr3::Resampling] object (recommended) or option passed
    #' to [rsmp()][mlr3::mlr_sugar] to initialize a
    #' [Resampling][mlr3::Resampling] for parameter tuning in `mlr3`.
    #' If not specified by the user, default is set to `"cv"`
    #' (cross-validation).
    #' * `n_folds_tune` (`integer(1)`, optional) \cr
    #' If `rsmp_tune = "cv"`, number of folds used for cross-validation.
    #' If not specified by the user, default is set to `5`.
    #' * `measure` (`NULL`, named `list()`, optional) \cr
    #' Named list containing the measures used for parameter tuning. Entries in
    #' list must either be [Measure][mlr3::Measure] objects or keys to be
    #' passed to [msr()][mlr3::msr()]. The names of the entries must
    #' match the learner names (see method `learner_names()`). If set to `NULL`,
    #' default measures are used, i.e., `"regr.mse"` for continuous outcome
    #' variables and `"classif.ce"` for binary outcomes.
    #' * `resolution` (`integer(1)`) \cr The resolution of the grid used when
    #' `algorithm` is set to `"grid_search"`. `resolution` is passed as an
    #' argument to [tnr()][mlr3tuning::tnr()].
    #'
    #' @param tune_on_folds (`logical(1)`) \cr
    #' Indicates whether the tuning should be done fold-specific or globally.
    #' Default is `FALSE`.
    #'
    #' @return self
    tune = function(param_set, tune_settings = list(
      n_folds_tune = 5,
      rsmp_tune = mlr3::rsmp("cv", folds = 5),
      measure = NULL,
      terminator = mlr3tuning::trm("evals", n_evals = 20),
      algorithm = mlr3tuning::tnr("grid_search"),
      resolution = 5),
    tune_on_folds = FALSE) {

      # Argument checks: `param_set` must be a list of ParamSet objects
      # whose names are a subset of the learner names.
      assert_list(param_set)
      valid_learner = self$learner_names()
      names_ok = test_names(names(param_set), subset.of = valid_learner)
      if (!names_ok) {
        stop(paste(
          "Invalid param_set", paste0(names(param_set), collapse = ", "),
          "\n param_grids must be a named list with elements named",
          paste0(valid_learner, collapse = ", ")))
      }
      for (i_grid in seq_along(param_set)) {
        assert_class(param_set[[i_grid]], "ParamSet")
      }
      assert_logical(tune_on_folds, len = 1)
      tune_settings = private$assert_tune_settings(tune_settings)

      if (!self$apply_cross_fitting) {
        stop("Parameter tuning for no-cross-fitting case not implemented.")
      }

      # Result container: one entry per treatment; for fold-specific tuning
      # each entry holds one slot per repetition.
      res_template = if (tune_on_folds) {
        rep(list(vector("list", self$n_rep)), self$data$n_treat)
      } else {
        vector("list", self$data$n_treat)
      }
      private$tuning_res_ = res_template
      names(private$tuning_res_) = self$data$d_cols
      if (tune_on_folds) {
        private$fold_specific_params = TRUE
      }

      for (i_treat in 1:self$data$n_treat) {
        private$i_treat = i_treat

        if (self$data$n_treat > 1) {
          self$data$set_data_model(self$data$d_cols[i_treat])
        }

        if (!tune_on_folds) {
          # Global tuning: run once with the repetition index set to 1.
          private$i_rep = 1
          tuned = private$nuisance_tuning(
            private$get__smpls(),
            param_set, tune_settings, tune_on_folds)
          private$tuning_res_[[i_treat]] = tuned

          for (nuisance_model in self$params_names()) {
            # Nuisance models without a tuning result are left untouched.
            if (is.null(tuned[[nuisance_model]][[1]])) {
              next
            }
            self$set_ml_nuisance_params(
              learner = nuisance_model,
              treat_var = self$data$treat_col,
              params = tuned[[nuisance_model]]$params[[1]],
              set_fold_specific = FALSE)
          }
        } else {
          # Fold-specific tuning: one tuning run per repetition.
          for (i_rep in 1:self$n_rep) {
            private$i_rep = i_rep
            tuned = private$nuisance_tuning(
              private$get__smpls(),
              param_set, tune_settings, tune_on_folds)
            private$tuning_res_[[i_treat]][[i_rep]] = tuned

            for (nuisance_model in names(tuned)) {
              # Nuisance models without a tuning result are left untouched.
              if (is.null(tuned[[nuisance_model]][[1]])) {
                next
              }
              self$set_ml_nuisance_params(
                learner = nuisance_model,
                treat_var = self$data$treat_col,
                params = tuned[[nuisance_model]]$params,
                set_fold_specific = FALSE)
            }
          }
        }
      }
      invisible(self)
    },
    #' @description
    #' Summary for DoubleML models after calling `fit()`.
    #'
    #' @param digits (`integer(1)`) \cr
    #' The number of significant digits to use when printing.
    summary = function(digits = max(3L, getOption("digits") -
      3L)) {
      # Prints a coefficient table (estimate, standard error, t value,
      # p-value) and invisibly returns what printCoefmat() produced, as a
      # matrix. Only emits a message while all coefficients are NA.
      if (all(is.na(self$coef))) {
        message("fit() not yet called.")
      } else {
        k = length(self$coef)
        table = matrix(NA_real_, ncol = 4, nrow = k)
        rownames(table) = names(self$coef)
        colnames(table) = c("Estimate.", "Std. Error", "t value", "Pr(>|t|)")
        table[, 1] = self$coef
        table[, 2] = self$se
        table[, 3] = self$t_stat
        table[, 4] = self$pval
        private$summary_table = table

        # `res` is defined up front so the return value exists on every
        # path. The check is on the number of coefficients itself: testing
        # `length(k)` is always TRUE because `k` is a scalar.
        res = NULL
        if (k > 0) {
          cat(
            "Estimates and significance testing of the",
            "effect of target variables\n")
          res = as.matrix(printCoefmat(private$summary_table,
            digits = digits,
            P.values = TRUE,
            has.Pvalue = TRUE))
          cat("\n")
        } else {
          cat("No coefficients\n")
        }
        cat("\n")
        invisible(res)
      }
    },
    #' @description
    #' Confidence intervals for DoubleML models.
    #'
    #' @param joint (`logical(1)`) \cr
    #' Indicates whether joint confidence intervals are computed.
    #' Default is `FALSE`.
    #'
    #' @param level (`numeric(1)`) \cr
    #' The confidence level. Default is `0.95`.
    #'
    #' @param parm (`numeric()` or `character()`) \cr
    #' A specification of which parameters are to be given confidence intervals
    #' among the variables for which inference was done, either a vector of
    #' numbers or a vector of names. If missing, all parameters are considered
    #' (default).
    #' @return A `matrix()` with the confidence interval(s).
    confint = function(parm, joint = FALSE, level = 0.95) {
      assert_logical(joint, len = 1)
      assert_numeric(level, len = 1)
      if (level <= 0 | level >= 1) {
        stop("'level' must be > 0 and < 1.")
      }

      # Resolve the requested parameters to coefficient names; numeric
      # positions are translated via names(self$coef).
      if (missing(parm)) {
        parm = names(self$coef)
      } else {
        assert(
          check_character(parm, max.len = self$data$n_treat),
          check_numeric(parm, max.len = self$data$n_treat))
        if (is.numeric(parm)) {
          parm = names(self$coef)[parm]
        }
      }

      # Both variants report the same lower/upper tail probabilities and fill
      # the same (parameter x bound) result template.
      alpha = 1 - level
      tail_probs = c(alpha / 2, 1 - alpha / 2)
      ci = array(NA_real_,
        dim = c(length(parm), 2L),
        dimnames = list(parm, format.perc(tail_probs, 3)))

      if (!joint) {
        # Pointwise intervals based on standard normal quantiles.
        ci[] = self$coef[parm] + self$se[parm] %o% qnorm(tail_probs)
      } else {
        if (all(is.na(self$boot_coef))) {
          stop(paste(
            "Multiplier bootstrap has not yet been performed.",
            "First call bootstrap() and then try confint() again."))
        }

        # Critical value: (1 - alpha) quantile of the column-wise maximum
        # absolute bootstrapped t-statistic.
        max_abs_t = apply(abs(self$boot_t_stat), 2, max)
        crit = quantile(max_abs_t, probs = 1 - alpha)
        half_width = crit * self$se[parm]

        ci[, 1] = self$coef[parm] - half_width
        ci[, 2] = self$coef[parm] + half_width
      }
      return(ci)
    },
    #' @description
    #' Returns the names of the learners.
    #'
    #' @return `character()` with names of learners.
    learner_names = function() {
      # The learner names are the element names of the `learner` field.
      learner_list = self$learner
      names(learner_list)
    },
    #' @description
    #' Returns the names of the nuisance models with hyperparameters.
    #'
    #' @return `character()` with names of nuisance models with hyperparameters.
    params_names = function() {
      # The nuisance model names are the element names of the `params` field.
      param_list = self$params
      names(param_list)
    },
    #' @description
    #' Set hyperparameters for the nuisance models of DoubleML models.
    #'
    #' Note that in the current implementation, either all parameters have to
    #' be set globally or all parameters have to be provided fold-specific.
    #'
    #' @param learner (`character(1)`) \cr
    #' The nuisance model/learner (see method `params_names`).
    #'
    #' @param treat_var (`character(1)`) \cr
    #' The treatment variable (hyperparameters can be set treatment-variable
    #' specific).
    #'
    #' @param params (named `list()`) \cr
    #' A named `list()` with estimator parameters. Parameters are used for all
    #' folds by default. Alternatively, parameters can be passed in a
    #' fold-specific way if option `set_fold_specific` is `TRUE`. In this
    #' case, the outer list needs to be of length `n_rep` and the inner list
    #' of length `n_folds`.
    #'
    #' @param set_fold_specific (`logical(1)`) \cr
    #' Indicates if the parameters passed in `params` should be passed in
    #' fold-specific way. Default is `FALSE`. If `TRUE`, the outer list needs
    #' to be of length `n_rep` and the inner list of length `n_folds`.
    #' Note that in the current implementation, either all parameters have to
    #' be set globally or all parameters have to be provided fold-specific.
    #'
    #' @return self
    set_ml_nuisance_params = function(learner = NULL, treat_var = NULL, params,
      set_fold_specific = FALSE) {

      # Validate the target: `learner` must be one of the nuisance model names
      # and `treat_var` one of the treatment columns of the data backend.
      valid_learner = self$params_names()
      assert_character(learner, len = 1)
      assert_choice(learner, valid_learner)

      assert_choice(treat_var, self$data$d_cols)
      assert_list(params)
      assert_logical(set_fold_specific, len = 1)

      if (!set_fold_specific) {
        if (private$fold_specific_params) {
          # Parameters are already stored per repetition: only replace the
          # entry of the repetition currently being processed.
          private$params_[[learner]][[treat_var]][[private$i_rep]] = params
        } else {
          private$params_[[learner]][[treat_var]] = params
        }
      } else {
        # Fold-specific parameters: outer list per repetition, inner list per
        # fold.
        if (length(params) != self$n_rep) {
          stop("Length of (outer) parameter list does not match n_rep.")
        }
        if (any(lengths(params) != self$n_folds)) {
          stop("Length of (inner) parameter list does not match n_folds.")
        }

        private$fold_specific_params = set_fold_specific
        private$params_[[learner]][[treat_var]] = params
      }
      # Return the object itself, as documented, so that calls can be chained.
      invisible(self)
    },
    #' @description
    #' Multiple testing adjustment for DoubleML models.
    #'
    #' @param method (`character(1)`) \cr
    #' A `character(1)`(`"romano-wolf"`, `"bonferroni"`, `"holm"`, etc)
    #' specifying the adjustment method. In addition to `"romano-wolf"`,
    #' all methods implemented in [p.adjust()][stats::p.adjust()] can be
    #' applied. Default is `"romano-wolf"`.
    #' @param return_matrix (`logical(1)`) \cr
    #' Indicates if the output is returned as a matrix with corresponding
    #' coefficient names.
    #'
    #' @return `numeric()` with adjusted p-values. If `return_matrix = TRUE`,
    #' a `matrix()` with adjusted p_values.
    p_adjust = function(method = "romano-wolf", return_matrix = TRUE) {
      if (all(is.na(self$coef))) {
        stop("apply fit() before p_adjust().")
      }

      if (tolower(method) %in% c("rw", "romano-wolf")) {
        boot_t_stat = self$boot_t_stat
        # The bootstrap arrays are pre-allocated with NA at initialization, so
        # a NULL test alone never detects a missing bootstrap() call; without
        # the NA test the procedure would silently return NA p-values.
        if (is.null(boot_t_stat) || all(is.na(boot_t_stat))) {
          stop("apply fit() & bootstrap() before p_adjust().")
        }
        k = self$data$n_treat
        t_stat = self$t_stat

        # Step-down order: hypotheses sorted by decreasing |t|; `ro` maps the
        # sorted results back to the original coefficient order.
        stepdown_ind = order(abs(t_stat), decreasing = TRUE)
        ro = order(stepdown_ind)
        abs_boot = abs(boot_t_stat)
        abs_t_sorted = abs(t_stat[stepdown_ind])

        pinit = vector(mode = "numeric", length = k)
        for (i_d in seq_len(k)) {
          if (i_d == 1) {
            remaining = abs_boot
          } else {
            # Drop the hypotheses already handled in earlier steps.
            remaining = abs_boot[-stepdown_ind[seq_len(i_d - 1)], ,
              drop = FALSE]
          }
          sim = apply(remaining, 2, max)
          pinit[i_d] = pmin(1, mean(sim > abs_t_sorted[i_d]))
        }
        # Ensure monotonicity along the step-down order, then restore the
        # original order.
        p_val = cummax(pinit)[ro]
      } else {
        if (is.element(method, p.adjust.methods)) {
          p_val = p.adjust(self$pval,
            method = method,
            n = self$data$n_treat)
        } else {
          stop(paste(
            "Invalid method", method,
            "argument specified in p_adjust()."))
        }
      }

      if (return_matrix) {
        res = as.matrix(cbind(self$coef, p_val))
        colnames(res) = c("Estimate.", "pval")
        return(res)
      } else {
        return(p_val)
      }
    },
    #' @description
    #' Get hyperparameters for the nuisance model of DoubleML models.
    #'
    #' @param learner (`character(1)`) \cr
    #' The nuisance model/learner (see method `params_names()`)
    #'
    #' @return named `list()` with parameters for the nuisance model/learner.
    get_params = function(learner) {
      # `learner` has to be exactly one of the known nuisance model names.
      assert_character(learner, len = 1)
      assert_choice(learner, self$params_names())

      # Parameters of the current treatment variable; with fold-specific
      # parameters, narrow down to the repetition currently being processed.
      treat_params = self$params[[learner]][[self$data$treat_col]]
      if (private$fold_specific_params) {
        treat_params = treat_params[[private$i_rep]]
      }
      return(treat_params)
    }
  ),
  private = list(
    # Private state of a DoubleML object.
    # Fields with a trailing underscore are the storage behind the read-only
    # active bindings of the same name (e.g. `all_coef` returns
    # `private$all_coef_`); bindings not shown nearby presumably follow the
    # same pattern.
    all_coef_ = NULL,
    all_dml1_coef_ = NULL,
    all_se_ = NULL,
    apply_cross_fitting_ = NULL,
    boot_coef_ = NULL,
    boot_t_stat_ = NULL,
    coef_ = NULL,
    data_ = NULL,
    dml_procedure_ = NULL,
    draw_sample_splitting_ = NULL,
    learner_ = NULL,
    n_folds_ = NULL,
    n_rep_ = NULL,
    params_ = NULL,
    psi_ = NULL,
    psi_a_ = NULL,
    psi_b_ = NULL,
    predictions_ = NULL,
    models_ = NULL,
    pval_ = NULL,
    score_ = NULL,
    se_ = NULL,
    smpls_ = NULL,
    t_stat_ = NULL,
    tuning_res_ = NULL,
    # Number of bootstrap replications; set by initialize_boot_arrays().
    n_rep_boot = NULL,
    # Indices of the repetition and treatment variable currently being
    # processed; the get__*() accessors slice with them.
    i_rep = NA_integer_,
    i_treat = NA_integer_,
    # FALSE after initialize_double_ml(); switched on by
    # set_ml_nuisance_params(set_fold_specific = TRUE).
    fold_specific_params = NULL,
    # Coefficient table last built by summary().
    summary_table = NULL,
    # Task type per learner name, filled by assert_learner().
    task_type = list(),
    # Cluster settings, set in initialize_double_ml().
    is_cluster_data = FALSE,
    n_folds_per_cluster = NA_integer_,
    smpls_cluster_ = NULL,
    # Written by var_est() / var_est_cluster_data(), read by agg_cross_fit().
    var_scaling_factor = NA_real_,
    initialize_double_ml = function(data, n_folds, n_rep, score,
      dml_procedure, draw_sample_splitting, apply_cross_fitting) {

      # Data backend: must be a DoubleMLData object; cluster data is limited
      # to at most two cluster variables.
      assert_class(data, "DoubleMLData")
      private$is_cluster_data = FALSE
      has_clusters = test_class(data, "DoubleMLClusterData")
      if (has_clusters && data$n_cluster_vars > 2) {
        stop("Multi-way (n_ways > 2) clustering not yet implemented.")
      }
      private$is_cluster_data = has_clusters
      private$data_ = data

      # Learners and parameters are model specific and start out empty;
      # parameters are not fold-specific at instantiation.
      private$learner_ = NULL
      private$params_ = NULL
      private$fold_specific_params = FALSE

      # Validate the resampling specification.
      assert_count(n_folds)
      assert_count(n_rep)
      assert_logical(apply_cross_fitting, len = 1)
      assert_logical(draw_sample_splitting, len = 1)

      # Store the resampling specification. With clustering, folds are drawn
      # per cluster variable, so the total number of folds is
      # n_folds^n_cluster_vars.
      if (!private$is_cluster_data) {
        private$n_folds_ = n_folds
      } else {
        if ((n_folds == 1) | (!apply_cross_fitting)) {
          stop(paste(
            "No cross-fitting (`apply_cross_fitting = False`)",
            "is not yet implemented with clustering."))
        }
        private$n_folds_per_cluster = n_folds
        private$n_folds_ = n_folds^self$data$n_cluster_vars
      }
      private$n_rep_ = n_rep
      private$apply_cross_fitting_ = apply_cross_fitting
      private$draw_sample_splitting_ = draw_sample_splitting

      # Validate and store the DML procedure and the score.
      assert_choice(dml_procedure, c("dml1", "dml2"))
      private$dml_procedure_ = dml_procedure
      private$score_ = score

      # A single fold cannot be cross-fitted: switch cross-fitting off and
      # tell the user.
      if (self$n_folds == 1 & self$apply_cross_fitting) {
        message(paste(
          "apply_cross_fitting is set to FALSE.",
          "Cross-fitting is not supported for n_folds = 1."))
        private$apply_cross_fitting_ = FALSE
      }

      if (!self$apply_cross_fitting) {
        if (self$n_folds > 2) {
          stop(paste(
            "Estimation without cross-fitting not supported for",
            "n_folds > 2."))
        }
        # Without cross-fitting the procedure is of no relevance; dml1 works
        # out-of-the-box, so dml2 is redirected to it.
        if (self$dml_procedure == "dml2") {
          private$dml_procedure_ = "dml1"
        }
      }

      # Draw the sample splitting now, or leave it to be provided later.
      if (self$draw_sample_splitting) {
        self$split_samples()
      } else {
        private$smpls_ = NULL
      }

      # Result arrays follow the data dimensions and resampling settings; the
      # bootstrap arrays use the default number of bootstrap replications.
      private$initialize_arrays()
      private$initialize_boot_arrays(n_rep_boot = 500)

      invisible(self)
    },
    assert_learner = function(learner, learner_name, Regr, Classif) {

      # Accept either a learner key (character) or a Learner object.
      assert(
        check_character(learner, max.len = 1),
        check_class(learner, "Learner"))

      if (test_class(learner, "AutoTuner")) {
        stop("Learners of class 'AutoTuner' are not supported.")
      }
      # A character key is turned into a learner object via lrn().
      if (is.character(learner)) {
        learner = lrn(learner)
      }

      task_type = learner$task_type
      is_regr = task_type == "regr"
      is_classif = task_type == "classif"

      # Remember the task type whenever it is one of the permitted ones.
      if ((Regr & is_regr) | (Classif & is_classif)) {
        private$task_type[learner_name] = task_type
      }

      # Reject a learner of the wrong type when only one type is permitted.
      if (Regr & !Classif & !is_regr) {
        stop(paste0(
          "Invalid learner provided for ", learner_name,
          ": 'learner$task_type' must be 'regr'"))
      }
      if (Classif & !Regr & !is_classif) {
        stop(paste0(
          "Invalid learner provided for ", learner_name,
          ": 'learner$task_type must be 'classif'"))
      }
      invisible(learner)
    },
    assert_tune_settings = function(tune_settings) {
      # Validates a `tune_settings` list and fills in defaults; returns the
      # completed list.

      valid_learner = self$learner_names()

      # A terminator is mandatory.
      if (!test_names(names(tune_settings), must.include = "terminator")) {
        stop(paste(
          "Invalid tune_settings\n",
          "object 'terminator' is missing."))
      }
      assert_class(tune_settings$terminator, "Terminator")

      # Number of folds used for tuning (default 5).
      if (test_names(names(tune_settings), must.include = "n_folds_tune")) {
        assert_integerish(tune_settings$n_folds_tune, len = 1, lower = 2)
      } else {
        tune_settings$n_folds_tune = 5
      }

      # Resampling scheme: a Resampling object or a key for rsmp(); defaults
      # to cross-validation with `n_folds_tune` folds.
      if (test_names(names(tune_settings), must.include = "rsmp_tune")) {
        assert(
          check_character(tune_settings$rsmp_tune),
          check_class(tune_settings$rsmp_tune, "Resampling"))
        if (!test_class(tune_settings$rsmp_tune, "Resampling")) {
          if (tune_settings$rsmp_tune == "cv") {
            tune_settings$rsmp_tune = rsmp(tune_settings$rsmp_tune,
              folds = tune_settings$n_folds_tune)
          } else {
            tune_settings$rsmp_tune = rsmp(tune_settings$rsmp_tune)
          }
        }
      } else {
        tune_settings$rsmp_tune = rsmp("cv", folds = tune_settings$n_folds_tune)
      }

      # Measures: optional named list (names must be learner names) holding
      # measure keys or Measure objects; otherwise one NULL entry per learner.
      if (test_names(names(tune_settings), must.include = "measure") &&
        !is.null(tune_settings$measure)) {
        assert_list(tune_settings$measure)
        if (!test_names(names(tune_settings$measure),
          subset.of = valid_learner)) {
          stop(paste(
            "Invalid name of measure", paste0(names(tune_settings$measure),
              collapse = ", "),
            "\n measure must be a named list with elements named",
            paste0(valid_learner, collapse = ", ")))
        }
        for (i_msr in seq_len(length(tune_settings$measure))) {
          assert(
            check_character(tune_settings$measure[[i_msr]]),
            check_class(tune_settings$measure[[i_msr]], "Measure"))
        }
      } else {
        tune_settings$measure = rep(list(NULL), length(valid_learner))
        names(tune_settings$measure) = valid_learner
      }

      # Anything that is not yet a Measure object is resolved through
      # set_default_measure(), based on the task type of the learner.
      for (this_learner in valid_learner) {
        if (!test_class(tune_settings$measure[[this_learner]], "Measure")) {
          tune_settings$measure[[this_learner]] = set_default_measure(
            tune_settings$measure[[this_learner]],
            private$task_type[[this_learner]])
        }
      }

      # Tuning algorithm: a key for tnr() or a Tuner object; defaults to grid
      # search.
      if (!test_names(names(tune_settings), must.include = "algorithm")) {
        tune_settings$algorithm = "grid_search"
      } else {
        assert(
          check_character(tune_settings$algorithm, len = 1),
          check_class(tune_settings$algorithm, "Tuner"))
      }

      if (test_character(tune_settings$algorithm)) {
        if (tune_settings$algorithm == "grid_search") {
          # Grid search additionally requires a resolution.
          if (is.null(tune_settings$resolution)) {
            stop(paste(
              "Invalid tune_settings\n",
              "object 'resolution' is missing."))
          } else {
            assert_count(tune_settings$resolution, positive = TRUE)
          }
          tune_settings$tuner = tnr(tune_settings$algorithm,
            resolution = tune_settings$resolution)
        } else if (is.null(tune_settings$tuner)) {
          # Any other algorithm key previously left `tuner` unset; build the
          # tuner from the key unless the caller already supplied one.
          tune_settings$tuner = tnr(tune_settings$algorithm)
        }
      } else {
        tune_settings$tuner = tune_settings$algorithm
      }

      return(tune_settings)
    },
    initialize_arrays = function() {
      # Allocates NA-filled containers for scores, estimates and standard
      # errors according to the data dimensions and the resampling settings.
      n_obs = self$data$n_obs
      n_treat = self$data$n_treat
      n_rep = self$n_rep

      # Score components: observations x repetitions x treatments.
      score_dim = c(n_obs, n_rep, n_treat)
      private$psi_ = array(NA_real_, dim = score_dim)
      private$psi_a_ = array(NA_real_, dim = score_dim)
      private$psi_b_ = array(NA_real_, dim = score_dim)

      # Aggregated estimates: one entry per treatment.
      private$coef_ = array(NA_real_, dim = n_treat)
      private$se_ = array(NA_real_, dim = n_treat)

      # Estimates per repetition: treatments x repetitions.
      rep_dim = c(n_treat, n_rep)
      private$all_coef_ = array(NA_real_, dim = rep_dim)
      private$all_se_ = array(NA_real_, dim = rep_dim)

      # dml1 additionally keeps the estimate of every fold (a single slot
      # without cross-fitting).
      if (self$dml_procedure == "dml1") {
        n_slots = if (self$apply_cross_fitting) self$n_folds else 1
        private$all_dml1_coef_ = array(NA_real_,
          dim = c(n_treat, n_rep, n_slots))
      }
    },
    initialize_boot_arrays = function(n_rep_boot) {
      # Remembers the number of bootstrap replications and allocates NA-filled
      # arrays with one row per treatment and one column per bootstrap draw
      # across all repetitions.
      private$n_rep_boot = n_rep_boot
      boot_dim = c(self$data$n_treat, n_rep_boot * self$n_rep)
      private$boot_coef_ = array(NA_real_, dim = boot_dim)
      private$boot_t_stat_ = array(NA_real_, dim = boot_dim)
    },
    initialize_predictions = function() {
      # One NA-filled array (observations x repetitions x treatments) per
      # nuisance model, collected in a list named by nuisance model.
      pred_dim = c(self$data$n_obs, self$n_rep, self$data$n_treat)
      private$predictions_ = sapply(self$params_names(),
        function(nuisance_name) array(NA_real_, dim = pred_dim),
        simplify = FALSE)
    },
    initialize_models = function() {
      # Nested container for fitted models:
      # nuisance model -> treatment variable -> repetition -> fold.
      empty_folds = function(i_rep) {
        vector("list", length = self$n_folds)
      }
      per_treatment = function(treat_name) {
        lapply(seq(self$n_rep), empty_folds)
      }
      per_nuisance = function(nuisance_name) {
        sapply(self$data$d_cols, per_treatment, simplify = FALSE)
      }
      private$models_ = sapply(self$params_names(), per_nuisance,
        simplify = FALSE)
    },
    store_predictions = function(preds) {
      # Writes the supplied predictions into the slice of the current
      # repetition and treatment; nuisance models without predictions are
      # skipped.
      for (nuisance_name in self$params_names()) {
        nuisance_preds = preds[[nuisance_name]]
        if (is.null(nuisance_preds)) {
          next
        }
        private$predictions_[[nuisance_name]][
          , private$i_rep, private$i_treat] = nuisance_preds
      }
    },
    store_models = function(models) {
      # Stores the supplied fitted models under the current treatment variable
      # and repetition; nuisance models without an entry are skipped.
      treat_name = self$data$treat_col
      for (nuisance_name in self$params_names()) {
        fitted = models[[nuisance_name]]
        if (is.null(fitted)) {
          next
        }
        private$models_[[nuisance_name]][[treat_name]][[
        private$i_rep]] = fitted
      }
    },
    # Accessors with a double underscore (naming carried over from the Python
    # implementation) always deliver the slice for a single treatment variable
    # and a single cross-fitting sample. The slice is selected by
    # private$i_treat, the index of the treatment variable, and
    # private$i_rep, the index of the cross-fitting sample.
    get__smpls = function() {
      self$smpls[[private$i_rep]]
    },
    get__smpls_cluster = function() {
      self$smpls_cluster[[private$i_rep]]
    },
    get__psi_a = function() {
      self$psi_a[, private$i_rep, private$i_treat]
    },
    get__psi_b = function() {
      self$psi_b[, private$i_rep, private$i_treat]
    },
    get__psi = function() {
      self$psi[, private$i_rep, private$i_treat]
    },
    get__all_coef = function() {
      self$all_coef[private$i_treat, private$i_rep]
    },
    get__all_se = function() {
      self$all_se[private$i_treat, private$i_rep]
    },
    est_causal_pars = function() {
      # Point estimate for the current treatment and repetition.
      procedure = self$dml_procedure
      fold_ids = private$get__smpls()$test_ids

      # Cluster data has its own estimator.
      if (private$is_cluster_data) {
        return(private$orth_est_cluster_data())
      }

      if (procedure == "dml1") {
        # One estimate per test fold, then averaged. length(fold_ids) differs
        # from n_folds only when cross-fitting is switched off.
        fold_thetas = rep(NA_real_, length(fold_ids))
        for (i_fold in seq_along(fold_ids)) {
          fold_thetas[i_fold] = private$orth_est(inds = fold_ids[[i_fold]])
        }
        private$all_dml1_coef_[private$i_treat, private$i_rep, ] = fold_thetas
        coef = mean(fold_thetas, na.rm = TRUE)
      } else if (procedure == "dml2") {
        # A single estimate from all observations.
        coef = private$orth_est()
      }
      return(coef)
    },
    se_causal_pars = function() {
      # Standard error: square root of the variance estimate, using the
      # cluster-specific estimator for cluster data.
      variance = if (private$is_cluster_data) {
        private$var_est_cluster_data()
      } else {
        private$var_est()
      }
      return(sqrt(variance))
    },
    agg_cross_fit = function() {
      # Aggregates the estimates of the repeated cross-fitting: the getters
      # only cover one treatment and one sample, so the full matrices are
      # used here.
      na_median = function(x) median(x, na.rm = TRUE)

      # Coefficient per treatment: median over the repetitions.
      private$coef_ = apply(self$all_coef, 1, na_median)

      # TODO: In the edge case of repeated no-cross-fitting, the test sets
      # might have different sizes, so it would not be valid to always use
      # the same var_scaling_factor.
      scaling = private$var_scaling_factor
      # Scaled variance of each repetition plus its squared deviation from
      # the aggregated coefficient.
      adjusted_var = scaling * self$all_se^2 + (self$all_coef - self$coef)^2
      private$se_ = sqrt(apply(adjusted_var, 1, na_median) / scaling)

      invisible(self)
    },
    compute_bootstrap = function(weights, n_rep_boot) {
      # Multiplier bootstrap for the current treatment and repetition.
      #
      # `weights` is multiplied (%*%) with the score vector, so it needs one
      # column per score entry used. `n_rep_boot` is not used here; it is kept
      # for interface compatibility.
      #
      # Returns a list with `boot_coef` and `boot_t_stat`.
      psi = private$get__psi()
      psi_a = private$get__psi_a()

      if (self$apply_cross_fitting) {
        n_obs = self$data$n_obs
      } else {
        # Without cross-fitting only the first test set enters the estimate.
        test_index = private$get__smpls()$test_ids[[1]]
        psi = psi[test_index]
        psi_a = psi_a[test_index]
        n_obs = length(test_index)
      }

      J = mean(psi_a)
      # The t-statistic is the bootstrapped coefficient additionally scaled by
      # the standard error. The no-cross-fitting branch previously had the two
      # formulas swapped; both cases now share the same definitions.
      boot_coef = weights %*% psi / (n_obs * J)
      boot_t_stat = weights %*% psi / (n_obs * private$get__all_se() * J)

      res = list(boot_coef = boot_coef, boot_t_stat = boot_t_stat)
      return(res)
    },
    var_est = function() {
      # Variance estimate for the current treatment and repetition; also
      # records the scaling factor (number of observations used).
      score_deriv = private$get__psi_a()
      score = private$get__psi()

      if (self$apply_cross_fitting) {
        scaling = self$data$n_obs
      } else {
        # Without cross-fitting only the first test set is evaluated.
        eval_index = private$get__smpls()$test_ids[[1]]
        score_deriv = score_deriv[eval_index]
        score = score[eval_index]
        scaling = length(eval_index)
      }
      private$var_scaling_factor = scaling

      J = mean(score_deriv)
      return(mean(score^2) / (J^2) / scaling)
    },
    var_est_cluster_data = function() {
      # Variance estimate for cluster data (one or two cluster variables) for
      # the current treatment and repetition. Also sets
      # private$var_scaling_factor. Returns the variance estimate
      # `sigma2_hat`.
      psi_a = private$get__psi_a()
      psi = private$get__psi()

      if (self$data$n_cluster_vars == 1) {
        # One-way clustering.
        this_cluster_var = self$data$data_model[[self$data$cluster_cols[1]]]
        clusters = unique(this_cluster_var)
        gamma_hat = 0
        j_hat = 0
        smpls = private$get__smpls()
        smpls_cluster = private$get__smpls_cluster()
        for (i_fold in 1:self$n_folds) {
          test_inds = smpls$test_ids[[i_fold]]
          test_cluster_inds = smpls_cluster$test_ids[[i_fold]]
          # Test clusters of this fold.
          I_k = test_cluster_inds[[1]]
          const = 1 / length(I_k)
          for (cluster_value in I_k) {
            # Sum of all pairwise score products within one cluster.
            ind_cluster = (this_cluster_var == cluster_value)
            gamma_hat = gamma_hat + const * sum(outer(
              psi[ind_cluster],
              psi[ind_cluster]))
          }
          j_hat = j_hat + sum(psi_a[test_inds]) / length(I_k)
        }

        # Average the fold contributions.
        gamma_hat = gamma_hat / private$n_folds_per_cluster
        j_hat = j_hat / private$n_folds_per_cluster
        # The scaling factor is the number of distinct clusters.
        private$var_scaling_factor = length(clusters)
        sigma2_hat = gamma_hat / (j_hat^2) / private$var_scaling_factor
      } else {
        # Two-way clustering (any other number of cluster variables fails the
        # assertion).
        assert_choice(self$data$n_cluster_vars, 2)
        first_cluster_var = self$data$data_model[[self$data$cluster_cols[1]]]
        second_cluster_var = self$data$data_model[[self$data$cluster_cols[2]]]
        gamma_hat = 0
        j_hat = 0
        smpls = private$get__smpls()
        smpls_cluster = private$get__smpls_cluster()
        for (i_fold in 1:self$n_folds) {
          test_inds = smpls$test_ids[[i_fold]]
          test_cluster_inds = smpls_cluster$test_ids[[i_fold]]
          # Test clusters of this fold in the first and second dimension.
          I_k = test_cluster_inds[[1]]
          J_l = test_cluster_inds[[2]]
          const = min(length(I_k), length(J_l)) / ((length(I_k) * length(J_l))^2)
          # Pairwise score products of observations sharing a first-dimension
          # cluster, restricted to the second-dimension test clusters.
          for (cluster_value in I_k) {
            ind_cluster = (first_cluster_var == cluster_value) &
              second_cluster_var %in% J_l
            gamma_hat = gamma_hat + const * sum(outer(
              psi[ind_cluster],
              psi[ind_cluster]))
          }
          # Same with the roles of the two cluster dimensions exchanged.
          for (cluster_value in J_l) {
            ind_cluster = (second_cluster_var == cluster_value) &
              first_cluster_var %in% I_k
            gamma_hat = gamma_hat + const * sum(outer(
              psi[ind_cluster],
              psi[ind_cluster]))
          }
          j_hat = j_hat + sum(psi_a[test_inds]) / (length(I_k) * length(J_l))
        }
        # Average the fold contributions.
        gamma_hat = gamma_hat / (private$n_folds_per_cluster^2)
        j_hat = j_hat / (private$n_folds_per_cluster^2)
        n_first_clusters = length(unique(first_cluster_var))
        n_second_clusters = length(unique(second_cluster_var))
        # The scaling factor is the smaller number of distinct clusters.
        private$var_scaling_factor = min(n_first_clusters, n_second_clusters)
        sigma2_hat = gamma_hat / (j_hat^2) / private$var_scaling_factor
      }
      return(sigma2_hat)
    },
    orth_est = function(inds = NULL) {
      # Estimate from the score components of the current treatment and
      # repetition, optionally restricted to the observations in `inds`.
      score_a = private$get__psi_a()
      score_b = private$get__psi_b()
      if (!is.null(inds)) {
        score_a = score_a[inds]
        score_b = score_b[inds]
      }
      return(-mean(score_b) / mean(score_a))
    },
    orth_est_cluster_data = function() {
      # Point estimate for cluster data for the current treatment and
      # repetition.
      procedure = self$dml_procedure
      score_a = private$get__psi_a()
      score_b = private$get__psi_b()
      fold_ids = private$get__smpls()$test_ids
      cluster_folds = private$get__smpls_cluster()

      # Weight of a fold: inverse of the product of the numbers of test
      # clusters over all cluster dimensions.
      fold_weight = function(i_fold) {
        n_test_clusters = sapply(cluster_folds$test_ids[[i_fold]], length)
        1 / prod(n_test_clusters)
      }

      if (procedure == "dml1") {
        # Note that in the dml1 case the standard estimator without cluster
        # adjustment could be applied as well.
        fold_thetas = rep(NA_real_, length(fold_ids))
        for (i_fold in seq_along(fold_ids)) {
          fold_index = fold_ids[[i_fold]]
          weight = fold_weight(i_fold)
          fold_thetas[i_fold] = -(weight * sum(score_b[fold_index])) /
            (weight * sum(score_a[fold_index]))
        }
        theta = mean(fold_thetas, na.rm = TRUE)
        private$all_dml1_coef_[private$i_treat, private$i_rep, ] = fold_thetas
      } else if (procedure == "dml2") {
        # See Chiang et al. (2021) Algorithm 1
        weighted_a = 0.
        weighted_b = 0.
        for (i_fold in seq_along(fold_ids)) {
          fold_index = fold_ids[[i_fold]]
          weight = fold_weight(i_fold)
          weighted_a = weighted_a + weight * sum(score_a[fold_index])
          weighted_b = weighted_b + weight * sum(score_b[fold_index])
        }
        theta = -weighted_b / weighted_a
      }
      return(theta)
    },
    compute_score = function() {
      # Score at the current estimate: psi_a * theta + psi_b for the current
      # treatment and repetition.
      theta = private$get__all_coef()
      score = private$get__psi_a() * theta + private$get__psi_b()
      return(score)
    }
  )
)

# Try the DoubleML package in your browser
#
# Any scripts or data that you put into this service are public.
#
# DoubleML documentation built on April 1, 2023, 12:16 a.m.