# Nothing
# Skip the calling test unless the {torch} R package is installed AND its
# backend reports itself installed via torch::torch_is_installed().
skip_if_no_torch = function() {
  # The package check must come first: the torch:: call below needs it.
  skip_if_not_installed("torch")
  backend_ready = torch::torch_is_installed()
  skip_if_not(backend_ready, "Torch backend not available")
}
# Shared fixtures built from the built-in iris data set.
# Regression: Sepal.Length predicted from the other three numeric columns.
iris_x = as.matrix(iris[c("Sepal.Width", "Petal.Length", "Petal.Width")])
iris_y = iris[["Sepal.Length"]]
# Classification: Species predicted from all four numeric measurements.
iris_cls_x = as.matrix(iris[-5])
iris_cls_y = iris[["Species"]]
# train_nn() must accept a numeric matrix, a data.frame and a formula.
# The two tabular forms are expected to carry exactly
# c("nn_fit_tab", "nn_fit") as their class vector.
test_that("train_nn() dispatches on matrix, data.frame, and formula", {
  skip_if_no_torch()

  from_matrix = train_nn(iris_x, iris_y, epochs = 5)
  expect_s3_class(from_matrix, "nn_fit")

  tabular_class = c("nn_fit_tab", "nn_fit")

  from_df = train_nn(iris[, 2:4], iris_y, epochs = 5)
  expect_s3_class(from_df, tabular_class, exact = TRUE)

  from_formula = train_nn(Sepal.Length ~ ., data = iris[, 1:4], epochs = 5)
  expect_s3_class(from_formula, tabular_class, exact = TRUE)
})
# ---- Classification ----
# A factor response (iris$Species) is expected to produce a classification
# fit whose stored levels and fitted values mirror the response's levels.
test_that("train_nn() handles classification correctly", {
  skip_if_no_torch()
  fit = train_nn(iris_cls_x, iris_cls_y, epochs = 5)
  species_levels = levels(iris_cls_y)

  expect_true(fit$is_classification)
  expect_equal(fit$y_levels, species_levels)

  fitted_classes = fit$fitted
  expect_s3_class(fitted_classes, "factor")
  expect_equal(levels(fitted_classes), species_levels)
})
# Invalid inputs must fail with an rlang error rather than an arbitrary one.
test_that("train_nn() errors on unsupported input types", {
  skip_if_no_torch()

  # A bare character string is not a supported training input.
  expect_error(train_nn("not valid"), class = "rlang_error")

  # A formula with data = NULL has nothing to be evaluated against.
  expect_error(
    train_nn(Sepal.Length ~ ., data = NULL),
    class = "rlang_error"
  )

  # An unrecognised architecture name must be rejected.
  expect_error(
    train_nn(iris_x, iris_y, arch = list(nn_name = "bad"), epochs = 5),
    class = "rlang_error"
  )
})
# Pins the full set of fields on a fit object, plus the per-epoch loss
# histories for a run with a validation split and no early stopping.
test_that("train_nn() return object has correct structure", {
  skip_if_no_torch()
  n_epochs = 10
  fit = train_nn(iris_x, iris_y, epochs = n_epochs, validation_split = 0.2)

  expected_fields = c(
    "model", "fitted", "loss_history", "val_loss_history",
    "n_epochs", "stopped_epoch", "hidden_neurons", "activations",
    "output_activation", "penalty", "mixture", "feature_names",
    "response_name", "no_x", "no_y", "is_classification",
    "y_levels", "n_classes", "device", "cached_weights", "arch"
  )
  expect_named(fit, expected_fields, ignore.order = TRUE)

  # One training and one validation loss entry per requested epoch.
  expect_length(fit$loss_history, n_epochs)
  expect_length(fit$val_loss_history, n_epochs)

  expect_equal(fit$no_x, ncol(iris_x))
  # cache_weights was not requested here, so nothing is cached.
  expect_null(fit$cached_weights)
  # No early stopping configured, so no stop epoch is recorded.
  expect_true(is.na(fit$stopped_epoch))
})
# With cache_weights = TRUE the fit should expose its cached weights as a list.
test_that("cache_weights stores weight matrices when TRUE", {
  skip_if_no_torch()
  fit = train_nn(iris_x, iris_y, epochs = 5, cache_weights = TRUE)
  weights = fit$cached_weights
  expect_type(weights, "list")
})
# act_funs() should accept three spellings of an activation spec:
# bare names, bracketed parameters (`name[arg = value]`), and custom
# functions wrapped with new_act_fn(). Each must train without error.
test_that("train_nn() accepts various act_funs() syntaxes", {
  skip_if_no_torch()

  # Bare names: two hidden layers, one activation named for each.
  expect_no_error(
    train_nn(
      iris_x,
      iris_y,
      hidden_neurons = c(16, 8),
      activations = act_funs(relu, elu),
      epochs = 5
    )
  )

  # Parameterised activation via bracket syntax.
  expect_no_error(
    train_nn(
      iris_x,
      iris_y,
      hidden_neurons = 16,
      activations = act_funs(elu[alpha = 0.5]),
      epochs = 5
    )
  )

  # Custom activation wrapped with new_act_fn().
  expect_no_error(
    train_nn(
      iris_x,
      iris_y,
      hidden_neurons = 16,
      activations = act_funs(new_act_fn(\(x) torch::torch_tanh(x))),
      epochs = 5
    )
  )
})
# Malformed activation specs must fail with specific condition classes.
test_that("act_funs() errors on bad activation specs", {
  skip_if_no_torch()

  # An activation name that does not exist.
  expect_error(
    act_funs(not_a_real_fn),
    class = "activation_not_found_error"
  )

  # A real activation given a parameter name it does not take.
  expect_error(
    act_funs(relu[bad_param = 1]),
    class = "purrr_error_indexed"
  )
})
# ---- Loss functions ----
# Losses can be given by name or as a two-argument (input, target) function.
# Unknown names, wrong return values and wrong arity must be rejected.
test_that("train_nn() accepts built-in and custom loss functions", {
  skip_if_no_torch()

  # Accepted: a built-in by name, and a custom two-argument function.
  expect_no_error(train_nn(iris_x, iris_y, loss = "mae", epochs = 5))
  expect_no_error(
    train_nn(
      iris_x, iris_y,
      loss = \(input, target) torch::nnf_mse_loss(input, target),
      epochs = 5
    )
  )

  # Rejected: an unknown loss name (no specific class is asserted).
  expect_error(
    train_nn(iris_x, iris_y, loss = "not_a_loss", epochs = 5)
  )

  # Rejected: a custom loss that returns a plain number.
  expect_error(
    train_nn(iris_x, iris_y, loss = \(input, target) 42, epochs = 5),
    class = "loss_fn_output_error"
  )

  # Rejected: a custom loss with only one parameter.
  expect_error(
    train_nn(iris_x, iris_y, loss = \(x) torch::nnf_mse_loss(x, x), epochs = 5),
    class = "loss_fn_arity_error"
  )
})
# ---- Early stopping ----
# Early stopping: it should trigger and trim the history, and it should
# reject configurations that cannot work.
describe("train_nn() early stopping", {
  it("runs cleanly and trims loss_history when triggered", {
    skip_if_no_torch()
    max_epochs = 50
    # An unreachable min_delta (1e10) makes early stopping fire reliably.
    stopper = early_stop(patience = 2, min_delta = 1e10, monitor = "val_loss")
    fit = train_nn(
      iris_x, iris_y,
      epochs = max_epochs,
      validation_split = 0.2,
      early_stopping = stopper
    )
    expect_lt(length(fit$loss_history), max_epochs)
    expect_false(is.na(fit$stopped_epoch))
  })

  it("errors when val_loss monitor is used without validation_split", {
    skip_if_no_torch()
    # Monitoring val_loss needs a validation split, and none is given here.
    expect_error(
      train_nn(
        iris_x, iris_y,
        epochs = 10,
        early_stopping = early_stop(patience = 3, monitor = "val_loss")
      ),
      class = "rlang_error"
    )
  })

  it("errors when early_stopping is not an early_stop_spec", {
    skip_if_no_torch()
    # A plain list is not accepted in place of an early_stop() spec.
    expect_error(
      train_nn(
        iris_x, iris_y,
        epochs = 5,
        early_stopping = list(patience = 5)
      ),
      class = "rlang_error"
    )
  })
})
# predict() output contract for regression and classification fits,
# including rejection of invalid `type` values on a regression fit.
test_that("predict.nn_fit() returns correct output types", {
  skip_if_no_torch()
  reg_fit = train_nn(iris_x, iris_y, epochs = 5)
  cls_fit = train_nn(iris_cls_x, iris_cls_y, epochs = 5)

  # Regression: without newdata, predict() returns the stored fitted values;
  # with newdata it returns doubles.
  expect_equal(predict(reg_fit), reg_fit$fitted)
  expect_type(predict(reg_fit, newdata = iris_x), "double")

  # Classification: factor labels by default, a row-stochastic matrix for
  # type = "prob".
  expect_s3_class(predict(cls_fit, newdata = iris_cls_x), "factor")
  class_probs = predict(cls_fit, newdata = iris_cls_x, type = "prob")
  expect_true(is.matrix(class_probs))
  expect_equal(
    rowSums(class_probs),
    rep(1, nrow(iris_cls_x)),
    tolerance = 1e-5
  )

  # "prob" is invalid on a regression fit, as are unknown type strings.
  for (bad_type in c("prob", "bad", "good")) {
    expect_error(
      predict(reg_fit, newdata = iris_x, type = bad_type),
      class = "rlang_error"
    )
  }
})
# `new_data` should still work as a spelling of `newdata`. It is expected to
# warn, but it must give the same predictions.
test_that("new_data is accepted as alias for newdata", {
  skip_if_no_torch()
  fit = train_nn(iris_x, iris_y, epochs = 5)

  expect_warning(predict(fit, new_data = iris_x))

  via_newdata = predict(fit, newdata = iris_x)
  via_alias = suppressWarnings(predict(fit, new_data = iris_x))
  expect_equal(via_newdata, via_alias)
})
# Boundary inputs: a single prediction row, a batch larger than the data,
# a single epoch, and a multi-column response.
test_that("train_nn() handles edge case inputs", {
  skip_if_no_torch()
  fit = train_nn(iris_x, iris_y, epochs = 5)

  # One row, kept as a matrix via drop = FALSE, gives one prediction.
  one_row = iris_x[1, , drop = FALSE]
  expect_length(predict(fit, newdata = one_row), 1)

  # batch_size (50) exceeds the 10 training rows.
  expect_no_error(
    train_nn(iris_x[1:10, ], iris_y[1:10], batch_size = 50, epochs = 5)
  )

  # A single epoch records exactly one loss value.
  single_epoch_fit = train_nn(iris_x, iris_y, epochs = 1)
  expect_length(single_epoch_fit$loss_history, 1)

  # A two-column response matrix yields a two-output model.
  multi_fit = train_nn(
    as.matrix(iris[, 3:4]),
    as.matrix(iris[, 1:2]),
    epochs = 5
  )
  expect_equal(multi_fit$no_y, 2L)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.