# tests/testthat/test_Learner.R

test_that("construction", {
  # The base class can be instantiated directly with an id, a task type and
  # the supported predict types.
  learner = Learner$new(
    "test-learner",
    task_type = "classif",
    predict_types = "prob"
  )
  expect_class(learner, "Learner")
})

test_that("clone", {
  original = lrn("classif.rpart")$train(tsk("iris"))
  copy = original$clone(deep = TRUE)

  # After a deep clone, log and parameter set live at different addresses.
  expect_different_address(original$state$log, copy$state$log)
  expect_different_address(original$param_set, copy$param_set)

  # Changing `xval` on the original must not leak into the copy, which is
  # expected to still report 0L.
  original$param_set$values = list(xval = 10L)
  expect_equal(original$param_set$values$xval, 10L)
  expect_equal(copy$param_set$values$xval, 0L)
})

test_that("Learners are called with invoke / small footprint of call", {
  task = tsk("boston_housing")
  learner = lrn("regr.rpart")
  learner$train(task)

  # The call stored in the fitted model is expected to mention
  # `task$formula()` and `task$data`, and to stay below 1000 characters
  # in total.
  call_parts = as.character(learner$model$call)
  expect_character(call_parts, min.len = 1L, any.missing = FALSE)
  expect_true(any(grepl("task$formula()", call_parts, fixed = TRUE)))
  expect_true(any(grepl("task$data", call_parts, fixed = TRUE)))
  expect_lt(sum(nchar(call_parts)), 1000)
})

test_that("Extra data slots of learners are kept / reset", {
  task = tsk("boston_housing")
  learner = lrn("regr.rpart")

  # A custom entry written into the state after training ...
  learner$train(task)
  learner$state$foo = "bar"
  expect_equal(learner$state$foo, "bar")

  # ... is still there after predicting ...
  learner$predict(task)
  expect_equal(learner$state$foo, "bar")

  # ... and is gone once the learner has been trained again.
  learner$train(task)
  expect_null(learner$state$foo)
})

test_that("task is checked in train() / predict()", {
  regr_learner = lrn("regr.rpart")

  # Both entry points must refuse the "pima" task with an error that
  # mentions "type".
  expect_error(regr_learner$train(tsk("pima")), "type")
  expect_error(regr_learner$predict(tsk("pima")), "type")
})

test_that("learner timings", {
  learner = lrn("regr.rpart")

  # Untrained learner: both entries exist, both are missing.
  timings = learner$timings
  expect_equal(unname(timings), as.double(c(NA, NA)))
  expect_equal(names(timings), c("train", "predict"))

  # After training: only the train timing is a number.
  learner$train(tsk("mtcars"))
  timings = learner$timings
  expect_number(timings[["train"]])
  expect_equal(timings[["predict"]], NA_real_)

  # After predicting: both timings are numbers.
  learner$predict(tsk("mtcars"))
  timings = learner$timings
  expect_number(timings[["train"]])
  expect_number(timings[["predict"]])
})

test_that("predict on newdata works / classif", {
  task = tsk("iris")$filter(1:120)
  learner = lrn("classif.featureless")

  # Predicting with an untrained learner is an error.
  expect_error(learner$predict(task), "trained")
  learner$train(task)
  expect_task(learner$state$train_task, null_backend_ok = TRUE)

  # Rows 121 to 150 were not part of the training task and serve as new data.
  newdata = tsk("iris")$filter(121:150)$data()

  # Checks shared by all scenarios: 30 predictions with row ids 1 to 30.
  expect_thirty_rows = function(pred) {
    tab = as.data.table(pred)
    expect_data_table(tab, nrows = 30)
    expect_set_equal(tab$row_ids, 1:30)
  }

  # Scenario 1: the task is passed explicitly.
  pred = learner$predict_newdata(newdata = newdata, task = task)
  expect_thirty_rows(pred)
  expect_factor(pred$truth, any.missing = FALSE, levels = task$class_names)

  # Scenario 2: no task is passed, the learner's stored task is used.
  pred = learner$predict_newdata(newdata = newdata, task = NULL)
  expect_thirty_rows(pred)
  expect_factor(pred$truth, any.missing = FALSE, levels = task$class_names)

  # Scenario 3: the target column is missing from the new data, so the truth
  # must be an all-missing factor with the task's class levels.
  newdata$Species = NULL
  pred = learner$predict_newdata(newdata = newdata, task = task)
  expect_thirty_rows(pred)
  expect_factor(pred$truth, levels = task$class_names)
  expect_true(allMissing(pred$truth))
})

