# tests/testthat/test-partykit.R

# Limit OpenMP to a single thread while the tests in this file run. The call
# sits at file level, ahead of every `test_that()` block below.
withr::local_envvar(OMP_THREAD_LIMIT = "1")

test_that("condition inference trees", {
  skip_if_not_installed("partykit")
  skip_if_not_installed("modeldata")

  # Attached so that `ctree()` and `ctree_control()` can be called unqualified
  # when building the reference fits below.
  suppressPackageStartupMessages(library(partykit))

  # Snapshot the printed specification, with and without an engine argument.
  expect_snapshot(
    decision_tree() |> set_engine("partykit") |> set_mode("regression")
  )
  expect_snapshot(
    decision_tree() |>
      set_engine("partykit", teststat = "maximum") |>
      set_mode("classification")
  )

  # ----------------------------------------------------------------------------
  # regression
  #
  # Each case fits the same model twice -- once from a `decision_tree()`
  # specification and once with a direct `ctree()` call -- and then compares
  # the stored `fitted` component and the predictions on the training data.

  # Case 1: default settings.
  expect_no_error({
    default_spec <-
      decision_tree() |>
      set_engine("partykit") |>
      set_mode("regression")
    tree_default <- fit(default_spec, mpg ~ ., data = mtcars)
  })
  ref_default <- ctree(mpg ~ ., data = mtcars)
  expect_equal(ref_default$fitted, tree_default$fit$fitted)

  expect_no_error({
    tree_default_pred <- predict(tree_default, new_data = mtcars)$.pred
  })
  ref_default_pred <- unname(predict(ref_default, newdata = mtcars))
  expect_equal(ref_default_pred, tree_default_pred)

  # Case 2: `tree_depth = 1` is compared against `maxdepth = 1`.
  expect_no_error({
    shallow_spec <-
      decision_tree(tree_depth = 1) |>
      set_engine("partykit") |>
      set_mode("regression")
    tree_shallow <- fit(shallow_spec, mpg ~ ., data = mtcars)
  })
  ref_shallow <- ctree(
    mpg ~ .,
    data = mtcars,
    control = ctree_control(maxdepth = 1)
  )
  expect_equal(ref_shallow$fitted, tree_shallow$fit$fitted)

  expect_no_error({
    tree_shallow_pred <- predict(tree_shallow, new_data = mtcars)$.pred
  })
  ref_shallow_pred <- unname(predict(ref_shallow, newdata = mtcars))
  expect_equal(ref_shallow_pred, tree_shallow_pred)

  # Case 3: `mincriterion` given as an engine argument is compared against the
  # same value supplied through `ctree_control()`.
  expect_no_error({
    strict_spec <-
      decision_tree() |>
      set_engine("partykit", mincriterion = 0.99) |>
      set_mode("regression")
    tree_strict <- fit(strict_spec, mpg ~ ., data = mtcars)
  })
  ref_strict <- ctree(
    mpg ~ .,
    data = mtcars,
    control = ctree_control(mincriterion = 0.99)
  )
  expect_equal(ref_strict$fitted, tree_strict$fit$fitted)

  expect_no_error({
    tree_strict_pred <- predict(tree_strict, new_data = mtcars)$.pred
  })
  ref_strict_pred <- unname(predict(ref_strict, newdata = mtcars))
  expect_equal(ref_strict_pred, tree_strict_pred)

  # ----------------------------------------------------------------------------
  # classification

  data("ad_data", package = "modeldata")

  # Case 4: default settings with `Class` as the outcome.
  expect_no_error({
    class_spec <-
      decision_tree() |>
      set_engine("partykit") |>
      set_mode("classification")
    tree_class <- fit(class_spec, Class ~ ., data = ad_data)
  })
  ref_class <- ctree(Class ~ ., data = ad_data)
  expect_equal(ref_class$fitted, tree_class$fit$fitted)

  # Hard class predictions.
  expect_no_error({
    tree_class_pred <- predict(tree_class, new_data = ad_data)$.pred_class
  })
  ref_class_pred <- unname(predict(ref_class, newdata = ad_data))
  expect_equal(ref_class_pred, tree_class_pred)

  # Probabilities: the second column of each result is compared.
  expect_no_error({
    tree_class_prob <-
      predict(tree_class, new_data = ad_data, type = "prob")[[2]]
  })
  ref_class_prob <-
    unname(predict(ref_class, newdata = ad_data, type = "prob")[, 2])
  expect_equal(ref_class_prob, tree_class_prob)
})


