# Nothing
# Skip the calling test unless torch is usable.
#
# Two gates are applied in order: the `torch` R package must be installed,
# and `torch::torch_is_installed()` must report TRUE. If either gate fails,
# testthat signals a skip instead of a failure.
skip_if_no_torch = function() {
  testthat::skip_if_not_installed("torch")

  # Only reached when the package itself is installed.
  backend_ready = torch::torch_is_installed()
  testthat::skip_if_not(backend_ready, "Torch backend not available")
}
# Build a small in-memory torch dataset for classification tests.
#
# @param n Number of observations to generate (default 10L).
# @return An instantiated torch dataset named "toy_cls_dataset" with
#   `x`: an n x 4 float32 tensor of standard-normal draws, and
#   `y`: a length-n long tensor alternating between 1 and 2.
#   Each item is `list(features_row, label)`.
make_dataset = function(n = 10L) {
  # Evaluate the promise now so the dataset methods see a fixed value.
  force(n)

  toy_generator = torch::dataset(
    name = "toy_cls_dataset",

    initialize = function() {
      # Random features first (the only RNG draw), then deterministic labels.
      features = matrix(rnorm(n * 4L), ncol = 4L)
      labels = rep(c(1L, 2L), length.out = n)

      self$x = torch::torch_tensor(features, dtype = torch::torch_float32())
      self$y = torch::torch_tensor(labels, dtype = torch::torch_long())
    },

    .getitem = function(i) {
      feature_row = self$x[i, ]
      label = self$y[i]
      list(feature_row, label)
    },

    .length = function() {
      # Number of rows in the feature tensor.
      self$x$size(1)
    }
  )

  toy_generator()
}
# End-to-end check of the dataset method of `kindling::train_nn()`:
# fit on a toy torch dataset, then verify the fitted object's class and the
# shape/labels of class and probability predictions.
test_that("dataset method trains and predicts", {
  skip_if_no_torch()
  set.seed(1)
  ds = make_dataset(10L)

  fit = kindling::train_nn(
    x = ds,
    # NOTE(review): the argument name was missing in the source (bare
    # `= c(8L)` is a parse error); restored as `hidden_neurons` -- confirm
    # against the train_nn() signature.
    hidden_neurons = c(8L),
    epochs = 2,
    batch_size = 4,
    learn_rate = 0.01,
    n_classes = 2,
    verbose = FALSE
  )
  expect_s3_class(fit, "nn_fit_ds")

  cls_pred = stats::predict(fit, ds)
  prob_pred = stats::predict(fit, ds, type = "prob")

  # Class predictions: one factor value per observation.
  expect_s3_class(cls_pred, "factor")
  expect_length(cls_pred, length(ds))

  # Probability predictions: one row per observation, one column per class,
  # with columns named after the class labels.
  expect_equal(dim(prob_pred), c(length(ds), 2L))
  expect_equal(colnames(prob_pred), c("1", "2"))
})
# Calling `kindling::train_nn()` on a dataset without `n_classes` is expected
# to raise an error whose message mentions "n_classes".
test_that("dataset classification requires n_classes", {
  skip_if_no_torch()
  set.seed(2)
  ds = make_dataset(8L)

  expect_error(
    kindling::train_nn(
      x = ds,
      # NOTE(review): the argument name was missing in the source (bare
      # `= c(4L)` is a parse error); restored as `hidden_neurons` -- confirm
      # against the train_nn() signature.
      hidden_neurons = c(4L),
      epochs = 1,
      batch_size = 4,
      learn_rate = 0.01,
      # n_classes intentionally omitted: this is the condition under test.
      verbose = FALSE
    ),
    "n_classes"
  )
})
# A single-observation dataset combined with `validation_split = 0.5` is
# expected to raise an error whose message matches "empty train/validation".
test_that("dataset validation split guards against empty partitions", {
  skip_if_no_torch()
  set.seed(3)
  # One observation only, combined with a 50% validation split below.
  ds = make_dataset(1L)

  expect_error(
    kindling::train_nn(
      x = ds,
      # NOTE(review): the argument name was missing in the source (bare
      # `= c(4L)` is a parse error); restored as `hidden_neurons` -- confirm
      # against the train_nn() signature.
      hidden_neurons = c(4L),
      epochs = 1,
      batch_size = 1,
      learn_rate = 0.01,
      n_classes = 2,
      validation_split = 0.5,
      verbose = FALSE
    ),
    "empty train/validation"
  )
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.