Nothing
# test-ALEPlot.R
# Tests to ensure that ale package gives exactly the same results
# as the gold standard reference ALEPlot package.
# test_file('tests/testthat/test-ALEPlot.R')
# To minimize test time, the reference output should be serialized with expect_snapshot_value.
# Do not run these on CRAN so that the required packages are not included as dependencies.
# https://community.rstudio.com/t/skip-an-entire-test-file-on-cran-only/162842
# Stop evaluating this file unless the NOT_CRAN environment variable is
# exactly "true"; otherwise none of the code below is run.
if (!identical(Sys.getenv("NOT_CRAN"), "true")) {
  return()
}
# nnet -----------------
# Simulated regression data, a neural network fitted to it, and the prediction
# wrappers used by the nnet tests below.

# Fix the RNG seed so that the simulated data is reproducible
set.seed(0)
n <- 1000 # smaller dataset for more rapid execution
x1 <- runif(n, min = 0, max = 1)
x2 <- runif(n, min = 0, max = 1)
x3 <- runif(n, min = 0, max = 1)
x4 <- runif(n, min = 0, max = 1)
# The response depends on x1, x2 and x3 (including an x1:x2 interaction) plus
# standard normal noise; x4 does not enter the formula.
y <- 4 * x1 + 3.87 * x2^2 + 2.97 * exp(-5 + 10 * x3) / (1 + exp(-5 + 10 * x3)) +
  13.86 * (x1 - 0.5) * (x2 - 0.5) + rnorm(n, 0, 1)

# `<<-` is used for every object that the test_that() blocks below refer to
# by name.
DAT <<- data.frame(y, x1, x2, x3, x4)

# Reseed so that the network fit does not depend on the draws made above
set.seed(0)
nnet.DAT <<- nnet::nnet(
  y ~ .,
  data = DAT,
  # Spelled-out TRUE/FALSE: `T` and `F` are ordinary variables that can be
  # reassigned.
  linout = TRUE, skip = FALSE, size = 6,
  decay = 0.1, maxit = 1000, trace = FALSE
)

# Define the predict functions

# Prediction wrapper with the (X.model, newdata) signature passed to
# ALEPlot::ALEPlot() as `pred.fun`; returns raw predictions as a numeric vector.
nnet_pred_fun_ALEPlot <<- function(X.model, newdata) {
  as.numeric(predict(X.model, newdata, type = "raw"))
}

# Prediction wrapper passed to ALE() as `pred_fun`; returns predictions of the
# requested `type` as a numeric vector.
# NOTE(review): the default `type = pred_type` refers to a name that is not
# defined in this file; presumably ALE() always supplies `type` -- confirm.
nnet_pred_fun_ale <<- function(object, newdata, type = pred_type) {
  as.numeric(predict(object, newdata, type = type))
}
# gbm ----------------
# Census data, a gbm model fitted to it, and the prediction wrappers used by
# the gbm tests below.

# `<<-` is used for every object that the test_that() blocks below refer to
# by name.
adult_data <<-
  census |>
  as.data.frame() |> # ALEPlot is not compatible with the tibble format
  select(age:native_country, higher_income) |> # Rearrange columns to match ALEPlot order
  # Drop rows with missing values. (The original passed a stray `data`
  # argument here, which na.omit() silently ignored.)
  stats::na.omit()

# Send anything drawn while fitting to a null PDF device (`file = NULL`) so
# that no plots are printed.
pdf(file = NULL)
set.seed(0)
gbm.data <<- tryCatch(
  gbm::gbm(
    higher_income ~ .,
    # Columns 3 and 4 are excluded from the model
    data = adult_data[, -c(3, 4)] |>
      # gbm::gbm() requires binary response outcomes to be numeric 0 or 1
      mutate(higher_income = as.integer(higher_income)),
    distribution = "bernoulli",
    n.trees = 100, # smaller model than ALEPlot example for rapid execution
    shrinkage = 0.02,
    interaction.depth = 3
  ),
  # Return to regular printing of plots, even if the fit fails
  finally = invisible(dev.off())
)

