tuning_run: Tune hyperparameters using training flags

View source: R/training_run.R

tuning_run    R Documentation

Tune hyperparameters using training flags

Description

Run all combinations of the specified training flags. To reduce the number of runs, set the sample parameter, which runs only a random sample of the flag combinations.
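To make the size of such a search concrete, the following base-R sketch enumerates every combination of two flags and then draws a random subset. It is an illustration only: it assumes that sample acts as a fraction of the full grid, and the rounding rule (ceiling) and the variable names are assumptions, not a description of tfruns internals.

```r
# Illustration only: a full grid of flag values plus a random subset of it.
flags <- list(
  dropout1 = c(0.2, 0.3, 0.4),
  dropout2 = c(0.2, 0.3, 0.4)
)

# Every combination of the flag values (3 x 3 = 9 rows).
grid <- expand.grid(flags)

# Hypothetical sampling fraction; rounding up is an assumption of this sketch.
fraction <- 0.3
n_keep <- ceiling(fraction * nrow(grid))

set.seed(42)  # only to make the sketch reproducible
sampled <- grid[sample.int(nrow(grid), size = n_keep), , drop = FALSE]

nrow(grid)     # number of combinations in the full grid
nrow(sampled)  # number of combinations kept after sampling
```

Because the number of combinations is the product of the number of values per flag, adding a third flag with three values would triple the grid, which is where sampling becomes useful.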

Usage

tuning_run(
  file = "train.R",
  context = "local",
  config = Sys.getenv("R_CONFIG_ACTIVE", unset = "default"),
  flags = NULL,
  sample = NULL,
  properties = NULL,
  runs_dir = getOption("tfruns.runs_dir", "runs"),
  artifacts_dir = getwd(),
  echo = TRUE,
  confirm = interactive(),
  envir = parent.frame(),
  encoding = getOption("encoding")
)

Arguments

file

Path to training script (defaults to "train.R")

context

Run context (defaults to "local")

config

The configuration to use. Defaults to the active configuration for the current environment (as specified by the R_CONFIG_ACTIVE environment variable), or default when unset.

flags

Either a named list with flag values (multiple values can be provided for each flag) or a data frame that contains pre-generated combinations of flags (e.g. via base::expand.grid()). The latter can be useful for subsetting combinations. See 'Examples'.

sample

Sampling rate for flag combinations, given as a fraction of all combinations (defaults to running all combinations).

properties

Named character vector with run properties. Properties are additional metadata about the run which will be subsequently available via ls_runs().

runs_dir

Directory containing runs. Defaults to "runs" beneath the current working directory (or to the value of the tfruns.runs_dir R option if specified).

artifacts_dir

Directory to capture created and modified files within. Pass NULL to capture no artifacts.

echo

Print expressions within training script

confirm

Confirm before executing tuning run.

envir

The environment in which the script should be evaluated

encoding

The encoding of the training script; see file().
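The defaults for config and runs_dir in the Usage signature are resolved with plain base-R lookups, which the following sketch reproduces outside of tfruns. The configuration name "production" and the directory name "my_runs" are hypothetical values chosen for illustration.

```r
# Remember the current state so the sketch can restore it afterwards.
old_env <- Sys.getenv("R_CONFIG_ACTIVE", unset = NA)
old_opt <- options(tfruns.runs_dir = NULL)

# config: falls back to "default" when R_CONFIG_ACTIVE is unset.
Sys.unsetenv("R_CONFIG_ACTIVE")
config_unset <- Sys.getenv("R_CONFIG_ACTIVE", unset = "default")

Sys.setenv(R_CONFIG_ACTIVE = "production")  # hypothetical configuration name
config_set <- Sys.getenv("R_CONFIG_ACTIVE", unset = "default")

# runs_dir: falls back to "runs" when the tfruns.runs_dir option is unset.
runs_unset <- getOption("tfruns.runs_dir", "runs")

options(tfruns.runs_dir = "my_runs")  # hypothetical directory name
runs_set <- getOption("tfruns.runs_dir", "runs")

# Restore the previous option and environment variable.
options(old_opt)
if (is.na(old_env)) {
  Sys.unsetenv("R_CONFIG_ACTIVE")
} else {
  Sys.setenv(R_CONFIG_ACTIVE = old_env)
}
```

Setting the option or the environment variable once per session therefore changes the default for every later call, while an explicit argument still takes precedence.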

Value

Data frame with summary of all training runs performed during tuning.
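Because the result is an ordinary data frame, base-R tools are enough to rank runs and pick the best one. The sketch below uses a hypothetical stand-in data frame: the values are placeholders rather than real results, the flag_dropout1 column name is an assumption, and eval_acc is borrowed from the Examples section.

```r
# Hypothetical stand-in for a tuning result; values are placeholders only.
runs <- data.frame(
  run_dir = c("runs/a", "runs/b", "runs/c"),
  flag_dropout1 = c(0.2, 0.3, 0.4),
  eval_acc = c(0.91, 0.95, 0.93),
  stringsAsFactors = FALSE
)

# Rank runs from highest to lowest evaluation accuracy.
ranked <- runs[order(runs$eval_acc, decreasing = TRUE), ]

# The first row of the ranked data frame is the best run.
best <- ranked[1, ]
best$run_dir
```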

Examples

## Not run: 
library(tfruns)

# using a list as input to the flags argument
runs <- tuning_run(
  system.file("examples/mnist_mlp/mnist_mlp.R", package = "tfruns"),
  flags = list(
    dropout1 = c(0.2, 0.3, 0.4),
    dropout2 = c(0.2, 0.3, 0.4)
  )
)
runs[order(runs$eval_acc, decreasing = TRUE), ]

# using a data frame as input to the flags argument,
# producing the same combinations as above but removing those
# where the combined dropout rate exceeds 1
grid <- expand.grid(
  dropout1 = c(0.2, 0.3, 0.4),
  dropout2 = c(0.2, 0.3, 0.4)
)
grid$combined_dropout <- grid$dropout1 + grid$dropout2
grid <- grid[grid$combined_dropout <= 1, ]
runs <- tuning_run(
  system.file("examples/mnist_mlp/mnist_mlp.R", package = "tfruns"),
  flags = grid[, c("dropout1", "dropout2")]
)

## End(Not run)

tfruns documentation built on May 29, 2024, 7:31 a.m.