# tests/testthat/test-set_pred.R

test_that("set_pred() works", {
  # Register a throwaway model with a single mode and engine.
  set_new_model("fire")
  set_model_mode("fire", "partition")
  set_model_engine("fire", "partition", "stats")

  # The prediction spec is built once: it is the value handed to set_pred()
  # and also the single element expected back in the `value` list-column.
  raw_spec <- list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args = list(
      object = rlang::expr(object$fit),
      newdata = rlang::expr(new_data),
      type = "response"
    )
  )

  set_pred(
    model = "fire",
    mode = "partition",
    eng = "stats",
    type = "raw",
    value = raw_spec
  )

  # One row describing the registration made above.
  expected <- tibble::tibble(
    engine = "stats",
    mode = "partition",
    type = "raw",
    value = list(raw_spec)
  )

  expect_identical(get_pred_type("fire", "raw"), expected)
})

# Checks that set_pred() fails for unusable `model` values. Each failing call
# is captured with expect_snapshot(error = TRUE); the snapshot records the
# call itself, so editing a call below means re-accepting the snapshot.
test_that("set_pred() errors with wrong `model` argument", {
  # Two models are registered so the vector case below uses names that exist.
  set_new_model("tent")
  set_model_mode("tent", "partition")
  set_model_engine("tent", "partition", "stats")
  set_new_model("shed")
  set_model_mode("shed", "partition")
  set_model_engine("shed", "partition", "stats")

  # "light" is not registered anywhere in this test, and no other argument
  # is supplied.
  # NOTE(review): presumably the unknown model name is what triggers the
  # error here rather than the missing arguments -- confirm in the snapshot.
  expect_snapshot(
    error = TRUE,
    set_pred("light")
  )

  # `model` is a length-two character vector; the remaining arguments follow
  # the same form as the successful set_pred() calls in this file.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = c("tent", "shed"),
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict"),
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )
})

# Checks that set_pred() fails for unusable `mode` values. Each failing call
# is captured with expect_snapshot(error = TRUE); the snapshot records the
# call itself, so editing a call below means re-accepting the snapshot.
test_that("set_pred() errors with wrong `mode` argument", {
  # "game" is registered with the single mode "partition".
  set_new_model("game")
  set_model_mode("game", "partition")
  set_model_engine("game", "partition", "stats")

  # Only the model is supplied; nothing is passed for the mode.
  expect_snapshot(
    error = TRUE,
    set_pred("game")
  )

  # Second positional argument is a length-two character vector.
  # NOTE(review): the second position is assumed to be `mode`, matching the
  # order of the named arguments used elsewhere in this file -- confirm.
  expect_snapshot(
    error = TRUE,
    set_pred("game", c("classification", "regression"))
  )

  # Second positional argument is NULL.
  expect_snapshot(
    error = TRUE,
    set_pred("game", NULL)
  )

  # "not partition" was never registered as a mode for "game" in this test;
  # every other argument follows the form of the successful calls in this
  # file.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "game",
      mode = "not partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict"),
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )
})

# Checks that set_pred() fails for unusable `eng` values. Each failing call
# is captured with expect_snapshot(error = TRUE); the snapshot records the
# call itself, so editing a call below means re-accepting the snapshot.
test_that("set_pred() errors with wrong `engine` argument", {
  # Only the model and its mode are registered; no engine is set for "dog".
  set_new_model("dog")
  set_model_mode("dog", "partition")

  # Model and mode only; nothing is passed for the engine.
  expect_snapshot(
    error = TRUE,
    set_pred("dog", "partition")
  )

  # Third positional argument is a length-two character vector.
  expect_snapshot(
    error = TRUE,
    set_pred("dog", "partition", c("glmnet", "stats"))
  )

  # Third positional argument is NULL. This case used to call
  # set_model_engine(), which left set_pred() untested for a NULL engine.
  # The stored snapshot for this test must be re-accepted after this change.
  expect_snapshot(
    error = TRUE,
    set_pred("dog", "partition", NULL)
  )
})

