# tests/testthat/helper-11-dml_iivm.R

# Double Machine Learning for Interactive Instrumental Variable Regression Model.
#
# For each of the `n_rep` sample splits the nuisance models are fitted with
# fit_nuisance_iivm(), the out-of-fold predictions are collected with
# extract_iivm_preds(), and a point estimate and standard error are computed.
#
# Arguments:
#   data               data set; indexed as data[, "name"], so presumably a
#                      plain data.frame (TODO confirm against callers).
#   y, d, z            column names of outcome, treatment and instrument.
#   n_folds            number of folds of each sample split.
#   ml_g, ml_m, ml_r   learners handed on to fit_nuisance_iivm().
#   dml_procedure      "dml1" (average of the fold-wise estimates) or "dml2"
#                      (one estimate from the pooled predictions). Any other
#                      value is rejected.
#   score              score name handed to orth_irmiv_dml() and var_irmiv().
#   always_takers,
#   never_takers       handed on to fit_nuisance_iivm().
#   n_rep              number of repeated sample splits.
#   smpls              optional list of length n_rep; each element has
#                      `train_ids` and `test_ids`. Drawn with
#                      sample_splitting() when NULL.
#   trimming_threshold handed on to extract_iivm_preds().
#   params_g/m/r       optional hyperparameter values for the learners.
#
# Returns a list with `coef` (median of the per-repetition estimates), `se`,
# `t`, `pval`, the per-repetition `thetas` and `ses`, `all_preds` and `smpls`.
dml_irmiv = function(data, y, d, z,
  n_folds,
  ml_g, ml_m, ml_r,
  dml_procedure, score,
  always_takers = TRUE, never_takers = TRUE,
  n_rep = 1, smpls = NULL,
  trimming_threshold = 1e-12,
  params_g = NULL, params_m = NULL, params_r = NULL) {

  # Fail fast: previously an unknown procedure left every estimate at NA and
  # the function returned NA results without any hint at the cause.
  if (!isTRUE(dml_procedure %in% c("dml1", "dml2"))) {
    stop("dml_procedure must be either \"dml1\" or \"dml2\"")
  }

  if (is.null(smpls)) {
    smpls = lapply(seq_len(n_rep), function(x) sample_splitting(n_folds, data))
  }

  all_thetas = all_ses = rep(NA_real_, n_rep)
  all_preds = vector("list", n_rep)

  for (i_rep in seq_len(n_rep)) {
    this_smpl = smpls[[i_rep]]
    train_ids = this_smpl$train_ids
    test_ids = this_smpl$test_ids

    all_preds[[i_rep]] = fit_nuisance_iivm(
      data, y, d, z,
      ml_g, ml_m, ml_r,
      train_ids, test_ids,
      always_takers, never_takers,
      params_g, params_m, params_r)
    res = extract_iivm_preds(data, y, d, z, n_folds,
      this_smpl, all_preds[[i_rep]],
      trimming_threshold = trimming_threshold)
    m_hat = res$m_hat
    g0_hat = res$g0_hat
    g1_hat = res$g1_hat
    r0_hat = res$r0_hat
    r1_hat = res$r1_hat
    D = data[, d]
    Y = data[, y]
    Z = data[, z]

    if (dml_procedure == "dml1") {
      # DML 1: solve the score fold by fold, then average the estimates
      thetas = rep(NA_real_, n_folds)
      for (i in seq_len(n_folds)) {
        test_index = test_ids[[i]]
        orth_est = orth_irmiv_dml(
          m_hat = m_hat[test_index],
          g0_hat = g0_hat[test_index], g1_hat = g1_hat[test_index],
          r0_hat = r0_hat[test_index], r1_hat = r1_hat[test_index],
          d = D[test_index], y = Y[test_index], z = Z[test_index],
          score = score)
        thetas[i] = orth_est$theta
      }
      all_thetas[i_rep] = mean(thetas, na.rm = TRUE)
      if (length(train_ids) == 1) {
        # Only one training set (presumably the setup without cross-fitting):
        # predictions exist only for the test observations, so the variance
        # below is computed on that subset. `test_index` still holds the test
        # ids of the last fold processed above.
        D = D[test_index]
        Y = Y[test_index]
        Z = Z[test_index]
        m_hat = m_hat[test_index]
        g0_hat = g0_hat[test_index]
        g1_hat = g1_hat[test_index]
        r0_hat = r0_hat[test_index]
        r1_hat = r1_hat[test_index]
      }
    } else {
      # DML 2: solve the score once on the pooled out-of-fold predictions
      orth_est = orth_irmiv_dml(
        m_hat = m_hat, g0_hat = g0_hat,
        g1_hat = g1_hat,
        r0_hat = r0_hat, r1_hat = r1_hat,
        d = D, y = Y, z = Z,
        score = score)
      all_thetas[i_rep] = orth_est$theta
    }

    all_ses[i_rep] = sqrt(var_irmiv(
      theta = all_thetas[i_rep], m_hat = m_hat, g0_hat = g0_hat,
      g1_hat = g1_hat, r0_hat = r0_hat, r1_hat = r1_hat,
      d = D, y = Y, z = Z, score = score))
  }

  # Aggregate over repetitions: median estimate, se combined by se_repeated()
  theta = stats::median(all_thetas)
  # Sample size the standard errors refer to; taken from the last sample split
  if (length(this_smpl$train_ids) > 1) {
    n = nrow(data)
  } else {
    n = length(this_smpl$test_ids[[1]])
  }
  se = se_repeated(all_ses * sqrt(n), all_thetas, theta) / sqrt(n)

  t = theta / se
  pval = 2 * stats::pnorm(-abs(t))

  names(theta) = names(se) = d
  res = list(
    coef = theta, se = se, t = t, pval = pval,
    thetas = all_thetas, ses = all_ses,
    all_preds = all_preds, smpls = smpls)

  return(res)
}


