# pillar/cli (plus dplyr/crayon) options to try to lock down formatting.
# All of them are set in a single call.
options(
  useFancyQuotes = FALSE,
  dplyr.print_min = 6,
  dplyr.print_max = 6,
  cli.width = 85,
  crayon.enabled = FALSE,
  pillar.min_title_chars = Inf
)

library(parsnip)
library(workflows)

# ------------------------------------------------------------------------------
# These are required to build md docs for parsnip and extensions

# Check that each package named in `x` is installed, then attach it.
#
# `x` is a character vector of package names. Every name is first passed to
# `rlang::check_installed()`; afterwards every package is attached with
# `require()`. The value is the list of `require()` results, one per package.
check_pkg_for_docs <- function(x) {
  ensure_installed <- function(pkg) rlang::check_installed(pkg)
  attach_pkg <- function(pkg) require(pkg, character.only = TRUE)

  # Run all installation checks before attaching anything
  purrr::map(x, ensure_installed)
  purrr::map(x, attach_pkg)
}

# Packages checked (and attached) before building the docs
rmd_pkgs <- c(
  "tune", "glue", "dplyr", "parsnip", "dials",
  "glmnet", "Cubist", "xrf", "ape"
)

# First the packages listed above, then the parsnip extension packages
check_pkg_for_docs(rmd_pkgs)
check_pkg_for_docs(parsnip:::extensions())


# ------------------------------------------------------------------------------
# Code to get information about main arguments and format the results to print

# Rows of parsnip's `models.tsv` for one model/engine pair, ordered by mode.
#
# `mod` and `eng` are compared against the `model` and `engine` columns. The
# `mode` column is converted to a factor with a fixed level order and the rows
# are sorted by it.
make_mode_list <- function(mod, eng) {
  mode_levels <- c("regression", "classification", "censored regression")
  model_info <-
    utils::read.delim(system.file("models.tsv", package = "parsnip"))

  # Need to get mode-specific dependencies (see tidymodels/parsnip#629)
  model_info %>%
    dplyr::filter(model == mod, engine == eng) %>%
    dplyr::mutate(mode = factor(mode, levels = mode_levels)) %>%
    dplyr::arrange(mode)
}

# Build one markdown bullet per tunable argument of `x`.
#
# `x` is passed to `tune::tunable()`. `defaults` is joined on its `parsnip`
# column and must supply the `default` value used in the bullet text. The
# result keeps the joined columns and adds an `item` column holding the
# formatted bullet.
make_parameter_list <- function(x, defaults) {
  label_of <- function(param) param$label
  type_of <- function(param) param$type

  # Drop bookkeeping columns; the argument name column becomes `parsnip`
  tunable_args <- dplyr::select(
    tune::tunable(x),
    -source, -component, -component_id, parsnip = name
  )

  # Create the dials object for each argument and pull out its label and type
  with_dials <- dplyr::mutate(
    tunable_args,
    dials = purrr::map(call_info, get_dials),
    label = purrr::map_chr(dials, label_of),
    type = purrr::map_chr(dials, type_of)
  )

  with_defaults <- dplyr::inner_join(with_dials, defaults, by = "parsnip")

  dplyr::mutate(
    with_defaults,
    item =
      glue::glue("- `{parsnip}`: {label} (type: {type}, default: {default})\n\n")
  )
}