test_that("train task is properly cloned (#383)", {
  learner = lrn("classif.rpart")$train(tsk("iris"))

  # Predicting twice on the same three feature rows must yield three row ids
  # each time.
  first = learner$predict_newdata(iris[1:3, 1:4])
  expect_equal(length(first$row_ids), 3)

  second = learner$predict_newdata(iris[1:3, 1:4])
  expect_equal(length(second$row_ids), 3)
})

test_that("predict on newdata works / regr", {
  task = tsk("boston_housing")

  # Even-numbered rows form `train`, the remaining rows form `test`.
  all_rows = seq_len(task$nrow)
  train = which(all_rows %% 2 == 0L)
  test = setdiff(all_rows, train)

  learner = lrn("regr.featureless")
  learner$train(task)

  newdata = task$clone()$filter(test)$data()
  pred = learner$predict_newdata(newdata)

  # One prediction per new row, numbered consecutively from 1.
  tab = as.data.table(pred)
  expect_data_table(tab, nrows = length(test))
  expect_set_equal(tab$row_ids, seq_along(test))
})


test_that("predict on newdata works / no target column", {
  task = tsk("boston_housing")

  # Even-numbered rows are used for training, the rest for prediction.
  all_rows = seq_len(task$nrow)
  train = which(all_rows %% 2 == 0L)
  test = setdiff(all_rows, train)

  learner = lrn("regr.featureless")
  learner$train(task$clone()$filter(train))

  # Strip the target column(s) from the held-out rows before predicting.
  held_out = task$clone()$filter(test)$data()
  newdata = remove_named(held_out, task$target_names)
  pred = learner$predict_newdata(newdata = newdata)

  tab = as.data.table(pred)
  expect_data_table(tab, nrows = length(test))
  expect_set_equal(tab$row_ids, seq_along(test))

  # newdata with 1 col and learner that does not support missings:
  # the final call only has to run without raising an error.
  xdt = data.table(x = 1, y = 1)
  task = as_task_regr(xdt, target = "y")
  learner = lrn("regr.featureless")
  learner$properties = setdiff(learner$properties, "missings")
  learner$train(task)
  learner$predict_newdata(xdt[, 1])
})

test_that("predict on newdata works / empty data frame", {
  task = tsk("mtcars")
  splits = partition(task)
  learner = lrn("regr.featureless")
  learner$train(task$clone()$filter(splits$train))

  # A backend built from zero rows of the task data.
  empty_rows = task$data(rows = integer())
  newdata = as_data_backend(empty_rows)
  pred = learner$predict_newdata(newdata = newdata)

  # The prediction must be empty as well.
  tab = as.data.table(pred)
  expect_data_table(tab, nrows = 0)
  expect_set_equal(tab$row_ids, integer())
})

test_that("predict on newdata works / data backend input", {
  task = tsk("mtcars")
  splits = partition(task)
  learner = lrn("regr.featureless")
  learner$train(task$clone()$filter(splits$train))

  # Build a backend over the test rows with a custom, randomly drawn
  # primary key column.
  newdata = task$data(rows = splits$test)
  newdata$new_row_id = sample(1e4, nrow(newdata))
  backend = as_data_backend(newdata, primary_key = "new_row_id")
  pred = learner$predict_newdata(newdata = backend)

  # The prediction's row ids must be the backend's row names.
  tab = as.data.table(pred)
  expect_data_table(tab, nrows = backend$nrow)
  expect_set_equal(tab$row_ids, backend$rownames)
})


test_that("predict on newdata works / titanic use case", {
  skip_if_not_installed("mlr3data")
  data("titanic", package = "mlr3data")

  # Drop the listed columns and convert the result to a data.table.
  drop = c("cabin", "name", "ticket", "passenger_id")
  passengers = setDT(remove_named(titanic, drop))

  # Rows with a known `survived` value form the training task ...
  task = TaskClassif$new(id = "titanic", passengers[!is.na(survived), ], target = "survived")
  learner = lrn("classif.rpart")$train(task)

  # ... and rows where `survived` is missing are used as new data.
  pred = learner$predict_newdata(newdata = passengers[is.na(survived)])
  expect_prediction_classif(pred)
  expect_factor(pred$response, levels = task$class_names, any.missing = FALSE)
  expect_factor(pred$truth, levels = task$class_names)
  expect_true(allMissing(pred$truth))
})

