# Hyperparameter fixture used by the tuning-grid tests below.
#
# Built as a `spec_tbl_df` (the class readr gives to data read from CSV), with
# one row per (param, value, method) combination: 52 rows in total.
# Every column is character, as declared in the `spec` attribute, so numeric
# hyperparameter values are stored as strings here.
default_hyperparams <- structure(
  list(
    param = c(
      rep("lambda", 13), "alpha",                     # glmnet
      rep("sigma", 8), rep("C", 9),                   # svmRadial
      rep("maxdepth", 6),                             # rpart2
      "nrounds", "gamma", rep("eta", 4), "max_depth", # xgbTree ...
      "colsample_bytree", "min_child_weight", rep("subsample", 4),
      rep("mtry", 2)                                  # rf
    ),
    value = c(
      # glmnet: 13 lambda values, then alpha
      "1e-6", "1e-5", "1e-4", "1e-3", "0.0025", "0.005", "0.01", "0.05",
      "0.1", "0.25", "0.5", "1", "10",
      "0",
      # svmRadial: 8 sigma values, then 9 C values
      "0.00000001", "0.0000001", "0.000001", "0.00001", "0.0001", "0.001",
      "0.01", "0.1",
      "0.0000001", "0.000001", "0.00001", "0.0001", "0.001", "0.01", "0.1",
      "1", "10",
      # rpart2: maxdepth
      "1", "2", "3", "4", "5", "6",
      # xgbTree: nrounds, gamma, eta (x4), max_depth, colsample_bytree,
      # min_child_weight, subsample (x4)
      "500", "0", "0.001", "0.01", "0.1", "1", "8", "0.8", "1",
      "0.4", "0.5", "0.6", "0.7",
      # rf: mtry
      "500", "1000"
    ),
    # Row counts per method must agree with the param/value groups above.
    method = rep(
      c("glmnet", "svmRadial", "rpart2", "xgbTree", "rf"),
      times = c(14, 17, 6, 13, 2)
    )
  ),
  class = c("spec_tbl_df", "tbl_df", "tbl", "data.frame"),
  row.names = c(NA, -52L),
  spec = structure(
    list(
      # Column spec names match the data columns (param, value, method).
      cols = list(
        param = structure(list(), class = c("collector_character", "collector")),
        value = structure(list(), class = c("collector_character", "collector")),
        method = structure(list(), class = c("collector_character", "collector"))
      ),
      default = structure(list(), class = c("collector_guess", "collector")),
      skip = 1
    ),
    class = "col_spec"
  )
)
# Tuning-grid tests, one per method: get_tuning_grid() is compared against an
# expand.grid() over that method's hyperparameter values, passed through
# mutate_all_types().
test_that("tune grid works for glmnet", {
  glmnet_params <- get_hyperparams_from_df(default_hyperparams, "glmnet")
  expected_grid <- mutate_all_types(
    expand.grid(
      alpha = glmnet_params$alpha,
      lambda = glmnet_params$lambda
    )
  )
  expect_equal(get_tuning_grid(glmnet_params, "glmnet"), expected_grid)
})
test_that("tune grid works for svmRadial", {
  svm_params <- get_hyperparams_from_df(default_hyperparams, "svmRadial")
  expected_grid <- mutate_all_types(
    expand.grid(
      C = svm_params$C,
      sigma = svm_params$sigma
    )
  )
  expect_equal(get_tuning_grid(svm_params, "svmRadial"), expected_grid)
})
test_that("tune grid works for rpart2", {
  rpart_params <- get_hyperparams_from_df(default_hyperparams, "rpart2")
  expected_grid <- mutate_all_types(
    expand.grid(maxdepth = rpart_params$maxdepth)
  )
  expect_equal(get_tuning_grid(rpart_params, "rpart2"), expected_grid)
})
test_that("tune grid works for rf", {
  rf_params <- get_hyperparams_from_df(default_hyperparams, "rf")
  expected_grid <- mutate_all_types(
    expand.grid(mtry = rf_params$mtry)
  )
  expect_equal(get_tuning_grid(rf_params, "rf"), expected_grid)
})
test_that("tune grid works for xgbTree", {
  xgb_params <- get_hyperparams_from_df(default_hyperparams, "xgbTree")
  expected_grid <- mutate_all_types(
    expand.grid(
      colsample_bytree = xgb_params$colsample_bytree,
      eta = xgb_params$eta,
      gamma = xgb_params$gamma,
      max_depth = xgb_params$max_depth,
      min_child_weight = xgb_params$min_child_weight,
      nrounds = xgb_params$nrounds,
      subsample = xgb_params$subsample
    )
  )
  expect_equal(get_tuning_grid(xgb_params, "xgbTree"), expected_grid)
})
# get_hyperparams_list(): default hyperparameters per method. Some defaults
# (rf's mtry, rpart2's maxdepth) differ with the dataset passed in.
test_that("get_hyperparams_list works for all models", {
  two_col_df <- data.frame(a = 1:10, b = 4:13)

  # glmnet
  expect_equal(
    get_hyperparams_list(otu_mini_bin, "glmnet"),
    list(lambda = c(1e-04, 0.001, 0.01, 0.1, 1, 10), alpha = 0)
  )

  # rf: mtry depends on the data
  expect_equal(get_hyperparams_list(otu_mini_bin, "rf"), list(mtry = c(2, 3, 6)))
  expect_equal(get_hyperparams_list(otu_small, "rf"), list(mtry = c(4, 8, 16)))
  expect_equal(get_hyperparams_list(two_col_df, "rf"), list(mtry = 1))

  # rpart2: maxdepth depends on the data
  expect_equal(
    get_hyperparams_list(otu_small, "rpart2"),
    list(maxdepth = c(1, 2, 4, 8, 16, 30))
  )
  expect_equal(
    get_hyperparams_list(two_col_df, "rpart2"),
    list(maxdepth = c(1, 2, 4, 8))
  )

  # svmRadial
  svm_expected <- list(
    C = c(0.001, 0.01, 0.1, 1, 10, 100),
    sigma = c(1e-06, 1e-05, 1e-04, 0.001, 0.01, 0.1)
  )
  expect_equal(get_hyperparams_list(otu_mini_bin, "svmRadial"), svm_expected)

  # xgbTree
  xgb_expected <- list(
    nrounds = 100,
    gamma = 0,
    eta = c(0.001, 0.01, 0.1, 1),
    max_depth = c(1, 2, 4, 8, 16, 30),
    colsample_bytree = 0.8,
    min_child_weight = 1,
    subsample = c(0.4, 0.5, 0.6, 0.7)
  )
  expect_equal(get_hyperparams_list(otu_mini_bin, "xgbTree"), xgb_expected)
})
test_that("parRF and rf use same default hyperparameters", {
  rf_defaults <- get_hyperparams_list(otu_mini_bin, "rf")
  par_rf_defaults <- get_hyperparams_list(otu_mini_bin, "parRF")
  expect_equal(rf_defaults, par_rf_defaults)
})
# An unknown method name must raise an informative error.
test_that("get_hyperparams_list throws error for unsupported method", {
  expect_error(
    get_hyperparams_list(otu_mini_bin, "not_a_method"),
    "method 'not_a_method' is not supported.",
    # Match the message literally: as a regex the trailing "." would match
    # any character.
    fixed = TRUE
  )
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.