# Fit all nuisance models of the IIVM on the supplied sample split.
#
# Five sets of out-of-fold predictions are produced with mlr3, one list
# element per fold, each holding predictions for that fold's test ids:
#   m_hat_list   probability of z == "1", learned from all columns of `data`
#                except y and d, trained on the full training set of a fold.
#   g0_hat_list  prediction of y from all columns except d and z, trained on
#                the training rows with z == 0.
#   g1_hat_list  same as g0, trained on the training rows with z == 1.
#   r0_hat_list  probability of d == "1" from all columns except y and z,
#                trained on the training rows with z == 0; a vector of zeros
#                when always_takers is FALSE.
#   r1_hat_list  same as r0, trained on the training rows with z == 1; a
#                vector of ones when never_takers is FALSE.
#
# Arguments:
#   data                 data set, indexed as data[rows, "name"].
#   y, d, z              column names of outcome, treatment and instrument.
#                        z and d are converted to factors and need a level "1".
#   ml_g                 learner for y; if it inherits from "LearnerClassif",
#                        y is treated as a factor with positive class "1" and
#                        class probabilities are returned, otherwise the
#                        regression response.
#   ml_m, ml_r           classification learners for z and d. Probabilities
#                        are read from `$prob`, so the learners presumably
#                        need predict_type = "prob" (TODO confirm).
#   train_ids, test_ids  lists of row indices, one element per fold.
#   always_takers,
#   never_takers         logical flags, see r0/r1 above.
#   params_g/m/r         optional lists of hyperparameter values; NULL keeps
#                        the learner's own values.
#
# The learners passed in are deep-cloned before hyperparameters are set, so the
# caller's learner objects are left untouched.
fit_nuisance_iivm = function(data, y, d, z,
  ml_g, ml_m, ml_r,
  train_ids, test_ids,
  always_takers, never_takers,
  params_g, params_m, params_r) {

  # Deep-clone a learner and, if given, overwrite its hyperparameter values.
  # A shallow R6 clone would still share the ParamSet object with the original,
  # so setting values on it would leak back to the caller.
  prepare_learner = function(learner, params) {
    learner = learner$clone(deep = TRUE)
    if (!is.null(params)) {
      learner$param_set$values = params
    }
    learner
  }

  # Train `learner` on `fold_train_ids` and predict on `test_ids`, fold by fold.
  cross_fit = function(task, learner, fold_train_ids) {
    resampling = mlr3::rsmp("custom")
    resampling$instantiate(task, fold_train_ids, test_ids)
    mlr3::resample(task, learner, resampling, store_models = TRUE)
  }

  # Per-fold predicted probabilities of class "1".
  prob_of_one = function(rr) {
    lapply(rr$predictions(), function(x) x$prob[, "1"])
  }

  # nuisance m: set up first, its resampling defines the folds for all models
  m_indx = names(data) != y & names(data) != d
  data_m = data[, m_indx, drop = FALSE]

  # tbd: handle a regression learner for m; for now always a classification
  data_m[, z] = factor(data_m[, z])
  task_m = mlr3::TaskClassif$new(
    id = paste0("nuis_m_", z), backend = data_m,
    target = z, positive = "1")

  resampling_m = mlr3::rsmp("custom")
  resampling_m$instantiate(task_m, train_ids, test_ids)
  n_iters = resampling_m$iters

  # in each fold, select the training observations with z == 0
  train_ids_0 = lapply(seq_len(n_iters), function(x) {
    this_train = resampling_m$train_set(x)
    this_train[data[this_train, z] == 0]
  })
  # in each fold, select the training observations with z == 1
  train_ids_1 = lapply(seq_len(n_iters), function(x) {
    this_train = resampling_m$train_set(x)
    this_train[data[this_train, z] == 1]
  })

  r_m = mlr3::resample(task_m, prepare_learner(ml_m, params_m), resampling_m,
    store_models = TRUE)
  m_hat_list = prob_of_one(r_m)

  # nuisance g0 / g1: model for y, trained separately on z == 0 and z == 1
  g_indx = names(data) != d & names(data) != z
  data_g = data[, g_indx, drop = FALSE]

  # binary or continuous outcome y
  g_is_classif = inherits(ml_g, "LearnerClassif")
  if (g_is_classif) {
    data_g[, y] = factor(data_g[, y])
  }

  fit_g = function(id, fold_train_ids) {
    if (g_is_classif) {
      task = mlr3::TaskClassif$new(
        id = id, backend = data_g,
        target = y, positive = "1")
    } else {
      task = mlr3::TaskRegr$new(id = id, backend = data_g, target = y)
    }
    rr = cross_fit(task, prepare_learner(ml_g, params_g), fold_train_ids)
    if (g_is_classif) {
      prob_of_one(rr)
    } else {
      lapply(rr$predictions(), function(x) x$response)
    }
  }

  # Train on subset with z == 0 (in each fold) only, predict for all test obs
  g0_hat_list = fit_g(paste0("nuis_g0_", y), train_ids_0)
  # Train on subset with z == 1 (in each fold) only, predict for all test obs
  g1_hat_list = fit_g(paste0("nuis_g1_", y), train_ids_1)

  # nuisance r0 / r1: model for d, trained separately on z == 0 and z == 1
  r_indx = names(data) != y & names(data) != z
  data_r = data[, r_indx, drop = FALSE]
  data_r[, d] = factor(data_r[, d])

  if (!always_takers && !never_takers) {
    message("If there are no always-takers and no never-takers, ATE is estimated")
  }

  fit_r = function(id, fold_train_ids) {
    task = mlr3::TaskClassif$new(
      id = id, backend = data_r,
      target = d, positive = "1")
    prob_of_one(cross_fit(task, prepare_learner(ml_r, params_r), fold_train_ids))
  }

  # One constant vector per fold, as long as that fold's test set.
  constant_preds = function(value) {
    lapply(test_ids, function(ids) rep(value, length(ids)))
  }

  if (always_takers) {
    # Train on subset with z == 0 (in each fold) only, predict for all test obs
    r0_hat_list = fit_r(paste0("nuis_r0_", d), train_ids_0)
  } else {
    # without always-takers nobody with z == 0 is treated
    r0_hat_list = constant_preds(0)
  }

  if (never_takers) {
    # Train on subset with z == 1 (in each fold) only, predict for all test obs
    r1_hat_list = fit_r(paste0("nuis_r1_", d), train_ids_1)
  } else {
    # without never-takers everybody with z == 1 is treated
    r1_hat_list = constant_preds(1)
  }

  all_preds = list(
    m_hat_list = m_hat_list,
    g0_hat_list = g0_hat_list,
    g1_hat_list = g1_hat_list,
    r0_hat_list = r0_hat_list,
    r1_hat_list = r1_hat_list)

  return(all_preds)
}


