# tests/testthat/test-performance.R

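# get_outcome_type() should infer "continuous" for numeric vectors, "binary" for
# two-level character vectors, and "multiclass" for three or more levels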
test_that("get_outcome_type works", {
  expect_equal(get_outcome_type(c(1, 2, 1)), "continuous")
  expect_equal(get_outcome_type(c("a", "b", "b")), "binary")
  expect_equal(get_outcome_type(c("a", "b", "c")), "multiclass")
})
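
# fewer than two distinct outcome values (including an empty vector) should raise an informative error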
test_that("get_outcome_type errors when num_outcomes < 2", {
  error_msg <- "A continuous, binary, or multi-class outcome variable is required, but this dataset has "
  expect_error(
    get_outcome_type(c(1, 1)),
    error_msg
  )
  expect_error(
    get_outcome_type(c("a", "a", "a")),
    error_msg
  )
  expect_error(
    get_outcome_type(c()),
    error_msg
  )
})

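# each outcome type should map to the expected caret summary function; unrecognized types error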
test_that("get_perf_metric_fn works", {
  expect_equal(get_perf_metric_fn("continuous"), caret::defaultSummary)
  expect_equal(get_perf_metric_fn("binary"), caret::multiClassSummary)
  expect_equal(get_perf_metric_fn("multiclass"), caret::multiClassSummary)
  expect_error(get_perf_metric_fn("asdf"), "Outcome type of outcome must be one of:")
})

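# each outcome type should map to its default performance metric name; unrecognized types error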
test_that("get_perf_metric_name works", {
  expect_equal(get_perf_metric_name("continuous"), "RMSE")
  expect_equal(get_perf_metric_name("binary"), "AUC")
  expect_equal(get_perf_metric_name("multiclass"), "logLoss")
  expect_error(get_perf_metric_name("asdf"), "Outcome type of outcome must be one of:")
})

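# calc_perf_metrics() should reproduce the stored performance values, excluding the
# cross-validation metric, method, and seed columns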
test_that("calc_perf_metrics works", {
  expect_equal(
    calc_perf_metrics(otu_mini_bin_results_glmnet$test_data,
      otu_mini_bin_results_glmnet$trained_model,
      "dx",
      caret::multiClassSummary,
      class_probs = TRUE
    ),
    unlist(c(otu_mini_bin_results_glmnet$performance[
      , !(colnames(otu_mini_bin_results_glmnet$performance) %in%
        c("cv_metric_AUC", "method", "seed"))
    ]))
  )
})

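# get_performance_tbl() should reproduce the stored performance table and warn when
# the requested metric differs from the one used to train the model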
test_that("get_performance_tbl works", {
  set.seed(2019)
  expect_equal(
    get_performance_tbl(
      otu_mini_bin_results_glmnet$trained_model,
      otu_mini_bin_results_glmnet$test_data,
      "dx",
      caret::multiClassSummary,
      "AUC",
      TRUE,
      "glmnet",
      seed = 2019
    ),
    otu_mini_bin_results_glmnet$performance
  )
  expect_warning(
    get_performance_tbl(
      otu_mini_bin_results_glmnet$trained_model,
      otu_mini_bin_results_glmnet$test_data,
      "dx",
      caret::multiClassSummary,
      "not_a_perf_metric",
      TRUE,
      "glmnet",
      seed = 2019
    ),
    "The performance metric provided does not match the metric used to train the data."
  )
})

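# resampling with the same RNG seed should produce identical bootstrap splits and
# therefore identical performance estimates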
test_that("calc_perf_bootstrap_split returns consistent results", {
  ml_result <- otu_mini_bin_results_glmnet
  outcome_colname <- "dx"
  model <- ml_result$trained_model
  test_dat <- ml_result$test_data
  outcome_type <- get_outcome_type(test_dat %>% dplyr::pull(outcome_colname))
  class_probs <- outcome_type != "continuous"
  method <- model$modelInfo$label
  seed <- ml_result$performance %>% dplyr::pull(seed)
  perf_metric_function <- get_perf_metric_fn(outcome_type)
  perf_metric_name <- model$metric

  set.seed(20230129)
  split_1 <- rsample::bootstraps(test_dat, times = 3) %>%
    dplyr::pull(splits) %>%
    .[[1]]
  perf_1 <- calc_perf_bootstrap_split(
    split_1, model, outcome_colname,
    perf_metric_function,
    perf_metric_name,
    class_probs,
    method,
    seed
  )
  set.seed(20230129)
  split_2 <- rsample::bootstraps(test_dat, times = 3) %>%
    dplyr::pull(splits) %>%
    .[[1]]
  perf_2 <- calc_perf_bootstrap_split(
    split_2, model, outcome_colname,
    perf_metric_function,
    perf_metric_name,
    class_probs,
    method,
    seed
  )

  expect_equal(perf_1, perf_2)
})

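# with a single bootstrap resample and alpha = 1, the lower bound, estimate, and upper
# bound should all be identical, and a warning about too few resamples is expected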
test_that("bootstrap_performance returns consistent results", {
  set.seed(20230129)
  expect_warning(
    boot_1 <- bootstrap_performance(
      otu_mini_bin_results_glmnet,
      "dx",
      bootstrap_times = 1,
      alpha = 1
    ),
    "Recommend at least 1000 non-missing bootstrap resamples for terms"
  )
  boot_2 <- structure(
    list(
      term = c(
        "Accuracy", "AUC", "Balanced_Accuracy",
        "cv_metric_AUC", "Detection_Rate", "F1", "Kappa", "logLoss",
        "Neg_Pred_Value", "Pos_Pred_Value", "prAUC", "Precision", "Recall",
        "Sensitivity", "Specificity"
      ),
      .lower = c(
        0.512820512820513,
        0.592391304347826, 0.520380434782609, 0.622173713235294, 0.282051282051282,
        0.536585365853659, 0.0389105058365759, 0.688040379015192, 0.428571428571429,
        0.611111111111111, 0.524734824159823, 0.611111111111111, 0.478260869565217,
        0.478260869565217, 0.5625
      ),
      .estimate = c(
        0.512820512820513,
        0.592391304347826, 0.520380434782609, 0.622173713235294, 0.282051282051282,
        0.536585365853659, 0.0389105058365759, 0.688040379015192, 0.428571428571429,
        0.611111111111111, 0.524734824159823, 0.611111111111111, 0.478260869565217,
        0.478260869565217, 0.5625
      ),
      .upper = c(
        0.512820512820513, 0.592391304347826,
        0.520380434782609, 0.622173713235294, 0.282051282051282, 0.536585365853659,
        0.0389105058365759, 0.688040379015192, 0.428571428571429, 0.611111111111111,
        0.524734824159823, 0.611111111111111, 0.478260869565217, 0.478260869565217,
        0.5625
      ),
      .alpha = c(
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1
      ),
      .method = c(
        "percentile", "percentile", "percentile", "percentile",
        "percentile", "percentile", "percentile", "percentile", "percentile",
        "percentile", "percentile", "percentile", "percentile", "percentile",
        "percentile"
      )
    ),
    class = c("tbl_df", "tbl", "data.frame"),
    row.names = c(NA, -15L)
  )
  expect_equal(boot_1 %>% dplyr::arrange(term), boot_2 %>% dplyr::arrange(term))
})

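# sensitivity/specificity, mean ROC, and mean PRC calculations are checked against stored fixtures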
test_that("sensspec calculations work", {
  expect_equal(
    calc_model_sensspec(
      otu_mini_bin_results_glmnet$trained_model,
      otu_mini_bin_results_glmnet$test_data,
      "dx"
    ),
    readRDS(
      testthat::test_path("fixtures", "otu_mini_bin_results_glmnet_sensspec.Rds")
    )
  )
  sensspec_dat <- readRDS(testthat::test_path("fixtures", "sensspec_dat.Rds"))
  expect_equal(
    calc_mean_roc(sensspec_dat),
    readRDS(testthat::test_path("fixtures", "sensspec_roc.Rds"))
  )
  expect_equal(
    calc_mean_prc(sensspec_dat),
    readRDS(testthat::test_path("fixtures", "sensspec_prc.Rds"))
  )
})

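# baseline precision should equal the proportion of observations in the positive class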
test_that("calc_baseline_precision works", {
  expect_equal(
    calc_baseline_precision(otu_mini_bin, "dx", "cancer"),
    0.49
  )
  expect_equal(
    calc_baseline_precision(otu_mini_bin, "dx", "normal"),
    0.51
  )
  expect_equal(
    data.frame(y = c("a", "b", "a", "b")) %>%
      calc_baseline_precision("y", "a"),
    0.50
  )
  expect_equal(
    data.frame(y = c("a", rep.int("b", 4))) %>%
      calc_baseline_precision(outcome_colname = "y", pos_outcome = "a"),
    0.2
  )
  expect_error(
    data.frame(y = c("a")) %>%
      calc_baseline_precision("y", "a"),
    "A binary or multi-class outcome variable is required"
  )
  expect_error(
    data.frame(y = c("b")) %>%
      calc_baseline_precision("y", "a"),
    "A binary or multi-class outcome variable is required"
  )
})
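
# calc_balanced_precision() should return the original precision unchanged when the
# baseline precision is 0.5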
test_that("calc_balanced_precision", {
  # when precision is already balanced
  expect_equal(calc_balanced_precision(0.3, 0.5), 0.3)
  expect_equal(calc_balanced_precision(0.9, 0.5), 0.9)

  # other situations
  expect_equal(calc_balanced_precision(0.2, 0.2), 0.5)
})