Nothing
test_that('classification models', {
  skip_if(!is_tab_pfn_installed())
  skip_on_cran()
  skip_if_not_installed("modeldata")
  # ---------------------------------------------------------------------------
  # Fixtures: the first two columns of `two_class_dat` are the predictors.
  # Rows 1-20 are used for training and rows 21-23 are held out for
  # prediction. Both are provided as data frames and as matrices.
  # NOTE(review): `exp_cls` is not defined in this test; presumably it comes
  # from a testthat helper file -- TODO confirm.
  data(two_class_dat, package = "modeldata")
  x_tr_df <- two_class_dat[1:20, 1:2]
  x_tr_mat <- as.matrix(x_tr_df)
  y_tr <- two_class_dat$Class[1:20]
  x_te_df <- two_class_dat[21:23, 1:2]
  x_te_mat <- as.matrix(x_te_df)
  # Zero-row prototype of the expected prediction tibble. `character(0)`
  # builds an empty factor directly instead of relying on a length-1 NA
  # factor being recycled down to zero rows.
  pred_ptype <-
    tibble::tibble(
      .pred_Class1 = numeric(0),
      .pred_Class2 = numeric(0),
      .pred_class = factor(character(0), levels = levels(two_class_dat$Class))
    )
  # ---------------------------------------------------------------------------
  # Data frame interface
  set.seed(956)
  mod_df <- try(tab_pfn(x_tr_df, y_tr), silent = TRUE)
  expect_s3_class(mod_df, exp_cls)
  expect_snapshot(mod_df)
  pred_df <- predict(mod_df, x_te_df)
  expect_equal(pred_df[0, ], pred_ptype)
  expect_equal(nrow(pred_df), 3L)
  aug_df <- augment(mod_df, x_te_df)
  expect_s3_class(aug_df, c("tbl_df", "tbl", "data.frame"))
  # 2 predictor columns + 3 prediction columns
  expect_equal(nrow(aug_df), 3L)
  expect_equal(ncol(aug_df), 5L)
  # ---------------------------------------------------------------------------
  # Formula interface. Predictions use the held-out rows: `x_tr_df` has only
  # 20 rows, so indexing it with 21:23 would silently produce all-NA rows.
  set.seed(956)
  mod_f <- try(tab_pfn(Class ~ ., data = two_class_dat[1:20, ]), silent = TRUE)
  expect_s3_class(mod_f, exp_cls)
  expect_snapshot(mod_f)
  pred_f <- predict(mod_f, x_te_df)
  expect_equal(pred_f[0, ], pred_ptype)
  expect_equal(nrow(pred_f), 3L)
  aug_f <- augment(mod_f, x_te_df)
  expect_s3_class(aug_f, c("tbl_df", "tbl", "data.frame"))
  expect_equal(nrow(aug_f), 3L)
  expect_equal(ncol(aug_f), 5L)
  # ---------------------------------------------------------------------------
  # Matrix interface. As in the original test, augment() is given a data frame.
  set.seed(956)
  mod_mat <- try(tab_pfn(x_tr_mat, y_tr), silent = TRUE)
  expect_s3_class(mod_mat, exp_cls)
  expect_snapshot(mod_mat)
  pred_mat <- predict(mod_mat, x_te_mat)
  expect_equal(pred_mat[0, ], pred_ptype)
  expect_equal(nrow(pred_mat), 3L)
  aug_mat <- augment(mod_mat, x_te_df)
  expect_s3_class(aug_mat, c("tbl_df", "tbl", "data.frame"))
  expect_equal(nrow(aug_mat), 3L)
  expect_equal(ncol(aug_mat), 5L)
})
test_that('classification models - recipes', {
  skip_if(!is_tab_pfn_installed())
  skip_on_cran()
  skip_if_not_installed("modeldata")
  skip_if_not_installed("recipes")
  # NOTE(review): presumably imports torch up front so the Python backend is
  # initialised before fitting, and attaches tabpfn for the fit/predict
  # generics -- TODO confirm both are still required.
  reticulate::import("torch")
  library(tabpfn)
  suppressPackageStartupMessages(library(recipes))
  # ---------------------------------------------------------------------------
  data(two_class_dat, package = "modeldata")
  # Zero-row prototype of the expected prediction tibble. `character(0)`
  # builds an empty factor directly instead of relying on a length-1 NA
  # factor being recycled down to zero rows.
  pred_ptype <-
    tibble::tibble(
      .pred_Class1 = numeric(0),
      .pred_Class2 = numeric(0),
      .pred_class = factor(character(0), levels = levels(two_class_dat$Class))
    )
  # ---------------------------------------------------------------------------
  # Recipe interface: the recipe adds an A:B interaction term before fitting.
  rec <-
    recipe(Class ~ ., data = two_class_dat) |>
    step_interact(~ A:B)
  set.seed(956)
  mod_rec <- try(tab_pfn(rec, two_class_dat[1:20, ]), silent = TRUE)
  expect_s3_class(mod_rec, exp_cls)
  expect_snapshot(mod_rec)
  pred_rec <- predict(mod_rec, two_class_dat[50:52, ])
  expect_equal(pred_rec[0, ], pred_ptype)
  expect_equal(nrow(pred_rec), 3L)
  aug_rec <- augment(mod_rec, two_class_dat[50:52, ])
  expect_s3_class(aug_rec, c("tbl_df", "tbl", "data.frame"))
  # 3 original columns (A, B, Class) + 3 prediction columns
  expect_equal(nrow(aug_rec), 3L)
  expect_equal(ncol(aug_rec), 6L)
})
test_that('main options', {
  skip_if(!is_tab_pfn_installed())
  skip_on_cran()
  skip_if_not_installed("modeldata")
  # Load the data in this test instead of relying on an earlier test having
  # placed `two_class_dat` in the global environment. Without this, the test
  # errors with "object not found" when run on its own or when the earlier
  # tests were skipped.
  data(two_class_dat, package = "modeldata")
  set.seed(956)
  # Each invalid argument value must be rejected with an error. The error
  # messages are recorded as snapshots.
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, num_estimators = "YES")
  )
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, softmax_temperature = -1)
  )
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, balance_probabilities = "nope")
  )
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, average_before_softmax = "suuuure")
  )
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.