tests/testthat/test-train-nn.R

skip_if_no_torch = function() {
    skip_if_not_installed("torch")
    skip_if_not(torch::torch_is_installed(), "Torch backend not available")
}

iris_x = as.matrix(iris[, 2:4])
iris_y = iris$Sepal.Length
iris_cls_x = as.matrix(iris[, 1:4])
iris_cls_y = iris$Species

test_that("train_nn() dispatches on matrix, data.frame, and formula", {
    skip_if_no_torch()
    expect_s3_class(
        train_nn(iris_x, iris_y, epochs = 5),
        "nn_fit"
    )
    expect_s3_class(
        train_nn(iris[, 2:4], iris_y, epochs = 5),
        c("nn_fit_tab", "nn_fit"), exact = TRUE
    )
    expect_s3_class(
        train_nn(Sepal.Length ~ ., data = iris[, 1:4], epochs = 5),
        c("nn_fit_tab", "nn_fit"), exact = TRUE
    )
})

# ---- Classification -----

test_that("train_nn() handles classification correctly", {
    skip_if_no_torch()
    m = train_nn(iris_cls_x, iris_cls_y, epochs = 5)
    expect_true(m$is_classification)
    expect_equal(m$y_levels, levels(iris_cls_y))
    expect_s3_class(m$fitted, "factor")
    expect_equal(levels(m$fitted), levels(iris_cls_y))
})

test_that("train_nn() errors on unsupported input types", {
    skip_if_no_torch()
    expect_error(train_nn("not valid"), class = "rlang_error")
    expect_error(train_nn(Sepal.Length ~ ., data = NULL), class = "rlang_error")
    expect_error(
        train_nn(iris_x, iris_y, arch = list(nn_name = "bad"), epochs = 5),
        class = "rlang_error"
    )
})

test_that("train_nn() return object has correct structure", {
    skip_if_no_torch()
    m = train_nn(iris_x, iris_y, epochs = 10, validation_split = 0.2)
    expect_named(m, c(
        "model", "fitted", "loss_history", "val_loss_history",
        "n_epochs", "stopped_epoch", "hidden_neurons", "activations",
        "output_activation", "penalty", "mixture", "feature_names",
        "response_name", "no_x", "no_y", "is_classification",
        "y_levels", "n_classes", "device", "cached_weights", "arch"
    ), ignore.order = TRUE)
    expect_length(m$loss_history, 10)
    expect_length(m$val_loss_history, 10)
    expect_equal(m$no_x, ncol(iris_x))
    expect_null(m$cached_weights)
    expect_true(is.na(m$stopped_epoch))
})

test_that("cache_weights stores weight matrices when TRUE", {
    skip_if_no_torch()
    m = train_nn(iris_x, iris_y, epochs = 5, cache_weights = TRUE)
    expect_type(m$cached_weights, "list")
})

test_that("train_nn() accepts various act_funs() syntaxes", {
    skip_if_no_torch()
    expect_no_error(
        train_nn(
            iris_x, 
            iris_y, 
            hidden_neurons = c(16, 8),
            activations = act_funs(relu, ), 
            epochs = 5
        )
    )
    expect_no_error(
        train_nn(
            iris_x, 
            iris_y, 
            hidden_neurons = 16,
            activations = act_funs(elu[alpha = 0.5]), 
            epochs = 5
        )
    )
    expect_no_error(
        train_nn(
            iris_x, 
            iris_y, 
            hidden_neurons = 16,
            activations = act_funs(new_act_fn(\(x) torch::torch_tanh(x))),
            epochs = 5
        )
    )
})

test_that("act_funs() errors on bad activation specs", {
    skip_if_no_torch()
    expect_error(act_funs(not_a_real_fn), class = "activation_not_found_error")
    expect_error(act_funs(relu[bad_param = 1]), class = "purrr_error_indexed")
})

# ---- Loss functions ----