test_that("condition inference forests", {
  skip_if_not_installed("partykit")
  skip_if_not_installed("modeldata")

  # Attached so that `cforest()` and `ctree_control()` can be called
  # unqualified when building the reference fits below.
  suppressPackageStartupMessages(library(partykit))

  # Snapshot the printed specification, with and without an engine argument.
  expect_snapshot(
    rand_forest() |> set_engine("partykit") |> set_mode("regression")
  )
  expect_snapshot(
    rand_forest() |>
      set_engine("partykit", teststat = "maximum") |>
      set_mode("classification")
  )

  # ----------------------------------------------------------------------------
  # regression
  #
  # Each case fits the same forest twice -- once from a `rand_forest()`
  # specification and once with a direct `cforest()` call -- and then compares
  # the stored `fitted` component and the predictions on the training data.
  # `set.seed(1)` is called immediately before each of the two fits.

  # Case 1: five trees, otherwise default settings.
  expect_no_error({
    set.seed(1)
    default_spec <-
      rand_forest(trees = 5) |>
      set_engine("partykit") |>
      set_mode("regression")
    forest_default <- fit(default_spec, mpg ~ ., data = mtcars)
  })
  set.seed(1)
  ref_default <- cforest(mpg ~ ., data = mtcars, ntree = 5)
  expect_equal(ref_default$fitted, forest_default$fit$fitted)

  expect_no_error({
    forest_default_pred <- predict(forest_default, new_data = mtcars)$.pred
  })
  ref_default_pred <- unname(predict(ref_default, newdata = mtcars))
  expect_equal(ref_default_pred, forest_default_pred)

  # Case 2: `mtry = 2` is passed on both sides.
  expect_no_error({
    set.seed(1)
    mtry_spec <-
      rand_forest(trees = 5, mtry = 2) |>
      set_engine("partykit") |>
      set_mode("regression")
    forest_mtry <- fit(mtry_spec, mpg ~ ., data = mtcars)
  })
  set.seed(1)
  ref_mtry <- cforest(mpg ~ ., data = mtcars, ntree = 5, mtry = 2)
  expect_equal(ref_mtry$fitted, forest_mtry$fit$fitted)

  expect_no_error({
    forest_mtry_pred <- predict(forest_mtry, new_data = mtcars)$.pred
  })
  ref_mtry_pred <- unname(predict(ref_mtry, newdata = mtcars))
  expect_equal(ref_mtry_pred, forest_mtry_pred)

  # Case 3: `mincriterion` given as an engine argument is compared against the
  # same value supplied through `ctree_control()`.
  expect_no_error({
    set.seed(1)
    strict_spec <-
      rand_forest(trees = 5) |>
      set_engine("partykit", mincriterion = 0.99) |>
      set_mode("regression")
    forest_strict <- fit(strict_spec, mpg ~ ., data = mtcars)
  })
  set.seed(1)
  ref_strict <- cforest(
    mpg ~ .,
    data = mtcars,
    ntree = 5,
    control = ctree_control(mincriterion = 0.99)
  )
  expect_equal(ref_strict$fitted, forest_strict$fit$fitted)

  expect_no_error({
    forest_strict_pred <- predict(forest_strict, new_data = mtcars)$.pred
  })
  ref_strict_pred <- unname(predict(ref_strict, newdata = mtcars))
  expect_equal(ref_strict_pred, forest_strict_pred)

  # ----------------------------------------------------------------------------
  # classification

  data("ad_data", package = "modeldata")

  # Case 4: five trees with `Class` as the outcome.
  expect_no_error({
    set.seed(1)
    class_spec <-
      rand_forest(trees = 5) |>
      set_engine("partykit") |>
      set_mode("classification")
    forest_class <- fit(class_spec, Class ~ ., data = ad_data)
  })
  set.seed(1)
  ref_class <- cforest(Class ~ ., data = ad_data, ntree = 5)
  expect_equal(ref_class$fitted, forest_class$fit$fitted)

  # Hard class predictions.
  expect_no_error({
    forest_class_pred <- predict(forest_class, new_data = ad_data)$.pred_class
  })
  ref_class_pred <- unname(predict(ref_class, newdata = ad_data))
  expect_equal(ref_class_pred, forest_class_pred)

  # Probabilities: the second column of each result is compared.
  expect_no_error({
    forest_class_prob <-
      predict(forest_class, new_data = ad_data, type = "prob")[[2]]
  })
  ref_class_prob <-
    unname(predict(ref_class, newdata = ad_data, type = "prob")[, 2])
  expect_equal(ref_class_prob, forest_class_prob)
})
# tidymodels/bonsai documentation built on July 3, 2025, 7:35 p.m.