test_hyperparams <- data.frame(
param = c(
"lambda", "lambda", "lambda", "alpha",
"sigma", "sigma", "C", "C",
"maxdepth", "maxdepth",
"nrounds", "gamma", "eta", "max_depth", "colsample_bytree", "min_child_weight", "subsample",
"mtry", "mtry"
),
value = c(
"1e-3", "1e-2", "1e-1", "1",
"0.00000001", "0.0000001", "0.01", "0.1",
"1", "2",
"10", "0", "0.01", "1", "0.8", "1", "0.4",
"1", "2"
),
method = c(
"glmnet", "glmnet", "glmnet", "glmnet",
"svmRadial", "svmRadial", "svmRadial", "svmRadial",
"rpart2", "rpart2",
"xgbTree", "xgbTree", "xgbTree", "xgbTree", "xgbTree", "xgbTree", "xgbTree",
"rf", "rf"
)
)
tg_rpart2 <- get_tuning_grid(get_hyperparams_from_df(test_hyperparams, "rpart2"), "rpart2")
tg_rf <- get_tuning_grid(get_hyperparams_from_df(test_hyperparams, "rf"), "rf")
tg_lr <- get_tuning_grid(get_hyperparams_from_df(test_hyperparams, "glmnet"), "glmnet")
train_dat <- otu_mini_bin_results_glmnet$trained_model$trainingData %>%
dplyr::rename(dx = .outcome)
hparams_list <- list(lambda = c("1e-3", "1e-2", "1e-1"), alpha = "0.01")
cv <- define_cv(
train_dat,
"dx",
hparams_list,
perf_metric_function = caret::multiClassSummary,
class_probs = TRUE,
cv_times = 2
)
test_that("train_model works", {
skip_on_cran() # this functionality is already tested in test-run_ml.R
set.seed(2019)
rf_model <- train_model(
train_dat,
"dx",
method = "rf",
cv = define_cv(
train_dat,
"dx",
get_hyperparams_list(train_dat, "rf"),
perf_metric_function = caret::multiClassSummary,
class_probs = TRUE,
cv_times = 2
),
perf_metric_name = "AUC",
tune_grid = tg_rf,
ntree = 1000,
weights = NULL
)
auc <- rf_model$results %>%
dplyr::filter(mtry == rf_model$bestTune$mtry) %>%
dplyr::pull(AUC)
expect_true(dplyr::near(auc, 0.671, tol = 10^-3))
set.seed(2019)
expect_equal(
train_model(
train_dat,
"dx",
method = "glmnet",
cv = cv,
perf_metric_name = "AUC",
tune_grid = tg_lr,
ntree = NULL
)$bestTune$lambda,
0.01
) %>% expect_warning("`caret::train\\(\\)` issued the following warning:")
})
test_that("case weights work", {
case_weights_dat <- train_dat %>%
dplyr::count(dx) %>%
dplyr::mutate(p = n / sum(n)) %>%
dplyr::select(dx, p)
case_weights_vctr <- train_dat %>%
dplyr::inner_join(case_weights_dat, by = "dx") %>%
dplyr::pull(p)
expect_warning(
lr_model_weighted <- train_model(
train_dat,
"dx",
method = "glmnet",
cv = cv,
perf_metric_name = "AUC",
tune_grid = tg_lr,
ntree = NULL,
weights = case_weights_vctr
),
"`caret::train\\(\\)` issued the following warning:"
)
model_weights <- lr_model_weighted$pred %>%
dplyr::select(obs, weights) %>%
dplyr::distinct() %>%
dplyr::rename(dx = obs, p = weights)
expect_true(dplyr::all_equal(
model_weights,
case_weights_dat
))
expect_false("weights" %in% colnames(otu_mini_bin_results_glmnet$trained_model$pred))
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.