# Reassemble the fold-wise nuisance predictions into full-length vectors.
#
# Arguments:
#   data                used only for its number of rows.
#   y, d, z             not used here; kept so the signature matches callers.
#   n_folds             number of folds to copy from the prediction lists.
#   smpls               one sample split; only `smpls$test_ids` is read.
#   all_preds           list with m_hat_list, g0_hat_list, g1_hat_list,
#                       r0_hat_list and r1_hat_list, one element per fold.
#   trimming_threshold  passed to trim_vec() for m_hat.
#
# Returns a list of five numeric vectors of length nrow(data): m_hat, g0_hat,
# g1_hat, r0_hat, r1_hat. Positions not covered by any test set stay NA. Only
# m_hat is passed through trim_vec(), which is defined elsewhere (presumably
# it bounds the propensities away from 0 and 1 - TODO confirm).
extract_iivm_preds = function(data, y, d, z, n_folds, smpls, all_preds,
  trimming_threshold) {

  test_ids = smpls$test_ids

  n = nrow(data)
  m_hat = g0_hat = g1_hat = r0_hat = r1_hat = rep(NA_real_, n)

  # Write each fold's predictions to the rows of that fold's test set
  for (i in seq_len(n_folds)) {
    test_index = test_ids[[i]]

    m_hat[test_index] = all_preds$m_hat_list[[i]]
    g0_hat[test_index] = all_preds$g0_hat_list[[i]]
    g1_hat[test_index] = all_preds$g1_hat_list[[i]]
    r0_hat[test_index] = all_preds$r0_hat_list[[i]]
    r1_hat[test_index] = all_preds$r1_hat_list[[i]]
  }

  m_hat = trim_vec(m_hat, trimming_threshold)

  res = list(
    m_hat = m_hat, g0_hat = g0_hat, g1_hat = g1_hat,
    r0_hat = r0_hat, r1_hat = r1_hat)
  return(res)
}


