# tests/testthat/test_partykit_surv_ctree.R

# Runs the generic mlr3 learner autotest suite against the surv.ctree learner
# under a fixed seed, with replicability checks switched off
# (check_replicable = FALSE).
test_that("autotest", {
  with_seed(42, {
    ctree_learner = lrn("surv.ctree")
    expect_learner(ctree_learner)
    autotest_result = run_autotest(ctree_learner, check_replicable = FALSE)
    # Presumably autotest_result is TRUE on success and an object carrying an
    # `$error` element otherwise; that element is attached as failure info.
    # NOTE(review): relies on testthat only evaluating `info` when the
    # expectation fails, since `$` on a bare TRUE would error -- confirm.
    expect_true(autotest_result, info = autotest_result$error)
  })
})

# Checks the shape of surv.ctree predictions on a random 50-row subset of the
# rats task:
# - the matrix in `p$data$distr` has one row per test observation and at most
#   one column per unique time in the training rows;
# - `crank` is numeric with one entry per test observation.
test_that("correct prediction types", {
  with_seed(42, {
    task = tsk("rats")
    # Subsample from the task's own row ids rather than a hard-coded 1:300, so
    # the test does not silently depend on the dataset's size or id range.
    task$filter(sample(task$row_ids, 50))
    part = partition(task, ratio = 0.9)
    train_rows = part$train
    test_rows = part$test
    unique_times = task$unique_times(train_rows)

    learner = lrn("surv.ctree")
    p = learner$train(task, train_rows)$predict(task, test_rows)
    expect_matrix(
      p$data$distr,
      nrows = length(test_rows),
      max.cols = length(unique_times)
    )
    expect_numeric(p$crank, len = length(test_rows))
  })
})
# mlr-org/mlr3extralearners documentation built on Sept. 16, 2024, 3:11 a.m.