# Shared fixtures for the tests below.
#
# Two analysis objects are built from `sim_data` with the same predictors:
# a classification task on `psych_well_bin` and a regression task on
# `psych_well`. The resulting model objects are changed in place by
# `$modify()`, so tests that need different settings should work on a
# `$clone()` instead of the fixtures themselves.
predictors <- "age + gender + depression"
formula_bin <- paste0("psych_well_bin ~ ", predictors)
formula_reg <- paste0("psych_well ~ ", predictors)

analysis_object_reg <- preprocessing(
  df = sim_data, formula_reg, task = "regression"
)
analysis_object_bin <- preprocessing(
  df = sim_data, formula_bin, task = "classification"
)

# Neural-network ranges. `learn_rate` bounds are presumably on the log10
# scale (dials convention) -- TODO confirm against the NN hyperparams class.
hyper_nn_tune_list <- list(
  learn_rate = c(-2, -1),
  hidden_units = c(3, 10)
)

# Random-forest settings: `mtry` gets a range, `trees` a single value
# (it does not show up as a column of the tuning grid tested below).
hyper_rf_tune_list <- list(
  mtry = c(2, 6),
  trees = 100
)

model_object_bin <- build_model(
  analysis_object = analysis_object_bin,
  model_name = "Neural Network",
  hyperparameters = hyper_nn_tune_list
)
model_object_reg <- build_model(
  analysis_object = analysis_object_reg,
  model_name = "Random Forest",
  hyperparameters = hyper_rf_tune_list
)
# create_workflow ----------------------------------------------------------

# create_workflow() must attach the preprocessing recipe to the returned
# workflow, for both the classification and the regression model object.
test_that("Check create_workflow works properly", {
  workflow_bin <- create_workflow(model_object_bin)
  workflow_reg <- create_workflow(model_object_reg)

  # exact = TRUE keeps the strict comparison against the full class vector.
  recipe_class <- c("action_recipe", "action_pre", "action")
  expect_s3_class(workflow_bin$pre$actions$recipe, recipe_class, exact = TRUE)
  expect_s3_class(workflow_reg$pre$actions$recipe, recipe_class, exact = TRUE)
})
# split_data ---------------------------------------------------------------

# split_data() chooses the resampling scheme from the model's tuner:
# - "Bayesian Optimization": a single validation split ("val_split"); the
#   final split is the training data stacked with the validation data.
# - "Grid Search CV": v-fold cross-validation ("vfold_split"); the final
#   split is the training data alone.
test_that("Check split_data works properly", {
  # Clone the fixtures: $modify() changes objects in place, and the tuner
  # set here must not leak into later tests.
  model_bin <- model_object_bin$clone()
  model_reg <- model_object_reg$clone()
  model_bin$modify("tuner", "Bayesian Optimization")
  model_reg$modify("tuner", "Grid Search CV")

  split_data_bin <- split_data(model_bin)
  split_data_reg <- split_data(model_reg)

  expect_s3_class(
    split_data_bin$sampling_method$splits[[1]],
    c("val_split", "rsplit"),
    exact = TRUE
  )
  expect_equal(
    split_data_bin$final_split,
    rbind(model_bin$train_data, model_bin$validation_data)
  )

  expect_s3_class(
    split_data_reg$sampling_method$splits[[1]],
    c("vfold_split", "rsplit"),
    exact = TRUE
  )
  expect_equal(split_data_reg$final_split, model_reg$train_data)
})
# create_metric_set --------------------------------------------------------

# create_metric_set() turns a model's metric names into a metric set:
# probability/class metrics give a "class_prob_metric_set", numeric
# metrics give a "numeric_metric_set".
test_that("Test create_metric_set works properly", {
  metrics_bin <- c("roc_auc", "accuracy")
  metrics_reg <- c("rmse", "ccc")

  # Clone the fixtures so the metrics change stays local to this test.
  model_bin <- model_object_bin$clone()
  model_reg <- model_object_reg$clone()
  model_bin$modify("metrics", metrics_bin)
  model_reg$modify("metrics", metrics_reg)

  metric_set_bin <- create_metric_set(model_bin$metrics)
  metric_set_reg <- create_metric_set(model_reg$metrics)

  expect_s3_class(
    metric_set_bin,
    c("class_prob_metric_set", "metric_set", "function"),
    exact = TRUE
  )
  expect_s3_class(
    metric_set_reg,
    c("numeric_metric_set", "metric_set", "function"),
    exact = TRUE
  )
})
# extract_hyperparams ------------------------------------------------------

# Once a workflow is attached, extract_hyperparams() reports one entry per
# hyperparameter: three for the neural network and two for the random forest.
test_that("Test extract_hyperparams works properly", {
  model_object_bin2 <- model_object_bin$clone()
  model_object_reg2 <- model_object_reg$clone()

  # Build each workflow from the clone it is attached to, so the fixtures
  # are neither modified nor consulted after cloning.
  model_object_bin2$modify("workflow", create_workflow(model_object_bin2))
  model_object_reg2$modify("workflow", create_workflow(model_object_reg2))

  extracted_hyp_bin <- extract_hyperparams(model_object_bin2)
  extracted_hyp_reg <- extract_hyperparams(model_object_reg2)

  expect_equal(
    extracted_hyp_bin$name,
    c("hidden_units", "activation", "learn_rate")
  )
  expect_length(extracted_hyp_bin$object, 3)

  expect_equal(extracted_hyp_reg$name, c("mtry", "min_n"))
  expect_length(extracted_hyp_reg$object, 2)
})
# hyperparams_grid ---------------------------------------------------------

# hyperparams_grid() builds a tuning grid from a HyperparamsRF object. Only
# mtry and min_n become columns; `trees` was given a single fixed value. With
# levels = 5 for two columns this is presumably a 5 x 5 regular grid (25 rows).
test_that("Test hyperparams_grid works properly", {
  hyp_rf <- HyperparamsRF$new(hyper_rf_tune_list)
  grid <- hyperparams_grid(hyp_rf, levels = 5)

  expect_named(grid, c("mtry", "min_n"))
  expect_equal(nrow(grid), 25)
})
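# NOTE: the lines below are a web-page footer captured along with this file;
# they are not part of the tests.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.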