# Orthogonalized estimation of the coefficient in the IIVM.
#
# All of m_hat, g0_hat, g1_hat, r0_hat, r1_hat, d, y and z are numeric vectors
# that are combined element-wise. `score` must be "LATE" or "partialling out";
# anything else raises an error.
#
# Returns list(theta = ...), where theta is the mean of the outcome terms
# multiplied by the reciprocal of the mean of the treatment terms.
orth_irmiv_dml = function(m_hat, g0_hat, g1_hat, r0_hat, r1_hat, d, y, z, score) {
  score_known = score == "LATE" || score == "partialling out"
  if (!score_known) {
    stop("Inference framework for orthogonal estimation unknown")
  }

  # Per-observation terms for the treatment d (denominator of the ratio)
  treat_terms = r1_hat - r0_hat + z * (d - r1_hat) / m_hat -
    ((1 - z) * (d - r0_hat) / (1 - m_hat))
  # Per-observation terms for the outcome y (numerator of the ratio)
  outcome_terms = g1_hat - g0_hat + z * (y - g1_hat) / m_hat -
    ((1 - z) * (y - g0_hat) / (1 - m_hat))

  estimate = 1 / mean(treat_terms) * mean(outcome_terms)
  return(list(theta = estimate))
}


# Variance estimation for the DML estimator in the Interactive Instrumental
# Variable Regression Model.
#
# `theta` is the point estimate; the remaining vector arguments are combined
# element-wise and n is taken as length(d). Only score "LATE" is supported,
# any other value raises an error.
#
# Returns a single number: mean(psi^2) / (n * mean(psi_a)^2), where psi_a are
# the treatment terms and psi = outcome terms - psi_a * theta.
var_irmiv = function(theta, m_hat, g0_hat, g1_hat, r0_hat, r1_hat, d, y, z, score) {
  if (score != "LATE") {
    stop("Inference framework for variance estimation unknown")
  }

  n_obs = length(d)

  # Treatment part of the score (multiplies theta)
  psi_a = r1_hat - r0_hat + z * (d - r1_hat) / m_hat -
    (1 - z) * (d - r0_hat) / (1 - m_hat)
  # Full score evaluated at theta
  psi = g1_hat - g0_hat + z * (y - g1_hat) / m_hat -
    (1 - z) * (y - g0_hat) / (1 - m_hat) -
    psi_a * theta

  sigma2 = 1 / n_obs / mean(psi_a)^2 * mean(psi^2)
  return(c(sigma2))
}


