# tests/testthat/test-orsf_vi.R

# Shared fixture for the variable importance tests: a copy of pbc_orsf
# with two extra columns that are generated at random (so they carry no
# real signal) and one factor level that never occurs in the data.
pbc_vi <- pbc_orsf

# continuous noise predictor
pbc_vi[["junk"]] <- rnorm(n = nrow(pbc_orsf))

# categorical noise predictor with five possible values
pbc_vi[["junk_cat"]] <- factor(
 sample(x = letters[1:5], size = nrow(pbc_orsf), replace = TRUE)
)

# simulate a variable with unused factor level
levels(pbc_vi[["edema"]]) <- c(levels(pbc_vi[["edema"]]), 'empty_lvl')

# model used throughout: three original predictors plus the two noise ones
formula <- Surv(time, status) ~ protime + edema + bili + junk + junk_cat

test_that(
 desc = paste(
  "(1) variable importance is independent from function order",
  "(2) variable importance is independent from n_thread",
  "(3) variable importance is correct"
 ),
 code = {

  # Every forest in this test shares the data, formula, number of trees
  # and tree seeds; only the arguments forwarded through `...` vary.
  fit_forest <- function(...){
   orsf(pbc_vi,
        formula = formula,
        n_tree = n_tree_test,
        tree_seeds = seeds_standard,
        ...)
  }

  # importance type -> the dedicated wrapper that should reproduce it.
  # The order of this list is the order in which types are tested.
  vi_wrappers <- list(
   negate = orsf_vi_negate,
   permute = orsf_vi_permute,
   anova = orsf_vi_anova
  )

  for(vi_type in names(vi_wrappers)){

   for(grouped in c(TRUE, FALSE)){

    # reference: importance computed while the forest is grown
    forest <- fit_forest(importance = vi_type,
                         group_factors = grouped)

    vi_reference <- orsf_vi(forest, group_factors = grouped)

    # the type-specific wrapper must agree with the generic function
    expect_equal(
     vi_reference,
     vi_wrappers[[vi_type]](forest, group_factors = grouped)
    )

    vi_names <- names(vi_reference)

    if(grouped){
     # factor levels are reported under the name of the factor
     expect_true("edema" %in% vi_names)
    } else {
     # one entry per non-reference level; the level that never
     # occurs in the data must have importance of exactly 0
     expect_true("edema_0.5" %in% vi_names)
     expect_true("edema_1" %in% vi_names)
     expect_true(vi_reference['edema_empty_lvl'] == 0)
    }

    # the remaining order-of-operations checks are skipped for anova
    if(vi_type != 'anova'){

     forest_no_vi <- fit_forest(importance = 'none',
                                group_factors = grouped)

     # asking for importance without naming a type must fail
     expect_error(orsf_vi(forest_no_vi), regexp = 'no variable importance')

     vi_post_hoc <- orsf_vi(forest_no_vi,
                            importance = vi_type,
                            group_factors = grouped)

     # NOTE: group_factors is not passed to this fit, only to the
     # orsf_vi() call that follows it.
     forest_custom <- fit_forest(oobag_fun = oobag_c_risk,
                                 importance = vi_type)

     vi_custom_fit <- orsf_vi(forest_custom,
                              group_factors = grouped)

     vi_custom_post_hoc <- orsf_vi(forest_no_vi,
                                   importance = vi_type,
                                   group_factors = grouped,
                                   oobag_fun = oobag_c_risk)

     # during-fit, after-fit and custom-function results all coincide
     expect_equal(vi_reference, vi_post_hoc)
     expect_equal(vi_custom_post_hoc, vi_post_hoc)
     expect_equal(vi_custom_fit, vi_post_hoc)

     forest_custom_grouped <- fit_forest(importance = vi_type,
                                         oobag_fun = oobag_c_risk,
                                         group_factors = grouped)

     vi_custom_grouped <- orsf_vi(forest_custom_grouped,
                                  group_factors = grouped)

     # why equal?  oobag_c_risk is a 'custom' eval fun
     # that is equivalent to the eval fun we use by default
     expect_equal(vi_reference, vi_custom_grouped)

    }

    # same forest grown with n_thread = 0 instead of the default
    # (presumably lets orsf choose the thread count; confirm in docs)
    forest_threads <- fit_forest(importance = vi_type,
                                 n_thread = 0,
                                 group_factors = grouped)

    expect_equal(
     vi_reference,
     orsf_vi(forest_threads, group_factors = grouped)
    )

    # predictors expected to matter vs. everything else; edema_0.5 is
    # left out of both sets
    signal_vars <- c('bili',
                     'protime',
                     if(grouped) 'edema' else 'edema_1')

    noise_vars <- setdiff(vi_names, c(signal_vars, "edema_0.5"))

    vi_noise <- vi_reference[noise_vars]

    # each signal variable must outrank more than half of the others
    for(vi_signal in vi_reference[signal_vars]){
     expect_true( mean(vi_noise < vi_signal) > 1/2 )
    }

   }

  }

 }

)

test_that('can only compute anova vi during fit', {

 # forest grown with importance switched off
 forest_no_vi <- orsf(pbc_vi,
                      formula = time+status~.,
                      importance = 'none',
                      n_tree = n_tree_test)

 # the dedicated wrapper refuses to compute ANOVA importance afterwards
 expect_error(
  orsf_vi_anova(forest_no_vi),
  regexp = 'ANOVA'
 )

 # and so does the generic function when asked for the anova type
 expect_error(
  orsf_vi(forest_no_vi, importance = 'anova'),
  regexp = 'ANOVA'
 )

})

test_that('can only compute vi if data were attached to fit', {

 # forest that was told not to attach its training data
 forest_no_data <- orsf(pbc_vi,
                        formula = time+status~.,
                        attach_data = FALSE,
                        n_tree = n_tree_test)

 # each importance wrapper must complain about the missing training data
 vi_funs <- list(orsf_vi_anova, orsf_vi_negate, orsf_vi_permute)

 for(vi_fun in vi_funs){
  expect_error(vi_fun(forest_no_data), regexp = 'training data')
 }

})


test_that('informative errors for custom functions', {

 # a single tree without importance is enough to trigger the checks
 forest_no_vi <- orsf(pbc_vi, formula, n_tree = 1, importance = 'none')

 # something that is not a fitted forest at all
 expect_error(orsf_vi_anova(object = 'nope'), regexp = 'inherit')

 # expected error pattern -> the faulty oobag function that should cause it.
 # Entries are checked in the order listed.
 faulty_oobag_funs <- list(
  'y_mat' = oobag_fun_bad_name,
  's_vec' = oobag_fun_bad_name_2,
  'w_vec' = oobag_fun_bad_name_3,
  'length 1' = oobag_fun_bad_out,
  'type character' = oobag_fun_bad_out_2,
  'encountered an error' = oobag_fun_errors_on_test,
  'has 4' = oobag_fun_4_args
 )

 for(pattern in names(faulty_oobag_funs)){
  expect_error(
   orsf_vi_negate(forest_no_vi, oobag_fun = faulty_oobag_funs[[pattern]]),
   regexp = pattern
  )
 }

})

# Try the aorsf package in your browser

# Any scripts or data that you put into this service are public.

# aorsf documentation built on Oct. 26, 2023, 5:08 p.m.