# Checks that set_pred() fails for malformed `value` arguments. The cases go
# from a missing/NULL `value`, through a `value` list lacking one of its four
# elements (pre, post, func, args), to elements of the wrong kind.
# Each failing call is captured with expect_snapshot(error = TRUE); the
# snapshot records the call itself, so editing a call below means
# re-accepting the snapshot.
test_that("set_pred() errors with wrong `value` argument", {
  # Model, mode and engine are all registered, so they are not the part of
  # the calls being varied below.
  set_new_model("trunk")
  set_model_mode("trunk", "partition")
  set_model_engine("trunk", "partition", "stats")

  # Four positional arguments; nothing is passed for `value`.
  expect_snapshot(
    error = TRUE,
    set_pred("trunk", "partition", "stats", "raw")
  )

  # Fifth positional argument is NULL.
  expect_snapshot(
    error = TRUE,
    set_pred("trunk", "partition", "stats", "raw", NULL)
  )

  # `value` has no `pre` element.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "trunk",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        post = NULL,
        func = c(fun = "predict"),
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )

  # `value` has no `post` element.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "trunk",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        func = c(fun = "predict"),
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )

  # `value` has no `func` element.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "trunk",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = NULL,
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )

  # `value` has no `args` element.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "trunk",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict")
      )
    )
  )

  # `pre` is the character string "NULL" instead of the NULL used in the
  # successful calls in this file.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "trunk",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = "NULL",
        post = NULL,
        func = c(fun = "predict"),
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )

  # `post` is the character string "NULL" instead of the NULL used in the
  # successful calls in this file.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "trunk",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = "NULL",
        func = c(fun = "predict"),
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )

  # `args` is a character string rather than a list.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "trunk",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict"),
        args = "not a list"
      )
    )
  )
})

# Checks that set_pred() fails for an unusable `type` value. The failing call
# is captured with expect_snapshot(error = TRUE); the snapshot records the
# call itself, so editing the call below means re-accepting the snapshot.
test_that("set_pred() errors with wrong `type` argument", {
  # Model, mode and engine are all registered, so only `type` is varied.
  set_new_model("scroll")
  set_model_mode("scroll", "partition")
  set_model_engine("scroll", "partition", "stats")

  # `type` is "not raw"; every other argument follows the form of the
  # successful calls in this file.
  # NOTE(review): presumably "not raw" is outside the set of prediction
  # types set_pred() accepts -- confirm against the snapshot.
  expect_snapshot(
    error = TRUE,
    set_pred(
      model = "scroll",
      mode = "partition",
      eng = "stats",
      type = "not raw",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict"),
        args =
          list(
            object = rlang::expr(object$fit),
            newdata = rlang::expr(new_data),
            type = "response"
          )
      )
    )
  )
})

test_that("is_discordant_info() triggers for set_pred()", {
  # Register the model together with its mode and engine.
  set_new_model("diamond")
  set_model_mode("diamond", "partition")
  set_model_engine("diamond", "partition", "stats")

  # First registration of the "raw" prediction; this one must succeed.
  first_spec <- list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args = list(
      object = rlang::expr(object$fit),
      newdata = rlang::expr(new_data),
      type = "response"
    )
  )
  set_pred(
    model = "diamond",
    mode = "partition",
    eng = "stats",
    type = "raw",
    value = first_spec
  )

  # Second registration for the same model/mode/engine/type, differing only
  # in `func`. The call is kept fully literal because the snapshot records
  # the call text.
  expect_snapshot(
    set_pred(
      model = "diamond",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "not predict"),
        args = list(
          object = rlang::expr(object$fit),
          newdata = rlang::expr(new_data),
          type = "response"
        )
      )
    ),
    error = TRUE
  )
})

test_that("set_pred() can be called multiple times", {
  # Register the model together with its mode and engine.
  set_new_model("firetruck")
  set_model_mode("firetruck", "partition")
  set_model_engine("firetruck", "partition", "stats")

  # Helper issuing the exact same registration each time it is called.
  register_raw_pred <- function() {
    set_pred(
      model = "firetruck",
      mode = "partition",
      eng = "stats",
      type = "raw",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict"),
        args = list(
          object = rlang::expr(object$fit),
          newdata = rlang::expr(new_data),
          type = "response"
        )
      )
    )
  }

  # Initial registration.
  register_raw_pred()

  # Repeating the identical registration must not raise an error.
  expect_no_error(register_raw_pred())
})

# Try the modelenv package in your browser
#
# Any scripts or data that you put into this service are public.
#
# modelenv documentation built on March 31, 2023, 10:31 p.m.