# Nothing
#' Compile and Validate Keras Model Architectures
#'
#' @title Compile Keras Models Over a Grid of Hyperparameters
#' @description
#' Pre-compiles Keras models for each hyperparameter combination in a grid.
#'
#' This function is a powerful debugging tool to use before running a full
#' `tune::tune_grid()`. It allows you to quickly validate multiple model
#' architectures, ensuring they can be successfully built and compiled without
#' the time-consuming process of actually fitting them. It helps catch common
#' errors like incompatible layer shapes or invalid argument values early.
#'
#' @details
#' The function iterates through each row of the provided `grid`. For each
#' hyperparameter combination, it attempts to build and compile the Keras model
#' defined by the `spec`. Each attempt is wrapped in a `tryCatch()` call to
#' gracefully handle and report any errors that occur during model instantiation
#' or compilation.
#'
#' The output is a tibble that mirrors the input `grid`, with additional columns
#' containing the compiled model object or the error message, making it easy to
#' inspect which architectures are valid.
#'
#' @param spec A `parsnip` model specification created by
#' `create_keras_sequential_spec()` or `create_keras_functional_spec()`.
#' @param grid A `tibble` or `data.frame` containing the grid of hyperparameters
#' to evaluate. Each row represents a unique model architecture to be compiled.
#' @param x A data frame or matrix of predictors. This is used to infer the
#' `input_shape` for the Keras model.
#' @param y A vector or factor of outcomes. This is used to infer the output
#' shape and the default loss function for the Keras model.
#'
#' @return A `tibble` with the following columns:
#' \itemize{
#' \item Columns from the input `grid`.
#' \item `compiled_model`: A list-column containing the compiled Keras model
#' objects. If compilation failed, the element will be `NULL`.
#' \item `error`: A character column containing `NA` for successes or the
#' error message for failures.
#' }
#'
#' @examples
#' \donttest{
#' if (requireNamespace("keras3", quietly = TRUE)) {
#' library(keras3)
#' library(parsnip)
#' library(dials)
#'
#' # 1. Define layer blocks
#' input_block <- function(model, input_shape) {
#' keras_model_sequential(input_shape = input_shape)
#' }
#' hidden_block <- function(model, units = 32) {
#' model |> layer_dense(units = units, activation = "relu")
#' }
#' output_block <- function(model, num_classes) {
#' model |> layer_dense(units = num_classes, activation = "softmax")
#' }
#'
#' # 2. Define a kerasnip model specification
#' create_keras_sequential_spec(
#' model_name = "my_mlp_grid",
#' layer_blocks = list(
#' input = input_block,
#' hidden = hidden_block,
#' output = output_block
#' ),
#' mode = "classification"
#' )
#'
#' mlp_spec <- my_mlp_grid(
#' hidden_units = tune(),
#' compile_loss = "categorical_crossentropy",
#' compile_optimizer = "adam"
#' )
#'
#' # 3. Create a hyperparameter grid
#' # Include an invalid value (-10) to demonstrate error handling
#' param_grid <- tibble::tibble(
#' hidden_units = c(32, 64, -10)
#' )
#'
#' # 4. Prepare dummy data
#' x_train <- matrix(rnorm(100 * 10), ncol = 10)
#' y_train <- factor(sample(0:1, 100, replace = TRUE))
#'
#' # 5. Compile models over the grid
#' compiled_grid <- compile_keras_grid(
#' spec = mlp_spec,
#' grid = param_grid,
#' x = x_train,
#' y = y_train
#' )
#'
#' print(compiled_grid)
#' remove_keras_spec("my_mlp_grid")
#'
#' # 6. Inspect the results
#' # The row with `hidden_units = -10` will show an error.
#' }
#' }
#' @importFrom dplyr bind_rows filter select
#' @importFrom cli cli_h1 cli_alert_danger cli_h2 cli_text cli_bullets cli_code cli_alert_info cli_alert_success
#' @export
compile_keras_grid <- function(spec, grid, x, y) {
  # Builds and compiles (never fits) one Keras model per row of `grid`,
  # recording a failure per row instead of aborting the whole run.

  # Input validation
  if (!inherits(spec, "model_spec")) {
    stop("`spec` must be a `parsnip` model specification.")
  }
  if (!is.data.frame(grid)) {
    stop("`grid` must be a data frame or tibble.")
  }

  # The fit registration for this spec is looked up in the model environment
  # under "<first class of spec>_fit".
  model_env <- get_model_env()
  model_name <- class(spec)[1]
  fit_info_name <- paste0(model_name, "_fit")
  model_info <- model_env[[fit_info_name]]
  if (is.null(model_info)) {
    stop("Could not find model information for this specification.")
  }
  fit_entry <- purrr::pluck(model_info, "value", 1)

  # Choose the builder from the second element of the registered `func`
  # entry, matched against "sequential" / "functional".
  fit_fun_char <- purrr::pluck(fit_entry, "func", 2)
  build_fn <- if (any(grepl("sequential", fit_fun_char))) {
    build_and_compile_sequential_model
  } else if (any(grepl("functional", fit_fun_char))) {
    build_and_compile_functional_model
  } else {
    stop("Unsupported fit function in model spec.")
  }
  layer_blocks <- purrr::pluck(fit_entry, "defaults", "layer_blocks")

  # An empty grid has nothing to compile. Return a zero-row tibble that still
  # carries the result columns so downstream helpers keep working.
  if (nrow(grid) == 0L) {
    return(
      tibble::as_tibble(
        c(as.list(grid), list(compiled_model = list(), error = character()))
      )
    )
  }

  # The spec's own arguments do not depend on the grid row, so evaluate them
  # once. Zapped arguments are dropped before evaluation.
  active_args <- purrr::discard(spec$args, function(arg) {
    inherits(rlang::get_expr(arg), "rlang_zap")
  })
  evaluated_args <- purrr::map(active_args, rlang::eval_tidy)

  # `[[<-` with a NULL value leaves the entry out, so arguments that evaluate
  # to NULL are never passed on to the builder.
  base_args <- list()
  base_args$x <- x
  base_args$y <- y
  base_args$layer_blocks <- layer_blocks
  for (name in names(evaluated_args)) {
    base_args[[name]] <- evaluated_args[[name]]
  }

  results <- purrr::map(seq_len(nrow(grid)), function(i) {
    # `drop = FALSE` keeps the column name when `grid` is a single-column
    # base data.frame.
    params <- as.list(grid[i, , drop = FALSE])

    # Grid values take precedence over the arguments stored in the spec.
    args <- base_args
    for (name in names(params)) {
      args[[name]] <- params[[name]]
    }

    # Use tryCatch to handle potential errors in model building/compilation
    result <- tryCatch(
      {
        model <- do.call(build_fn, args)
        list(
          compiled_model = list(model),
          error = NA_character_
        )
      },
      error = function(e) {
        list(
          compiled_model = list(NULL),
          error = as.character(conditionMessage(e))
        )
      }
    )

    # Combine grid params with the result
    tibble::as_tibble(c(params, result))
  })

  # Combine all results into a single tibble
  dplyr::bind_rows(results)
}
#' Filter a Grid to Only Valid Hyperparameter Sets
#'
#' @title Extract Valid Grid from Compilation Results
#' @description
#' This helper function filters the results from `compile_keras_grid()` to
#' return a new hyperparameter grid containing only the combinations that
#' compiled successfully.
#'
#' @details
#' After running `compile_keras_grid()`, you can use this function to remove
#' problematic hyperparameter combinations before proceeding to the full
#' `tune::tune_grid()`.
#'
#' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
#'
#' @return A tibble containing the subset of the original grid that resulted in
#' a successful model compilation. The `compiled_model` and `error` columns
#' are removed, leaving a clean grid ready for tuning.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("keras3", quietly = TRUE)) {
#' library(keras3)
#' library(parsnip)
#' library(dials)
#'
#' # 1. Define layer blocks
#' input_block <- function(model, input_shape) {
#' keras_model_sequential(input_shape = input_shape)
#' }
#' hidden_block <- function(model, units = 32) {
#' model |> layer_dense(units = units, activation = "relu")
#' }
#' output_block <- function(model, num_classes) {
#' model |> layer_dense(units = num_classes, activation = "softmax")
#' }
#'
#' # 2. Define a kerasnip model specification
#' create_keras_sequential_spec(
#' model_name = "my_mlp_grid_2",
#' layer_blocks = list(
#' input = input_block,
#' hidden = hidden_block,
#' output = output_block
#' ),
#' mode = "classification"
#' )
#'
#' mlp_spec <- my_mlp_grid_2(
#' hidden_units = tune(),
#' compile_loss = "categorical_crossentropy",
#' compile_optimizer = "adam"
#' )
#'
#' # 3. Create a hyperparameter grid
#' param_grid <- tibble::tibble(
#' hidden_units = c(32, 64, -10)
#' )
#'
#' # 4. Prepare dummy data
#' x_train <- matrix(rnorm(100 * 10), ncol = 10)
#' y_train <- factor(sample(0:1, 100, replace = TRUE))
#'
#' # 5. Compile models over the grid
#' compiled_grid <- compile_keras_grid(
#' spec = mlp_spec,
#' grid = param_grid,
#' x = x_train,
#' y = y_train
#' )
#'
#' # 6. Extract the valid grid
#' valid_grid <- extract_valid_grid(compiled_grid)
#' print(valid_grid)
#' remove_keras_spec("my_mlp_grid_2")
#' }
#' }
#' @export
extract_valid_grid <- function(compiled_grid) {
  # Keep only the rows of a `compile_keras_grid()` result whose `error` is NA,
  # then drop the two result columns so only hyperparameters remain.
  required_cols <- c("error", "compiled_model")
  looks_compiled <- is.data.frame(compiled_grid) &&
    all(required_cols %in% names(compiled_grid))

  if (!looks_compiled) {
    stop(
      "`compiled_grid` must be a data frame produced by `compile_keras_grid()`."
    )
  }

  succeeded <- dplyr::filter(compiled_grid, is.na(error))
  dplyr::select(succeeded, -c(compiled_model, error))
}
#' Display a Summary of Compilation Errors
#'
#' @title Inform About Compilation Errors
#' @description
#' This helper function inspects the results from `compile_keras_grid()` and
#' prints a formatted, easy-to-read summary of any compilation errors that
#' occurred.
#'
#' @details
#' This is most useful for interactive debugging of complex tuning grids where
#' some hyperparameter combinations may lead to invalid Keras models.
#'
#' @param compiled_grid A tibble, the result of a call to `compile_keras_grid()`.
#' @param n A single integer for the maximum number of distinct errors to
#' display in detail.
#'
#' @return Invisibly returns the input `compiled_grid`. Called for its side
#' effect of printing a summary to the console.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("keras3", quietly = TRUE)) {
#' library(keras3)
#' library(parsnip)
#' library(dials)
#'
#' # 1. Define layer blocks
#' input_block <- function(model, input_shape) {
#' keras_model_sequential(input_shape = input_shape)
#' }
#' hidden_block <- function(model, units = 32) {
#' model |> layer_dense(units = units, activation = "relu")
#' }
#' output_block <- function(model, num_classes) {
#' model |> layer_dense(units = num_classes, activation = "softmax")
#' }
#'
#' # 2. Define a kerasnip model specification
#' create_keras_sequential_spec(
#' model_name = "my_mlp_grid_3",
#' layer_blocks = list(
#' input = input_block,
#' hidden = hidden_block,
#' output = output_block
#' ),
#' mode = "classification"
#' )
#'
#' mlp_spec <- my_mlp_grid_3(
#' hidden_units = tune(),
#' compile_loss = "categorical_crossentropy",
#' compile_optimizer = "adam"
#' )
#'
#' # 3. Create a hyperparameter grid
#' param_grid <- tibble::tibble(
#' hidden_units = c(32, 64, -10)
#' )
#'
#' # 4. Prepare dummy data
#' x_train <- matrix(rnorm(100 * 10), ncol = 10)
#' y_train <- factor(sample(0:1, 100, replace = TRUE))
#'
#' # 5. Compile models over the grid
#' compiled_grid <- compile_keras_grid(
#' spec = mlp_spec,
#' grid = param_grid,
#' x = x_train,
#' y = y_train
#' )
#'
#' # 6. Inform about errors
#' inform_errors(compiled_grid)
#' remove_keras_spec("my_mlp_grid_3")
#' }
#' }
#' @export
inform_errors <- function(compiled_grid, n = 10) {
  # Prints a cli-formatted report of the rows whose `error` is not NA and
  # returns the input invisibly.
  if (
    !is.data.frame(compiled_grid) || !("error" %in% names(compiled_grid))
  ) {
    stop(
      "`compiled_grid` must be a data frame produced by `compile_keras_grid()`."
    )
  }

  error_grid <- dplyr::filter(compiled_grid, !is.na(error))
  n_errors <- nrow(error_grid)

  if (n_errors > 0) {
    cli::cli_h1("Compilation Errors Summary")
    cli::cli_alert_danger(
      "{n_errors} of {nrow(compiled_grid)} models failed to compile."
    )

    # Clamp at zero so `n = 0` (or a negative `n`) prints no detail blocks
    # instead of iterating over `1:0`.
    n_shown <- max(0, min(n_errors, n))
    for (i in seq_len(n_shown)) {
      # `drop = FALSE` keeps a one-row data frame even for single-column input.
      row <- error_grid[i, , drop = FALSE]
      # Only `error` is required by the validation above, so tolerate a
      # missing `compiled_model` column when isolating the hyperparameters.
      params <- dplyr::select(
        row,
        -dplyr::any_of(c("compiled_model", "error"))
      )
      cli::cli_h2("Error {i}/{n_errors}")
      cli::cli_text("Hyperparameters:")
      cli::cli_bullets(paste0(names(params), ": ", as.character(params)))
      cli::cli_text("Error Message:")
      cli::cli_code(row$error)
    }

    if (n_errors > n_shown) {
      cli::cli_alert_info("... and {n_errors - n_shown} more errors.")
    }
  } else {
    cli::cli_alert_success("All models compiled successfully!")
  }

  invisible(compiled_grid)
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.