test_that("predict train + test set", {
  task = tsk("iris")

  # Three debug measures: one on the train set, one on the test set, one on
  # both.
  measures = list(
    msr("debug_classif", id = "tr", predict_sets = "train"),
    msr("debug_classif", id = "te", predict_sets = "test"),
    msr("debug_classif", id = "trte", predict_sets = c("train", "test"))
  )

  holdout = rsmp("holdout")$instantiate(task)
  n_train = length(holdout$train_set(1))
  n_test = length(holdout$test_set(1))

  # Learner with default predict sets: aggregating must warn twice, with a
  # message mentioning "predict sets".
  learner = lrn("classif.rpart")
  rr = resample(task, learner, holdout)
  expect_warning(expect_warning(rr$aggregate(measures = measures), "predict sets", fixed = TRUE))

  # Learner predicting on the train set only: again two warnings.
  learner = lrn("classif.rpart", predict_sets = "train")
  rr = resample(task, learner, holdout)
  expect_warning(expect_warning(rr$aggregate(measures = measures), "predict sets", fixed = TRUE))

  # Learner predicting on both sets: all three measures are available and
  # equal the respective number of observations.
  learner = lrn("classif.rpart", predict_sets = c("train", "test"))
  rr = resample(task, learner, holdout)
  scores = rr$aggregate(measures = measures)
  expect_equal(unname(is.na(scores)), c(FALSE, FALSE, FALSE))
  expect_equal(unname(scores), c(n_train, n_test, n_train + n_test))
})

test_that("assertions work (#357)", {
  # NOTE(review): `measures` is constructed but not used below.
  measures = lapply(c("classif.auc", "classif.acc"), msr)
  task = tsk("iris")
  learner = lrn("classif.featureless")
  pred = learner$train(task)$predict(task)

  # NOTE(review): "predict.type" / "task.type" are passed as the second
  # positional argument of `$score()`, not as the `regexp` of
  # `expect_error()`. As written, only "some error is raised" is asserted and
  # the message is never matched. Confirm whether the closing parenthesis was
  # meant to follow `msr(...)`.
  expect_error(pred$score(msr("classif.auc"), "predict.type"))
  expect_error(pred$score(msr("regr.mse"), "task.type"))
})

test_that("reset()", {
  task = tsk("iris")
  learner = lrn("classif.rpart")

  # Training populates the state with uniquely named entries.
  learner$train(task)
  expect_list(learner$state, names = "unique")

  # `$reset()` returns a learner and clears the state.
  expect_learner(learner$reset())
  expect_null(learner$state)
})

test_that("empty predict set (#421)", {
  task = tsk("iris")
  learner = lrn("classif.featureless")

  # Holdout with ratio = 1; per the test title its test set is expected to be
  # empty.
  holdout = rsmp("holdout", ratio = 1)
  holdout = holdout$instantiate(task)

  learner$train(task, holdout$train_set(1))
  pred = learner$predict(task, holdout$test_set(1))

  # A valid prediction is still returned and the situation is logged.
  expect_prediction(pred)
  expect_true(any(grepl("No data to predict on", learner$log$msg)))
})

test_that("fallback learner is deep cloned (#511)", {
  learner = lrn("classif.rpart")
  learner$fallback = lrn("classif.featureless")

  # The fallback of a deep clone must not be the same object as the original
  # fallback.
  cloned_fallback = learner$clone(deep = TRUE)$fallback
  expect_different_address(learner$fallback, cloned_fallback)
})

test_that("learner cannot be trained with TuneToken present", {
  task = tsk("boston_housing")

  # `cp` holds a tune token instead of a concrete value.
  learner = lrn("regr.rpart", cp = paradox::to_tune(0.1, 0.3))

  # Training must fail with exactly this message (matched as a fixed string).
  expected_msg = "<LearnerRegrRpart:regr.rpart> cannot be trained with TuneToken present in hyperparameter: cp"
  expect_error(learner$train(task), regexp = expected_msg, fixed = TRUE)
})

