# tests/testthat/test-classification.R

test_that('classification models', {
  skip_if(!is_tab_pfn_installed())
  skip_on_cran()
  skip_if_not_installed("modeldata")

  # Fits a model on the first 20 rows of `two_class_dat` (first two columns as
  # predictors) through the x/y data frame, formula, and x/y matrix
  # interfaces, then checks the shape of `predict()` and `augment()` output
  # on three held-out rows (21:23).
  # NOTE(review): `exp_cls` is not defined in this file; presumably it comes
  # from a testthat helper file -- confirm.

  #-----------------------------------------------------------------------------

  data(two_class_dat, package = "modeldata")
  x_tr_df <- two_class_dat[1:20, 1:2]
  x_tr_mat <- as.matrix(x_tr_df)
  y_tr <- two_class_dat$Class[1:20]

  # Held-out rows. `x_tr_df` only has 20 rows, so `x_tr_df[21:23, ]` would
  # silently produce three all-NA rows; use these objects as new data.
  x_te_df <- two_class_dat[21:23, 1:2]
  x_te_mat <- as.matrix(x_te_df)

  # Zero-row prototype of the prediction output: one numeric column per class
  # level plus `.pred_class`, a factor carrying the outcome's levels.
  pred_ptype <-
    tibble::tibble(
      .pred_Class1 = numeric(0),
      .pred_Class2 = numeric(0),
      .pred_class = factor(character(0), levels = levels(two_class_dat$Class))
    )

  #-----------------------------------------------------------------------------
  # x/y interface with a data frame of predictors

  set.seed(956)
  mod_df <- try(tab_pfn(x_tr_df, y_tr), silent = TRUE)
  expect_s3_class(mod_df, exp_cls)
  expect_snapshot(mod_df)

  pred_df <- predict(mod_df, x_te_df)
  expect_equal(pred_df[0, ], pred_ptype)
  expect_equal(nrow(pred_df), 3L)

  # 5 columns expected: the 2 predictors plus the 3 prediction columns.
  aug_df <- augment(mod_df, x_te_df)
  expect_s3_class(aug_df, c("tbl_df", "tbl", "data.frame"))
  expect_equal(nrow(aug_df), 3L)
  expect_equal(ncol(aug_df), 5L)

  #-----------------------------------------------------------------------------
  # formula interface

  set.seed(956)
  mod_f <- try(tab_pfn(Class ~ ., data = two_class_dat[1:20, ]), silent = TRUE)
  expect_s3_class(mod_f, exp_cls)
  expect_snapshot(mod_f)

  pred_f <- predict(mod_f, x_te_df)
  expect_equal(pred_f[0, ], pred_ptype)
  expect_equal(nrow(pred_f), 3L)

  aug_f <- augment(mod_f, x_te_df)
  expect_s3_class(aug_f, c("tbl_df", "tbl", "data.frame"))
  expect_equal(nrow(aug_f), 3L)
  expect_equal(ncol(aug_f), 5L)

  #-----------------------------------------------------------------------------
  # x/y interface with a matrix of predictors

  set.seed(956)
  mod_mat <- try(tab_pfn(x_tr_mat, y_tr), silent = TRUE)
  expect_s3_class(mod_mat, exp_cls)
  expect_snapshot(mod_mat)

  pred_mat <- predict(mod_mat, x_te_mat)
  expect_equal(pred_mat[0, ], pred_ptype)
  expect_equal(nrow(pred_mat), 3L)

  # augment() is given the data frame version of the held-out rows, matching
  # the data frame new data used by the other interfaces.
  aug_mat <- augment(mod_mat, x_te_df)
  expect_s3_class(aug_mat, c("tbl_df", "tbl", "data.frame"))
  expect_equal(nrow(aug_mat), 3L)
  expect_equal(ncol(aug_mat), 5L)
})

test_that('classification models - recipes', {
  skip_if(!is_tab_pfn_installed())
  skip_on_cran()
  skip_if_not_installed("modeldata")
  skip_if_not_installed("recipes")

  # Fits a model through the recipe interface (with an A:B interaction step)
  # on 20 rows of `two_class_dat`, then checks `predict()` / `augment()`
  # output shape on three other rows (50:52).
  # NOTE(review): `exp_cls` is not defined in this file; presumably it comes
  # from a testthat helper file -- confirm.

  # NOTE(review): called only for its side effect; presumably this makes sure
  # the Python torch module is loaded before tab_pfn() runs -- confirm.
  reticulate::import("torch")

  library(tabpfn)
  suppressPackageStartupMessages(library(recipes))

  #-----------------------------------------------------------------------------

  data(two_class_dat, package = "modeldata")

  # Zero-row prototype of the prediction output: one numeric column per class
  # level plus `.pred_class`, a factor carrying the outcome's levels.
  pred_ptype <-
    tibble::tibble(
      .pred_Class1 = numeric(0),
      .pred_Class2 = numeric(0),
      .pred_class = factor(character(0), levels = levels(two_class_dat$Class))
    )

  #-----------------------------------------------------------------------------

  rec <-
    recipe(Class ~ ., data = two_class_dat) |>
    step_interact(~ A:B)

  set.seed(956)
  mod_rec <- try(tab_pfn(rec, two_class_dat[1:20, ]), silent = TRUE)
  expect_s3_class(mod_rec, exp_cls)
  expect_snapshot(mod_rec)

  pred_rec <- predict(mod_rec, two_class_dat[50:52, ])
  expect_equal(pred_rec[0, ], pred_ptype)
  expect_equal(nrow(pred_rec), 3L)

  # 6 columns expected: the 3 columns of `two_class_dat` (including the
  # outcome) plus the 3 prediction columns.
  aug_rec <- augment(mod_rec, two_class_dat[50:52, ])
  expect_s3_class(aug_rec, c("tbl_df", "tbl", "data.frame"))
  expect_equal(nrow(aug_rec), 3L)
  expect_equal(ncol(aug_rec), 6L)
})

test_that('main options', {
  skip_if(!is_tab_pfn_installed())
  skip_on_cran()
  skip_if_not_installed("modeldata")

  # Load the data into this test's own environment rather than relying on a
  # copy that an earlier test's `data()` call left in the global environment
  # (`data()` defaults to `envir = .GlobalEnv`). Without this, running this
  # test alone snapshots an "object not found" error instead of the intended
  # argument-validation error.
  data(two_class_dat, package = "modeldata", envir = environment())

  # Each call passes an invalid value for one option; the resulting error
  # messages are captured as snapshots.
  set.seed(956)
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, num_estimators = "YES")
  )
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, softmax_temperature = -1)
  )
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, balance_probabilities = "nope")
  )
  expect_snapshot_error(
    tab_pfn(Class ~ ., data = two_class_dat, average_before_softmax = "suuuure")
  )
})

# Try the tabpfn package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tabpfn documentation built on March 18, 2026, 5:07 p.m.