# tests/testthat/helpers.R

# Fixture games used throughout the tests: one from the 1999 season and one
# from the 2019 season.
game_ids <- c(
  "1999_01_MIN_ATL",
  "2019_01_GB_CHI"
)

# Working directory at load time; expectation files and games.rds are read
# relative to it (presumably tests/testthat when run via testthat - confirm).
test_dir <- getwd()

# Per-session cache file for the play-by-play built from `game_ids`, so later
# calls to load_test_pbp() can reuse the result instead of rebuilding it.
pbp_cache <- tempfile(pattern = "pbp_cache", fileext = ".rds")

# Build (or reuse) the play-by-play data for the fixture games in `game_ids`.
#
# When `dir` is non-NULL the result is cached in the file `pbp`: an existing
# cache file is returned as-is, otherwise the freshly built data is saved there.
# With `dir = NULL` the cache is neither read nor written.
#
# pbp: path of the .rds cache file.
# dir: forwarded to build_nflfastR_pbp() as its `dir` argument; NULL also
#      disables caching (presumably NULL makes the build fetch raw data
#      remotely - TODO confirm against build_nflfastR_pbp()).
# Returns the play-by-play data.
load_test_pbp <- function(pbp = pbp_cache, dir = test_dir) {
  use_cache <- !is.null(dir)

  if (use_cache && file.exists(pbp)) {
    if (interactive()) cli::cli_alert_info("Will return pbp from cache")
    return(readRDS(pbp))
  }

  # Deliberately `test_dir`, not `dir`: `dir` may be NULL, and
  # file.path(NULL, ...) would yield character(0).
  games <- readRDS(file.path(test_dir, "games.rds"))

  pbp_data <- build_nflfastR_pbp(game_ids, dir = dir, games = games)
  if (use_cache) saveRDS(pbp_data, pbp)
  pbp_data
}

# Write `object` to a temporary CSV with every numeric column rounded to
# `digits` significant digits, e.g. to inspect or create test expectations.
#
# The temp file name starts with (a sanitised form of) the expression passed
# as `object`.
#
# object: a data frame.
# digits: significant digits kept in numeric columns (default 3).
# Returns the path of the written CSV, invisibly.
save_test_object <- function(object, digits = 3) {
  # deparse1() always returns a single string; deparse() splits long
  # expressions into several strings, which made tempfile() return several
  # paths and fwrite() fail.
  obj_name <- deparse1(substitute(object))
  # Keep the file-name prefix portable and reasonably short: replace anything
  # other than letters, digits, "_", "." and "-" (e.g. "/", "[", spaces).
  obj_name <- substr(gsub("[^[:alnum:]_.-]+", "_", obj_name), 1L, 50L)
  tmp_file <- tempfile(obj_name, fileext = ".csv")

  rounded <- dplyr::mutate(
    object,
    dplyr::across(dplyr::where(is.numeric), \(x) signif(x, digits = digits))
  )
  data.table::fwrite(rounded, tmp_file, na = "NA")
  invisible(tmp_file)
}

# Load one of the stored expected-output fixtures, normalised for comparison:
# nflverse/nflfastR attributes are stripped and the selected floating point
# columns are rounded, because model output differs slightly across platforms.
#
# type: which fixture to read ("pbp", "sc", "sc_weekly", "ep" or "wp").
# dir:  directory holding the expected_*.rds files.
load_expectation <- function(type = c("pbp", "sc", "sc_weekly", "ep", "wp"),
                             dir = test_dir) {
  type <- match.arg(type)

  fixture_files <- c(
    pbp = "expected_pbp.rds",
    sc = "expected_sc.rds",
    sc_weekly = "expected_sc_weekly.rds",
    ep = "expected_ep.rds",
    wp = "expected_wp.rds"
  )
  fixture_path <- file.path(dir, fixture_files[[type]])

  expected <- strip_nflverse_attributes(readRDS(fixture_path))
  round_double_to_digits(expected)
}

# Remove attributes that make otherwise-equal objects compare as different in
# tests: nflverse/nflfastR metadata (e.g. timestamp and version) and
# data.table's `.internal.selfref`, which is not needed here.
#
# df: a data frame (or any object carrying attributes).
# Returns `df` without every attribute whose name contains "nflverse",
# "nflfastR" or ".internal.selfref"; all other attributes are kept.
strip_nflverse_attributes <- function(df) {
  attr_names <- names(attributes(df))
  # Dots are escaped so ".internal.selfref" matches literally rather than
  # treating "." as "any character".
  to_drop <- attr_names[
    grepl("nflverse|\\.internal\\.selfref|nflfastR", attr_names)
  ]
  attributes(df)[to_drop] <- NULL
  df
}

# Round the columns selected by relevant_variables() to `digits` significant
# digits by formatting them with formatC(format = "fg") and parsing the text
# back to numbers. Used to make floating point model output comparable across
# platforms. Conversion warnings are suppressed.
round_double_to_digits <- function(df, digits = 3) {
  to_significant <- function(values) {
    suppressWarnings(
      as.numeric(formatC(values, digits = digits, format = "fg"))
    )
  }
  dplyr::mutate(
    df,
    dplyr::across(.cols = relevant_variables(), .fns = to_significant)
  )
}

# Tidyselect specification of the columns rounded before comparison
# (presumably the model-derived ones: probabilities, EP/EPA, WP/WPA, CP, xpass,
# xyac). Must be called inside a selecting context such as dplyr::across() or
# dplyr::select(); names in the explicit list that are missing from the data
# are skipped via any_of().
relevant_variables <- function() {
  model_columns <- c(
    "no_score_prob", "opp_fg_prob", "opp_safety_prob", "opp_td_prob", "fg_prob",
    "safety_prob", "td_prob", "ep", "cp", "cpoe", "pass_oe", "xpass"
  )
  c(
    dplyr::any_of(model_columns),
    dplyr::ends_with(c("epa", "wp", "wp_post", "wpa")),
    dplyr::starts_with("xyac")
  )
}
# nflverse/nflfastR documentation built on April 17, 2025, 9:34 p.m.