# Initialize model environments
# Modes a celery model can take. This vector seeds the `modes` entry of the
# model environment (together with the placeholder "unknown").
all_modes <- "partition"
# ------------------------------------------------------------------------------
# Prediction types accepted when registering a prediction module.
pred_types <- "cluster"
# ------------------------------------------------------------------------------
## Rules about model-related information
### Definitions:
# - the model is the model type (e.g. "k_means", etc)
# - the model's mode is the kind of model, such as "partition"
# - the engines are within a model and mode and describe the
# method/implementation of the model in question. These are often R package
# names.
### The package dependencies are model- and engine-specific.
### They are shared across modes.
### The `fit` information is a list of data that is needed to fit the model.
### This information is specific to an engine and mode.
### The `predict` information is also a list of data that is needed to make some
### sort of prediction on the model object. The possible types are contained
### in `pred_types` and this information is specific to the engine, mode, and
### type (although there are no types across different modes).
# ------------------------------------------------------------------------------
# The model environment: one shared registry for every model type.
celery <- rlang::new_environment()
# No models are registered yet; the binding exists but holds NULL.
assign("models", NULL, envir = celery)
# Every known mode plus the placeholder "unknown".
assign("modes", c(all_modes, "unknown"), envir = celery)
# Check that a model class, engine, and mode are mutually compatible.
#
# cls:  single character string naming a registered model class.
# eng:  single character string for the engine, or NULL (an unset engine is
#       allowed and skips the engine-specific checks).
# mode: single character string for the mode.
# Returns invisible(NULL); aborts with an informative message otherwise.
check_spec_mode_engine_val <- function(cls, eng, mode) {
  model_modes <- get_from_env_celery(paste0(cls, "_modes"))
  # A NULL or multi-valued mode cannot be valid. Report the modes known for
  # this class instead of failing inside `if()` with a cryptic error.
  if (length(mode) != 1) {
    stop_incompatible_mode(model_modes, cls = cls)
  }
  if (!(mode %in% model_modes)) {
    rlang::abort(paste0("'", mode, "' is not a known mode for model `", cls, "()`."))
  }

  model_info <- rlang::env_get(get_model_env_celery(), cls)

  # Cases where the model definition is in celery but all of the engines
  # are contained in a different package
  if (nrow(model_info) == 0) {
    check_mode_with_no_engine(cls, mode)
    return(invisible(NULL))
  }

  # ----------------------------------------------------------------------------
  # First check engine against any mode for the given model class
  spec_engs <- model_info$engine
  # engine is allowed to be NULL
  if (!is.null(eng) && !(eng %in% spec_engs)) {
    rlang::abort(
      paste0(
        "Engine '", eng, "' is not supported for `", cls, "()`. See ",
        "`show_engines('", cls, "')`."
      )
    )
  }

  # ----------------------------------------------------------------------------
  # Check modes based on model and engine ("unknown" is always acceptable)
  spec_modes <- model_info$mode
  if (!is.null(eng)) {
    spec_modes <- spec_modes[model_info$engine == eng]
  }
  spec_modes <- unique(c("unknown", spec_modes))
  if (!(mode %in% spec_modes)) {
    stop_incompatible_mode(spec_modes, eng)
  }

  # ----------------------------------------------------------------------------
  # Check engine based on model and mode: when a concrete mode was chosen,
  # only engines registered for that mode are compatible.
  if (mode != "unknown") {
    spec_engs <- spec_engs[model_info$mode == mode]
  }
  spec_engs <- unique(spec_engs)
  if (!is.null(eng) && !(eng %in% spec_engs)) {
    stop_incompatible_engine(spec_engs, mode)
  }

  invisible(NULL)
}
#' Working with the celery model environment
#'
#' These functions read and write to the environment where the package stores
#' information about model specifications.
#'
#' @param items A character string of objects in the model environment.
#' @param ... Named values that will be assigned to the model environment.
#' @param name A single character value for a new symbol in the model environment.
#' @param value A single value for a new value in the model environment.
#' @return `get_model_env_celery()` returns the `celery` environment object from
#'   the celery namespace. `get_from_env_celery()` returns the stored object, or
#'   `NULL` when nothing is stored under `items`.
#' @keywords internal
#' @examples
#' # Access the model data:
#' current_code <- get_model_env_celery()
#' ls(envir = current_code)
#' @export
get_model_env_celery <- function() {
  # Look the registry up in the namespace on every call, so callers always
  # see the environment object that the package itself modifies.
  utils::getFromNamespace("celery", ns = "celery")
}
#' @rdname get_model_env_celery
#' @keywords internal
#' @export
get_from_env_celery <- function(items) {
  # Missing entries yield NULL rather than an error.
  rlang::env_get(get_model_env_celery(), items, default = NULL)
}
#' @rdname get_model_env_celery
#' @keywords internal
#' @export
set_env_val_celery <- function(name, value) {
  # Bind `value` to the symbol `name` inside the model environment.
  is_scalar_string <- is.character(name) && length(name) == 1
  if (!is_scalar_string) {
    rlang::abort("`name` should be a single character value.")
  }
  # A one-element named list, spliced into env_bind(), creates exactly one
  # binding called `name`, even when `value` is NULL.
  binding <- structure(list(value), names = name)
  rlang::env_bind(get_model_env_celery(), !!!binding)
}
# ------------------------------------------------------------------------------
#' Tools to Register Models
#'
#' These functions are similar to constructors and can be used to validate
#' that there are no conflicts with the underlying model structures used by the
#' package.
#'
#' @param model A single character string for the model type (e.g.
#'   `"k_means"`, etc).
#' @param mode A single character string for the model mode (e.g. "partition").
#' @param eng A single character string for the model engine.
#' @param arg A single character string for the model argument name.
#' @param has_submodel A single logical for whether the argument
#'   can make predictions on multiple submodels at once.
#' @param func A named character vector that describes how to call
#'   a function. `func` should have elements `pkg` and `fun`. The
#'   former is optional but is recommended and the latter is
#'   required. For example, `c(pkg = "stats", fun = "lm")` would be
#'   used to invoke the usual linear regression function. In some
#'   cases, it is helpful to use `c(fun = "predict")` when using a
#'   package's `predict` method.
#' @param fit_obj A list with elements `interface`, `protect`,
#'   `func` and `defaults`. See the package vignette "Making a
#'   `celery` model from scratch".
#' @param pred_obj A list with elements `pre`, `post`, `func`, and `args`.
#' @param type A single character value for the type of prediction. The
#'   accepted values are those in the internal `pred_types` vector (currently
#'   `cluster`).
#' @param pkg An optional character string for a package name.
#' @param celery A single character string for the "harmonized" argument name
#'   that `celery` exposes.
#' @param original A single character string for the argument name that
#'   the underlying model function uses.
#' @param value A list that conforms to the `fit_obj` or `pred_obj` description
#'   above, depending on context.
#' @param pre,post Optional functions for pre- and post-processing of prediction
#'   results.
#' @param options A list of options for engine-specific preprocessing encodings.
#'   See Details below.
#' @param ... Optional arguments that should be passed into the `args` slot for
#'   prediction objects.
#' @keywords internal
#' @details These functions are available for users to add their own models or
#' engines (in a package or otherwise) so that they can be accessed using
#' `celery`.
#'
#' In short, `celery` stores an environment object that contains all of the
#' information and code about how models are used (e.g. fitting, predicting,
#' etc). These functions can be used to add models to that environment as well
#' as helper functions that can be used to make sure that the model data is
#' in the right format.
#'
#' `check_model_exists_celery()` checks the model value and ensures that the
#' model has already been registered. `check_model_doesnt_exist_celery()`
#' checks the model value and also checks to see if it is novel in the
#' environment.
#'
#' The options for engine-specific encodings dictate how the predictors should
#' be handled. These options ensure that the data that `celery` gives to the
#' underlying model allows for a model fit that is as similar as possible to
#' what it would have produced directly.
#'
#' For example, if `fit()` is used to fit a model that does not have a formula
#' interface, typically some predictor preprocessing must be conducted.
#' `glmnet` is a good example of this.
#'
#' There are four options that can be used for the encodings:
#'
#' `predictor_indicators` describes whether and how to create indicator/dummy
#' variables from factor predictors. There are three options: `"none"` (do not
#' expand factor predictors), `"traditional"` (apply the standard
#' `model.matrix()` encodings), and `"one_hot"` (create the complete set
#' including the baseline level for all factors). This encoding only affects
#' cases when [fit.cluster_spec()] is used and the underlying model has an x/y
#' interface.
#'
#' Another option is `compute_intercept`; this controls whether
#' `model.matrix()` should include the intercept in its formula. This affects
#' more than the inclusion of an intercept column. With an intercept,
#' `model.matrix()` computes dummy variables for all but one factor levels.
#' Without an intercept, `model.matrix()` computes a full set of indicators
#' for the _first_ factor variable, but an incomplete set for the remainder.
#'
#' Next, the option `remove_intercept` will remove the intercept column
#' _after_ `model.matrix()` is finished. This can be useful if the model
#' function (e.g. `lm()`) automatically generates an intercept.
#'
#' Finally, `allow_sparse_x` specifies whether the model function can natively
#' accommodate a sparse matrix representation for predictors during fitting
#' and tuning.
#'
#' @examples
#' # set_new_model_celery("shallow_learning_model")
#'
#' # Show the information about a model:
#' show_model_info_celery("k_means")
#' @export
set_new_model_celery <- function(model) {
  # Register a brand-new model type: append it to the list of models and
  # create empty registry entries that later set_*_celery() calls fill in.
  check_model_doesnt_exist_celery(model)

  registered <- get_model_env_celery()$models
  set_env_val_celery("models", c(registered, model))

  # Entry-name suffixes and their empty templates, bound in this order:
  # engines, package dependencies, modes, arguments, fit and predict modules.
  suffixes <- c("", "_pkgs", "_modes", "_args", "_fit", "_predict")
  templates <- list(
    dplyr::tibble(engine = character(0), mode = character(0)),
    dplyr::tibble(engine = character(0), pkg = list(), mode = character(0)),
    "unknown",
    dplyr::tibble(
      engine = character(0),
      celery = character(0),
      original = character(0),
      func = list(),
      has_submodel = logical(0)
    ),
    dplyr::tibble(
      engine = character(0),
      mode = character(0),
      value = list()
    ),
    dplyr::tibble(
      engine = character(0),
      mode = character(0),
      type = character(0),
      value = list()
    )
  )
  for (i in seq_along(suffixes)) {
    set_env_val_celery(paste0(model, suffixes[[i]]), templates[[i]])
  }

  invisible(NULL)
}
#' @rdname set_new_model_celery
#' @export
check_model_doesnt_exist_celery <- function(model) {
  # Abort unless `model` is a single string that is not yet registered.
  valid <- !rlang::is_missing(model) && length(model) == 1 && is.character(model)
  if (!valid) {
    rlang::abort("Please supply a character string for a model name (e.g. `'k_means'`)")
  }
  already_registered <- any(get_model_env_celery()$models == model)
  if (already_registered) {
    rlang::abort(glue::glue("Model `{model}` already exists"))
  }
  invisible(NULL)
}
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
set_model_mode_celery <- function(model, mode) {
  # Record `mode` both in the global list of modes and in the model's own list.
  check_model_exists_celery(model)
  check_mode_val(mode)

  model_env <- get_model_env_celery()
  # `model_env` is an environment, so this assignment updates the registry itself.
  is_new_mode <- !any(model_env$modes == mode)
  if (is_new_mode) {
    model_env$modes <- unique(c(model_env$modes, mode))
  }

  modes_key <- paste0(model, "_modes")
  set_env_val_celery(modes_key, unique(c(get_from_env_celery(modes_key), mode)))
  invisible(NULL)
}
#' @rdname set_new_model_celery
#' @export
check_model_exists_celery <- function(model) {
  # Abort unless `model` is a single string naming an already registered model.
  valid <- !rlang::is_missing(model) && length(model) == 1 && is.character(model)
  if (!valid) {
    rlang::abort("Please supply a character string for a model name (e.g. `'k_means'`)")
  }
  is_registered <- any(get_model_env_celery()$models == model)
  if (!is_registered) {
    rlang::abort(glue::glue("Model `{model}` has not been registered."))
  }
  invisible(NULL)
}
# Abort unless `mode` is supplied as a single character string.
check_mode_val <- function(mode) {
  ok <- !rlang::is_missing(mode) && length(mode) == 1 && is.character(mode)
  if (ok) {
    return(invisible(NULL))
  }
  rlang::abort("Please supply a character string for a mode (e.g. `'partition'`).")
}
# ------------------------------------------------------------------------------
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
set_model_engine_celery <- function(model, mode, eng) {
  # Register engine `eng` for `model` under an already-known `mode`, then make
  # sure the mode is recorded for the model. Returns NULL invisibly.
  check_model_exists_celery(model)
  check_mode_val(mode)
  # check_eng_val() fully validates `eng`; a second mode-style check on the
  # engine would only duplicate it under a misleading message.
  check_eng_val(eng)
  check_mode_for_new_engine(model, eng, mode)

  new_eng <- dplyr::tibble(engine = eng, mode = mode)
  # distinct() makes registering the same engine/mode pair twice harmless.
  engs <-
    get_from_env_celery(model) %>%
    dplyr::bind_rows(new_eng) %>%
    dplyr::distinct()
  set_env_val_celery(model, engs)
  set_model_mode_celery(model, mode)
  invisible(NULL)
}
# Abort unless `eng` is supplied as a single character string.
check_eng_val <- function(eng) {
  ok <- !rlang::is_missing(eng) && length(eng) == 1 && is.character(eng)
  if (ok) {
    return(invisible(NULL))
  }
  rlang::abort("Please supply a character string for an engine name (e.g. `'stats'`)")
}
# Abort unless `mode` is already listed among the modes of model class `cls`.
# `eng` is part of the signature but is not consulted.
check_mode_for_new_engine <- function(cls, eng, mode) {
  known_modes <- get_from_env_celery(paste0(cls, "_modes"))
  if (mode %in% known_modes) {
    return(invisible(NULL))
  }
  rlang::abort(paste0("'", mode, "' is not a known mode for model `", cls, "()`."))
}
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
set_dependency_celery <- function(model, eng, pkg = "celery", mode = NULL) {
  # Record that engine `eng` of `model` needs package `pkg`, either for one
  # mode or (when `mode` is NULL) for every mode the engine is registered in.
  check_model_exists_celery(model)
  check_eng_val(eng)
  check_pkg_val(pkg)

  model_info <- get_from_env_celery(model)
  pkg_info <- get_from_env_celery(paste0(model, "_pkgs"))

  # ----------------------------------------------------------------------------
  # Check engine
  has_engine <-
    model_info %>%
    dplyr::distinct(engine) %>%
    dplyr::filter(engine == eng) %>%
    nrow()
  if (has_engine != 1) {
    rlang::abort(
      glue::glue("The engine '{eng}' has not been registered for model '{model}'.")
    )
  }

  # ----------------------------------------------------------------------------
  # check mode; if missing assign all modes
  all_modes <- unique(model_info$mode[model_info$engine == eng])
  if (is.null(mode)) {
    # For backward compatibility
    mode <- all_modes
  } else {
    # A zero-length mode would otherwise crash inside `if()`.
    if (length(mode) != 1) {
      rlang::abort("'mode' should be a single character value or NULL.")
    }
    # `%in%` never yields NA, so an NA mode gets the informative error too.
    if (!(mode %in% all_modes)) {
      rlang::abort(glue::glue("mode '{mode}' is not a valid mode for '{model}'"))
    }
  }

  # ----------------------------------------------------------------------------
  new_pkgs <- tibble::tibble(engine = eng, pkg = list(pkg), mode = mode)

  # Add the new entry to the existing list for this engine (if any) and
  # keep unique results
  eng_pkgs <-
    pkg_info %>%
    dplyr::filter(engine == eng) %>%
    dplyr::bind_rows(new_pkgs) %>%
    # Take unique combinations in case packages have already been registered
    dplyr::distinct() %>%
    # In case there are existing results (in a list column pkg), aggregate the
    # list results and re-list their unique values.
    dplyr::group_by(mode, engine) %>%
    dplyr::summarize(pkg = list(unique(unlist(pkg))), .groups = "drop") %>%
    dplyr::select(engine, pkg, mode)

  pkg_info <-
    pkg_info %>%
    dplyr::filter(engine != eng) %>%
    dplyr::bind_rows(eng_pkgs) %>%
    dplyr::arrange(engine, mode)

  set_env_val_celery(paste0(model, "_pkgs"), pkg_info)
  invisible(NULL)
}
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
get_dependency_celery <- function(model) {
  # Return the "<model>_pkgs" dependency table, aborting with a clear message
  # when no such entry exists in the model environment.
  check_model_exists_celery(model)
  pkg_name <- paste0(model, "_pkgs")
  model_env <- get_model_env_celery()
  if (!(pkg_name %in% rlang::env_names(model_env))) {
    rlang::abort(glue::glue("`{model}` does not have a dependency list in celery."))
  }
  rlang::env_get(model_env, pkg_name)
}
# Abort unless `pkg` is supplied as a single character string.
check_pkg_val <- function(pkg) {
  ok <- !rlang::is_missing(pkg) && length(pkg) == 1 && is.character(pkg)
  if (ok) {
    return(invisible(NULL))
  }
  rlang::abort("Please supply a single character value for the package name.")
}
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
set_fit_celery <- function(model, mode, eng, value) {
  # Register the fit module `value` for one engine/mode pair of `model`.
  # Aborts if the pair is unregistered or already has a fit module.
  check_model_exists_celery(model)
  check_eng_val(eng)
  check_spec_mode_engine_val(model, eng, mode)
  check_fit_info(value)

  registered <- dplyr::filter(
    get_from_env_celery(model),
    engine == eng & mode == !!mode
  )
  if (nrow(registered) != 1) {
    rlang::abort(glue::glue(
      "The combination of '{eng}' and mode '{mode}' has not ",
      "been registered for model '{model}'."
    ))
  }

  fit_key <- paste0(model, "_fit")
  old_fits <- get_from_env_celery(fit_key)
  existing <- dplyr::filter(old_fits, engine == eng & mode == !!mode)
  if (nrow(existing) > 0) {
    rlang::abort(glue::glue(
      "The combination of '{eng}' and mode '{mode}' ",
      "already has a fit component for model '{model}'."
    ))
  }

  new_fit <- dplyr::tibble(engine = eng, mode = mode, value = list(value))
  updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
  if (inherits(updated, "try-error")) {
    rlang::abort("An error occured when adding the new fit module.")
  }
  set_env_val_celery(fit_key, updated)
  invisible(NULL)
}
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
get_fit_celery <- function(model) {
  # Return the "<model>_fit" table of fit modules, aborting with a clear
  # message when no such entry exists in the model environment.
  check_model_exists_celery(model)
  fit_name <- paste0(model, "_fit")
  model_env <- get_model_env_celery()
  if (!(fit_name %in% rlang::env_names(model_env))) {
    rlang::abort(glue::glue("`{model}` does not have a `fit` method in celery."))
  }
  rlang::env_get(model_env, fit_name)
}
# Validate a `fit` module before it is registered.
#
# fit_obj: a list that must contain `defaults`, `func`, `interface`, and
#          `protect`, and may additionally contain `data` (a vector whose
#          elements are all named).
# Returns invisible(NULL); aborts on the first problem found.
check_fit_info <- function(fit_obj) {
  if (is.null(fit_obj)) {
    rlang::abort("The `fit` module cannot be NULL.")
  }

  # check required data elements
  exp_nms <- c("defaults", "func", "interface", "protect")
  has_req_nms <- exp_nms %in% names(fit_obj)
  if (!all(has_req_nms)) {
    rlang::abort(
      glue::glue(
        "The `fit` module should have elements: ",
        glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", ")
      )
    )
  }

  # check optional data elements: every name beyond the required ones must be
  # one of the allowed optional names.
  opt_nms <- c("data")
  other_nms <- setdiff(names(fit_obj), exp_nms)
  if (!all(other_nms %in% opt_nms)) {
    rlang::abort(
      glue::glue(
        "The `fit` module can only have optional elements: ",
        glue::glue_collapse(glue::glue("`{opt_nms}`"), sep = ", ")
      )
    )
  }

  if ("data" %in% other_nms) {
    data_nms <- names(fit_obj$data)
    if (length(data_nms) == 0 || any(is.na(data_nms) | data_nms == "")) {
      rlang::abort("All elements of the `data` argument vector must be named.")
    }
  }

  check_interface_val(fit_obj$interface)
  check_func_val(fit_obj$func)
  if (!is.list(fit_obj$defaults)) {
    rlang::abort("The `defaults` element should be a list: ")
  }
  invisible(NULL)
}
# Abort unless `x` is exactly one of the supported fit interfaces.
check_interface_val <- function(x) {
  exp_interf <- c("data.frame", "formula", "matrix")
  is_valid <- length(x) == 1 && x %in% exp_interf
  if (is_valid) {
    return(invisible(NULL))
  }
  choices <- glue::glue_collapse(glue::glue("`{exp_interf}`"), sep = ", ")
  rlang::abort(
    glue::glue("The `interface` element should have a single value of: ", choices)
  )
}
# Validate a `func` specification (how to call a function).
#
# func: a named vector/list with a required `fun` element and optional `pkg`,
#       `range`, `trans`, and `values` elements. `fun` (and `pkg`, when
#       present) must be character.
# Returns invisible(NULL); aborts with a single shared message otherwise.
check_func_val <- function(func) {
  msg <-
    paste(
      "`func` should be a named vector with element 'fun' and the optional ",
      "elements 'pkg', 'range', 'trans', and 'values'.",
      "`func` and 'pkg' should both be single character strings."
    )

  if (rlang::is_missing(func) || !is.vector(func)) {
    rlang::abort(msg)
  }

  nms <- names(func)
  if (is.null(nms)) {
    rlang::abort(msg)
  }

  allow_nms <- c("fun", "pkg", "range", "trans", "values")
  # Reject unknown names. Require `fun` explicitly, so that a missing `fun`
  # gets this message rather than a "subscript out of bounds" error from
  # `func[["fun"]]` below.
  if (!all(nms %in% allow_nms) || !("fun" %in% nms)) {
    rlang::abort(msg)
  }

  if (!is.character(func[["fun"]])) {
    rlang::abort(msg)
  }
  if ("pkg" %in% nms && !is.character(func[["pkg"]])) {
    rlang::abort(msg)
  }
  invisible(NULL)
}
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
get_encoding_celery <- function(model) {
  # Return the encoding table for `model`. If no "<model>_encoding" entry
  # exists (models registered before encodings were specified), build a
  # default table from the model's registered engine/mode pairs: traditional
  # indicators, intercept computed and then removed, no sparse input.
  check_model_exists_celery(model)
  res <- get_from_env_celery(paste0(model, "_encoding"))
  # get_from_env_celery() returns NULL for a missing entry (it does not
  # error), so test for NULL to reach the fallback.
  if (is.null(res)) {
    res <-
      get_from_env_celery(model) %>%
      dplyr::mutate(
        model = !!model,
        predictor_indicators = "traditional",
        compute_intercept = TRUE,
        remove_intercept = TRUE,
        allow_sparse_x = FALSE
      ) %>%
      # Keep all four encoding options, matching the columns that
      # set_encoding_celery() stores.
      dplyr::select(
        model, engine, mode, predictor_indicators,
        compute_intercept, remove_intercept, allow_sparse_x
      )
  }
  res
}
#' @export
#' @rdname set_new_model_celery
#' @keywords internal
set_encoding_celery <- function(model, mode, eng, options) {
  # Store the encoding `options` for one engine/mode pair of `model` in the
  # "<model>_encoding" table. Aborts if a matching row is already present.
  check_model_exists_celery(model)
  check_eng_val(eng)
  check_mode_val(mode)
  check_encodings(options)

  new_values <- dplyr::bind_cols(
    tibble::tibble(model = model, engine = eng, mode = mode),
    tibble::as_tibble(options)
  )

  nm <- paste0(model, "_encoding")
  current <- NULL
  if (nm %in% ls(envir = get_model_env_celery())) {
    current <- get_from_env_celery(nm)
    dup_check <- dplyr::inner_join(
      current,
      new_values,
      by = c("model", "engine", "mode", "predictor_indicators")
    )
    if (nrow(dup_check) > 0) {
      rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings for model '{model}'."))
    }
  }

  set_env_val_celery(nm, dplyr::bind_rows(current, new_values))
  invisible(NULL)
}
# Validate an encoding-options list: it must contain exactly the four
# encoding options. Returns `x` invisibly when valid.
check_encodings <- function(x) {
  if (!is.list(x)) {
    rlang::abort("`values` should be a list.")
  }

  req_names <- c(
    "predictor_indicators",
    "compute_intercept",
    "remove_intercept",
    "allow_sparse_x"
  )

  missing_args <- setdiff(req_names, names(x))
  if (length(missing_args) > 0) {
    missing_txt <- paste0("'", missing_args, "'", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The values passed to `set_encoding_celery()` are missing arguments: ",
        missing_txt
      )
    )
  }

  extra_args <- setdiff(names(x), req_names)
  if (length(extra_args) > 0) {
    extra_txt <- paste0("'", extra_args, "'", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The values passed to `set_encoding_celery()` had extra arguments: ",
        extra_txt
      )
    )
  }

  invisible(x)
}
# ------------------------------------------------------------------------------
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
set_model_arg_celery <- function(model, eng, celery, original, func, has_submodel) {
  # Register how the harmonized argument `celery` maps onto the engine's
  # `original` argument name. Exact duplicate rows are collapsed.
  check_model_exists_celery(model)
  check_eng_val(eng)
  check_arg_val(celery)
  check_arg_val(original)
  check_func_val(func)
  check_submodels_val(has_submodel)

  args_key <- paste0(model, "_args")
  old_args <- get_from_env_celery(args_key)
  new_arg <- dplyr::tibble(
    engine = eng,
    celery = celery,
    original = original,
    func = list(func),
    has_submodel = has_submodel
  )

  combined <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE)
  if (inherits(combined, "try-error")) {
    rlang::abort("An error occured when adding the new argument.")
  }
  set_env_val_celery(args_key, vctrs::vec_unique(combined))
  invisible(NULL)
}
# Abort unless `arg` is supplied as a single character string.
check_arg_val <- function(arg) {
  ok <- !rlang::is_missing(arg) && length(arg) == 1 && is.character(arg)
  if (ok) {
    return(invisible(NULL))
  }
  rlang::abort("Please supply a character string for the argument.")
}
# Abort unless `has_submodel` is a single logical value.
check_submodels_val <- function(has_submodel) {
  if (is.logical(has_submodel) && length(has_submodel) == 1) {
    return(invisible(NULL))
  }
  rlang::abort("The `submodels` argument should be a single logical.")
}
# Validate `mode` for a model class that has no registered engines, using the
# class-level list of modes. Returns invisible(NULL) or aborts.
check_mode_with_no_engine <- function(cls, mode) {
  spec_modes <- get_from_env_celery(paste0(cls, "_modes"))
  # A NULL or multi-valued mode cannot match. Report the available modes
  # instead of failing inside `if()` with a cryptic error.
  if (length(mode) != 1 || !(mode %in% spec_modes)) {
    stop_incompatible_mode(spec_modes, cls = cls)
  }
  invisible(NULL)
}
# Abort with a message listing the available modes. The lead-in mentions the
# model class and/or engine when they are supplied.
stop_incompatible_mode <- function(spec_modes, eng = NULL, cls = NULL) {
  has_eng <- !is.null(eng)
  has_cls <- !is.null(cls)
  msg <- if (has_eng && has_cls) {
    glue::glue("Available modes for model type {cls} with engine {eng} are: ")
  } else if (has_cls) {
    glue::glue("Available modes for model type {cls} are: ")
  } else if (has_eng) {
    glue::glue("Available modes for engine {eng} are: ")
  } else {
    "Available modes are: "
  }
  mode_list <- glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
  rlang::abort(glue::glue(msg, mode_list))
}
# Abort with a message listing the engines available for `mode`.
stop_incompatible_engine <- function(spec_engs, mode) {
  engine_list <- glue::glue_collapse(glue::glue("'{spec_engs}'"), sep = ", ")
  rlang::abort(glue::glue("Available engines for mode {mode} are: ", engine_list))
}
# ------------------------------------------------------------------------------
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
show_model_info_celery <- function(model) {
  # Print a summary of everything registered for `model`: modes, engines,
  # argument translations, fit modules, and prediction modules.
  # Returns NULL invisibly.
  check_model_exists_celery(model)
  cat("Information for `", model, "`\n", sep = "")

  cat(
    " modes:",
    paste0(get_from_env_celery(paste0(model, "_modes")), collapse = ", "),
    "\n\n"
  )

  engines <- get_from_env_celery(model)
  if (nrow(engines) > 0) {
    cat(" engines: \n")
    engines %>%
      dplyr::mutate(
        mode = format(paste0(mode, ": "))
      ) %>%
      dplyr::group_by(mode) %>%
      dplyr::summarize(
        engine = paste0(sort(engine), collapse = ", "),
        .groups = "drop"
      ) %>%
      dplyr::mutate(
        lab = paste0(" ", mode, engine, "\n")
      ) %>%
      dplyr::pull(lab) %>%
      cat(sep = "")
    cat("\n")
  } else {
    cat(" no registered engines.\n\n")
  }

  args <- get_from_env_celery(paste0(model, "_args"))
  if (nrow(args) > 0) {
    cat(" arguments: \n")
    args %>%
      dplyr::select(engine, celery, original) %>%
      dplyr::distinct() %>%
      dplyr::mutate(
        engine = format(paste0(" ", engine, ": ")),
        celery = paste0(" ", format(celery), " --> ", original, "\n")
      ) %>%
      dplyr::group_by(engine) %>%
      # Show each engine label only on its first argument row.
      dplyr::mutate(
        engine2 = ifelse(dplyr::row_number() == 1, engine, ""),
        celery = ifelse(dplyr::row_number() == 1, paste0("\n", celery), celery),
        lab = paste0(engine2, celery)
      ) %>%
      dplyr::ungroup() %>%
      dplyr::pull(lab) %>%
      cat(sep = "")
    cat("\n")
  } else {
    cat(" no registered arguments.\n\n")
  }

  fits <- get_from_env_celery(paste0(model, "_fit"))
  if (nrow(fits) > 0) {
    cat(" fit modules:\n")
    fits %>%
      dplyr::select(-value) %>%
      dplyr::mutate(engine = paste0(" ", engine)) %>%
      as.data.frame() %>%
      print(row.names = FALSE)
    cat("\n")
  } else {
    cat(" no registered fit modules.\n\n")
  }

  preds <- get_from_env_celery(paste0(model, "_predict"))
  if (nrow(preds) > 0) {
    cat(" prediction modules:\n")
    # `.groups = "drop"` stops summarize() from emitting its "grouped output"
    # message into the printed summary when two grouping variables are used.
    preds %>%
      dplyr::group_by(mode, engine) %>%
      dplyr::summarize(
        methods = paste0(sort(type), collapse = ", "),
        .groups = "drop"
      ) %>%
      dplyr::mutate(mode = paste0(" ", mode)) %>%
      as.data.frame() %>%
      print(row.names = FALSE)
    cat("\n")
  } else {
    cat(" no registered prediction modules.\n\n")
  }

  invisible(NULL)
}
# ------------------------------------------------------------------------------
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
set_pred_celery <- function(model, mode, eng, type, value) {
  # Register prediction module `value` of prediction type `type` for one
  # engine/mode pair of `model`. Re-registering identical information is a
  # silent no-op; conflicting information aborts inside is_discordant_info().
  check_model_exists_celery(model)
  check_eng_val(eng)
  check_spec_mode_engine_val(model, eng, mode)
  check_pred_info(value, type)
  check_unregistered(model, mode, eng)

  new_pred <-
    dplyr::tibble(
      engine = eng,
      mode = mode,
      type = type,
      value = list(value)
    )

  pred_check <- is_discordant_info(model, mode, eng, new_pred, pred_type = type, component = "predict")
  if (!pred_check) {
    return(invisible(NULL))
  }

  old_pred <- get_from_env_celery(paste0(model, "_predict"))
  updated <- try(dplyr::bind_rows(old_pred, new_pred), silent = TRUE)
  if (inherits(updated, "try-error")) {
    rlang::abort("An error occurred when adding the new prediction module.")
  }
  set_env_val_celery(paste0(model, "_predict"), updated)
  invisible(NULL)
}
#' @rdname set_new_model_celery
#' @keywords internal
#' @export
get_pred_type_celery <- function(model, type) {
  # Return the rows of the "<model>_predict" table whose prediction type is
  # `type`. Aborts when the table is absent or has no such type.
  check_model_exists_celery(model)
  pred_name <- paste0(model, "_predict")
  model_env <- get_model_env_celery()
  if (!(pred_name %in% rlang::env_names(model_env))) {
    rlang::abort(glue::glue("`{model}` does not have any `pred` methods in celery."))
  }
  all_preds <- rlang::env_get(model_env, pred_name)
  if (!any(all_preds$type == type)) {
    rlang::abort(glue::glue("`{model}` does not have any `{type}` prediction methods in celery."))
  }
  dplyr::filter(all_preds, type == !!type)
}
# Validate a prediction module before it is registered.
#
# pred_obj: a list with exactly the elements `args`, `func`, `post`, `pre`.
#           `pre`/`post` are NULL or functions, `func` is checked by
#           check_func_val(), and `args` must be a list.
# type:     a single prediction type taken from `pred_types`.
# Returns invisible(NULL); aborts on the first problem found.
check_pred_info <- function(pred_obj, type) {
  # `%in%` plus a length check avoids an NA or zero-length `if()` condition.
  if (length(type) != 1 || !(type %in% pred_types)) {
    rlang::abort(
      glue::glue(
        "The prediction type should be one of: ",
        glue::glue_collapse(glue::glue("'{pred_types}'"), sep = ", ")
      )
    )
  }

  exp_nms <- c("args", "func", "post", "pre")
  if (!isTRUE(all.equal(sort(names(pred_obj)), exp_nms))) {
    rlang::abort(
      glue::glue(
        "The `predict` module should have elements: ",
        glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", ")
      )
    )
  }

  # Scalar conditions: use short-circuiting `&&`.
  if (!is.null(pred_obj$pre) && !is.function(pred_obj$pre)) {
    rlang::abort("The `pre` module should be null or a function: ")
  }
  if (!is.null(pred_obj$post) && !is.function(pred_obj$post)) {
    rlang::abort("The `post` module should be null or a function: ")
  }

  check_func_val(pred_obj$func)

  if (!is.list(pred_obj$args)) {
    rlang::abort("The `args` element should be a list. ")
  }

  invisible(NULL)
}
# Abort unless the engine/mode pair is registered for `model` exactly once.
check_unregistered <- function(model, mode, eng) {
  matches <- dplyr::filter(
    get_from_env_celery(model),
    engine == eng & mode == !!mode
  )
  if (nrow(matches) == 1) {
    return(invisible(NULL))
  }
  rlang::abort(
    glue::glue(
      "The combination of engine '{eng}' and mode '{mode}' has not ",
      "been registered for model '{model}'."
    )
  )
}
# This will be used to see if the same information is being registered for the
# same model/mode/engine (and prediction type). If it already exists and the
# new information is different, fail with a message. See issue parsnip/#653
#
# Returns TRUE when `candidate` is new and should be added, and FALSE when
# identical information is already stored (the caller should skip it).
# Aborts when different information is already stored.
is_discordant_info <- function(model, mode, eng, candidate,
                               pred_type = NULL, component = "fit") {
  current <- get_from_env_celery(paste0(model, "_", component))

  # For older versions of parsnip before set_encoding()
  if (is.null(current) && component == "encoding") {
    return(TRUE)
  }
  current <- dplyr::filter(current, engine == eng & mode == !!mode)

  if (component == "predict" && !is.null(pred_type)) {
    current <- dplyr::filter(current, type == pred_type)
    p_type <- paste0("and prediction type '", pred_type, "'")
  } else {
    p_type <- ""
  }

  if (nrow(current) == 0) {
    return(TRUE)
  }

  # Function environments differ between registrations, so ignore them.
  same_info <- isTRUE(all.equal(current, candidate, check.environment = FALSE))
  if (!same_info) {
    rlang::abort(
      glue::glue(
        "The combination of engine '{eng}' and mode '{mode}' {p_type} already has ",
        "{component} data for model '{model}' and the new information being ",
        "registered is different."
      )
    )
  }

  FALSE
}
# Abort unless `object$spec$method$pred` contains a prediction method named
# `type`. The error lists the method names that are available.
check_spec_pred_type <- function(object, type) {
  possible_preds <- names(object$spec$method$pred)
  if (any(possible_preds == type)) {
    return(invisible(NULL))
  }
  choices <- glue::glue_collapse(glue::glue("'{possible_preds}'"), sep = ", ")
  rlang::abort(c(
    glue::glue("No {type} prediction method available for this model."),
    glue::glue("Value for `type` should be one of: ", choices)
  ))
}
# TRUE when `object$spec$method$pred` has an element named `type`.
spec_has_pred_type <- function(object, type) {
  any(names(object$spec$method$pred) == type)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.