# tests/testthat/test_rewardestimation.R

# Label reported by testthat for the tests in this file.
# NOTE(review): `context()` is superseded in testthat's 3rd edition (the file
# name is used as the label instead); consider dropping it if the package
# opts into 3e.
context("RewardEstimation")


test_that("equal_propensity_estimator", {
  skip_on_cran()

  # The constructor exists from IAI 2.1.0; older versions must refuse it with
  # a version message.
  too_old <- iai:::iai_version_less_than("2.1.0")
  if (!too_old) {
    estimator <- iai::equal_propensity_estimator()
    expect_true(iai:::jl_isa(estimator, "IAI.EqualPropensityEstimator"))
  } else {
    expect_error(iai::equal_propensity_estimator(),
                 "requires IAI version 2.1.0")
  }
})

test_that("`reward_estimator` is deprecated", {
  skip_on_cran()

  # Three eras, checked newest first:
  #   >= 2.3.0          -> removed (error mentions IAI v3)
  #   2.0.0 .. < 2.3.0  -> available but deprecated
  #   < 2.0.0           -> not yet available
  if (!iai:::iai_version_less_than("2.3.0")) {
    expect_error(iai::reward_estimator(), "removed in IAI v3")
  } else if (!iai:::iai_version_less_than("2.0.0")) {
    lifecycle::expect_deprecated(iai::reward_estimator())
  } else {
    expect_error(iai::reward_estimator(), "requires IAI version 2.0.0")
  }
})

test_that("`categorical_reward_estimator` is deprecated", {
  skip_on_cran()

  # Three eras, checked newest first:
  #   >= 2.3.0          -> removed (error mentions IAI v3)
  #   2.0.0 .. < 2.3.0  -> available but deprecated
  #   < 2.0.0           -> not yet available
  if (!iai:::iai_version_less_than("2.3.0")) {
    expect_error(iai::categorical_reward_estimator(), "removed in IAI v3")
  } else if (!iai:::iai_version_less_than("2.0.0")) {
    lifecycle::expect_deprecated(iai::categorical_reward_estimator())
  } else {
    expect_error(iai::categorical_reward_estimator(),
                 "requires IAI version 2.0.0")
  }
})

test_that("`numeric_reward_estimator` is deprecated", {
  skip_on_cran()

  # Three eras, checked newest first:
  #   >= 2.3.0          -> removed (error mentions IAI v3)
  #   2.1.0 .. < 2.3.0  -> available but deprecated
  #   < 2.1.0           -> not yet available
  if (!iai:::iai_version_less_than("2.3.0")) {
    expect_error(iai::numeric_reward_estimator(), "removed in IAI v3")
  } else if (!iai:::iai_version_less_than("2.1.0")) {
    lifecycle::expect_deprecated(iai::numeric_reward_estimator())
  } else {
    expect_error(iai::numeric_reward_estimator(), "requires IAI version 2.1.0")
  }
})


