# tests/testthat/test_pipeop_tunethreshold.R

# Label the tests in this file as belonging to the "tunethreshold" context.
context("tunethreshold")

test_that("threshold works for multiclass", {
  # Train a "learner_cv" PipeOp wrapping a probability-predicting rpart learner
  # on the iris task; its output is the training input for "tunethreshold".
  task = tsk("iris")
  prob_learner = lrn("classif.rpart", predict_type = "prob")
  cv_op = po("learner_cv", learner = prob_learner)
  cv_train_out = cv_op$train(list(task))

  thr_op = po("tunethreshold")
  expect_pipeop(thr_op)
  thr_op$train(cv_train_out)

  # After training, the state must hold one threshold in [0, 1] per class,
  # named after the task's class labels.
  tuned = thr_op$state$threshold
  expect_numeric(tuned, len = 3L, lower = 0, upper = 1)
  expect_set_equal(names(tuned), task$class_names)

  # Predict on the same task and check the thresholded prediction.
  cv_pred_out = cv_op$predict(list(task))
  prediction = thr_op$predict(cv_pred_out)[[1]]
  expect_prediction(prediction)
  # The default score of the prediction has to stay below 0.33.
  expect_true(prediction$score() < 0.33)
})

test_that("threshold works for binary", {
  # Train a "learner_cv" PipeOp wrapping a probability-predicting rpart learner
  # on the pima task; its output is the training input for "tunethreshold".
  task = tsk("pima")
  prob_learner = lrn("classif.rpart", predict_type = "prob")
  cv_op = po("learner_cv", learner = prob_learner)
  cv_train_out = cv_op$train(list(task))

  thr_op = po("tunethreshold")
  expect_pipeop(thr_op)
  thr_op$train(cv_train_out)

  # After training, the state must hold one threshold in [0, 1] per class,
  # named after the task's class labels.
  tuned = thr_op$state$threshold
  expect_numeric(tuned, len = 2, lower = 0, upper = 1)
  expect_set_equal(names(tuned), task$class_names)

  # Predict on the same task and check the thresholded prediction.
  cv_pred_out = cv_op$predict(list(task))
  prediction = thr_op$predict(cv_pred_out)[[1]]
  expect_prediction(prediction)
  # The default score of the prediction has to stay below 0.33.
  expect_true(prediction$score() < 0.33)

  # With a learner set to predict_type = "response", training the combined
  # graph is expected to raise an error whose message matches "prob".
  response_learner = lrn("classif.rpart", predict_type = "response")
  graph = po("learner_cv", learner = response_learner) %>>% po("tunethreshold")
  expect_error(graph$train(task), "prob")
})

# Try the mlr3pipelines package in your browser
#
# Any scripts or data that you put into this service are public.
#
# mlr3pipelines documentation built on Sept. 21, 2022, 9:09 a.m.