tests/testthat/helper-cal.R

# ---------------------------- Test data sets ----------------------------------
# ------------------------------ >> Binary -------------------------------------
.cal_env <- new.env()

testthat_cal_binary <- function() {
  ret <- .cal_env$tune_results

  if (is.null(ret)) {
    ret_file <- test_path("cal_files/binary_sim.rds")

    if (!file.exists(ret_file)) {
      if (!dir.exists(test_path("cal_files"))) {
        dir.create(test_path("cal_files"))
      }

      set.seed(111)

      sim_data <- modeldata::sim_classification(500)

      rec <- recipes::recipe(class ~ ., data = sim_data) %>%
        recipes::step_ns(linear_01, deg_free = tune::tune("linear_01"))

      ret <- tune::tune_grid(
        object = parsnip::set_engine(parsnip::logistic_reg(), "glm"),
        preprocessor = rec,
        resamples = rsample::vfold_cv(sim_data, v = 2, repeats = 3),
        control = tune::control_resamples(save_pred = TRUE)
      )
      saveRDS(ret, ret_file, version = 2)
    } else {
      ret <- readRDS(ret_file)
    }
    .cal_env$tune_results <- ret
    cp <- tune::collect_predictions(ret, summarize = TRUE)
    .cal_env$tune_results_count <- nrow(cp)
  }

  ret
}

testthat_cal_binary_count <- function() {
  ret <- .cal_env$tune_results_count
  if (is.null(ret)) {
    invisible(testthat_cal_binary())
    ret <- .cal_env$tune_results_count
  }
  ret
}

testthat_cal_sampled <- function() {
  ret <- .cal_env$resampled_data
  if (is.null(ret)) {
    set.seed(100)
    ret <- rsample::vfold_cv(segment_logistic)
    .cal_env$resampled_data <- ret
  }
  ret
}

# -------------------------- >> Multiclass (Tune) ------------------------------

testthat_cal_multiclass <- function() {
  ret <- .cal_env$tune_results_multi

  if (is.null(ret)) {
    ret_file <- test_path("cal_files/multiclass_ames.rds")

    if (!file.exists(ret_file)) {
      if (!dir.exists(test_path("cal_files"))) {
        dir.create(test_path("cal_files"))
      }

      set.seed(111)

      df <- sim_multinom_df(500)

      ranger_recipe <- recipes::recipe(
        formula = class ~ .,
        data = df
      )

      ranger_spec <- parsnip::rand_forest(
        mtry = tune(),
        min_n = tune(),
        trees = 200
      ) %>%
        parsnip::set_mode("classification") %>%
        parsnip::set_engine("ranger")

      ret <- tune::tune_grid(
        object = ranger_spec,
        preprocessor = ranger_recipe,
        resamples = rsample::vfold_cv(df, v = 2, repeats = 3),
        control = tune::control_resamples(save_pred = TRUE)
      )

      saveRDS(ret, ret_file, version = 2)
    } else {
      ret <- readRDS(ret_file)
    }
    .cal_env$tune_results_multi <- ret
    cp <- tune::collect_predictions(ret, summarize = TRUE)
    .cal_env$tune_results_multi_count <- nrow(cp)
  }

  ret
}

testthat_cal_multiclass_count <- function() {
  ret <- .cal_env$tune_results_multi_count
  if (is.null(ret)) {
    invisible(testthat_cal_multiclass())
    ret <- .cal_env$tune_results_multi_count
  }
  ret
}

# -------------------------- >> Multiclass (Sim) -------------------------------

testthat_cal_sim_multi <- function() {
  x <- "sim_multi"
  ret <- .cal_env[[x]]

  if (is.null(ret)) {
    pt <- paste0("cal_files/", x, ".rds")

    ret_file <- test_path(pt)

    if (!file.exists(ret_file)) {
      if (!dir.exists(test_path("cal_files"))) {
        dir.create(test_path("cal_files"))
      }

      set.seed(1)
      train <- sim_multinom_df(200)
      test <- sim_multinom_df()

      model <- randomForest::randomForest(class ~ ., train)

      ret <- model %>%
        predict(test, type = "prob") %>%
        as.data.frame() %>%
        dplyr::rename_all(~ paste0(".pred_", .x)) %>%
        dplyr::mutate(class = test$class)

      saveRDS(ret, ret_file, version = 2)
    } else {
      ret <- readRDS(ret_file)
    }
    .cal_env[[x]] <- ret
  }
  ret
}


testthat_cal_reg <- function() {
  ret <- .cal_env$reg_tune_results

  if (is.null(ret)) {
    ret_file <- test_path("cal_files/reg_sim.rds")

    if (!file.exists(ret_file)) {
      if (!dir.exists(test_path("cal_files"))) {
        dir.create(test_path("cal_files"))
      }

      set.seed(111)

      sim_data <- modeldata::sim_regression(100)[, 1:3]

      rec <- recipes::recipe(outcome ~ ., data = sim_data) %>%
        recipes::step_ns(predictor_01, deg_free = tune::tune("predictor_01"))

      ret <- tune::tune_grid(
        object = parsnip::linear_reg(),
        preprocessor = rec,
        resamples = rsample::bootstraps(sim_data, times = 3),
        control = tune::control_resamples(save_pred = TRUE)
      )
      saveRDS(ret, ret_file, version = 2)
    } else {
      ret <- readRDS(ret_file)
    }
    .cal_env$reg_tune_results <- ret
    cp <- tune::collect_predictions(ret, summarize = TRUE)
    .cal_env$reg_tune_results_count <- nrow(cp)
  }

  ret
}