test_that("internal estimator consistency", {
  skip_on_cran()
  if (iai:::iai_version_less_than("2.2.0")) {
    # All six internal estimators are exercised below from 2.2.0 onwards, so
    # each one must refuse to construct on older versions.
    # NOTE(review): assumes the survival variants share the 2.2.0 requirement
    # message of the other four — confirm against the package source.
    expect_error(iai::categorical_classification_reward_estimator(),
                 "requires IAI version 2.2.0")
    expect_error(iai::categorical_regression_reward_estimator(),
                 "requires IAI version 2.2.0")
    expect_error(iai::categorical_survival_reward_estimator(),
                 "requires IAI version 2.2.0")
    expect_error(iai::numeric_classification_reward_estimator(),
                 "requires IAI version 2.2.0")
    expect_error(iai::numeric_regression_reward_estimator(),
                 "requires IAI version 2.2.0")
    expect_error(iai::numeric_survival_reward_estimator(),
                 "requires IAI version 2.2.0")
  } else {
    class_lnr <- iai::optimal_tree_classifier()
    reg_lnr <- iai::optimal_tree_regressor()
    surv_lnr <- iai::optimal_tree_survival_learner()

    # Categorical treatments: the propensity estimator must be a classifier,
    # and the outcome estimator must match the outcome type in the name.
    L <- iai::categorical_classification_reward_estimator
    # propensity
    L(propensity_estimator = class_lnr)
    expect_error(L(propensity_estimator = reg_lnr))
    # outcome
    L(outcome_estimator = class_lnr)
    expect_error(L(outcome_estimator = reg_lnr))

    L <- iai::categorical_regression_reward_estimator
    # propensity
    L(propensity_estimator = class_lnr)
    expect_error(L(propensity_estimator = reg_lnr))
    # outcome
    expect_error(L(outcome_estimator = class_lnr))
    L(outcome_estimator = reg_lnr)

    L <- iai::categorical_survival_reward_estimator
    # propensity
    L(propensity_estimator = class_lnr)
    expect_error(L(propensity_estimator = reg_lnr))
    # outcome
    expect_error(L(outcome_estimator = class_lnr))
    L(outcome_estimator = surv_lnr)

    # Numeric treatments: the propensity estimator must be a regressor,
    # and the outcome estimator must match the outcome type in the name.
    L <- iai::numeric_classification_reward_estimator
    # propensity
    expect_error(L(propensity_estimator = class_lnr))
    L(propensity_estimator = reg_lnr)
    # outcome
    L(outcome_estimator = class_lnr)
    expect_error(L(outcome_estimator = reg_lnr))

    L <- iai::numeric_regression_reward_estimator
    # propensity
    expect_error(L(propensity_estimator = class_lnr))
    L(propensity_estimator = reg_lnr)
    # outcome
    expect_error(L(outcome_estimator = class_lnr))
    L(outcome_estimator = reg_lnr)

    L <- iai::numeric_survival_reward_estimator
    # propensity
    expect_error(L(propensity_estimator = class_lnr))
    L(propensity_estimator = reg_lnr)
    # outcome
    expect_error(L(outcome_estimator = class_lnr))
    L(outcome_estimator = surv_lnr)
  }
})


test_that("all_treatment_combinations", {
  skip_on_cran()

  if (!iai:::iai_version_less_than("2.1.0")) {
    c1 <- seq(1, 5)
    c2 <- c(1, 2)

    # A combination grid must be a data frame with the given column names
    # and the given number of rows (one column per treatment argument).
    check_grid <- function(grid, expected_names, expected_rows) {
      expect_true(is.data.frame(grid))
      expect_equal(colnames(grid), expected_names)
      expect_equal(nrow(grid), expected_rows)
      expect_equal(ncol(grid), length(expected_names))
    }

    # Unnamed arguments get default names treatment1, treatment2, ...
    check_grid(iai::all_treatment_combinations(c1),
               "treatment1", length(c1))
    check_grid(iai::all_treatment_combinations(c1, c2),
               c("treatment1", "treatment2"), length(c1) * length(c2))

    # Named arguments keep their names as column names.
    check_grid(iai::all_treatment_combinations(name1 = c1),
               "name1", length(c1))
    check_grid(iai::all_treatment_combinations(name1 = c1, name2 = c2),
               c("name1", "name2"), length(c1) * length(c2))
  } else {
    expect_error(iai::all_treatment_combinations(),
                 "requires IAI version 2.1.0")
  }
})


