Nothing
# Label the tests in this file as the "IAITrees" group in testthat's output.
context("IAITrees")
test_that("common structure", {
  skip_on_cran()

  # The synthetic data, the planted decision rule and the fitted tree all
  # differ between IAI releases, so resolve the version checks once up front.
  before_2_0 <- iai:::iai_version_less_than("2.0.0")
  before_2_2 <- before_2_0 || iai:::iai_version_less_than("2.2.0")

  if (before_2_0) {
    data_command <-
      "IAIConvert.convert_to_R(IAI.IAIBase.generate_mixed_data())"
    cutoff <- 50
    complexity <- 0.04
  } else {
    if (before_2_2) {
      data_command <-
        "IAIConvert.convert_to_R(IAI.IAIBase.generate_mixed_data(rng = IAI.IAIBase.make_rng(3)))"
    } else {
      data_command <-
        "IAIConvert.convert_to_R(IAI.IAIBase.generate_mixed_data(rng = IAI.IAIBase.make_rng(5)))"
    }
    cutoff <- 60
    complexity <- 0.01
  }

  X <- JuliaCall::julia_eval(data_command)
  names(X) <- c("num_attempts", "score1", "score2", "score3", "num_children",
                "region")

  # Target rule: at or above the score1 cutoff the label depends on region
  # (A or B); below it, on a linear score over score2, score3 and region E.
  upper_rule <- X$score1 >= cutoff & X$region %in% c("A", "B")
  lower_rule <- X$score1 < cutoff &
    (X$score2 + 85 * X$score3 + 90 * (X$region == "E")) > 140
  y <- upper_rule | lower_rule

  lnr <- iai::optimal_tree_classifier(
    random_seed = 1,
    max_depth = 2,
    cp = complexity,
    hyperplane_config = list(sparsity = "all"),
  )
  iai::fit(lnr, X, y)

  # Overall shape of the fitted tree.
  expect_equal(iai::get_num_nodes(lnr), 7)
  expect_equal(iai::is_leaf(lnr, 1), FALSE)
  expect_equal(iai::get_depth(lnr, 6), 2)
  node6_samples <- if (before_2_0) {
    97
  } else if (before_2_2) {
    72
  } else {
    78
  }
  expect_equal(iai::get_num_samples(lnr, 6), node6_samples)

  # Parent/child links around the root.
  expect_equal(iai::get_parent(lnr, 2), 1)
  expect_equal(iai::get_lower_child(lnr, 1), 2)
  expect_equal(iai::get_upper_child(lnr, 1), 5)

  # Split types: node 1 parallel, node 2 hyperplane, node 5 categoric.
  expect_equal(iai::is_parallel_split(lnr, 1), TRUE)
  expect_equal(iai::is_hyperplane_split(lnr, 2), TRUE)
  expect_equal(iai::is_categoric_split(lnr, 5), TRUE)
  expect_equal(iai::is_ordinal_split(lnr, 1), FALSE)
  expect_equal(iai::is_mixed_parallel_split(lnr, 2), FALSE)
  expect_equal(iai::is_mixed_ordinal_split(lnr, 5), FALSE)
  expect_equal(iai::missing_goes_lower(lnr, 1), FALSE)

  # The root split should recover the planted cutoff on score1.
  expect_equal(iai::get_split_feature(lnr, 1), as.symbol("score1"))
  expect_equal(iai::get_split_threshold(lnr, 1), cutoff, tolerance = 0.5)

  region_split <- list(A = TRUE, B = TRUE, C = FALSE, D = FALSE, E = FALSE)
  expect_mapequal(iai::get_split_categories(lnr, 5), region_split)

  # Hyperplane coefficients at node 2 are release-specific reference values.
  weights <- iai::get_split_weights(lnr, 2)
  if (before_2_0) {
    numeric_weights <- list(score2 = 0.010100076620278502,
                            score3 = 2.0478494324868732)
    region_e_weight <- 1.5176596636410404
  } else if (before_2_2) {
    numeric_weights <- list(score2 = 0.0012369248211116827,
                            score3 = 0.09806740780674195)
    region_e_weight <- 0.10571515793193487
  } else {
    numeric_weights <- list(score2 = 0.018901518025769143,
                            score3 = 1.2041462082802483)
    region_e_weight <- 1.4792242450156097
  }
  expect_mapequal(weights$numeric, numeric_weights)
  expect_mapequal(weights$categoric,
                  list(region = list(E = region_e_weight)))
})
test_that("classification structure", {
  skip_on_cran()

  # Load the bundled iris tree and attach the R-side class to the Julia object.
  lnr <- iai:::set_obj_class(JuliaCall::julia_eval(
    "IAI.OptimalTrees.load_iris_tree(random_seed=1)"
  ))

  expect_equal(iai::get_classification_label(lnr, 2), "setosa")
  node4_proba <- list(
    virginica = 0.09259259259259259,
    setosa = 0.0,
    versicolor = 0.9074074074074074
  )
  expect_mapequal(iai::get_classification_proba(lnr, 4), node4_proba)

  # Node 1 is rejected by default; from 2.1.0 on, `check_leaf = FALSE` lets
  # the same query run without raising.
  expect_error(iai::get_classification_label(lnr, 1))
  expect_error(iai::get_classification_proba(lnr, 1))
  at_least_2_1 <- !iai:::iai_version_less_than("2.1.0")
  if (at_least_2_1) {
    iai::get_classification_label(lnr, 1, check_leaf = FALSE)
    iai::get_classification_proba(lnr, 1, check_leaf = FALSE)
  }

  # Regression accessors on a classification tree: NaN constant and empty
  # weights from 3.0.0, a version error before that.
  at_least_3_0 <- !iai:::iai_version_less_than("3.0.0")
  if (at_least_3_0) {
    expect_equal(iai::get_regression_constant(lnr, 2), NaN)
    weights <- iai::get_regression_weights(lnr, 2)
    expect_equal(length(weights$numeric), 0)
    expect_equal(length(weights$categoric), 0)
  } else {
    expect_error(iai::get_regression_constant(lnr, 2),
                 "requires IAI version 3.0.0")
    expect_error(iai::get_regression_weights(lnr, 2),
                 "requires IAI version 3.0.0")
  }
})
test_that("regression structure", {
  skip_on_cran()

  # Pick the Julia loader call and the node-3 reference values that match the
  # installed IAI release, then run one shared set of assertions.
  before_2_1 <- iai:::iai_version_less_than("2.1.0")
  if (before_2_1) {
    julia_cmd <- "IAI.OptimalTrees.load_mtcars_tree(random_seed=1,
regression_sparsity=\"all\",
regression_lambda=0.2)"
  } else if (iai:::iai_version_less_than("3.1.0")) {
    julia_cmd <- "IAI.OptimalTrees.load_mtcars_tree(random_seed=1,
regression_sparsity=\"all\",
regression_lambda=0.02)"
  } else {
    julia_cmd <- "IAI.OptimalTrees.load_mtcars_tree(random_seed=1,
regression_features=Set([\"All\"]),
regression_lambda=0.02)"
  }
  lnr <- iai:::set_obj_class(JuliaCall::julia_eval(julia_cmd))

  expect_equal(iai::get_regression_constant(lnr, 2), 30.879999999999995)

  if (before_2_1) {
    node3_constant <- 26.56192034262967
    node3_numeric <- list(Disp = -0.021044493648366,
                          HP = -0.018861409939436)
  } else {
    node3_constant <- 30.887599089534906
    node3_numeric <- list(Cyl = -0.794565711367838,
                          Gear = 0.058519556715652,
                          HP = -0.012667192837728,
                          WT = -1.649738918131852)
  }
  expect_equal(iai::get_regression_constant(lnr, 3), node3_constant)
  weights <- iai::get_regression_weights(lnr, 3)
  expect_mapequal(weights$numeric, node3_numeric)
  expect_true(is.list(weights$categoric) && length(weights$categoric) == 0)

  # Node 1 is rejected by default; from 2.1.0 on, `check_leaf = FALSE` lets
  # the same query run without raising.
  expect_error(iai::get_regression_constant(lnr, 1))
  expect_error(iai::get_regression_weights(lnr, 1))
  if (!before_2_1) {
    iai::get_regression_constant(lnr, 1, check_leaf = FALSE)
    iai::get_regression_weights(lnr, 1, check_leaf = FALSE)
  }
})
test_that("survival structure", {
  skip_on_cran()

  # Resolve the release checks once; later checks imply the earlier ones.
  before_2_0 <- iai:::iai_version_less_than("2.0.0")
  before_2_1 <- before_2_0 || iai:::iai_version_less_than("2.1.0")
  before_2_2 <- before_2_1 || iai:::iai_version_less_than("2.2.0")

  if (before_2_0) {
    # Releases before 2.0.0 are seeded through the Julia RNG instead of a
    # loader argument.
    iai::set_julia_seed(4)
    julia_cmd <- "IAI.OptimalTrees.load_survival_tree()"
  } else {
    julia_cmd <-
      "IAI.OptimalTrees.load_survival_tree(random_seed=1, max_depth=1, cp=0)"
  }
  lnr <- iai:::set_obj_class(JuliaCall::julia_eval(julia_cmd))

  curve <- iai::get_survival_curve(lnr, 2)
  expect_equal(class(curve), c("survival_curve", "IAIObject", "JuliaObject"))

  if (before_2_2) {
    expect_error(iai::predict_expected_survival_time(curve),
                 "requires IAI version 2.2.0")
  } else {
    expect_true(is.numeric(iai::predict_expected_survival_time(curve)))
  }

  # Every reference curve ends with the same grid: 11000 to 30000 in steps of
  # 1000. Only the leading time points differ between releases.
  late_times <- seq(11000, 30000, by = 1000)
  curve_data <- iai::get_survival_curve_data(curve)
  if (before_2_0) {
    expect_equal(curve_data$coefs, c(
      0.00000000, 0.02380952, 0.02597403, 0.02813853, 0.03030303,
      0.03246753, 0.03463203, 0.03679654, 0.03896104, 0.04112554,
      0.04329004, 0.04545455, 0.06734007, 0.08922559, 0.11111111,
      0.13299663, 0.15488215, 0.17676768, 0.19865320, 0.22053872,
      0.24242424
    ))
    expect_equal(curve_data$times, c(0, late_times))
  } else {
    if (before_2_2) {
      reference_coefs <- c(
        0.000000, 0.003472, 0.024306, 0.066330, 0.098207, 0.112815,
        0.123968, 0.150682, 0.193501, 0.225786, 0.260612, 0.316929,
        0.377445, 0.427421, 0.462996, 0.501441, 0.547923, 0.575877,
        0.618289, 0.685522, 0.749382, 0.862943
      )
      early_time <- 6000
    } else {
      reference_coefs <- c(
        0.000000, 0.005814, 0.023256, 0.052538, 0.076310, 0.094997,
        0.114033, 0.140796, 0.168202, 0.190038, 0.220994, 0.271124,
        0.340665, 0.395205, 0.433004, 0.482354, 0.523395, 0.546250,
        0.597468, 0.669475, 0.727317, 0.845066
      )
      early_time <- 7000
    }
    expect_equal(curve_data$coefs, reference_coefs, tolerance = 1e-6)
    expect_equal(curve_data$times, c(0, early_time, late_times))
  }

  if (before_2_1) {
    expect_error(iai::get_survival_expected_time(),
                 "requires IAI version 2.1.0")
    expect_error(iai::get_survival_hazard(), "requires IAI version 2.1.0")
  } else {
    reference_time <- if (before_2_2) 22981.39 else 23443.187
    reference_hazard <- if (before_2_2) 0.9541041 else 0.8880508
    expect_equal(iai::get_survival_expected_time(lnr, 2), reference_time)
    expect_equal(iai::get_survival_hazard(lnr, 2), reference_hazard,
                 tolerance = 1e-6)
  }

  # Node 1 is rejected by default; from 2.1.0 on, `check_leaf = FALSE` lets
  # the same query run without raising.
  expect_error(iai::get_survival_curve(lnr, 1))
  expect_error(iai::get_survival_expected_time(lnr, 1))
  expect_error(iai::get_survival_hazard(lnr, 1))
  if (!before_2_1) {
    iai::get_survival_curve(lnr, 1, check_leaf = FALSE)
    iai::get_survival_expected_time(lnr, 1, check_leaf = FALSE)
    iai::get_survival_hazard(lnr, 1, check_leaf = FALSE)
  }

  # Regression accessors on a survival tree: NaN constant and empty weights
  # from 3.0.0, a version error before that.
  if (!iai:::iai_version_less_than("3.0.0")) {
    expect_equal(iai::get_regression_constant(lnr, 2), NaN)
    weights <- iai::get_regression_weights(lnr, 2)
    expect_equal(length(weights$numeric), 0)
    expect_equal(length(weights$categoric), 0)
  } else {
    expect_error(iai::get_regression_constant(lnr, 2),
                 "requires IAI version 3.0.0")
    expect_error(iai::get_regression_weights(lnr, 2),
                 "requires IAI version 3.0.0")
  }
})
test_that("prescription structure", {
  skip_on_cran()

  # Resolve the release checks once; later checks imply the earlier ones.
  before_2_0 <- iai:::iai_version_less_than("2.0.0")
  before_2_1 <- before_2_0 || iai:::iai_version_less_than("2.1.0")
  before_2_2 <- before_2_1 || iai:::iai_version_less_than("2.2.0")

  # Loader arguments changed across releases; choose the matching Julia call.
  if (before_2_0) {
    iai::set_julia_seed(2)
    julia_cmd <- "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\",
regression_lambda=0.22,
max_depth=2)"
  } else if (before_2_1) {
    julia_cmd <- "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\",
regression_lambda=0.22,
max_depth=2,
random_seed=1)"
  } else if (before_2_2) {
    julia_cmd <- "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\",
regression_weighted_betas=true,
regression_lambda=1.9,
max_depth=2,
random_seed=1)"
  } else if (iai:::iai_version_less_than("3.1.0")) {
    julia_cmd <- "IAI.OptimalTrees.load_prescription_tree(regression_sparsity=\"all\",
regression_weighted_betas=true,
regression_lambda=1.9,
max_depth=2,
random_seed=2)"
  } else {
    julia_cmd <- "IAI.OptimalTrees.load_prescription_tree(
regression_features=Set([\"All\"]),
regression_weighted_betas=true,
regression_lambda=1.9,
max_depth=2,
random_seed=2,
)"
  }
  lnr <- iai:::set_obj_class(JuliaCall::julia_eval(julia_cmd))

  # Reference values per release:
  #   weights_node / weights_treatment - arguments for get_regression_weights
  #   leaf - node queried for the treatment rank and the treatment-0 constant
  if (before_2_0) {
    ref <- list(weights_node = 5, weights_treatment = 1,
                numeric = list(Disp = -0.007198454096246),
                leaf = 2, rank = c(1, 0), constant = 30.5)
  } else if (before_2_1) {
    ref <- list(weights_node = 5, weights_treatment = 0,
                numeric = list(Disp = -0.00853409230131,
                               AM = 1.316408317777783),
                leaf = 5, rank = c(1, 0), constant = 18.507454507299066)
  } else if (before_2_2) {
    ref <- list(weights_node = 4, weights_treatment = 0,
                numeric = list(Cyl = -0.189847291283807),
                leaf = 4, rank = c(0, 1), constant = 20.7970532059596)
  } else {
    ref <- list(weights_node = 2, weights_treatment = 0,
                numeric = list(Cyl = -1.377692110219233),
                leaf = 2, rank = c(0, 1), constant = 28.682819327982067)
  }

  weights <- iai::get_regression_weights(lnr, ref$weights_node,
                                         ref$weights_treatment)
  expect_mapequal(weights$numeric, ref$numeric)
  expect_true(is.list(weights$categoric) && length(weights$categoric) == 0)
  expect_equal(iai::get_prescription_treatment_rank(lnr, ref$leaf), ref$rank)
  expect_equal(iai::get_regression_constant(lnr, ref$leaf, 0), ref$constant)

  # Node 1 is rejected by default; from 2.1.0 on, `check_leaf = FALSE` lets
  # the same query run without raising.
  expect_error(iai::get_prescription_treatment_rank(lnr, 1))
  expect_error(iai::get_regression_constant(lnr, 1, 0))
  expect_error(iai::get_regression_weights(lnr, 1, 0))
  if (!before_2_1) {
    iai::get_prescription_treatment_rank(lnr, 1, check_leaf = FALSE)
    iai::get_regression_constant(lnr, 1, 0, check_leaf = FALSE)
    iai::get_regression_weights(lnr, 1, 0, check_leaf = FALSE)
  }
})
test_that("policy structure", {
  skip_on_cran()

  # Each accessor was introduced in a different IAI release. Resolve its
  # availability once so that every expectation below runs only where it
  # actually tests what it claims to test.
  has_rank <- !iai:::iai_version_less_than("2.0.0")
  has_outcome <- has_rank && !iai:::iai_version_less_than("2.1.0")
  has_std_error <- has_outcome && !iai:::iai_version_less_than("3.2.0")

  if (has_rank) {
    lnr <- iai:::set_obj_class(JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_policy_tree(max_depth=2, random_seed=1)"
    ))
    expect_equal(iai::get_policy_treatment_rank(lnr, 3), c("A", "C", "B"))
  } else {
    expect_error(iai::get_policy_treatment_rank(),
                 "requires IAI version 2.0.0")
  }

  if (has_outcome) {
    outcomes <- iai::get_policy_treatment_outcome(lnr, 3)
    if (iai:::iai_version_less_than("2.2.0")) {
      expect_equal(outcomes$A, 0.8276032, tolerance = 1e-6)
      expect_equal(outcomes$B, 1.698339, tolerance = 1e-6)
      expect_equal(outcomes$C, 1.096775, tolerance = 1e-6)
    } else {
      expect_equal(outcomes$A, 0.827778, tolerance = 1e-6)
      expect_equal(outcomes$B, 1.70248, tolerance = 1e-5)
      expect_equal(outcomes$C, 1.09849, tolerance = 1e-5)
    }
  } else {
    expect_error(iai::get_policy_treatment_outcome(),
                 "requires IAI version 2.1.0")
  }

  if (has_std_error) {
    errors <- iai::get_policy_treatment_outcome_standard_error(lnr, 3)
    expect_equal(errors$A, 0.0777876, tolerance = 1e-5)
    expect_equal(errors$B, 0.083841, tolerance = 1e-5)
    expect_equal(errors$C, 0.10806, tolerance = 1e-5)
  } else {
    expect_error(iai::get_policy_treatment_outcome_standard_error(),
                 "requires IAI version 3.2.0")
  }

  # Node 1 is rejected by default, and `check_leaf = FALSE` lets the same
  # query run without raising. These pattern-less `expect_error()` calls are
  # only meaningful when the accessor exists: on older releases the learner
  # is never created and any error would come from the version gate, so the
  # checks would pass without exercising the leaf check.
  if (has_rank) {
    expect_error(iai::get_policy_treatment_rank(lnr, 1))
  }
  if (has_outcome) {
    expect_error(iai::get_policy_treatment_outcome(lnr, 1))
    iai::get_policy_treatment_rank(lnr, 1, check_leaf = FALSE)
    iai::get_policy_treatment_outcome(lnr, 1, check_leaf = FALSE)
  }
  if (has_std_error) {
    expect_error(iai::get_policy_treatment_outcome_standard_error(lnr, 1))
    iai::get_policy_treatment_outcome_standard_error(lnr, 1,
                                                     check_leaf = FALSE)
  }
})
test_that("visualization", {
  skip_on_cran()

  lnr <- iai:::set_obj_class(
    JuliaCall::julia_eval("IAI.OptimalTrees.load_iris_tree()")
  )
  # One extra-content entry per node of the tree.
  extra_content <- replicate(iai::get_num_nodes(lnr),
                             list("node_color" = "#FFFFFF"),
                             FALSE)
  before_2_1 <- iai:::iai_version_less_than("2.1.0")

  # Call `writer(path)` with a fresh, unique temporary path and assert that
  # the file was created. A unique path means a leftover file from an earlier
  # run cannot satisfy the check, and `on.exit` removes the file even when an
  # expectation fails. Returns the file's lines when `keep_lines` is TRUE.
  expect_writes <- function(ext, writer, keep_lines = FALSE) {
    path <- tempfile(fileext = ext)
    on.exit(unlink(path), add = TRUE)
    writer(path)
    expect_true(file.exists(path))
    if (keep_lines) readLines(path) else invisible(NULL)
  }
  # TRUE when any line contains `text` literally.
  mentions <- function(lines, text) any(grepl(text, lines, fixed = TRUE))
  # Class vector shared by every visualization object; only the head differs.
  vis_class <- function(kind) {
    c(kind, "abstract_visualization", "IAIObject", "JuliaObject")
  }

  # From 3.1.0 on, graphviz must be available, loading it first if needed.
  if (!iai:::iai_version_less_than("3.1.0")) {
    if (!JuliaCall::julia_eval("IAI.IAITrees.has_graphviz()")) {
      iai::load_graphviz()
    }
    expect_true(JuliaCall::julia_eval("IAI.IAITrees.has_graphviz()"))
  }

  # Image outputs are only exercised when graphviz is present.
  if (JuliaCall::julia_eval("IAI.IAITrees.has_graphviz()")) {
    expect_writes(".png", function(path) iai::write_png(path, lnr))
    if (before_2_1) {
      error_message <- "requires IAI version 2.1.0"
      expect_error(iai::write_pdf(tempfile(fileext = ".pdf"), lnr),
                   error_message)
      expect_error(iai::write_svg(tempfile(fileext = ".svg"), lnr),
                   error_message)
    } else {
      expect_writes(".png", function(path) {
        iai::write_png(path, lnr, extra_content = extra_content)
      })
      expect_writes(".pdf", function(path) iai::write_pdf(path, lnr))
      expect_writes(".pdf", function(path) {
        iai::write_pdf(path, lnr, extra_content = extra_content)
      })
      expect_writes(".svg", function(path) iai::write_svg(path, lnr))
      expect_writes(".svg", function(path) {
        iai::write_svg(path, lnr, extra_content = extra_content)
      })
    }
  }

  expect_writes(".dot", function(path) iai::write_dot(path, lnr))
  if (!before_2_1) {
    expect_writes(".dot", function(path) {
      iai::write_dot(path, lnr, extra_content = extra_content)
    })
  }

  # Without data attached, the HTML must not contain the data panels.
  lines <- expect_writes(".html",
                         function(path) iai::write_html(path, lnr),
                         keep_lines = TRUE)
  expect_false(mentions(lines, "\"Target\""))
  expect_false(mentions(lines, "\"Results\""))
  expect_writes(".html", function(path) {
    iai::write_html(path, lnr, extra_content = extra_content)
  })
  expect_writes(".html", function(path) iai::write_questionnaire(path, lnr))

  # Stand-alone visualization objects.
  if (iai:::iai_version_less_than("1.1.0")) {
    expect_error(iai::tree_plot(lnr), "requires IAI version 1.1.0")
    expect_error(iai::questionnaire(lnr), "requires IAI version 1.1.0")
    expect_error(iai::multi_tree_plot(list()), "requires IAI version 1.1.0")
    expect_error(iai::multi_questionnaire(list()),
                 "requires IAI version 1.1.0")
  } else {
    feature_renames <- list(
      "PetalLength" = "A",
      "PetalWidth" = "B",
      "SepalWidth" = "C"
    )
    vis <- iai::tree_plot(lnr, feature_renames = feature_renames)
    expect_equal(class(vis), vis_class("tree_plot"))
    expect_writes(".html", function(path) iai::write_html(path, vis))

    vis <- iai::questionnaire(lnr, feature_renames = feature_renames)
    expect_equal(class(vis), vis_class("questionnaire"))
    expect_writes(".html", function(path) iai::write_html(path, vis))

    questions <- list("Use learner with" = list(
      "renamed features" = lnr,
      "extra text output" = lnr
    ))
    vis <- iai::multi_tree_plot(questions)
    expect_equal(class(vis), vis_class("multi_tree_plot"))
    expect_writes(".html", function(path) iai::write_html(path, vis))

    vis <- iai::multi_questionnaire(questions)
    expect_equal(class(vis), vis_class("multi_questionnaire"))
    expect_writes(".html", function(path) iai::write_html(path, vis))
  }

  # Visualizations built from a fitted grid search.
  X <- iris[, 1:4]
  y <- iris$Species
  grid <- iai::grid_search(
    iai::optimal_tree_classifier(
      random_seed = 1,
      max_depth = 1
    )
  )
  iai::fit(grid, X, y)
  if (iai:::iai_version_less_than("2.0.0")) {
    grid_path <- tempfile(fileext = ".html")
    expect_error(iai::write_html(grid_path, grid),
                 "requires IAI version 2.0.0")
    expect_error(iai::write_questionnaire(grid_path, grid),
                 "requires IAI version 2.0.0")
    expect_error(iai::show_in_browser(grid), "requires IAI version 2.0.0")
    expect_error(iai::show_questionnaire(grid), "requires IAI version 2.0.0")
  } else {
    vis <- iai::multi_tree_plot(grid)
    expect_equal(class(vis), vis_class("multi_tree_plot"))
    expect_writes(".html", function(path) iai::write_html(path, vis))

    vis <- iai::multi_questionnaire(grid)
    expect_equal(class(vis), vis_class("multi_questionnaire"))
    expect_writes(".html", function(path) iai::write_html(path, vis))
  }

  # Data visualization: reuse the grid fitted above (same data, same seed).
  if (!before_2_1) {
    lnr <- iai::get_learner(grid)
    lines <- expect_writes(".html", function(path) {
      iai::write_html(path, lnr, data = list(X, y))
    }, keep_lines = TRUE)
    expect_true(mentions(lines, "\"Target\""))
    expect_true(mentions(lines, "\"Results\""))

    # Passing features without a target is an error before 2.2.0; from 2.2.0
    # the output has results but no target panel.
    if (iai:::iai_version_less_than("2.2.0")) {
      expect_error(
        iai::write_html(tempfile(fileext = ".html"), lnr, data = X)
      )
    } else {
      lines <- expect_writes(".html", function(path) {
        iai::write_html(path, lnr, data = X)
      }, keep_lines = TRUE)
      expect_false(mentions(lines, "\"Target\""))
      expect_true(mentions(lines, "\"Results\""))
    }
  }
})
test_that("tree API", {
  skip_on_cran()

  # Fit a depth-1 classifier on iris and check the shape of what the generic
  # tree accessors return.
  features <- iris[, 1:4]
  species <- iris$Species
  lnr <- iai::optimal_tree_classifier(max_depth = 1, cp = 0)
  iai::fit(lnr, features, species)

  n_obs <- length(species)
  leaf_of_each_row <- iai::apply(lnr, features)
  expect_equal(length(leaf_of_each_row), n_obs)
  rows_of_each_node <- iai::apply_nodes(lnr, features)
  expect_equal(length(rows_of_each_node), iai::get_num_nodes(lnr))

  # The decision path has one row per observation and one column per node.
  path <- iai::decision_path(lnr, features)
  expect_equal(nrow(path), n_obs)
  expect_equal(ncol(path), iai::get_num_nodes(lnr))

  # Only needs to run without raising.
  iai::print_path(lnr, features, 1)

  importance <- iai::variable_importance(lnr)
  expect_true(is.data.frame(importance))

  if (!iai:::iai_version_less_than("2.2.0")) {
    expect_true(is.vector(iai::get_features_used(lnr)))
  } else {
    expect_error(iai::get_features_used(lnr), "requires IAI version 2.2.0")
  }
})
test_that("classification tree API", {
  skip_on_cran()

  # Binary problem: is the flower a setosa?
  features <- iris[, 1:4]
  is_setosa <- iris$Species == "setosa"
  lnr <- iai::optimal_tree_classifier(max_depth = 1, cp = 0)
  iai::fit(lnr, features, is_setosa)

  probabilities <- iai::predict_proba(lnr, features)
  expect_true(is.data.frame(probabilities))

  # With `simplify = TRUE`, setting the threshold is expected to shrink the
  # tree from three nodes to one.
  expect_equal(iai::get_num_nodes(lnr), 3)
  iai::set_threshold(lnr, TRUE, 0, simplify = TRUE)
  expect_equal(iai::get_num_nodes(lnr), 1)

  # The display label shows up in the printed tree until it is reset.
  iai::set_display_label(lnr, TRUE)
  with_label <- print(lnr)
  expect_true(grepl("true)", with_label))
  iai::reset_display_label(lnr)
  without_label <- print(lnr)
  expect_false(grepl("true)", without_label))
})
test_that("survival tree API", {
  skip_on_cran()

  lnr <- iai::optimal_tree_survival_learner(max_depth = 1, cp = 0)

  # Random survival data: n rows, two numeric features.
  n <- 100
  X <- matrix(rnorm(2 * n), nrow = n, ncol = 2)
  died <- rbinom(n, 1, 0.5) == 1
  times <- runif(n)
  iai::fit(lnr, X, died, times)

  # Each prediction type yields one value per row where the release has it,
  # and a version error otherwise.
  if (!iai:::iai_version_less_than("1.2.0")) {
    hazards <- iai::predict_hazard(lnr, X)
    expect_equal(length(hazards), n)
  } else {
    expect_error(iai::predict_hazard(lnr, X), "requires IAI version 1.2.0")
  }

  if (!iai:::iai_version_less_than("2.0.0")) {
    expected_times <- iai::predict_expected_survival_time(lnr, X)
    expect_equal(length(expected_times), n)
  } else {
    expect_error(iai::predict_expected_survival_time(lnr, X),
                 "requires IAI version 2.0.0")
  }
})
test_that("prescription tree API", {
  skip_on_cran()

  # The minimizer and the maximizer are put through the same checks.
  constructors <- list(iai::optimal_tree_prescription_minimizer,
                       iai::optimal_tree_prescription_maximizer)
  for (make_learner in constructors) {
    lnr <- make_learner(max_depth = 1, cp = 0)

    # Random data: 100 rows, two features, a binary treatment and an outcome.
    X <- matrix(rnorm(200), nrow = 100, ncol = 2)
    treatments <- rbinom(100, 1, 0.5)
    outcomes <- runif(100)
    iai::fit(lnr, X, treatments, outcomes)

    predicted_outcomes <- iai::predict_outcomes(lnr, X)
    expect_true(is.data.frame(predicted_outcomes))

    pred <- iai::predict(lnr, X)
    expect_true(is.list(pred))
    expect_equal(names(pred), c("treatments", "outcomes"))
  }
})
test_that("policy tree API", {
  skip_on_cran()

  if (!iai:::iai_version_less_than("2.0.0")) {
    # The minimizer and the maximizer are put through the same checks.
    constructors <- list(iai::optimal_tree_policy_minimizer,
                         iai::optimal_tree_policy_maximizer)
    for (make_learner in constructors) {
      lnr <- make_learner(max_depth = 1, cp = 0)

      # Random data: 100 rows, two features, a two-column rewards matrix.
      X <- matrix(rnorm(200), nrow = 100, ncol = 2)
      rewards <- matrix(rnorm(200), nrow = 100, ncol = 2)
      iai::fit(lnr, X, rewards)

      expect_true(is.vector(iai::predict(lnr, X)))
      expect_true(is.vector(iai::predict_outcomes(lnr, X, rewards)))
    }
  } else {
    # Policy learners cannot be constructed before 2.0.0.
    expect_error(iai::optimal_tree_policy_minimizer(),
                 "requires IAI version 2.0.0")
    expect_error(iai::optimal_tree_policy_maximizer(),
                 "requires IAI version 2.0.0")
  }
})
test_that("stability", {
  skip_on_cran()

  features <- iris[, 1:4]
  is_setosa <- iris$Species == "setosa"
  lnr <- iai::optimal_tree_classifier(max_depth = 1, cp = 0)
  iai::fit(lnr, features, is_setosa)

  # Both analyses return a visualization object whose class vector differs
  # only in its first element.
  vis_class <- function(kind) {
    c(kind, "abstract_visualization", "IAIObject", "JuliaObject")
  }
  version_error <- "requires IAI version 2.2.0"
  unsupported <- iai:::iai_version_less_than("2.2.0")

  if (unsupported) {
    expect_error(iai::stability_analysis(lnr, features, is_setosa),
                 version_error)
  } else {
    stability <- iai::stability_analysis(lnr, features, is_setosa)
    expect_equal(class(stability), vis_class("stability_analysis"))
  }

  # Four random deviations are passed along with the learner compared to
  # itself.
  deviations <- runif(4)
  if (unsupported) {
    expect_error(iai::similarity_comparison(lnr, lnr, deviations),
                 version_error)
  } else {
    similarity <- iai::similarity_comparison(lnr, lnr, deviations)
    expect_equal(class(similarity), vis_class("similarity_comparison"))
  }
})
test_that("multi classification structure", {
  skip_on_cran()

  if (!iai:::iai_version_less_than("3.2.0")) {
    lnr <- iai:::set_obj_class(JuliaCall::julia_eval(
      "IAI.OptimalTrees.load_iris_tree_multi(random_seed=1, max_depth=1)"
    ))

    # Each accessor is queried at node 3 twice: once for all tasks (a list
    # with one entry per task) and once for task "y1" only. The two results
    # must agree.
    label_all <- iai::get_classification_label(lnr, 3)
    label_single <- iai::get_classification_label(lnr, 3, "y1")
    expect_true(is.list(label_all))
    expect_true(is.character(label_single))
    expect_equal(label_all$y1, label_single)

    proba_all <- iai::get_classification_proba(lnr, 3)
    proba_single <- iai::get_classification_proba(lnr, 3, "y1")
    expect_true(is.list(proba_all))
    expect_true(is.list(proba_single))
    expect_equal(proba_all$y1, proba_single)

    # Regression accessors on this classification tree give NaN constants.
    const_all <- iai::get_regression_constant(lnr, 3)
    const_single <- iai::get_regression_constant(lnr, 3, "y1")
    expect_true(is.list(const_all))
    expect_true(is.nan(const_single))
    expect_true(is.nan(const_all$y1))

    # ... and empty numeric and categoric weight lists.
    weights_all <- iai::get_regression_weights(lnr, 3)
    weights_single <- iai::get_regression_weights(lnr, 3, "y1")
    expect_true(is.list(weights_all))
    expect_true(is.list(weights_single))
    expect_equal(length(weights_single), 2)
    expect_true(is.list(weights_single$numeric))
    expect_true(is.list(weights_single$categoric))
    expect_equal(length(weights_single$numeric), 0)
    expect_equal(length(weights_single$categoric), 0)
    expect_equal(weights_all$y1, weights_single)
  } else {
    expect_error(iai::optimal_tree_multi_classifier(),
                 "requires IAI version 3.2.0")
  }
})
test_that("multi regression structure", {
  skip_on_cran()

  if (!iai:::iai_version_less_than("3.2.0")) {
    julia_cmd <- "IAI.OptimalTrees.load_mtcars_tree_multi(
random_seed=1,
regression_lambda=0.01,
regression_features=Set([\"All\"]),
max_depth=1,
)"
    lnr <- iai:::set_obj_class(JuliaCall::julia_eval(julia_cmd))

    # Each accessor is queried at node 1 twice: once for all tasks (a list
    # with one entry per task) and once for task "MPG" only. The two results
    # must agree.
    const_all <- iai::get_regression_constant(lnr, 1)
    const_single <- iai::get_regression_constant(lnr, 1, "MPG")
    expect_true(is.list(const_all))
    expect_true(is.numeric(const_single))
    expect_equal(const_all$MPG, const_single)

    weights_all <- iai::get_regression_weights(lnr, 1)
    weights_single <- iai::get_regression_weights(lnr, 1, "MPG")
    expect_true(is.list(weights_all))
    expect_true(is.list(weights_single))
    expect_equal(length(weights_single), 2)
    expect_true(is.list(weights_single$numeric))
    expect_true(is.list(weights_single$categoric))
    expect_equal(weights_all$MPG, weights_single)
  } else {
    expect_error(iai::optimal_tree_multi_regressor(),
                 "requires IAI version 3.2.0")
  }
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.