# Prediction wrapper with the (X.model, newdata) signature passed to
# ALEPlot::ALEPlot() as `pred.fun`; returns link-scale predictions from all
# 100 trees as a numeric vector.
gbm_pred_fun_ALEPlot <<- function(X.model, newdata) {
  as.numeric(gbm::predict.gbm(X.model, newdata, n.trees = 100, type = "link"))
}

# Prediction wrapper passed to ALE() as `pred_fun`; returns predictions of the
# requested `type` from all 100 trees as a numeric vector.
# NOTE(review): the default `type = pred_type` refers to a name that is not
# defined in this file; presumably ALE() always supplies `type` -- confirm.
gbm_pred_fun_ale <<- function(object, newdata, type = pred_type) {
  as.numeric(gbm::predict.gbm(object, newdata, n.trees = 100, type = type))
}
# Tests --------------------
# 1D ALE for the nnet model: the values from ALE() must agree with those from
# ALEPlot::ALEPlot() for each of the four predictors, within tolerance 0.01.
test_that('ale function matches output of ALEPlot with nnet', {
  # Send plots drawn while computing the reference values to a null PDF
  # device so that they don't print.
  pdf(file = NULL)

  # Create list of ALEPlot data that can be readily compared for accuracy:
  # one tibble (x.values, f.values) per predictor, named after the predictor.
  nnet_ALEPlot <- tryCatch(
    map(1:4, \(it.col_idx) {
      ALEPlot::ALEPlot(
        DAT[, 2:5], nnet.DAT,
        pred.fun = nnet_pred_fun_ALEPlot,
        J = it.col_idx, K = 10
      ) |>
        as_tibble() |>
        select(-K)
    }) |>
      set_names(names(DAT[, 2:5])),
    # Return to regular printing of plots, even if ALEPlot fails, so that the
    # device is never left open for later tests.
    finally = invisible(dev.off())
  )

  # Create ale results with data only
  nnet_ale <- ALE(
    # basic arguments
    model = nnet.DAT,
    data = DAT,
    # make ale equivalent to ALEPlot
    parallel = 0,
    output_stats = FALSE,
    boot_it = 0,
    # specific options requested by ALEPlot example
    pred_type = "raw", pred_fun = nnet_pred_fun_ale,
    max_num_bins = 10 + 1,
    silent = TRUE
  )

  # Convert ale results to version that can be readily compared with ALEPlot:
  # the first column holds the x values and `.y` the ALE values.
  nnet_ale_to_ALEPlot <-
    get(nnet_ale, ale_centre = 'zero') |>
    map(\(it.x) {
      tibble(
        x.values = it.x[[1]],
        f.values = it.x$.y
      )
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(nnet_ALEPlot, nnet_ale_to_ALEPlot, tolerance = 0.01)
  )
})
# 1D ALE for the gbm model: the values from ALE() must agree with those from
# ALEPlot::ALEPlot() for four selected predictors.
test_that('ale function matches output of ALEPlot with gbm', {
  # Send plots drawn while computing the reference values to a null PDF
  # device so that they don't print.
  pdf(file = NULL)

  # Create list of ALEPlot data that can be readily compared for accuracy
  # For this test, get only four variables: c('age', 'workclass', 'education_num', 'sex')
  # These are column indexes c(1, 2, 3, 8)
  gbm_ALEPlot <- tryCatch(
    map(c(1, 2, 3, 8), \(it.col_idx) {
      ALEPlot::ALEPlot(
        adult_data[, -c(3, 4, 15)], gbm.data,
        pred.fun = gbm_pred_fun_ALEPlot,
        J = it.col_idx,
        K = 10, NA.plot = TRUE
      ) |>
        as_tibble() |>
        select(-K)
    }) |>
      set_names(names(adult_data[, -c(3, 4, 15)])[c(1, 2, 3, 8)]),
    # Return to regular printing of plots, even if ALEPlot fails, so that the
    # device is never left open for later tests.
    finally = invisible(dev.off())
  )

  # Create ale results with data only
  gbm_ale <- ALE(
    model = gbm.data,
    x_cols = c('age', 'workclass', 'education_num', 'sex'),
    data = adult_data[, -c(3, 4)], # unlike ALEPlot, include the y column (15)
    # make ale equivalent to ALEPlot
    parallel = 0,
    output_stats = FALSE,
    boot_it = 0,
    # specific options requested by ALEPlot example
    pred_fun = gbm_pred_fun_ale, pred_type = 'link',
    max_num_bins = 10 + 1,
    silent = TRUE
  ) |>
    suppressMessages()

  # Convert ale results to version that can be readily compared with ALEPlot
  gbm_ale_to_ALEPlot <-
    get(gbm_ale, ale_centre = 'zero') |>
    map(\(it.x) {
      tibble(
        x.values = it.x[[1]],
        f.values = unname(it.x$.y)
      ) |>
        # Factor x values are compared as character strings
        mutate(across(where(is.factor), as.character))
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(gbm_ALEPlot, gbm_ale_to_ALEPlot)
  )
})
# 2D ALE for the nnet model: the interaction values from ALE() must agree with
# those from ALEPlot::ALEPlot() for every pair of predictors, within
# tolerance 0.01.
test_that('2D ALE matches output of ALEPlot interactions with nnet', {
  # Create list of ALEPlot data that can be readily compared for accuracy
  nnet_ALEPlot_ixn <- list()

  # Send plots drawn while computing the reference values to a null PDF
  # device so that they don't print.
  pdf(file = NULL)
  tryCatch(
    {
      for (it.x1 in 1:4) {
        for (it.x2 in 1:4) {
          # Visit each unordered pair of predictors exactly once
          if (it.x1 < it.x2) {
            ap_data <- ALEPlot::ALEPlot(
              DAT[, 2:5],
              nnet.DAT,
              pred.fun = nnet_pred_fun_ALEPlot,
              J = c(it.x1, it.x2),
              K = 10
            )
            .x1 <- ap_data$x.values[[1]]
            .x2 <- ap_data$x.values[[2]]
            .y <- ap_data$f.values

            # Flatten the f.values matrix into long format: one row per
            # (.x1, .x2) grid cell.
            ixn_tbl <-
              expand.grid(
                # seq_along() rather than 1:length(), which would yield
                # c(1, 0) for an empty vector
                row = seq_along(.x1),
                col = seq_along(.x2)
              ) |>
              as_tibble() |>
              mutate(
                .x1 = .x1[row],
                .x2 = as.numeric(.x2[col]),
                # Matrix indexing by (row, col) pairs
                .y = as.numeric(.y[cbind(row, col)])
              ) |>
              select(-row, -col) |>
              arrange(.x1, .x2, .y)

            # Remove extraneous attributes, otherwise comparison will not match
            attributes(ixn_tbl)$out.attrs <- NULL

            nnet_ALEPlot_ixn[[str_glue('x{it.x1}:x{it.x2}')]] <- ixn_tbl
          }
        }
      }
    },
    # Return to regular printing of plots, even if ALEPlot fails, so that the
    # device is never left open for later tests.
    finally = invisible(dev.off())
  )

  nnet_2D <- ALE(
    # basic arguments
    model = nnet.DAT,
    data = DAT,
    x_cols = list(d2 = TRUE),
    parallel = 0,
    output_stats = FALSE,
    pred_fun = nnet_pred_fun_ale,
    pred_type = "raw", max_num_bins = 10 + 1, # specific options requested
    silent = TRUE
  )

  # Convert ale results to version that can be readily compared with ALEPlot
  nnet_2D_to_ALEPlot <-
    get(nnet_2D, ale_centre = 'zero') |>
    map(\(it.ale) {
      # Keep the two x columns and the ALE values, renamed and sorted in the
      # same way as the ALEPlot tables above
      it.ale <- it.ale |>
        select(1, 2, .y) |>
        set_names(c('.x1', '.x2', '.y')) |>
        arrange(.x1, .x2, .y)

      # Strip incomparable attributes
      attr(it.ale, 'x') <- NULL

      it.ale
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(nnet_ALEPlot_ixn, nnet_2D_to_ALEPlot, tolerance = 0.01)
  )
})
# 2D ALE for the gbm model: the interaction values from ALE() must agree with
# those from ALEPlot::ALEPlot() for six selected pairs of predictors.
test_that('2D ALE matches output of ALEPlot interactions with gbm', {
  # Create list of ALEPlot data that can be readily compared for accuracy
  gbm_ALEPlot_ixn <- list()
  adult_data_subset <- adult_data[, -c(3, 4, 15)]

  # Send plots to a null PDF device so that they don't print. The device stays
  # open until after ALE() has run because, for some reason, calling ALE() on
  # gbm.data also prints some plots.
  pdf(file = NULL)
  tryCatch(
    {
      for (it.x1 in c(1, 2, 3, 8)) {
        for (it.x2 in c(1, 3, 11)) {
          # Only pairs whose first column index is lower than the second
          if (it.x1 < it.x2) {
            ap_data <- ALEPlot::ALEPlot(
              adult_data_subset,
              gbm.data,
              pred.fun = gbm_pred_fun_ALEPlot,
              J = c(it.x1, it.x2),
              K = 10,
              NA.plot = TRUE
            )
            .x1 <- ap_data$x.values[[1]]
            .x2 <- ap_data$x.values[[2]]
            .y <- ap_data$f.values

            # Flatten the f.values matrix into long format: one row per
            # (.x1, .x2) grid cell.
            ixn_tbl <-
              expand.grid(
                # seq_along() rather than 1:length(), which would yield
                # c(1, 0) for an empty vector
                row = seq_along(.x1),
                col = seq_along(.x2)
              ) |>
              as_tibble() |>
              mutate(
                .x1 = .x1[row],
                .x2 = as.numeric(.x2[col]),
                # Matrix indexing by (row, col) pairs
                .y = as.numeric(.y[cbind(row, col)])
              ) |>
              select(-row, -col) |>
              arrange(.x1, .x2, .y)

            # Remove extraneous attributes, otherwise comparison will not match
            attributes(ixn_tbl)$out.attrs <- NULL

            # Name each table '<first column name>:<second column name>'
            gbm_ALEPlot_ixn[[str_glue(
              '{names(adult_data_subset)[it.x1]}:{names(adult_data_subset)[it.x2]}'
            )]] <- ixn_tbl
          }
        }
      }

      gbm_2D <- ALE(
        model = gbm.data,
        data = adult_data,
        x_cols = c(
          'age:education_num',
          'age:hours_per_week',
          'workclass:education_num',
          'workclass:hours_per_week',
          'education_num:hours_per_week',
          'sex:hours_per_week'
        ),
        parallel = 0,
        output_stats = FALSE,
        pred_fun = gbm_pred_fun_ale,
        pred_type = 'link', max_num_bins = 10 + 1, # specific options requested
        silent = TRUE
      )
    },
    # Return to regular printing of plots, even if ALEPlot or ALE fails, so
    # that the device is never left open for later tests.
    finally = invisible(dev.off())
  )

  # Convert ale results to version that can be readily compared with ALEPlot
  gbm_2D_to_ALEPlot <-
    get(gbm_2D, ale_centre = 'zero') |>
    map(\(it.ale) {
      it.ale <- it.ale |>
        select(1, 2, .y) |>
        set_names(c('.x1', '.x2', '.y')) |>
        # Convert [ordered] factor columns to character for comparability with ALEPlot
        mutate(across(
          '.x1',
          \(it.col) if (is.factor(it.col)) as.character(it.col) else it.col
        )) |>
        arrange(.x1, .x2, .y)

      # Strip incomparable attributes
      attr(it.ale, 'x') <- NULL

      it.ale
    })

  # Compare results of ALEPlot with ale
  expect_true(
    all.equal(gbm_ALEPlot_ixn, gbm_2D_to_ALEPlot)
  )
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.