# Fits brulee_mlp() on simulated two-class data through both the formula and
# the recipe interface, then checks the structure of the predictions and that
# the Brier score is below a fixed threshold.
# No `optimize` argument is passed, so the fits use brulee_mlp()'s default
# optimizer (presumably LBFGS, going by the test name -- confirm against the
# brulee_mlp() documentation).
test_that("basic binomial mlp LBFGS", {
  skip_if_not(torch::torch_is_installed())
  skip_if_not_installed("modeldata")
  skip_if_not_installed("yardstick")
  # dplyr and recipes are attached with library() below; library() raises an
  # error (rather than skipping) when a package is missing, so guard them too.
  skip_if_not_installed("dplyr")
  skip_if_not_installed("recipes")

  suppressPackageStartupMessages(library(dplyr))
  suppressPackageStartupMessages(library(recipes))

  # Data ----------------------------------------------------------------------

  # Seed the RNG so the simulated training and test sets are reproducible.
  set.seed(585)
  bin_tr <- modeldata::sim_classification(5000)
  bin_te <- modeldata::sim_classification(1000)

  rec <-
    recipe(class ~ ., data = bin_tr) %>%
    step_normalize(all_predictors())

  num_class <- length(levels(bin_tr$class))

  # Fits ----------------------------------------------------------------------

  # `regexp = NA` makes expect_error() assert that NO error is raised.
  # Formula interface; the fitted object is only checked for fitting cleanly.
  expect_error(
    {
      set.seed(392)
      bin_fit_f_lbfgs <-
        brulee_mlp(
          class ~ .,
          bin_tr,
          epochs = 200,
          hidden_units = 5,
          rate_schedule = "cyclic",
          learn_rate = 0.1
        )
    },
    regexp = NA
  )

  # Recipe interface; this fit is the one used for prediction below.
  expect_error(
    {
      set.seed(392)
      bin_fit_lbfgs <-
        brulee_mlp(
          rec,
          bin_tr,
          epochs = 200,
          hidden_units = 5,
          rate_schedule = "cyclic",
          learn_rate = 0.1
        )
    },
    regexp = NA
  )

  # Predictions ---------------------------------------------------------------

  # Hard class predictions, class probabilities and the observed class,
  # side by side.
  expect_error(
    bin_pred_lbfgs <-
      predict(bin_fit_lbfgs, bin_te) %>%
      bind_cols(predict(bin_fit_lbfgs, bin_te, type = "prob")) %>%
      bind_cols(bin_te) %>%
      select(starts_with(".pred"), class),
    regexp = NA
  )

  # Expected zero-row prototype: column names, column types and tibble class.
  fact_str <-
    structure(
      integer(0),
      levels = c("class_1", "class_2"),
      class = "factor"
    )
  exp_str <-
    structure(
      list(
        .pred_class = fact_str,
        .pred_class_1 = numeric(0),
        .pred_class_2 = numeric(0),
        class = fact_str
      ),
      row.names = integer(0),
      class = c("tbl_df", "tbl", "data.frame")
    )

  expect_equal(bin_pred_lbfgs[0, ], exp_str)
  expect_equal(nrow(bin_pred_lbfgs), nrow(bin_te))

  # Did it learn anything? ----------------------------------------------------

  bin_brier_lbfgs <-
    bin_pred_lbfgs %>%
    yardstick::brier_class(class, .pred_class_1)

  # The threshold evaluates to 0.25 when there are two classes.
  # expect_lt() reports both values on failure, unlike expect_true(a < b).
  expect_lt(bin_brier_lbfgs$.estimate, (1 - 1 / num_class)^2)
})
# Fits brulee_mlp() with `optimize = "SGD"` (mini-batches of 64, momentum,
# dropout, no penalty) on simulated two-class data through both the formula
# and the recipe interface, then checks the structure of the predictions and
# that the Brier score is below a fixed threshold.
test_that("basic binomial mlp SGD", {
  skip_if_not(torch::torch_is_installed())
  skip_if_not_installed("modeldata")
  skip_if_not_installed("yardstick")
  # dplyr and recipes are attached with library() below; library() raises an
  # error (rather than skipping) when a package is missing, so guard them too.
  skip_if_not_installed("dplyr")
  skip_if_not_installed("recipes")

  suppressPackageStartupMessages(library(dplyr))
  suppressPackageStartupMessages(library(recipes))

  # Data ----------------------------------------------------------------------

  # Seed the RNG so the simulated training and test sets are reproducible.
  set.seed(585)
  bin_tr <- modeldata::sim_classification(5000)
  bin_te <- modeldata::sim_classification(1000)

  rec <-
    recipe(class ~ ., data = bin_tr) %>%
    step_normalize(all_predictors())

  num_class <- length(levels(bin_tr$class))

  # Fits ----------------------------------------------------------------------

  # `regexp = NA` makes expect_error() assert that NO error is raised.
  # Formula interface; the fitted object is only checked for fitting cleanly.
  expect_error(
    {
      set.seed(392)
      bin_fit_f_sgd <-
        brulee_mlp(
          class ~ .,
          bin_tr,
          epochs = 200,
          penalty = 0,
          dropout = .1,
          hidden_units = 5,
          optimize = "SGD",
          batch_size = 64,
          momentum = 0.5,
          learn_rate = 0.1
        )
    },
    regexp = NA
  )

  # Recipe interface; this fit is the one used for prediction below.
  expect_error(
    {
      set.seed(392)
      bin_fit_sgd <-
        brulee_mlp(
          rec,
          bin_tr,
          epochs = 200,
          penalty = 0,
          dropout = .1,
          hidden_units = 5,
          optimize = "SGD",
          batch_size = 64,
          momentum = 0.5,
          learn_rate = 0.1
        )
    },
    regexp = NA
  )

  # Predictions ---------------------------------------------------------------

  # Hard class predictions, class probabilities and the observed class,
  # side by side.
  expect_error(
    bin_pred_sgd <-
      predict(bin_fit_sgd, bin_te) %>%
      bind_cols(predict(bin_fit_sgd, bin_te, type = "prob")) %>%
      bind_cols(bin_te) %>%
      select(starts_with(".pred"), class),
    regexp = NA
  )

  # Expected zero-row prototype: column names, column types and tibble class.
  fact_str <-
    structure(
      integer(0),
      levels = c("class_1", "class_2"),
      class = "factor"
    )
  exp_str <-
    structure(
      list(
        .pred_class = fact_str,
        .pred_class_1 = numeric(0),
        .pred_class_2 = numeric(0),
        class = fact_str
      ),
      row.names = integer(0),
      class = c("tbl_df", "tbl", "data.frame")
    )

  expect_equal(bin_pred_sgd[0, ], exp_str)
  expect_equal(nrow(bin_pred_sgd), nrow(bin_te))

  # Did it learn anything? ----------------------------------------------------

  bin_brier_sgd <-
    bin_pred_sgd %>%
    yardstick::brier_class(class, .pred_class_1)

  # The threshold evaluates to 0.25 when there are two classes.
  # expect_lt() reports both values on failure, unlike expect_true(a < b).
  expect_lt(bin_brier_sgd$.estimate, (1 - 1 / num_class)^2)
})
# Fits the same recipe-based brulee_mlp() model twice on simulated two-class
# data -- once with `class_weights = 10`, once without -- and checks that the
# weighted fit predicts "class_2" for more test rows than the unweighted fit.
# NOTE(review): the description says "case weights" but the argument being
# exercised is `class_weights`; the description is left unchanged so existing
# test filters keep matching.
test_that("binomial mlp case weights", {
  skip_if_not(torch::torch_is_installed())
  skip_if_not_installed("modeldata")
  # NOTE(review): yardstick is not referenced in this test; the skip is kept
  # as-is -- confirm whether it is still needed.
  skip_if_not_installed("yardstick")
  # dplyr and recipes are attached with library() below; library() raises an
  # error (rather than skipping) when a package is missing, so guard them too.
  skip_if_not_installed("dplyr")
  skip_if_not_installed("recipes")

  suppressPackageStartupMessages(library(dplyr))
  suppressPackageStartupMessages(library(recipes))

  # Data ----------------------------------------------------------------------

  # Seed the RNG so the simulated training and test sets are reproducible.
  # `intercept = 1` presumably changes the class balance of the simulated
  # outcome -- confirm against modeldata::sim_classification().
  set.seed(585)
  bin_tr <- modeldata::sim_classification(5000, intercept = 1)
  bin_te <- modeldata::sim_classification(1000, intercept = 1)

  rec <-
    recipe(class ~ ., data = bin_tr) %>%
    step_normalize(all_predictors())

  # Weighted fit --------------------------------------------------------------

  # `regexp = NA` makes expect_error() assert that NO error is raised.
  expect_error(
    {
      set.seed(392)
      weighted <-
        brulee_mlp(
          rec,
          bin_tr,
          epochs = 200,
          hidden_units = 5,
          rate_schedule = "cyclic",
          class_weights = 10,
          learn_rate = 0.1
        )
    },
    regexp = NA
  )

  expect_error(
    weighted_pred <-
      predict(weighted, bin_te) %>%
      bind_cols(predict(weighted, bin_te, type = "prob")) %>%
      bind_cols(bin_te) %>%
      select(starts_with(".pred"), class),
    regexp = NA
  )

  # Unweighted fit ------------------------------------------------------------

  # Same seed and settings as above, minus `class_weights`, so the weighting
  # is the only difference between the two calls.
  expect_error(
    {
      set.seed(392)
      unweighted <-
        brulee_mlp(
          rec,
          bin_tr,
          epochs = 200,
          hidden_units = 5,
          rate_schedule = "cyclic",
          learn_rate = 0.1
        )
    },
    regexp = NA
  )

  expect_error(
    unweighted_pred <-
      predict(unweighted, bin_te) %>%
      bind_cols(predict(unweighted, bin_te, type = "prob")) %>%
      bind_cols(bin_te) %>%
      select(starts_with(".pred"), class),
    regexp = NA
  )

  # Comparison ----------------------------------------------------------------

  # The weighted model must predict "class_2" strictly more often.
  # expect_gt() reports both counts on failure, unlike expect_true(a > b).
  expect_gt(
    sum(weighted_pred$.pred_class == "class_2"),
    sum(unweighted_pred$.pred_class == "class_2")
  )
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.