test_that("integer<->numeric conversion in newdata (#533)", {
  # `x` is integer in the training data but non-integer numeric in `shifted`.
  train_data = data.table(y = runif(10), x = 1:10)
  shifted = data.table(y = runif(10), x = 1:10 + 0.1)

  task = TaskRegr$new("test", train_data, "y")
  learner = lrn("regr.featureless")
  learner$train(task)

  # Both variants must be accepted as new data.
  expect_prediction(learner$predict_newdata(train_data))
  expect_prediction(learner$predict_newdata(shifted))
})

test_that("weights", {
  # Rows 51 to 100 get weight 100, all other rows weight 1.
  weighted = cbind(iris, w = rep(c(1, 100, 1), each = 50))
  task = TaskClassif$new("weighted_task", weighted, "Species")
  task$set_col_roles("w", "weight")

  learner = lrn("classif.rpart")
  learner$train(task)

  # no errors in class with weight 100
  confusion = learner$predict(task)$confusion
  expect_equal(unname(confusion[, 2]), c(0, 50, 0))

  # New data is accepted with and without the weight column.
  expect_prediction(learner$predict_newdata(weighted[1:3, ]))
  expect_prediction(learner$predict_newdata(iris[1:3, ]))
})

test_that("mandatory properties", {
  task = tsk("iris")
  learner = lrn("classif.rpart")

  # Remove the "multiclass" property from the learner.
  learner$properties = setdiff(learner$properties, "multiclass")
  resampling = rsmp("holdout")

  # Every entry point must then reject the task with an error mentioning
  # "multiclass".
  expect_error(learner$train(task), "multiclass")
  expect_error(resample(task, learner, resampling), "multiclass")
  design = benchmark_grid(task, learner, resampling)
  expect_error(benchmark(design), "multiclass")
})

test_that("train task is cloned (#382)", {
  train_task = tsk("iris")
  learner = lrn("classif.rpart")$train(train_task)

  # The task stored in the state is not the object passed to `$train()`.
  expect_different_address(learner$state$train_task, train_task)

  # A deep clone of the learner gets its own stored task as well.
  learner_copy = learner$clone(deep = TRUE)
  expect_different_address(learner$state$train_task, learner_copy$state$train_task)
})

test_that("Error on missing data (#413)", {
  task = tsk("pima")
  learner = lrn("classif.rpart")

  # Once the "missings" property is removed, training on this task must fail
  # with an error about missing values.
  learner$properties = setdiff(learner$properties, "missings")
  expect_error(learner$train(task), "missing values")
})

test_that("Task prototype is stored in state", {
  task = tsk("boston_housing")
  learner = lrn("regr.rpart")
  learner$train(task)

  # The prototype is a zero-row table with 18 columns whose names are exactly
  # the task's feature and target names (in any order).
  proto = learner$state$data_prototype
  expected_cols = c(task$feature_names, task$target_names)
  expect_data_table(proto, nrows = 0, ncols = 18)
  expect_names(names(proto), permutation.of = expected_cols)
})

test_that("Models can be replaced", {
  learner = lrn("regr.featureless")
  learner$train(tsk("boston_housing"))

  # Writing into the fitted model must stick.
  learner$model$location = 1
  expect_equal(learner$model$location, 1)
})

test_that("validation task's backend is removed", {
  learner = lrn("regr.rpart")
  task = tsk("mtcars")$divide(ids = 1:10)
  learner$train(task)

  # The validation task kept inside the stored train task has no backend.
  stored_valid_task = learner$state$train_task$internal_valid_task
  expect_true(is.null(stored_valid_task$backend))
})

test_that("manual $train() stores validation hash and validation ids", {
  # Validation given as a ratio: a hash is stored in the state.
  task = tsk("iris")
  learner = lrn("classif.debug", validate = 0.2)
  learner$train(task)
  expect_character(learner$state$internal_valid_task_hash)

  # Predefined validation task: the stored hash is that task's hash.
  learner = lrn("classif.debug", validate = "predefined")
  task = tsk("iris")
  task$divide(ids = 1:10)
  learner$train(task)
  expect_equal(learner$state$internal_valid_task_hash, task$internal_valid_task$hash)

  # nothing is stored for learners that don't do it
  featureless = lrn("classif.featureless")
  featureless$train(task)
  expect_true(is.null(featureless$state$internal_valid_task_hash))
})