# Make a table that maps the parsnip argument names of one model to the
# original argument names used by each engine.
#
# `model_name` is the exact name of a model (e.g. the same string that is used
# to find `get_defaults_<model_name>()` in `get_arg_defaults()`). The argument
# tables are read from the `<model>_args` objects in the parsnip model
# environment, joined with the engine defaults, and returned as a
# `knitr::kable()` with one column per engine.
convert_args <- function(model_name) {
  envir <- get_model_env()

  args <-
    ls(envir) %>%
    tibble::tibble(name = .) %>%
    # Only keep the `<model>_args` objects; anchor the pattern so that the
    # suffix is what gets matched and removed
    dplyr::filter(grepl("_args$", name)) %>%
    dplyr::mutate(model = sub("_args$", "", name),
                  args  = purrr::map(name, ~envir[[.x]])) %>%
    # Exact comparison: a regex/substring match would also pick up other
    # models whose names contain `model_name` (e.g. "mars" vs "bag_mars")
    dplyr::filter(model == model_name) %>%
    tidyr::unnest(args) %>%
    dplyr::select(model:original) %>%
    dplyr::full_join(get_arg_defaults(model_name),
                     by = c("model", "engine", "parsnip", "original")) %>%
    # Append the default value, when there is one, to the original name
    dplyr::mutate(original = dplyr::if_else(!is.na(default),
                                            paste0(original, " (", default, ")"),
                                            original)) %>%
    dplyr::select(-default)

  # One row per parsnip argument, one column per engine
  convert_df <- args %>%
    dplyr::select(-model) %>%
    tidyr::pivot_wider(names_from = engine, values_from = original)

  convert_df %>%
    knitr::kable(col.names = paste0("**", colnames(convert_df), "**"))

}

# Look up and call the `get_defaults_<model>()` function for `model`.
#
# `check_model_exists()` is called on `model` first. The value is whatever the
# looked-up function returns.
get_arg_defaults <- function(model) {
  check_model_exists(model)
  fn_name <- paste0("get_defaults_", model)
  get(fn_name)()
}

# Return the default of argument `arg` of function `f` in namespace `ns`,
# converted with `as.character()`.
#
# `ns` and `f` are strings naming the namespace and the function; `arg` names
# the formal argument whose default is extracted.
get_arg <- function(ns, f, arg) {
  fn <- utils::getFromNamespace(f, ns)
  defaults <- as.list(formals(fn))
  as.character(defaults[[arg]])
}

# Build and evaluate the call described by `x`.
#
# `x` supplies the function name (`x$fun`) and its namespace (`x$pkg`). When
# `x` has an element named "range", it is passed on as the `range` argument.
get_dials <- function(x) {
  has_range <- any(names(x) == "range")
  dial_args <- if (has_range) list(range = x$range) else list()

  # Splice the (possibly empty) argument list into a single call
  dial_call <- rlang::call2(x$fun, !!!dial_args, .ns = x$pkg)
  rlang::eval_tidy(dial_call)
}

# ------------------------------------------------------------------------------
# Write text about modes

# Write a sentence listing the modes available for model `mod` with engine
# `eng`.
#
# The modes are the `mode` values of the distinct rows returned by
# `get_from_env(mod)` whose `engine` equals `eng`.
descr_models <- function(mod, eng) {
  engine_rows <- dplyr::filter(get_from_env(mod), engine == eng)
  modes <- purrr::pluck(dplyr::distinct(engine_rows), "mode")

  # Singular wording only when there is exactly one mode
  txt <- if (length(modes) == 1) "is a single mode:" else "are multiple modes:"

  paste("For this engine, there", txt, combine_words(modes))
}

# Write a sentence naming the extension package needed for a
# model/engine/mode combination.
#
# Rows of parsnip's `models.tsv` are matched on `mod`, `eng`, and `mod_mode`
# and restricted to packages listed by `parsnip:::extensions()`. Returns ""
# when no extension package matches and aborts when more than one does.
uses_extension <- function(mod, eng, mod_mode) {
  model_info <-
    utils::read.delim(system.file("models.tsv", package = "parsnip"))

  matching <- dplyr::filter(
    model_info,
    model == mod,
    engine == eng,
    mode == mod_mode,
    pkg %in% parsnip:::extensions()
  )
  ext_pkgs <- purrr::pluck(dplyr::distinct(matching, pkg), "pkg")

  if (length(ext_pkgs) > 1) {
    rlang::abort(c("There are more than one extension packages for:",
                   mod, eng, mod_mode))
  }
  if (length(ext_pkgs) == 0) {
    return("")
  }

  paste0("The **",
         ext_pkgs,
         "** extension package is required to fit this model.")
}

# Set the `width` option (console output line width) to 80 characters
options(width = 80)


# Try the parsnip package in your browser
#
# Any scripts or data that you put into this service are public.
#
# parsnip documentation built on Aug. 18, 2023, 1:07 a.m.