tests/testthat/test-surv_reg_survreg.R

if (R.Version()$major < "4") {
  data(lung, package = 'survival')
} else {
  data(cancer, package = 'survival')
}

basic_form <- survival::Surv(time, status) ~ group
complete_form <- survival::Surv(time) ~ group

# ------------------------------------------------------------------------------

test_that('survival execution', {
  skip_on_travis()

  rlang::local_options(lifecycle_verbosity = "quiet")
  surv_basic <- surv_reg() %>% set_engine("survival")
  surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival")

  expect_error(
    res <- fit(
      surv_basic,
      survival::Surv(time, status) ~ age + sex,
      data = lung,
      control = ctrl
    ),
    regexp = NA
  )

  expect_error(
    res <- fit(
      surv_lnorm,
      survival::Surv(time) ~ age + sex,
      data = lung,
      control = ctrl
    ),
    regexp = NA
  )
  expect_error(
    res <- fit_xy(
      surv_basic,
      x = lung[, c("age", "sex")],
      y = lung$time,
      control = ctrl
    )
  )
})

test_that('survival prediction', {
  skip_on_travis()

  rlang::local_options(lifecycle_verbosity = "quiet")
  surv_basic <- surv_reg() %>% set_engine("survival")
  surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival")

  res <- fit(
    surv_basic,
    survival::Surv(time, status) ~ age + sex,
    data = lung,
    control = ctrl
  )
  exp_pred <- predict(extract_fit_engine(res), head(lung))
  exp_pred <- tibble(.pred = unname(exp_pred))
  expect_equal(exp_pred, predict(res, head(lung)))

  exp_quant <- predict(extract_fit_engine(res), head(lung), p = (2:4)/5, type = "quantile")
  exp_quant <-
    apply(exp_quant, 1, function(x)
      tibble(.pred = x, .quantile = (2:4) / 5))
  exp_quant <- tibble(.pred = exp_quant)
  obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4)/5)

  expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant))

})
topepo/parsnip documentation built on April 16, 2024, 3:23 a.m.