test_that("error when training a learner that sets valiadte to 'predefined' on a task without a validation task", {
  task = tsk("iris")
  learner = lrn("classif.debug", validate = "predefined")

  # Without a validation task, training must fail.
  expect_error(learner$train(task), "is set to 'predefined'")

  # Once a validation task has been split off via `$divide()`, training must
  # succeed. The learner has to actually be trained here: checking the class
  # of the untrained `learner` object would pass unconditionally.
  task$divide(ids = 1:10)
  expect_class(learner$train(task), "Learner")
})

test_that("properties are also checked on validation task", {
  task = tsk("iris")

  # Append a copy of the first row with a missing value in its first column
  # as row 151, then make that row the validation task.
  extra_row = task$data(1)
  extra_row[[1]][1] = NA
  extra_row$..row_id = 151
  task$rbind(extra_row)
  task$divide(ids = 151)

  # A learner without the "missings" property must refuse to train.
  learner = lrn("classif.debug", validate = "predefined")
  learner$properties = setdiff(learner$properties, "missings")
  expect_error(learner$train(task), "missing values")
})

test_that("validate changes phash and hash", {
  # Two otherwise identical learners that differ only in `validate`.
  l1 = lrn("classif.debug")
  l2 = lrn("classif.debug")
  l2$validate = 0.2

  # Both hashes named in the test title must differ; `$phash` was previously
  # not checked at all.
  expect_false(l1$hash == l2$hash)
  expect_false(l1$phash == l2$phash)
})

test_that("marshaling and encapsulation", {
  task = tsk("iris")
  learner = lrn("classif.debug", count_marshaling = TRUE)

  # Training with callr encapsulation marshals the model once; afterwards the
  # learner is unmarshaled and can predict.
  learner$encapsulate = c(train = "callr")
  learner$train(task)
  expect_equal(learner$model$marshal_count, 1)
  expect_false(learner$marshaled)
  expect_prediction(learner$predict(task))

  # Training without encapsulation does not marshal the model.
  learner$encapsulate = c(train = "none")
  learner$train(task)
  expect_equal(learner$model$marshal_count, 0)
})

test_that("marshal state", {
  state = lrn("classif.debug")$train(tsk("iris"))$state

  # A marshaled learner state is recognized as marshaled ...
  marshaled_state = marshal_model(state)
  expect_true(is_marshaled_model(marshaled_state))

  # ... and a marshal/unmarshal round trip reproduces the original state.
  round_trip = unmarshal_model(marshal_model(state))
  expect_equal(state, round_trip)
})


test_that("internal_valid_task is created correctly", {
  # Debug learner that stores a deep clone of the task handed to `.train()`,
  # so the test can inspect what the learner actually received.
  LearnerClassifTest = R6Class("LearnerClassifTest", inherit = LearnerClassifDebug,
    public = list(
      task = NULL
    ),
    private = list(
      .train = function(task, ...)  {
        self$task = task$clone(deep = TRUE)
        super$.train(task, ...)
      }
    )
  )
  # validate = NULL (but task has one)
  learner = LearnerClassifTest$new()
  # Set `validate` before training: assigned after `$train()` it could not
  # influence the training run this scenario is about.
  learner$validate = NULL
  task = tsk("iris")$divide(0.3)
  learner$train(task)
  expect_true(is.null(learner$internal_valid_scores))
  expect_true(is.null(learner$task$internal_valid_task))

  # validate = NULL (but task has none)
  learner1 = LearnerClassifTest$new()
  task1 = tsk("iris")
  learner1$train(task1)
  expect_true(is.null(learner1$internal_valid_scores))
  expect_true(is.null(learner1$task$internal_valid_task))

  # validate = "test"
  # Debug learner that errors in `.train()` unless the row ids of the task and
  # of its validation task are permutations of the expected ids.
  LearnerClassifTest2 = R6Class("LearnerClassifTest2", inherit = LearnerClassifDebug,
    public = list(
      expected_valid_ids = NULL,
      expected_train_ids = NULL
    ),
    private = list(
      .train = function(task, ...)  {
        if (!test_permutation(task$internal_valid_task$row_ids, self$expected_valid_ids)) {
          stopf("something went wrong")
        }
        if (!test_permutation(task$row_ids, self$expected_train_ids)) {
          stopf("something went wrong")
        }
        super$.train(task, ...)
      }
    )
  )
  task2 = tsk("iris")
  learner2 = LearnerClassifTest2$new()
  learner2$validate = "test"
  resampling = rsmp("holdout")$instantiate(task2)
  learner2$expected_valid_ids = resampling$test_set(1)
  learner2$expected_train_ids = resampling$train_set(1)
  # `regexp = NA` asserts that no error is raised.
  expect_error(resample(task2, learner2, resampling), regexp = NA)

  # ratio works
  # Debug learner that errors in `.train()` unless the validation task has
  # exactly 20 rows (0.2 of the 100 rows used below).
  LearnerClassifTest3 = R6Class("LearnerClassifTest3", inherit = LearnerClassifDebug,
    private = list(
      .train = function(task, ...)  {
        if (length(task$internal_valid_task$row_ids) != 20) {
          stopf("something went wrong")
        }
        super$.train(task, ...)
      }
    )
  )
  task = tsk("iris")$filter(1:100)
  learner3 = LearnerClassifTest3$new()
  learner3$validate = 0.2
  learner3$train(task)

  # check that the validation task is discarded at the end
  learner4 = lrn("classif.debug", validate = 0.2)
  task = tsk("iris")
  learner4$train(task)
  expect_true(is.null(task$internal_valid_task))
})