test_that("train_nn() accepts built-in and custom loss functions", {
    skip_if_no_torch()
    expect_no_error(train_nn(iris_x, iris_y, loss = "mae", epochs = 5))
    expect_no_error(
        train_nn(
            iris_x, 
            iris_y,
            loss = \(input, target) torch::nnf_mse_loss(input, target),
            epochs = 5
        )
    )
    expect_error(train_nn(iris_x, iris_y, loss = "not_a_loss", epochs = 5))
    expect_error(
        train_nn(iris_x, iris_y, loss = \(input, target) 42, epochs = 5),
        class = "loss_fn_output_error"
    )
    expect_error(
        train_nn(iris_x, iris_y, loss = \(x) torch::nnf_mse_loss(x, x), epochs = 5),
        class = "loss_fn_arity_error"
    )
})

# ---- Early stopping ----

describe("train_nn() early stopping", {
    it("runs cleanly and trims loss_history when triggered", {
        skip_if_no_torch()
        # min_delta = 1e10 forces early stopping to fire reliably
        es = early_stop(patience = 2, min_delta = 1e10, monitor = "val_loss")
        m = train_nn(iris_x, iris_y, epochs = 50,
                     validation_split = 0.2, early_stopping = es)
        expect_lt(length(m$loss_history), 50)
        expect_false(is.na(m$stopped_epoch))
    })
    
    it("errors when val_loss monitor is used without validation_split", {
        skip_if_no_torch()
        expect_error(
            train_nn(
                iris_x, 
                iris_y, 
                epochs = 10,
                early_stopping = early_stop(patience = 3, monitor = "val_loss")
            ),
            class = "rlang_error"
        )
    })
    
    it("errors when early_stopping is not an early_stop_spec", {
        skip_if_no_torch()
        expect_error(
            train_nn(iris_x, iris_y, epochs = 5, early_stopping = list(patience = 5)),
            class = "rlang_error"
        )
    })
})

test_that("predict.nn_fit() returns correct output types", {
    skip_if_no_torch()
    m_reg = train_nn(iris_x, iris_y, epochs = 5)
    m_cls = train_nn(iris_cls_x, iris_cls_y, epochs = 5)
    
    expect_equal(predict(m_reg), m_reg$fitted)
    expect_type(predict(m_reg, newdata = iris_x), "double")
    expect_s3_class(predict(m_cls, newdata = iris_cls_x), "factor")
    
    probs = predict(m_cls, newdata = iris_cls_x, type = "prob")
    expect_true(is.matrix(probs))
    expect_equal(rowSums(probs), rep(1, nrow(iris_cls_x)), tolerance = 1e-5)
    
    expect_error(predict(m_reg, newdata = iris_x, type = "prob"), class = "rlang_error")
    expect_error(predict(m_reg, newdata = iris_x, type = "bad"), class = "rlang_error")
    expect_error(predict(m_reg, newdata = iris_x, type = "good"), class = "rlang_error")
})

test_that("new_data is accepted as alias for newdata", {
    skip_if_no_torch()
    m = train_nn(iris_x, iris_y, epochs = 5)
    expect_warning(predict(m, new_data = iris_x))
    expect_equal(predict(m, newdata = iris_x), suppressWarnings(predict(m, new_data = iris_x)))
})

test_that("train_nn() handles edge case inputs", {
    skip_if_no_torch()
    m = train_nn(iris_x, iris_y, epochs = 5)
    expect_length(predict(m, newdata = iris_x[1, , drop = FALSE]), 1)
    
    expect_no_error(
        train_nn(iris_x[1:10, ], iris_y[1:10], batch_size = 50, epochs = 5)
    )
    
    expect_length(train_nn(iris_x, iris_y, epochs = 1)$loss_history, 1)
    
    m_multi = train_nn(as.matrix(iris[, 3:4]), as.matrix(iris[, 1:2]), epochs = 5)
    expect_equal(m_multi$no_y, 2L)
})

Try the kindling package in your browser

Any scripts or data that you put into this service are public.

kindling documentation built on March 3, 2026, 9:07 a.m.