testthat_cal_reg_count <- function() {
  ret <- .cal_env$reg_tune_results_count
  if (is.null(ret)) {
    invisible(testthat_cal_reg())
    ret <- .cal_env$reg_tune_results_count
  }
  ret
}

testthat_cal_reg_sampled <- function() {
  ret <- .cal_env$resampled_reg_data
  if (is.null(ret)) {
    set.seed(100)
    ret <- rsample::vfold_cv(boosting_predictions_oob)
    .cal_env$resampled_reg_data <- ret
  }
  ret
}

sim_multinom_df <- function(n = 1000) {
  modeldata::sim_multinomial(
    n,
    ~ -0.5 + 0.6 * abs(A),
    ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, -2),
    ~ -0.6 * A + 0.50 * B - A * B
  )
}

testthat_cal_fit_rs <- function() {
  ret <- .cal_env$resample_results

  if (is.null(ret)) {
    ret_file <- test_path("cal_files/fit_rs.rds")

    if (!file.exists(ret_file)) {
      if (!dir.exists(test_path("cal_files"))) {
        dir.create(test_path("cal_files"))
      }
      suppressPackageStartupMessages(library(tune))
      suppressPackageStartupMessages(library(parsnip))
      suppressPackageStartupMessages(library(rsample))

      ctrl <- control_resamples(save_pred = TRUE)

      set.seed(111)
      rs_bin <-
        modeldata::sim_classification(100)[, 1:3] %>%
        dplyr::rename(outcome = class) %>%
        vfold_cv() %>%
        fit_resamples(logistic_reg(), outcome ~ ., resamples = ., control = ctrl)
      set.seed(112)
      rs_mlt <-
        sim_multinom_df(500) %>%
        dplyr::rename(outcome = class) %>%
        vfold_cv() %>%
        fit_resamples(mlp() %>% set_mode("classification"),
          outcome ~ .,
          resamples = ., control = ctrl
        )
      set.seed(113)
      rs_reg <-
        modeldata::sim_regression(100)[, 1:3] %>%
        vfold_cv() %>%
        fit_resamples(linear_reg(), outcome ~ ., resamples = ., control = ctrl)

      ret <- list(binary = rs_bin, multin = rs_mlt, reg = rs_reg)

      saveRDS(ret, file = ret_file, version = 2)
    } else {
      ret <- readRDS(ret_file)
    }
    .cal_env$resample_results <- ret
  }

  ret
}

# --------------------------- Custom Expect Functions --------------------------

expect_cal_type <- function(x, type) {
  expect_equal(x$type, type)
}

expect_cal_method <- function(x, method) {
  expect_equal(x$method, method)
}

expect_cal_estimate <- function(x, class) {
  expect_s3_class(x$estimates[[1]]$estimate, class)
}

expect_cal_rows <- function(x, n = 1010) {
  expect_equal(x$rows, n)
}


# ------------------------------------------------------------------------------

save_png <- function(code, width = 400, height = 400) {
  path <- tempfile(fileext = ".png")
  png(path, width = width, height = height)
  on.exit(dev.off())
  code

  path
}

expect_snapshot_plot <- function(name, code) {
  skip_on_os("windows")
  skip_on_os("linux")
  skip_on_os("solaris")

  name <- paste0(name, ".png")

  # Announce the file before touching `code`. This way, if `code`
  # unexpectedly fails or skips, testthat will not auto-delete the
  # corresponding snapshot file.
  announce_snapshot_file(name = name)

  path <- save_png(code)
  expect_snapshot_file(path, name)
}

has_facet <- function(x) {
  inherits(x$facet, c("FacetWrap", "FacetGrid"))
}

are_groups_configs <- function(x) {
  fltrs <- purrr::map(x$estimates, ~ .x$filter)

  # Check if anything is in the filter slot
  are_null <- purrr::map_lgl(fltrs, ~ all(is.null(.x)))
  if (all(are_null)) {
    return(FALSE)
  }

  fltr_vars <- purrr::map(fltrs, all.vars)
  are_config <- purrr::map_lgl(fltr_vars, ~ identical(.x, ".config"))
  all(are_config)
}

bin_with_configs <- function() {
  set.seed(1)
  segment_logistic %>%
    dplyr::mutate(.config = sample(letters[1:2], nrow(segment_logistic), replace = TRUE))
}

mnl_with_configs <- function() {
  data("hpc_cv", package = "modeldata")

  set.seed(1)
  hpc_cv %>%
    dplyr::mutate(.config = sample(letters[1:2], nrow(hpc_cv), replace = TRUE))
}

reg_with_configs <- function() {
  data("solubility_test", package = "modeldata")

  set.seed(1)

  solubility_test %>%
    dplyr::mutate(.config = sample(letters[1:2], nrow(solubility_test), replace = TRUE))
}

holdout_length <- function(x) {
  length(as.integer(x, data = "assessment"))
}
topepo/probably documentation built on April 6, 2024, 7:32 p.m.