test_that("compatability check on validation task", {
  learner = lrn("classif.debug", validate = "predefined")
  task = tsk("german_credit")$divide(ids = 1:10)

  # Restrict the features of the primary task only: the validation task no
  # longer matches and training must fail.
  task$col_roles$feature = "age"
  expect_error(learner$train(task), "has different features")

  # Align the features again but give the validation task another target:
  # training must fail with the target-specific message.
  task$internal_valid_task$col_roles$feature = "age"
  task$internal_valid_task$col_roles$target = "credit_history"
  expect_error(learner$train(task), "has different target")
})
  
test_that("model is marshaled during parallel predict", {
  # With check_pid = TRUE, unmarshal_model() sets the process id to the
  # current id. During `.predict()`, LearnerClassifDebug then checks whether
  # the marshal_id of the model equals the current process id and errs if it
  # does not.
  task = tsk("iris")
  learner = lrn("classif.debug", check_pid = TRUE)
  learner$train(task)
  learner$parallel_predict = TRUE

  # Predict under a multisession future plan.
  prediction = with_future(future::multisession, {
    learner$predict(task)
  })
  expect_class(prediction, "Prediction")
})

test_that("model is marshaled during callr prediction", {
  # With check_pid = TRUE, unmarshal_model() sets the process id to the
  # current id. During `.predict()`, LearnerClassifDebug then checks whether
  # the marshal_id of the model equals the current process id and errs if it
  # does not.
  task = tsk("iris")
  learner = lrn("classif.debug", check_pid = TRUE, encapsulate = c(predict = "callr"))
  learner$train(task)

  # Prediction runs with callr encapsulation and must still succeed.
  prediction = learner$predict(task)
  expect_class(prediction, "Prediction")
})

test_that("predict leaves marshaling status as-is", {
  task = tsk("iris")
  learner = lrn("classif.debug", check_pid = TRUE, encapsulate = c(predict = "callr"))
  learner$train(task)

  # A marshaled learner can predict and is still marshaled afterwards.
  learner$marshal()
  expect_class(learner$predict(task), "Prediction")
  expect_true(learner$marshaled)

  # An unmarshaled learner can predict and is still unmarshaled afterwards.
  learner$unmarshal()
  expect_class(learner$predict(task), "Prediction")
  expect_false(learner$marshaled)
})

test_that("learner state contains validate field", {
  # The `validate` value used for training is recorded in the state.
  trained = lrn("classif.debug", validate = 0.2)
  trained$train(tsk("iris"))
  expect_equal(trained$state$validate, 0.2)
})

test_that("learner state contains internal valid task information", {
  task = tsk("iris")
  learner = lrn("classif.debug", validate = 0.2)

  # 1. resample: the learner of the first iteration carries the hash
  rr = resample(task, learner, rsmp("holdout"))
  expect_string(rr$learners[[1L]]$state$internal_valid_task_hash)

  # 2. manual: training directly stores the hash as well
  learner$train(task)
  expect_string(learner$state$internal_valid_task_hash)
})

test_that("validation task with 0 observations", {
  learner = lrn("classif.debug", validate = "predefined")

  # Split off a validation task with no rows at all.
  task = tsk("iris")
  task$divide(ids = integer(0))

  # Training on it must fail with a message about 0 observations.
  expect_error(learner$train(task), "has 0 observations")
})
# mlr-org/mlr3 documentation built on July 10, 2024, 10:53 a.m.