R/press.R

Defines functions press

Documented in press

#' Run setup code and benchmarks across a grid of parameters
#'
#' @description
#' `press()` is used to run [bench::mark()] across a grid of parameters and
#' then _press_ the results together.
#'
#' The parameters you want to set are given as named arguments and a grid of
#' all possible combinations is automatically created.
#'
#' The code to set up and benchmark is given by one unnamed expression (often
#' delimited by `\{`).
#'
#' If replicates are desired a dummy variable can be used, e.g. `rep = 1:5` for
#' replicates.
#'
#' @param ... If named, parameters to define, if unnamed the expression to run.
#'   Only one unnamed expression is permitted.
#' @param .grid A pre-built grid of values to use, typically a [data.frame()] or
#'   [tibble::tibble()]. This is useful if you only want to benchmark a subset
#'   of all possible combinations. Cannot be combined with named arguments.
#' @param .quiet If `TRUE`, progress messages will not be emitted.
#' @return The results of every evaluation of the unnamed expression bound
#'   together by row (every evaluation must return the same number of rows),
#'   with the parameter values of each evaluation inserted as columns after
#'   the first result column, wrapped by `bench_mark()`.
#' @export
#' @examples
#' # Helper function to create a simple data.frame of the specified dimensions
#' create_df <- function(rows, cols) {
#'   as.data.frame(setNames(
#'     replicate(cols, runif(rows, 1, 1000), simplify = FALSE),
#'     rep_len(c("x", letters), cols)))
#' }
#'
#' # Run 4 data sizes x 3 expressions x 2 replicates (24 total benchmarks)
#' press(
#'   rows = c(1000, 10000),
#'   cols = c(10, 100),
#'   rep = 1:2,
#'   {
#'     dat <- create_df(rows, cols)
#'     bench::mark(
#'       min_time = .05,
#'       bracket = dat[dat$x > 500, ],
#'       which = dat[which(dat$x > 500), ],
#'       subset = subset(dat, x > 500)
#'     )
#'   }
#' )
press <- function(..., .grid = NULL, .quiet = FALSE) {
  args <- rlang::quos(...)

  assert(
    "`.quiet` must be `TRUE` or `FALSE`",
    isTRUE(.quiet) || isFALSE(.quiet)
  )

  unnamed <- names(args) == ""

  if (sum(unnamed) < 1) {
    stop("Must supply one unnamed argument", call. = FALSE)
  }

  if (sum(unnamed) > 1) {
    stop("Must supply no more than one unnamed argument", call. = FALSE)
  }

  # The single unnamed quosure is the code evaluated once per grid row.
  expr <- args[[which(unnamed)]]

  if (!is.null(.grid)) {
    if (any(!unnamed)) {
      stop(
        "Must supply either `.grid` or named arguments, not both",
        call. = FALSE
      )
    }
    parameters <- .grid
  } else {
    # Evaluate each named argument, then build every combination of values.
    parameters <- expand.grid(
      lapply(args[!unnamed], rlang::eval_tidy),
      stringsAsFactors = FALSE
    )
  }

  # For consistent `[` methods
  parameters <- tibble::as_tibble(parameters)

  # An empty grid (no named arguments, a zero-length parameter, or a 0-row
  # `.grid`) would otherwise run nothing and then fail with an opaque
  # "subscript out of bounds" error when indexing `rows[[1]]` below.
  if (nrow(parameters) == 0) {
    stop("Must supply at least one combination of parameters", call. = FALSE)
  }

  if (!.quiet) {
    # NOTE(review): assumes `format()` on a tibble yields a summary line, a
    # column-header line, a column-type line, then one line per row; hence
    # `status[[2]]` here and `status[[row + 3L]]` in `eval_one()`.
    status <- format(parameters, n = Inf)
    message(glue::glue("Running with:\n{status[[2]]}"))
  }

  param_names <- names(parameters)

  # Evaluate `expr` with the values from grid row `row` bound as variables in
  # a fresh data mask, so each evaluation sees only its own parameter values.
  eval_one <- function(row) {
    env <- rlang::new_data_mask(new.env(parent = emptyenv()))

    for (i in seq_along(parameters)) {
      assign(param_names[[i]], parameters[[i]][row], envir = env)
    }

    if (!.quiet) {
      message(status[[row + 3L]])
    }

    rlang::eval_tidy(expr, data = env)
  }

  res <- lapply(seq_len(nrow(parameters)), eval_one)
  rows <- vapply(res, NROW, integer(1))

  if (!all(rows == rows[[1]])) {
    stop("Results must have equal rows", call. = FALSE)
    # TODO: print parameters / results that are unequal?
  }

  res <- do.call(rbind, res)

  # Repeat each parameter row once per result row so the two line up.
  parameters <- parameters[rep(seq_len(nrow(parameters)), each = rows[[1]]), ]
  bench_mark(tibble::as_tibble(cbind(res[1], parameters, res[-1])))
}
jimhester/bench documentation built on Jan. 18, 2025, 4:54 p.m.