tests/testthat/test-dataset-method.R

skip_if_no_torch = function() {
    testthat::skip_if_not_installed("torch")
    testthat::skip_if_not(torch::torch_is_installed(), "Torch backend not available")
}

make_dataset = function(n = 10L) {
    force(n)

    torch::dataset(
        name = "toy_cls_dataset",
        initialize = function() {
            self$x = torch::torch_tensor(
                matrix(rnorm(n * 4L), ncol = 4L),
                dtype = torch::torch_float32()
            )
            self$y = torch::torch_tensor(
                rep(c(1L, 2L), length.out = n),
                dtype = torch::torch_long()
            )
        },
        .getitem = function(i) {
            list(self$x[i, ], self$y[i])
        },
        .length = function() {
            self$x$size(1)
        }
    )()
}

test_that("dataset method trains and predicts", {
    skip_if_no_torch()

    set.seed(1)
    ds = make_dataset(10L)

    fit = kindling::train_nn(
        x = ds,
        hidden_neurons = c(8L),
        epochs = 2,
        batch_size = 4,
        learn_rate = 0.01,
        n_classes = 2,
        verbose = FALSE
    )

    expect_s3_class(fit, "nn_fit_ds")

    cls_pred = stats::predict(fit, ds)
    prob_pred = stats::predict(fit, ds, type = "prob")

    expect_s3_class(cls_pred, "factor")
    expect_length(cls_pred, length(ds))
    expect_equal(dim(prob_pred), c(length(ds), 2L))
    expect_equal(colnames(prob_pred), c("1", "2"))
})

test_that("dataset classification requires n_classes", {
    skip_if_no_torch()

    set.seed(2)
    ds = make_dataset(8L)

    expect_error(
        kindling::train_nn(
            x = ds,
            hidden_neurons = c(4L),
            epochs = 1,
            batch_size = 4,
            learn_rate = 0.01,
            verbose = FALSE
        ),
        "n_classes"
    )
})

test_that("dataset validation split guards against empty partitions", {
    skip_if_no_torch()

    set.seed(3)
    ds = make_dataset(1L)

    expect_error(
        kindling::train_nn(
            x = ds,
            hidden_neurons = c(4L),
            epochs = 1,
            batch_size = 1,
            learn_rate = 0.01,
            n_classes = 2,
            validation_split = 0.5,
            verbose = FALSE
        ),
        "empty train/validation"
    )
})

Try the kindling package in your browser

Any scripts or data that you put into this service are public.

kindling documentation built on March 3, 2026, 9:07 a.m.