test_that("categorical API", {
  skip_on_cran()
  # Nothing to exercise before IAI 2.0.0; skip explicitly instead of leaving
  # an empty test with no expectations.
  skip_if(iai:::iai_version_less_than("2.0.0"),
          "categorical reward estimation requires IAI 2.0.0")

  X <- iris[, 1:2]
  # Categorical treatment (species) with a numeric outcome. Named to avoid
  # masking base::t().
  treatments <- iris$Species
  outcomes <- iris[, 4]
  n_treatments <- length(unique(treatments))

  # From IAI 3.0.0 results are a list of reward/propensity/outcome
  # components; each must have one row per sample and one column per
  # treatment.
  expect_component_dims <- function(result) {
    for (part in c("reward", "propensity", "outcome")) {
      expect_equal(nrow(result[[part]]), nrow(X), info = part)
      expect_equal(ncol(result[[part]]), n_treatments, info = part)
    }
  }

  if (iai:::iai_version_less_than("2.2.0")) {
    # Legacy (deprecated) constructor; silence the lifecycle warning.
    withr::local_options(lifecycle_verbosity = "quiet")

    lnr <- iai::categorical_reward_estimator(
        propensity_estimation_method = "random_forest",
        outcome_estimation_method = "random_forest",
        reward_estimation_method = "doubly_robust",
        random_seed = 1
    )
    iai::fit_predict(lnr, X, treatments, outcomes)
    preds <- iai::predict(lnr, X, treatments, outcomes)

    expect_equal(nrow(preds), nrow(X))
    expect_equal(ncol(preds), n_treatments)

    if (iai:::iai_version_less_than("2.1.0")) {
      expect_error(iai::score(lnr, X, treatments, outcomes),
                   "requires IAI version 2.1.0")
    } else {
      s <- iai::score(lnr, X, treatments, outcomes)
      expect_equal(length(s), 2)
      expect_equal(length(s$propensity), 1)
      expect_equal(length(s$outcome), n_treatments)
    }

    expect_error(iai::predict_reward(lnr, X, treatments, outcomes),
                 "requires IAI version 3.0.0")
  } else {
    lnr <- iai::categorical_regression_reward_estimator(
        propensity_estimator = iai::random_forest_classifier(num_trees = 5),
        outcome_estimator = iai::random_forest_regressor(num_trees = 5),
        reward_estimator = "doubly_robust"
    )
    rewards <- iai::fit_predict(lnr, X, treatments, outcomes)
    preds <- iai::predict(lnr, X, treatments, outcomes)

    s <- iai::score(lnr, X, treatments, outcomes)
    expect_equal(length(s), 2)
    expect_equal(length(s$propensity), 1)
    expect_equal(length(s$outcome), n_treatments)

    if (iai:::iai_version_less_than("3.0.0")) {
      expect_equal(nrow(preds), nrow(X))
      expect_equal(ncol(preds), n_treatments)

      expect_error(iai::predict_reward(lnr, X, treatments, outcomes),
                   "requires IAI version 3.0.0")
    } else {
      expect_component_dims(preds)

      rewards2 <- iai::predict_reward(lnr, treatments, outcomes,
                                      rewards$predictions)
      expect_component_dims(rewards2)
    }
  }
})

test_that("numeric API", {
  skip_on_cran()

  X <- iris[, 1:2]
  # Numeric treatment with a numeric outcome. Named to avoid masking
  # base::t() and base::c().
  treatments <- iris[, 3]
  outcomes <- iris[, 4]
  # NOTE(review): candidate treatments are taken from the unique *outcome*
  # values rather than from `treatments`; presumably this just keeps the
  # candidate set small — confirm it is intentional.
  candidates <- unique(outcomes)
  n_candidates <- length(candidates)

  # From IAI 3.0.0 results are a list of reward/propensity/outcome
  # components; each must have one row per sample and one column per
  # candidate treatment.
  expect_component_dims <- function(result) {
    for (part in c("reward", "propensity", "outcome")) {
      expect_equal(nrow(result[[part]]), nrow(X), info = part)
      expect_equal(ncol(result[[part]]), n_candidates, info = part)
    }
  }

  # Fit a learner; before IAI 2.1.0 there is nothing to fit and only the
  # version errors further down are checked.
  if (!iai:::iai_version_less_than("2.2.0")) {
    lnr <- iai::numeric_regression_reward_estimator(
        propensity_estimator = iai::random_forest_regressor(num_trees = 5),
        outcome_estimator = iai::random_forest_regressor(num_trees = 5),
        reward_estimator = "doubly_robust"
    )
    rewards <- iai::fit_predict(lnr, X, treatments, outcomes, candidates)
    preds <- iai::predict(lnr, X, treatments, outcomes)

    if (iai:::iai_version_less_than("3.0.0")) {
      expect_equal(nrow(preds), nrow(X))
      expect_equal(ncol(preds), n_candidates)
    } else {
      expect_component_dims(preds)
    }
  } else if (!iai:::iai_version_less_than("2.1.0")) {
    # Legacy (deprecated) constructor; silence the lifecycle warning.
    withr::local_options(lifecycle_verbosity = "quiet")

    lnr <- iai::numeric_reward_estimator(
        outcome_estimator = iai::random_forest_regressor(num_trees = 5)
    )
    iai::fit_predict(lnr, X, treatments, outcomes, candidates)

    preds <- iai::predict(lnr, X, candidates)
    expect_equal(nrow(preds), nrow(X))
    expect_equal(ncol(preds), n_candidates)

    expect_equal(length(iai::score(lnr, X, treatments, outcomes)), 1)
  }

  if (iai:::iai_version_less_than("2.2.0")) {
    expect_error(iai::get_estimation_densities(), "requires IAI version 2.2.0")
    expect_error(iai::tune_reward_kernel_bandwidth(),
                 "requires IAI version 2.2.0")
    expect_error(iai::set_reward_kernel_bandwidth(),
                 "requires IAI version 2.2.0")
  } else {
    # `lnr` and `rewards` come from the >= 2.2.0 branch above. In IAI 3.0.0
    # these functions take the treatments/outcomes/predictions explicitly.
    pre_v3 <- iai:::iai_version_less_than("3.0.0")

    densities <- if (pre_v3) {
      iai::get_estimation_densities(lnr)
    } else {
      iai::get_estimation_densities(lnr, treatments, candidates)
    }
    expect_equal(length(densities), n_candidates)

    bandwidth_values <- c(1, 2, 3)
    tuned <- if (pre_v3) {
      iai::tune_reward_kernel_bandwidth(lnr, bandwidth_values)
    } else {
      iai::tune_reward_kernel_bandwidth(lnr, treatments, outcomes,
                                        rewards$predictions, bandwidth_values)
    }
    expect_equal(length(tuned), length(bandwidth_values))

    if (pre_v3) {
      rewards2 <- iai::set_reward_kernel_bandwidth(lnr, 2)
      expect_equal(nrow(rewards2), nrow(X))
      expect_equal(ncol(rewards2), n_candidates)

      # Same version guard as asserted in the categorical API test.
      expect_error(iai::predict_reward(lnr, treatments, outcomes,
                                       rewards$predictions),
                   "requires IAI version 3.0.0")
    } else {
      rewards2 <- iai::set_reward_kernel_bandwidth(lnr, treatments, outcomes,
                                                   rewards$predictions, 2)
      expect_component_dims(rewards2)

      rewards2 <- iai::predict_reward(lnr, treatments, outcomes,
                                      rewards$predictions)
      expect_component_dims(rewards2)
    }
  }
})


