Nothing
# Prepare common environment for testing
# https://testthat.r-lib.org/articles/special-files.html
# setup.R is only available for tests
# helper.R is also available for package devoloment with devtools::load_all()
# Train a GAM on var_cars dataset -------------------
## Create test_cars: --------
# var_cars with more variables so that test will not fail with mgcv::gam error "Model has more coefficients than data"
set.seed(0)
dbl_jitter <- stats::runif(nrow(var_cars), 0.99, 1.01)
int_jitter <- sample(c(-1L, 0L, 1L), nrow(var_cars), replace = TRUE)
test_cars <- var_cars |>
dplyr::bind_rows(
var_cars |>
mutate(
across(where(is.double), \(it.dbl) it.dbl * dbl_jitter)) |>
mutate(across(
where(is.integer),
\(it.int) if_else(it.int > 1, it.int + int_jitter, it.int)
))
) |>
# Use only a few variables to minimize ALE processing time
select(
# outcomes: numeric, binary, categorical
mpg, vs, continent,
# predictors: binary, categorical, ordinal, integer, double
am, model, gear, carb, wt
)
rm(dbl_jitter)
rm(int_jitter)
test_gam <- mgcv::gam(
mpg ~ model + s(wt) + am + gear + carb,
# mpg ~ model + cyl + s(disp) + s(hp) + s(drat) + s(wt) + s(qsec) + vs + am + gear + carb + country + continent,
data = test_cars
)
test_gam_binary <- mgcv::gam(
vs ~ model + s(wt) + am + gear + carb,
# vs ~ model + cyl + s(disp) + s(hp) + s(drat) + s(wt) + s(qsec) + am + gear + carb + country + continent,
data = test_cars,
family = stats::binomial()
)
test_nn_categorical <- nnet::multinom(
# Remove mpg (typical target) and country (perfectly determines continent) from the model
continent ~ model + wt + am + gear + carb,
# continent ~ . - mpg - country,
data = test_cars,
trace = FALSE # suppress noisy output from nnet
)
# Returns list of ALE plots converted to ggplot data format ---------------
ale_plots_to_data <- function(
ale_plots # ALEPlots object
) {
purrr::imap(ale_plots@plots, \(it.cat_plots, it.cat_name) {
list(
d1 = if (it.cat_name != '.all_cats') {
it.cat_plots$d1 |>
purrr::map(\(it.plot) {
ggplot2::ggplot_build(it.plot)$data[[1]]
})
} else {
it.cat_plots$d1 |>
purrr::map(\(it.x_col) {
it.x_col |>
purrr::map(\(it.plot) {
ggplot2::ggplot_build(it.plot)$data[[1]]
})
})
},
d2 = if (it.cat_name != '.all_cats') {
it.cat_plots$d2 |>
purrr::map(\(it.plot) {
ggplot2::ggplot_build(it.plot)$data[[1]]
})
} else {
it.cat_plots$d2 |>
purrr::map(\(it.plot) {
ggplot2::ggplot_build(it.plot)$data[[1]]
})
},
eff = if (it.cat_name != '.all_cats') {
if (!is.null(it.cat_plots$eff)) {
it.cat_plots$eff |>
ggplot2::ggplot_build() |>
(`[[`)('data')
} else {
# No effects plot if no 1D data or no statistics
NULL
}
} else {
# No effects plot for '.all_cats' category
NULL
}
)
})
}
# custom predict function ------------
test_predict <- function(object, newdata, type = pred_type) {
predict(object, newdata, se.fit = TRUE, type = type)$fit
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.