# Multiplier bootstrap for the Interactive Instrumental Variable Regression
# Model.
#
# `theta` and `se` hold one estimate and one standard error per repetition;
# `smpls` and `all_preds` hold the matching sample splits and nuisance
# predictions. For every repetition the score components are rebuilt from the
# predictions, bootstrap weights are drawn with draw_bootstrap_weights() and
# passed to functional_bootstrap() (both defined elsewhere). Only score "LATE"
# is supported.
#
# Returns the result of functional_bootstrap() for the first repetition, with
# `boot_coef` and `boot_t_stat` of later repetitions appended column-wise.
bootstrap_irmiv = function(theta, se, data, y, d, z, n_folds, smpls, all_preds,
  score, bootstrap, n_rep_boot,
  n_rep = 1, trimming_threshold = 1e-12) {
  for (i_rep in 1:n_rep) {
    preds = extract_iivm_preds(data, y, d, z, n_folds,
      smpls[[i_rep]], all_preds[[i_rep]],
      trimming_threshold = trimming_threshold)
    prop_z = preds$m_hat
    y_z0 = preds$g0_hat
    y_z1 = preds$g1_hat
    d_z0 = preds$r0_hat
    d_z1 = preds$r1_hat
    treat = data[, d]
    outcome = data[, y]
    instr = data[, z]

    if (score != "LATE") {
      stop("Inference framework for multiplier bootstrap unknown")
    }

    # Treatment part of the score; it enters psi multiplied by theta and,
    # with a negative sign, forms psi_a
    treat_part = d_z1 - d_z0 + instr * (treat - d_z1) / prop_z -
      (1 - instr) * (treat - d_z0) / (1 - prop_z)

    psi = y_z1 - y_z0 + instr * (outcome - y_z1) / prop_z -
      (1 - instr) * (outcome - y_z0) / (1 - prop_z) -
      treat_part * theta[i_rep]
    psi_a = -treat_part

    n_obs = length(psi)
    boot_weights = draw_bootstrap_weights(bootstrap, n_rep_boot, n_obs)
    rep_res = functional_bootstrap(
      theta[i_rep], se[i_rep], psi, psi_a, n_folds,
      smpls[[i_rep]],
      n_rep_boot, boot_weights)

    if (i_rep == 1) {
      boot_res = rep_res
    } else {
      # later repetitions only contribute additional bootstrap columns
      boot_res$boot_coef = cbind(boot_res$boot_coef, rep_res$boot_coef)
      boot_res$boot_t_stat = cbind(boot_res$boot_t_stat, rep_res$boot_t_stat)
    }
  }
  return(boot_res)
}

# Try the DoubleML package in your browser
#
# Any scripts or data that you put into this service are public.
#
# DoubleML documentation built on April 1, 2023, 12:16 a.m.