test_that("JSON", {
  skip_on_cran()
  # Round-tripping is only exercised from IAI 3.0.0; skip explicitly instead
  # of leaving an empty test with no expectations.
  skip_if(iai:::iai_version_less_than("3.0.0"),
          "reward estimator JSON round-trip is tested from IAI 3.0.0")

  lnrs <- list(iai::categorical_classification_reward_estimator(),
               iai::categorical_regression_reward_estimator(),
               iai::categorical_survival_reward_estimator(),
               iai::numeric_classification_reward_estimator(),
               iai::numeric_regression_reward_estimator(),
               iai::numeric_survival_reward_estimator(),
               iai::equal_propensity_estimator())

  # Use a per-session temporary file that is removed when the test exits,
  # even if an expectation or the read fails, rather than writing `re.json`
  # into the working directory.
  path <- tempfile(fileext = ".json")
  withr::defer(unlink(path))

  for (lnr in lnrs) {
    iai::write_json(path, lnr)
    new_lnr <- iai::read_json(path)
    expect_true(lnr == new_lnr)
  }
})


test_that("class", {
  skip_on_cran()

  # Every class vector checked below ends with the same generic learner
  # classes.
  learner_tail <- c("supervised_learner", "learner", "IAIObject",
                    "JuliaObject")

  if (!iai:::iai_version_less_than("2.1.0")) {
    expect_equal(
      class(iai::equal_propensity_estimator()),
      c("equal_propensity_estimator", "classification_learner", learner_tail)
    )
  }

  if (!iai:::iai_version_less_than("3.0.0")) {
    # Internal estimators are named <treatment>_<outcome>_reward_estimator and
    # inherit from <treatment>_reward_estimator, then reward_estimator.
    for (treatment_kind in c("categorical", "numeric")) {
      for (outcome_kind in c("classification", "regression", "survival")) {
        ctor_name <- paste0(treatment_kind, "_", outcome_kind,
                            "_reward_estimator")
        ctor <- getExportedValue("iai", ctor_name)
        expect_equal(
          class(ctor()),
          c(ctor_name,
            paste0(treatment_kind, "_reward_estimator"),
            "reward_estimator",
            learner_tail),
          info = ctor_name
        )
      }
    }
  }
})

# Try the iai package in your browser

# Any scripts or data that you put into this service are public.

# iai documentation built on July 9, 2023, 5:41 p.m.