# S7_Hyperparameters.R
# ::rtemis::
# 2025 EDG rtemis.org

# References ----
# S7
# https://github.com/RConsortium/S7
# https://rconsortium.github.io/S7/
# LightGBM parameters
# https://lightgbm.readthedocs.io/en/latest/Parameters.html

# `tuned` values ----
# -9: Set by Tuner: Actively being tuned (Values fixed by Tuner).
# -2: Set by constructor: Not tunable (No tunable_hyperparameters).
# -1: Set by constructor: Not tunable (tunable_hyperparameters exist, but none of them have more than one value).
#  0: Set by constructor: Untuned but tunable (at least one of tunable_hyperparameters has more than one value).
#  1: Set by Tuner: Tuned (Started as 0, set to 1 when tuned).
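#
# Examples (illustrative sketch, using setup_* constructors defined below):
# setup_GLM()               # tuned = -1: "ifw" is tunable but has one value.
# setup_GAM(k = c(3L, 5L))  # tuned =  0: "k" has search values => needs tuning.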

# `resampled` values ----
# 0: Running on single training set.
# 1: Running on cross-validated training sets.

#' @title Hyperparameters
#'
#' @description
#' Superclass for hyperparameters.
#'
#' @field algorithm Character: Algorithm name.
#' @field hyperparameters Named list of algorithm hyperparameter values.
#' @field tuned Integer: Tuning status.
#' @field resampled Integer: Cross-validation status.
#' @field tunable_hyperparameters Character: Names of tunable hyperparameters.
#' @field fixed_hyperparameters Character: Names of fixed hyperparameters.
#'
#' @author EDG
#' @noRd
Hyperparameters <- new_class(
  name = "Hyperparameters",
  properties = list(
    algorithm = class_character,
    hyperparameters = class_list,
    tunable_hyperparameters = class_character,
    fixed_hyperparameters = class_character,
    tuned = class_integer,
    resampled = class_integer
  ),
  constructor = function(
    algorithm,
    hyperparameters,
    tunable_hyperparameters,
    fixed_hyperparameters
  ) {
    # Test if any tunable_hyperparameters have more than one value
    if (length(tunable_hyperparameters) > 0) {
      if (any(sapply(hyperparameters[tunable_hyperparameters], length) > 1)) {
        tuned <- 0L # Search values defined for tunable hyperparameters.
      } else {
        tuned <- -1L # No search values defined for tunable hyperparameters.
      }
    } else {
      tuned <- -2L # No tunable hyperparameters
    }
    # GLMNET
    if (algorithm == "GLMNET") {
      if (is.null(hyperparameters[["lambda"]])) {
        tuned <- 0L
      }
    }
    # LightGBM
    if (algorithm == "LightGBM") {
      if (is.null(hyperparameters[["nrounds"]])) {
        tuned <- 0L
      }
    }
    # SVM
    # Check kernel-specific hyperparameters
    if (algorithm == "SVM") {
      # linear => cost
      if (hyperparameters[["kernel"]] == "linear") {
        if (length(hyperparameters[["cost"]]) > 1) {
          tuned <- 0L
        }
      } else if (hyperparameters[["kernel"]] == "polynomial") {
        if (length(hyperparameters[["degree"]]) > 1) {
          tuned <- 0L
        }
      } else if (hyperparameters[["kernel"]] == "radial") {
        if (length(hyperparameters[["gamma"]]) > 1) {
          tuned <- 0L
        }
      }
    }
    new_object(
      S7_object(),
      algorithm = algorithm,
      hyperparameters = hyperparameters,
      tunable_hyperparameters = tunable_hyperparameters,
      fixed_hyperparameters = fixed_hyperparameters,
      tuned = tuned,
      resampled = 0L
    )
  }
) # /Hyperparameters

# Print Hyperparameters ----
#' Print Hyperparameters
#'
#' Print method for Hyperparameters object.
#'
#' @param x Hyperparameters object.
#' @param pad Integer: Left padding for printed output.
#' @param ... Not used.
#'
#' @author EDG
#' @noRd
print.Hyperparameters <- function(x, pad = 0L, ...) {
  objcat(paste(x@algorithm, "Hyperparameters"))
  printls(props(x)[-1], pad = pad)
  if (x@tuned == -9L) {
    cat(hilite2("\n  Hyperparameters are being tuned.\n"))
  } else if (x@tuned == -2L) {
    cat(hilite2("\n  No hyperparameters are tunable.\n"))
  } else if (x@tuned == 0L) {
    need_tuning <- names(get_params_need_tuning(x))
    cat(hilite2(
      "\n  ",
      ngettext(length(need_tuning), "Hyperparameter ", "Hyperparameters "),
      oxfordcomma(
        need_tuning,
        format_fn = underline
      ),
      ngettext(length(need_tuning), " needs ", " need "),
      "tuning.\n"
    ))
  } else if (x@tuned == -1L) {
    cat(hilite2("\n  No search values defined for tunable hyperparameters.\n"))
  } else if (x@tuned == 1L) {
    cat(hilite2("\n  Hyperparameters are tuned.\n"))
  }
  invisible(x)
}
method(print, Hyperparameters) <- function(x, ...) {
  print.Hyperparameters(x, ...)
} # rtemis::print.Hyperparameters

# is_tuned() ----
is_tuned <- new_generic("is_tuned", "x")
method(is_tuned, Hyperparameters) <- function(x) {
  x@tuned == 1L
} # /is_tuned.Hyperparameters

# get_tuned_status() ----
get_tuned_status <- new_generic("get_tuned_status", "x")
method(get_tuned_status, Hyperparameters) <- function(x) {
  if (length(x@tunable_hyperparameters) > 0) {
    if (any(sapply(x@hyperparameters[x@tunable_hyperparameters], length) > 1)) {
      0L
    } else {
      -1L
    }
  } else {
    -2L
  }
} # /get_tuned_status.Hyperparameters

# Update Hyperparameters ----
#' Update Hyperparameters
#'
#' @param object Hyperparameters object.
#' @param hyperparameters Named list of algorithm hyperparameter values.
#' @param tuned Integer: Tuning status to set. If NULL, the status is recomputed
#' from the hyperparameter values.
#' @param ... Not used.
#'
#' @author EDG
#' @keywords internal
#' @noRd
# update <- new_generic("update", "x")
method(update, Hyperparameters) <- function(
  object,
  hyperparameters,
  tuned = NULL,
  ...
) {
  for (hp in names(hyperparameters)) {
    object@hyperparameters[[hp]] <- hyperparameters[[hp]]
  }
  # Update tuned status
  if (is.null(tuned)) {
    object@tuned <- get_tuned_status(object)
  } else {
    object@tuned <- tuned
  }
  object
} # /update.Hyperparameters
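
# Illustrative usage sketch (comments only): updating hyperparameter values
# recomputes the tuned status unless `tuned` is given explicitly.
# hp <- setup_GAM(k = c(3L, 5L))  # tuned = 0
# hp <- update(hp, list(k = 5L))  # tuned recomputed to -1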

# Freeze Hyperparameters ----
#' Freeze Hyperparameters
#'
#' @param x Hyperparameters object.
#'
#' @author EDG
#' @keywords internal
#' @noRd
freeze <- new_generic("freeze", "x")
method(freeze, Hyperparameters) <- function(x) {
  x@tuned <- -1L
  x
}

# Lock Hyperparameters ----
#' Lock Hyperparameters
#'
#' @param x Hyperparameters object.
#'
#' @author EDG
#' @keywords internal
#' @noRd
lock <- new_generic("lock", "x")
method(lock, Hyperparameters) <- function(x) {
  x@tuned <- 1L
  x
}

# Make Hyperparameters@hyperparameters@name `$`-accessible
method(`$`, Hyperparameters) <- function(x, name) {
  x@hyperparameters[[name]]
}

# `$`-autocomplete Hyperparameters@hyperparameters ----
method(`.DollarNames`, Hyperparameters) <- function(x, pattern = "") {
  all_names <- names(x@hyperparameters)
  grep(pattern, all_names, value = TRUE)
}

# Make Hyperparameters@hyperparameters@name `[[`-accessible
method(`[[`, Hyperparameters) <- function(x, name) {
  x@hyperparameters[[name]]
}
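
# Illustrative usage sketch (comments only):
# hp <- setup_GAM(k = 9L)
# hp$k          # 9
# hp[["ifw"]]   # FALSE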

# needs_tuning ----
#' needs_tuning
#'
#' @keywords internal
#' @noRd
needs_tuning <- new_generic("needs_tuning", "x")
method(needs_tuning, Hyperparameters) <- function(x) {
  x@tuned == 0
} # /needs_tuning.Hyperparameters

# get_params_need_tuning ----
#' Get hyperparameters that need tuning in an algorithm-specific way.
#'
#' @keywords internal
#' @noRd
method(get_params_need_tuning, Hyperparameters) <- function(x) {
  # -> list
  # Get tunable hyperparameters with more than one value
  x@hyperparameters[x@tunable_hyperparameters[
    sapply(x@hyperparameters[x@tunable_hyperparameters], length) > 1
  ]]
} # /get_params_need_tuning.Hyperparameters

# get_params.(Hyperparameters, character) ----
method(get_params, list(Hyperparameters, class_character)) <- function(
  x,
  param_names
) {
  sapply(param_names, function(p) x@hyperparameters[p], USE.NAMES = FALSE)
}


# GLMHyperparameters ----
#' @author EDG
#'
#' @keywords internal
#' @noRd
GLMHyperparameters <- new_class(
  name = "GLMHyperparameters",
  parent = Hyperparameters,
  constructor = function(ifw) {
    new_object(
      Hyperparameters(
        algorithm = "GLM",
        hyperparameters = list(
          ifw = ifw
        ),
        tunable_hyperparameters = "ifw",
        fixed_hyperparameters = character()
      )
    )
  } # /constructor
) # /rtemis::GLMHyperparameters

#' Setup GLM Hyperparameters
#'
#' Setup hyperparameters for GLM training.
#'
#' @param ifw (Tunable) Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return GLMHyperparameters object.
#'
#' @author EDG
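#' @examples
#' # Minimal usage sketch:
#' hp <- setup_GLM(ifw = TRUE)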
#' @export
setup_GLM <- function(ifw = FALSE) {
  GLMHyperparameters(ifw = ifw)
}

# GAMHyperparameters ----
GAM_tunable <- c("k", "ifw")
GAM_fixed <- character()

#' @author EDG
#' @keywords internal
#' @noRd
GAMHyperparameters <- new_class(
  name = "GAMHyperparameters",
  parent = Hyperparameters,
  constructor = function(k, ifw) {
    new_object(
      Hyperparameters(
        algorithm = "GAM",
        hyperparameters = list(
          k = k,
          ifw = ifw
        ),
        tunable_hyperparameters = GAM_tunable,
        fixed_hyperparameters = GAM_fixed
      )
    )
  } # /constructor
) # /rtemis::GAMHyperparameters

#' Setup GAM Hyperparameters
#'
#' Setup hyperparameters for GAM training.
#'
#' Get more information from [mgcv::gam].
#'
#' @param k (Tunable) Integer: Number of knots.
#' @param ifw (Tunable) Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return GAMHyperparameters object.
#'
#' @author EDG
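#' @examples
#' # Usage sketch: multiple values for a tunable hyperparameter mark it as
#' # needing tuning (tuned status 0).
#' hp <- setup_GAM(k = c(3L, 5L, 9L))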
#' @export
setup_GAM <- function(k = 5L, ifw = FALSE) {
  k <- clean_posint(k)
  GAMHyperparameters(k = k, ifw = ifw)
} # /rtemis::setup_GAM


# CARTHyperparameters ----
CART_tunable <- c("cp", "maxdepth", "minsplit", "minbucket", "prune_cp", "ifw")
CART_fixed <- c(
  "method",
  "model",
  "maxcompete",
  "maxsurrogate",
  "usesurrogate",
  "surrogatestyle",
  "xval",
  "cost"
)

#' @title CARTHyperparameters
#'
#' @description
#' Hyperparameters subclass for CART.
#'
#' @author EDG
#' @keywords internal
#' @noRd
CARTHyperparameters <- new_class(
  name = "CARTHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    cp,
    maxdepth,
    minsplit,
    minbucket,
    prune_cp,
    method,
    model,
    maxcompete,
    maxsurrogate,
    usesurrogate,
    surrogatestyle,
    xval,
    cost,
    ifw
  ) {
    new_object(
      Hyperparameters(
        algorithm = "CART",
        hyperparameters = list(
          cp = cp,
          maxdepth = maxdepth,
          minsplit = minsplit,
          minbucket = minbucket,
          prune_cp = prune_cp,
          method = method,
          model = model,
          maxcompete = maxcompete,
          maxsurrogate = maxsurrogate,
          usesurrogate = usesurrogate,
          surrogatestyle = surrogatestyle,
          xval = xval,
          cost = cost,
          ifw = ifw
        ),
        tunable_hyperparameters = CART_tunable,
        fixed_hyperparameters = CART_fixed
      )
    )
  } # /constructor
) # /rtemis::CARTHyperparameters

#' Setup CART Hyperparameters
#'
#' Setup hyperparameters for CART training.
#'
#' Get more information from [rpart::rpart] and [rpart::rpart.control].
#'
#' @param cp (Tunable) Numeric: Complexity parameter.
#' @param maxdepth (Tunable) Integer: Maximum depth of tree.
#' @param minsplit (Tunable) Integer: Minimum number of observations in a node to split.
#' @param minbucket (Tunable) Integer: Minimum number of observations in a terminal node.
#' @param prune_cp (Tunable) Numeric: Complexity parameter for cost-complexity pruning after the tree is built.
#' @param method String: Splitting method.
#' @param model Logical: If TRUE, return a model.
#' @param maxcompete Integer: Maximum number of competitive splits.
#' @param maxsurrogate Integer: Maximum number of surrogate splits.
#' @param usesurrogate Integer: Number of surrogate splits to use.
#' @param surrogatestyle Integer: Type of surrogate splits.
#' @param xval Integer: Number of cross-validation folds.
#' @param cost Numeric (>=0): One for each feature.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return CARTHyperparameters object.
#'
#' @author EDG
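#' @examples
#' # Usage sketch: define a small grid over maxdepth for tuning.
#' hp <- setup_CART(maxdepth = c(2L, 5L, 10L))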
#' @export
setup_CART <- function(
  # tunable
  cp = 0.01,
  maxdepth = 20L,
  minsplit = 2L,
  minbucket = 1L, # round(minsplit / 3),
  prune_cp = NULL,
  # fixed
  method = "auto",
  model = TRUE,
  maxcompete = 4L,
  maxsurrogate = 5L,
  usesurrogate = 2L,
  surrogatestyle = 0L,
  xval = 0L,
  cost = NULL,
  ifw = FALSE
) {
  check_inherits(cp, "numeric")
  maxdepth <- clean_int(maxdepth)
  minsplit <- clean_int(minsplit)
  minbucket <- clean_int(minbucket)
  check_inherits(prune_cp, "numeric")
  check_inherits(method, "character")
  check_inherits(model, "logical")
  maxcompete <- clean_int(maxcompete)
  maxsurrogate <- clean_int(maxsurrogate)
  usesurrogate <- clean_int(usesurrogate)
  surrogatestyle <- clean_int(surrogatestyle)
  xval <- clean_int(xval)
  check_inherits(cost, "numeric")
  CARTHyperparameters(
    cp = cp,
    maxdepth = maxdepth,
    minsplit = minsplit,
    minbucket = minbucket,
    prune_cp = prune_cp,
    method = method,
    model = model,
    maxcompete = maxcompete,
    maxsurrogate = maxsurrogate,
    usesurrogate = usesurrogate,
    surrogatestyle = surrogatestyle,
    xval = xval,
    cost = cost,
    ifw = ifw
  )
} # /setup_CART

# Test that all CART hyperparameters are set by setup_CART
stopifnot(all(c(CART_tunable, CART_fixed) %in% names(formals(setup_CART))))

# GLMNETHyperparameters ----
GLMNET_tunable <- c("alpha", "ifw")
GLMNET_fixed <- c(
  "family",
  "offset",
  "which_lambda_cv",
  "nlambda",
  "penalty_factor",
  "standardize",
  "intercept"
)

#' @title GLMNETHyperparameters
#'
#' @description
#' Hyperparameters subclass for GLMNET.
#'
#' @author EDG
#' @keywords internal
#' @noRd
GLMNETHyperparameters <- new_class(
  name = "GLMNETHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    alpha,
    family,
    offset,
    which_lambda_cv,
    nlambda,
    lambda,
    penalty_factor,
    standardize,
    intercept,
    ifw
  ) {
    check_float01inc(alpha)
    check_inherits(which_lambda_cv, "character")
    nlambda <- clean_posint(nlambda)
    check_inherits(penalty_factor, "numeric")
    check_inherits(standardize, "logical")
    new_object(
      Hyperparameters(
        algorithm = "GLMNET",
        hyperparameters = list(
          alpha = alpha,
          family = family,
          offset = offset,
          which_lambda_cv = which_lambda_cv,
          nlambda = nlambda,
          lambda = lambda,
          penalty_factor = penalty_factor,
          standardize = standardize,
          intercept = intercept,
          ifw = ifw
        ),
        tunable_hyperparameters = GLMNET_tunable,
        fixed_hyperparameters = GLMNET_fixed
      )
    )
  } # /constructor
) # /rtemis::GLMNETHyperparameters

#' Setup GLMNET Hyperparameters
#'
#' Setup hyperparameters for GLMNET training.
#'
#' Get more information from [glmnet::glmnet].
#'
#' @param alpha (Tunable) Numeric: Mixing parameter.
#' @param family Character: Family for GLMNET.
#' @param offset Numeric: Offset for GLMNET.
#' @param which_lambda_cv Character: Which lambda to use for prediction:
#' "lambda.1se" or "lambda.min"
#' @param nlambda Positive integer: Number of lambda values.
#' @param lambda Numeric: Lambda values.
#' @param penalty_factor Numeric: Penalty factor for each feature.
#' @param standardize Logical: If TRUE, standardize features.
#' @param intercept Logical: If TRUE, include intercept.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return GLMNETHyperparameters object.
#'
#' @author EDG
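#' @examples
#' # Usage sketch: with the default lambda = NULL, lambda is also marked for
#' # tuning (see the GLMNET method for get_params_need_tuning).
#' hp <- setup_GLMNET(alpha = c(0, 0.5, 1))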
#' @export
setup_GLMNET <- function(
  # tunable
  alpha = 1,
  # fixed
  family = NULL,
  offset = NULL,
  which_lambda_cv = "lambda.1se",
  nlambda = 100L,
  lambda = NULL,
  penalty_factor = NULL,
  standardize = TRUE,
  intercept = TRUE,
  ifw = FALSE
) {
  check_float01inc(alpha)
  check_inherits(which_lambda_cv, "character")
  nlambda <- clean_posint(nlambda)
  check_inherits(penalty_factor, "numeric")
  check_logical(standardize)
  check_logical(ifw)
  GLMNETHyperparameters(
    family = family,
    offset = offset,
    alpha = alpha,
    which_lambda_cv = which_lambda_cv,
    nlambda = nlambda,
    lambda = lambda,
    penalty_factor = penalty_factor,
    standardize = standardize,
    intercept = intercept,
    ifw = ifw
  )
} # /setup_GLMNET

# Test that all GLMNET hyperparameters are set by setup_GLMNET
stopifnot(all(
  c(GLMNET_tunable, GLMNET_fixed) %in% names(formals(setup_GLMNET))
))

method(get_params_need_tuning, GLMNETHyperparameters) <- function(x) {
  # Get tunable hyperparameters with more than one value
  out <- x@hyperparameters[x@tunable_hyperparameters[
    sapply(x@hyperparameters[x@tunable_hyperparameters], length) > 1
  ]]
  if (is.null(x[["lambda"]])) {
    out <- c(out, list(lambda = NULL))
  }
  out
} # /get_params_need_tuning.GLMNETHyperparameters

# LightCARTHyperparameters ----
LightCART_tunable <- c(
  "num_leaves",
  "max_depth",
  "lambda_l1",
  "lambda_l2",
  "min_data_in_leaf",
  "max_cat_threshold",
  "min_data_per_group",
  "linear_tree",
  "ifw"
)
LightCART_fixed <- c("objective")

#' @title LightCARTHyperparameters
#'
#' @description
#' Hyperparameters subclass for LightCART
#'
#' @author EDG
#' @keywords internal
#' @noRd
LightCARTHyperparameters <- new_class(
  name = "LightCARTHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    num_leaves,
    max_depth,
    lambda_l1,
    lambda_l2,
    min_data_in_leaf,
    max_cat_threshold,
    min_data_per_group,
    linear_tree,
    objective,
    ifw
  ) {
    new_object(
      Hyperparameters(
        algorithm = "LightCART",
        hyperparameters = list(
          num_leaves = num_leaves,
          max_depth = max_depth,
          lambda_l1 = lambda_l1,
          lambda_l2 = lambda_l2,
          min_data_in_leaf = min_data_in_leaf,
          max_cat_threshold = max_cat_threshold,
          min_data_per_group = min_data_per_group,
          linear_tree = linear_tree,
          objective = objective,
          ifw = ifw
        ),
        tunable_hyperparameters = LightCART_tunable,
        fixed_hyperparameters = LightCART_fixed
      )
    )
  } # /constructor
) # /rtemis::LightCARTHyperparameters

#' Setup LightCART Hyperparameters
#'
#' Setup hyperparameters for LightCART training.
#'
#' Get more information from [lightgbm::lgb.train].
#'
#' @param num_leaves (Tunable) Positive integer: Maximum number of leaves in one tree.
#' @param max_depth (Tunable) Integer: Maximum depth of trees.
#' @param lambda_l1 (Tunable) Numeric: L1 regularization.
#' @param lambda_l2 (Tunable) Numeric: L2 regularization.
#' @param min_data_in_leaf (Tunable) Positive integer: Minimum number of data in a leaf.
#' @param max_cat_threshold (Tunable) Positive integer: Maximum number of categories for categorical features.
#' @param min_data_per_group (Tunable) Positive integer: Minimum number of observations per categorical group.
#' @param linear_tree (Tunable) Logical: If TRUE, use linear trees.
#' @param objective Character: Objective function.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return LightCARTHyperparameters object.
#'
#' @author EDG
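#' @examples
#' # Usage sketch: a single shallow tree.
#' hp <- setup_LightCART(max_depth = 3L, num_leaves = 8L)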
#' @export
setup_LightCART <- function(
  num_leaves = 32L,
  max_depth = -1L,
  lambda_l1 = 0,
  lambda_l2 = 0,
  min_data_in_leaf = 20L,
  max_cat_threshold = 32L,
  min_data_per_group = 100L,
  linear_tree = FALSE,
  objective = NULL,
  ifw = FALSE
) {
  num_leaves <- clean_posint(num_leaves)
  max_depth <- clean_int(max_depth)
  check_inherits(lambda_l1, "numeric")
  check_inherits(lambda_l2, "numeric")
  min_data_in_leaf <- clean_posint(min_data_in_leaf)
  max_cat_threshold <- clean_posint(max_cat_threshold)
  min_data_per_group <- clean_posint(min_data_per_group)
  check_logical(linear_tree)
  LightCARTHyperparameters(
    num_leaves = num_leaves,
    max_depth = max_depth,
    lambda_l1 = lambda_l1,
    lambda_l2 = lambda_l2,
    min_data_in_leaf = min_data_in_leaf,
    max_cat_threshold = max_cat_threshold,
    min_data_per_group = min_data_per_group,
    linear_tree = linear_tree,
    objective = objective,
    ifw = ifw
  )
} # /rtemis::setup_LightCART


# LightRFHyperparameters ----
LightRF_tunable <- c(
  "nrounds",
  "num_leaves",
  "max_depth",
  "feature_fraction",
  "subsample",
  "lambda_l1",
  "lambda_l2",
  "max_cat_threshold",
  "min_data_per_group",
  "ifw"
)
LightRF_fixed <- c(
  "objective",
  "device_type",
  "tree_learner",
  "boosting_type",
  "learning_rate",
  "subsample_freq",
  "early_stopping_rounds",
  "force_col_wise",
  "num_threads"
)

#' @title LightRFHyperparameters
#'
#' @description
#' Hyperparameters subclass for LightRF
#'
#' @author EDG
#' @keywords internal
#' @noRd
LightRFHyperparameters <- new_class(
  name = "LightRFHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    nrounds,
    num_leaves,
    max_depth,
    feature_fraction,
    subsample,
    lambda_l1,
    lambda_l2,
    max_cat_threshold,
    min_data_per_group,
    linear_tree,
    ifw,
    # fixed
    objective,
    device_type,
    tree_learner,
    force_col_wise,
    num_threads
  ) {
    new_object(
      Hyperparameters(
        algorithm = "LightRF",
        hyperparameters = list(
          nrounds = nrounds,
          num_leaves = num_leaves,
          max_depth = max_depth,
          feature_fraction = feature_fraction,
          subsample = subsample,
          lambda_l1 = lambda_l1,
          lambda_l2 = lambda_l2,
          max_cat_threshold = max_cat_threshold,
          min_data_per_group = min_data_per_group,
          linear_tree = linear_tree,
          ifw = ifw,
          # fixed
          objective = objective,
          device_type = device_type,
          tree_learner = tree_learner,
          force_col_wise = force_col_wise,
          num_threads = num_threads,
          # unsettable: LightGBM params for RF
          boosting_type = "rf",
          learning_rate = 1, # no effect? in boosting_type 'rf', but set for clarity
          subsample_freq = 1L, # a.k.a. bagging_freq
          early_stopping_rounds = -1L
        ),
        tunable_hyperparameters = LightRF_tunable,
        fixed_hyperparameters = LightRF_fixed
      )
    )
  }
) # /rtemis::LightRFHyperparameters

#' Setup LightRF Hyperparameters
#'
#' Setup hyperparameters for LightRF training.
#'
#' Get more information from [lightgbm::lgb.train].
#' Note that hyperparameters `boosting_type`, `learning_rate`, `subsample_freq`, and
#' `early_stopping_rounds` are fixed and cannot be set, because they are what makes
#' `lightgbm` train a random forest. They can all be set when training gradient
#' boosting with [setup_LightGBM].
#'
#' @param nrounds (Tunable) Positive integer: Number of boosting rounds.
#' @param num_leaves (Tunable) Positive integer: Maximum number of leaves in one tree.
#' @param max_depth (Tunable) Integer: Maximum depth of trees.
#' @param feature_fraction (Tunable) Numeric: Fraction of features to use.
#' @param subsample (Tunable) Numeric: Fraction of data to use.
#' @param lambda_l1 (Tunable) Numeric: L1 regularization.
#' @param lambda_l2 (Tunable) Numeric: L2 regularization.
#' @param max_cat_threshold (Tunable) Positive integer: Maximum number of categories for categorical features.
#' @param min_data_per_group (Tunable) Positive integer: Minimum number of observations per categorical group.
#' @param linear_tree Logical: If TRUE, use linear trees.
#' @param objective Character: Objective function.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#' @param device_type Character: "cpu" or "gpu".
#' @param tree_learner Character: "serial", "feature", "data", or "voting".
#' @param force_col_wise Logical: Use only with CPU - If TRUE, force col-wise histogram building
#' @param num_threads Integer: Number of threads to use. 0 means default number of threads in OpenMP.
#'
#' @return LightRFHyperparameters object.
#'
#' @author EDG
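#' @examples
#' # Usage sketch: a 200-tree forest using 80% of features per tree.
#' hp <- setup_LightRF(nrounds = 200L, feature_fraction = 0.8)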
#' @export
setup_LightRF <- function(
  nrounds = 500L,
  num_leaves = 4096L,
  max_depth = -1L,
  feature_fraction = 0.7,
  subsample = 0.623, # a.k.a. bagging_fraction
  lambda_l1 = 0,
  lambda_l2 = 0,
  max_cat_threshold = 32L,
  min_data_per_group = 32L,
  linear_tree = FALSE,
  ifw = FALSE,
  # fixed
  objective = NULL,
  device_type = "cpu",
  tree_learner = "serial",
  force_col_wise = TRUE,
  num_threads = 0L # 0 means default number of threads in OpenMP
) {
  nrounds <- clean_posint(nrounds)
  num_leaves <- clean_posint(num_leaves)
  max_depth <- clean_int(max_depth)
  check_float01inc(feature_fraction)
  check_float01inc(subsample)
  check_inherits(lambda_l1, "numeric")
  check_inherits(lambda_l2, "numeric")
  max_cat_threshold <- clean_posint(max_cat_threshold)
  min_data_per_group <- clean_posint(min_data_per_group)
  check_logical(linear_tree)
  num_threads <- clean_int(num_threads)
  LightRFHyperparameters(
    nrounds = nrounds,
    num_leaves = num_leaves,
    max_depth = max_depth,
    feature_fraction = feature_fraction,
    subsample = subsample,
    lambda_l1 = lambda_l1,
    lambda_l2 = lambda_l2,
    max_cat_threshold = max_cat_threshold,
    min_data_per_group = min_data_per_group,
    linear_tree = linear_tree,
    ifw = ifw,
    objective = objective,
    device_type = device_type,
    tree_learner = tree_learner,
    force_col_wise = force_col_wise,
    num_threads = num_threads
  )
} # /rtemis::setup_LightRF

# Test that all LightRF hyperparameters are set by setup_LightRF
# LightRF fixed hyperparameters are not editable.
stopifnot(all(LightRF_tunable %in% names(formals(setup_LightRF))))


# LightGBMHyperparameters ----
LightGBM_tunable <- c(
  "num_leaves",
  "max_depth",
  "learning_rate",
  "feature_fraction",
  "subsample",
  "subsample_freq",
  "lambda_l1",
  "lambda_l2",
  "max_cat_threshold",
  "min_data_per_group",
  "linear_tree",
  "ifw"
)
LightGBM_fixed <- c(
  "max_nrounds",
  "force_nrounds",
  "early_stopping_rounds",
  "objective",
  "device_type",
  "tree_learner",
  "force_col_wise",
  "num_threads"
)

#' @title LightGBMHyperparameters
#'
#' @description
#' Hyperparameters subclass for LightGBM
#'
#' @author EDG
#' @keywords internal
#' @noRd
LightGBMHyperparameters <- new_class(
  name = "LightGBMHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    max_nrounds,
    force_nrounds,
    early_stopping_rounds,
    # tunable
    num_leaves,
    max_depth,
    learning_rate,
    feature_fraction,
    subsample,
    subsample_freq,
    lambda_l1,
    lambda_l2,
    max_cat_threshold,
    min_data_per_group,
    linear_tree,
    ifw,
    objective,
    device_type,
    tree_learner,
    force_col_wise,
    num_threads
  ) {
    nrounds <- force_nrounds # NULL unless the user forces nrounds
    new_object(
      Hyperparameters(
        algorithm = "LightGBM",
        hyperparameters = list(
          nrounds = nrounds,
          max_nrounds = max_nrounds,
          force_nrounds = force_nrounds,
          early_stopping_rounds = early_stopping_rounds,
          num_leaves = num_leaves,
          max_depth = max_depth,
          learning_rate = learning_rate,
          feature_fraction = feature_fraction,
          subsample = subsample,
          subsample_freq = subsample_freq,
          lambda_l1 = lambda_l1,
          lambda_l2 = lambda_l2,
          max_cat_threshold = max_cat_threshold,
          min_data_per_group = min_data_per_group,
          linear_tree = linear_tree,
          ifw = ifw,
          objective = objective,
          device_type = device_type,
          tree_learner = tree_learner,
          force_col_wise = force_col_wise,
          num_threads = num_threads
        ),
        tunable_hyperparameters = LightGBM_tunable,
        fixed_hyperparameters = LightGBM_fixed
      )
    )
  }
) # /rtemis::LightGBMHyperparameters

method(update, LightGBMHyperparameters) <- function(
  object,
  hyperparameters,
  tuned = NULL,
  ...
) {
  for (hp in names(hyperparameters)) {
    object@hyperparameters[[hp]] <- hyperparameters[[hp]]
  }
  # Update tuned status
  if (is.null(tuned)) {
    object@tuned <- get_tuned_status(object)
  } else {
    object@tuned <- tuned
  }
  # Update nrounds (e.g. in LightRuleFit)
  if (
    is.null(object@hyperparameters[["nrounds"]]) &&
      !is.null(object@hyperparameters[["force_nrounds"]])
  ) {
    object@hyperparameters[["nrounds"]] <- object@hyperparameters[[
      "force_nrounds"
    ]]
  }
  object
} # /update.LightGBMHyperparameters

# References:
# LightGBM parameters: https://lightgbm.readthedocs.io/en/latest/Parameters.html

#' Setup LightGBM Hyperparameters
#'
#' Setup hyperparameters for LightGBM training.
#'
#' Get more information from [lightgbm::lgb.train].
#'
#' @param max_nrounds Positive integer: Maximum number of boosting rounds.
#' @param force_nrounds Positive integer: Use this many boosting rounds. Disable search for nrounds.
#' @param early_stopping_rounds Positive integer: Number of rounds without improvement to stop training.
#' @param num_leaves (Tunable) Positive integer: Maximum number of leaves in one tree.
#' @param max_depth (Tunable) Integer: Maximum depth of trees.
#' @param learning_rate (Tunable) Numeric: Learning rate.
#' @param feature_fraction (Tunable) Numeric: Fraction of features to use.
#' @param subsample (Tunable) Numeric: Fraction of data to use.
#' @param subsample_freq (Tunable) Positive integer: Frequency of subsample.
#' @param lambda_l1 (Tunable) Numeric: L1 regularization.
#' @param lambda_l2 (Tunable) Numeric: L2 regularization.
#' @param max_cat_threshold (Tunable) Positive integer: Maximum number of categories for categorical features.
#' @param min_data_per_group (Tunable) Positive integer: Minimum number of observations per categorical group.
#' @param linear_tree Logical: If TRUE, use linear trees.
#' @param objective Character: Objective function.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#' @param device_type Character: "cpu" or "gpu".
#' @param tree_learner Character: "serial", "feature", "data", or "voting".
#' @param force_col_wise Logical: Use only with CPU - If TRUE, force col-wise histogram building
#' @param num_threads Integer: Number of threads to use. 0 means default number of threads in OpenMP.
#'
#' @return LightGBMHyperparameters object.
#'
#' @author EDG
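#' @examples
#' # Usage sketch: tune the learning rate over two values; nrounds is
#' # auto-tuned up to max_nrounds since force_nrounds is NULL.
#' hp <- setup_LightGBM(learning_rate = c(0.01, 0.1), max_nrounds = 1000L)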
#' @export
setup_LightGBM <- function(
  # nrounds is auto-tuned (up to max_nrounds, using early_stopping_rounds)
  # unless force_nrounds is set.
  max_nrounds = 500L,
  force_nrounds = NULL,
  early_stopping_rounds = 10L,
  # tunable
  num_leaves = 8L,
  max_depth = -1L,
  learning_rate = 0.01,
  feature_fraction = 1.0,
  subsample = 1.0, # a.k.a. bagging_fraction
  subsample_freq = 1L,
  lambda_l1 = 0,
  lambda_l2 = 0,
  max_cat_threshold = 32L,
  min_data_per_group = 32L,
  linear_tree = FALSE,
  ifw = FALSE,
  objective = NULL,
  device_type = "cpu",
  tree_learner = "serial",
  force_col_wise = TRUE,
  num_threads = 0L # 0 means default number of threads in OpenMP
) {
  max_nrounds <- clean_posint(max_nrounds)
  force_nrounds <- clean_posint(force_nrounds)
  early_stopping_rounds <- clean_posint(early_stopping_rounds)
  num_leaves <- clean_posint(num_leaves)
  max_depth <- clean_int(max_depth)
  check_floatpos1(learning_rate)
  check_floatpos1(feature_fraction)
  check_floatpos1(subsample)
  subsample_freq <- clean_posint(subsample_freq)
  check_inherits(lambda_l1, "numeric")
  check_inherits(lambda_l2, "numeric")
  max_cat_threshold <- clean_posint(max_cat_threshold)
  min_data_per_group <- clean_posint(min_data_per_group)
  check_logical(linear_tree)
  LightGBMHyperparameters(
    max_nrounds = max_nrounds,
    force_nrounds = force_nrounds,
    early_stopping_rounds = early_stopping_rounds,
    num_leaves = num_leaves,
    max_depth = max_depth,
    learning_rate = learning_rate,
    feature_fraction = feature_fraction,
    subsample = subsample,
    subsample_freq = subsample_freq,
    lambda_l1 = lambda_l1,
    lambda_l2 = lambda_l2,
    max_cat_threshold = max_cat_threshold,
    min_data_per_group = min_data_per_group,
    linear_tree = linear_tree,
    ifw = ifw,
    objective = objective,
    device_type = device_type,
    tree_learner = tree_learner,
    force_col_wise = force_col_wise,
    num_threads = num_threads
  )
} # /rtemis::setup_LightGBM

# Test that all LightGBM hyperparameters are set by setup_LightGBM
stopifnot(all(
  c(LightGBM_tunable, LightGBM_fixed) %in% names(formals(setup_LightGBM))
))

method(get_params_need_tuning, LightGBMHyperparameters) <- function(x) {
  # Get tunable hyperparameters with more than one value
  out <- x@hyperparameters[x@tunable_hyperparameters[
    sapply(x@hyperparameters[x@tunable_hyperparameters], length) > 1
  ]]
  if (is.null(x[["nrounds"]])) {
    out <- c(out, list(nrounds = NULL))
  }
  out
} # /get_params_need_tuning.LightGBMHyperparameters


# LightRuleFitHyperparameters ----
LightRuleFit_tunable <- c(
  "nrounds",
  "num_leaves",
  "max_depth",
  "learning_rate",
  "subsample",
  "subsample_freq",
  "lambda_l1",
  "lambda_l2",
  "alpha",
  "ifw_lightgbm",
  "ifw_glmnet"
)
LightRuleFit_fixed <- c("lambda", "objective")
LightRuleFit_lightgbm_params <- c(
  "nrounds",
  "num_leaves",
  "max_depth",
  "learning_rate",
  "subsample",
  "subsample_freq",
  "lambda_l1",
  "lambda_l2",
  "objective"
)
LightRuleFit_glmnet_params <- c("alpha", "lambda")

#' @title LightRuleFitHyperparameters
#'
#' @description
#' Hyperparameters subclass for LightRuleFit.
#'
#' @author EDG
#' @keywords internal
#' @noRd
LightRuleFitHyperparameters <- new_class(
  name = "LightRuleFitHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    nrounds,
    num_leaves,
    max_depth,
    learning_rate,
    subsample,
    subsample_freq,
    lambda_l1,
    lambda_l2,
    objective,
    ifw_lightgbm,
    # GLMNET
    alpha,
    lambda,
    ifw_glmnet,
    # IFW
    ifw
  ) {
    new_object(
      Hyperparameters(
        algorithm = "LightRuleFit",
        hyperparameters = list(
          nrounds = nrounds,
          num_leaves = num_leaves,
          max_depth = max_depth,
          learning_rate = learning_rate,
          subsample = subsample,
          subsample_freq = subsample_freq,
          lambda_l1 = lambda_l1,
          lambda_l2 = lambda_l2,
          objective = objective,
          ifw_lightgbm = ifw_lightgbm,
          # GLMNET
          alpha = alpha,
          lambda = lambda,
          ifw_glmnet = ifw_glmnet,
          # IFW
          ifw = ifw
        ),
        tunable_hyperparameters = LightRuleFit_tunable,
        fixed_hyperparameters = LightRuleFit_fixed
      )
    )
  }
) # /rtemis::LightRuleFitHyperparameters

#' Setup LightRuleFit Hyperparameters
#'
#' Setup hyperparameters for LightRuleFit training.
#'
#' Get more information from [lightgbm::lgb.train].
#'
#' @param nrounds (Tunable) Positive integer: Number of boosting rounds.
#' @param num_leaves (Tunable) Positive integer: Maximum number of leaves in one tree.
#' @param max_depth (Tunable) Integer: Maximum depth of trees.
#' @param learning_rate (Tunable) Numeric: Learning rate.
#' @param subsample (Tunable) Numeric: Fraction of data to use.
#' @param subsample_freq (Tunable) Positive integer: Frequency of subsample.
#' @param lambda_l1 (Tunable) Numeric: L1 regularization.
#' @param lambda_l2 (Tunable) Numeric: L2 regularization.
#' @param objective Character: Objective function.
#' @param ifw_lightgbm (Tunable) Logical: If TRUE, use Inverse Frequency Weighting in the LightGBM
#' step.
#' @param alpha (Tunable) Numeric: Alpha for GLMNET.
#' @param lambda Numeric: Lambda for GLMNET.
#' @param ifw_glmnet (Tunable) Logical: If TRUE, use Inverse Frequency Weighting in the GLMNET step.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification. This applies IFW
#' to both LightGBM and GLMNET.
#'
#' @return LightRuleFitHyperparameters object.
#'
#' @author EDG
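#' @examples
#' # Usage sketch: fewer, shallower trees yield fewer, shorter rules.
#' hp <- setup_LightRuleFit(nrounds = 100L, max_depth = 3L)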
#' @export
setup_LightRuleFit <- function(
  nrounds = 200L,
  num_leaves = 32L,
  max_depth = 4L,
  learning_rate = 0.1,
  subsample = 0.666,
  subsample_freq = 1L,
  lambda_l1 = 0,
  lambda_l2 = 0,
  objective = NULL,
  ifw_lightgbm = FALSE,
  alpha = 1,
  lambda = NULL,
  ifw_glmnet = FALSE,
  ifw = FALSE
) {
  nrounds <- clean_posint(nrounds)
  num_leaves <- clean_posint(num_leaves)
  max_depth <- clean_int(max_depth)
  check_floatpos1(learning_rate)
  check_floatpos1(subsample)
  subsample_freq <- clean_posint(subsample_freq)
  check_inherits(lambda_l1, "numeric")
  check_inherits(lambda_l2, "numeric")
  check_float01inc(alpha)
  check_inherits(lambda, "numeric")
  check_logical(ifw_lightgbm)
  check_logical(ifw_glmnet)
  check_logical(ifw)
  # If ifw, cannot have ifw_lightgbm or ifw_glmnet
  if (ifw) {
    if (ifw_lightgbm) {
      stop("Cannot set ifw and ifw_lightgbm at the same time.")
    }
    if (ifw_glmnet) {
      stop("Cannot set ifw and ifw_glmnet at the same time.")
    }
  }
  LightRuleFitHyperparameters(
    nrounds = nrounds,
    num_leaves = num_leaves,
    max_depth = max_depth,
    learning_rate = learning_rate,
    subsample = subsample,
    subsample_freq = subsample_freq,
    lambda_l1 = lambda_l1,
    lambda_l2 = lambda_l2,
    objective = objective,
    ifw_lightgbm = ifw_lightgbm,
    alpha = alpha,
    lambda = lambda,
    ifw_glmnet = ifw_glmnet,
    ifw = ifw
  )
} # /rtemis::setup_LightRuleFit


# IsotonicHyperparameters ----
Isotonic_tunable <- c("ifw")
Isotonic_fixed <- character()

#' @title IsotonicHyperparameters
#'
#' @description
#' Hyperparameters subclass for Isotonic Regression.
#'
#' @author EDG
#' @keywords internal
#' @noRd
IsotonicHyperparameters <- new_class(
  name = "IsotonicHyperparameters",
  parent = Hyperparameters,
  constructor = function(ifw) {
    new_object(
      Hyperparameters(
        algorithm = "Isotonic",
        hyperparameters = list(
          ifw = ifw
        ),
        tunable_hyperparameters = "ifw",
        fixed_hyperparameters = Isotonic_fixed
      )
    )
  }
) # /rtemis::IsotonicHyperparameters

# setup_Isotonic ----
#' Setup Isotonic Hyperparameters
#'
#' Setup hyperparameters for Isotonic Regression.
#'
#' There are no tunable hyperparameters for this algorithm at this time.
#'
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return IsotonicHyperparameters object.
#'
#' @author EDG
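#' @examples
#' hp <- setup_Isotonic()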
#' @export
setup_Isotonic <- function(ifw = FALSE) {
  IsotonicHyperparameters(ifw = ifw)
} # /rtemis::setup_Isotonic


# SVMHyperparameters ----
SVM_tunable <- c("cost", "degree", "gamma", "coef0", "ifw")
SVM_fixed <- c("kernel")

#' @title SVMHyperparameters
#'
#' @description
#' Hyperparameters subclass for SVM.
#'
#' @author EDG
#' @keywords internal
#' @noRd
SVMHyperparameters <- new_class(
  name = "SVMHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    hyperparameters,
    tunable_hyperparameters,
    fixed_hyperparameters
  ) {
    new_object(
      Hyperparameters(
        algorithm = "SVM",
        hyperparameters = hyperparameters,
        tunable_hyperparameters = tunable_hyperparameters,
        fixed_hyperparameters = fixed_hyperparameters
      )
    )
  } # /constructor
) # /rtemis::SVMHyperparameters


# RadialSVMHyperparameters ----
#' @title RadialSVMHyperparameters
#'
#' @description
#' Hyperparameters subclass for SVM with radial kernel.
#'
#' @author EDG
#' @keywords internal
#' @noRd
RadialSVMHyperparameters <- new_class(
  name = "RadialSVMHyperparameters",
  parent = SVMHyperparameters,
  constructor = function(cost, gamma, ifw) {
    new_object(
      SVMHyperparameters(
        hyperparameters = list(
          kernel = "radial",
          cost = cost,
          gamma = gamma,
          ifw = ifw
        ),
        tunable_hyperparameters = c("cost", "gamma", "ifw"),
        fixed_hyperparameters = c("kernel")
      )
    )
  } # /constructor
) # /rtemis::RadialSVMHyperparameters

#' Setup RadialSVM Hyperparameters
#'
#' Setup hyperparameters for RadialSVM training.
#'
#' Get more information from [e1071::svm].
#'
#' @param cost (Tunable) Numeric: Cost of constraints violation.
#' @param gamma (Tunable) Numeric: Kernel coefficient.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return RadialSVMHyperparameters object.
#'
#' @author EDG
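#' @examples
#' # Usage sketch: tune cost on a log scale.
#' hp <- setup_RadialSVM(cost = 10^seq(-1, 1))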
#' @export
setup_RadialSVM <- function(
  cost = 1,
  gamma = 0.01,
  ifw = FALSE
) {
  check_inherits(cost, "numeric")
  check_inherits(gamma, "numeric")
  check_logical(ifw)
  RadialSVMHyperparameters(
    cost = cost,
    gamma = gamma,
    ifw = ifw
  )
} # /setup_RadialSVM

setup_SVM <- setup_RadialSVM


# TabNetHyperparameters ----
tabnet_tunable <- c(
  "batch_size",
  "penalty",
  "clip_value",
  "loss",
  "epochs",
  "drop_last",
  "decision_width",
  "attention_width",
  "num_steps",
  "feature_reusage",
  "mask_type",
  "virtual_batch_size",
  "valid_split",
  "learn_rate",
  "optimizer",
  "lr_scheduler",
  "lr_decay",
  "step_size",
  "checkpoint_epochs",
  "cat_emb_dim",
  "num_independent",
  "num_shared",
  "num_independent_decoder",
  "num_shared_decoder",
  "momentum",
  "pretraining_ratio",
  "importance_sample_size",
  "early_stopping_monitor",
  "early_stopping_tolerance",
  "early_stopping_patience",
  "num_workers",
  "skip_importance",
  "early_stopping_patience",
  "ifw"
)

tabnet_fixed <- c("device", "num_workers", "skip_importance")

#' @title TabNetHyperparameters
#'
#' @description
#' Hyperparameters subclass for TabNet.
#'
#' @author EDG
#' @keywords internal
#' @noRd
TabNetHyperparameters <- new_class(
  name = "TabNetHyperparameters",
  parent = Hyperparameters,
  constructor = function(
    batch_size,
    penalty,
    clip_value,
    loss,
    epochs,
    drop_last,
    decision_width,
    attention_width,
    num_steps,
    feature_reusage,
    mask_type,
    virtual_batch_size,
    valid_split,
    learn_rate,
    optimizer,
    lr_scheduler,
    lr_decay,
    step_size,
    checkpoint_epochs,
    cat_emb_dim,
    num_independent,
    num_shared,
    num_independent_decoder,
    num_shared_decoder,
    momentum,
    pretraining_ratio,
    device,
    importance_sample_size,
    early_stopping_monitor,
    early_stopping_tolerance,
    early_stopping_patience,
    num_workers,
    skip_importance,
    ifw
  ) {
    new_object(
      Hyperparameters(
        algorithm = "TabNet",
        hyperparameters = list(
          batch_size = batch_size,
          penalty = penalty,
          clip_value = clip_value,
          loss = loss,
          epochs = epochs,
          drop_last = drop_last,
          decision_width = decision_width,
          attention_width = attention_width,
          num_steps = num_steps,
          feature_reusage = feature_reusage,
          mask_type = mask_type,
          virtual_batch_size = virtual_batch_size,
          valid_split = valid_split,
          learn_rate = learn_rate,
          optimizer = optimizer,
          lr_scheduler = lr_scheduler,
          lr_decay = lr_decay,
          step_size = step_size,
          checkpoint_epochs = checkpoint_epochs,
          cat_emb_dim = cat_emb_dim,
          num_independent = num_independent,
          num_shared = num_shared,
          num_independent_decoder = num_independent_decoder,
          num_shared_decoder = num_shared_decoder,
          momentum = momentum,
          pretraining_ratio = pretraining_ratio,
          device = device,
          importance_sample_size = importance_sample_size,
          early_stopping_monitor = early_stopping_monitor,
          early_stopping_tolerance = early_stopping_tolerance,
          early_stopping_patience = early_stopping_patience,
          num_workers = num_workers,
          skip_importance = skip_importance,
          ifw = ifw
        ),
        tunable_hyperparameters = tabnet_tunable,
        fixed_hyperparameters = tabnet_fixed
      )
    )
  } # /constructor
) # /rtemis::TabNetHyperparameters

#' Setup TabNet Hyperparameters
#'
#' Setup hyperparameters for TabNet training.
#'
#' Get more information from [tabnet::tabnet_config].
#'
#' @param batch_size (Tunable) Positive integer: Batch size.
#' @param penalty (Tunable) Numeric: Regularization penalty.
#' @param clip_value Numeric: Clip value.
#' @param loss Character: Loss function.
#' @param epochs (Tunable) Positive integer: Number of epochs.
#' @param drop_last Logical: If TRUE, drop last batch.
#' @param decision_width (Tunable) Positive integer: Decision width.
#' @param attention_width (Tunable) Positive integer: Attention width.
#' @param num_steps (Tunable) Positive integer: Number of steps.
#' @param feature_reusage (Tunable) Numeric: Feature reusage.
#' @param mask_type Character: Mask type.
#' @param virtual_batch_size (Tunable) Positive integer: Virtual batch size.
#' @param valid_split Numeric: Validation split.
#' @param learn_rate (Tunable) Numeric: Learning rate.
#' @param optimizer Character or torch function: Optimizer.
#' @param lr_scheduler Character or torch function: "step", "reduce_on_plateau".
#' @param lr_decay Numeric: Learning rate decay.
#' @param step_size Positive integer: Step size.
#' @param checkpoint_epochs (Tunable) Positive integer: Checkpoint epochs.
#' @param cat_emb_dim (Tunable) Positive integer: Categorical embedding dimension.
#' @param num_independent (Tunable) Positive integer: Number of independent Gated Linear Units (GLU)
#' at each step of the encoder.
#' @param num_shared (Tunable) Positive integer: Number of shared Gated Linear Units (GLU) at each
#' step of the encoder.
#' @param num_independent_decoder (Tunable) Positive integer: Number of independent GLU layers for
#' pretraining.
#' @param num_shared_decoder (Tunable) Positive integer: Number of shared GLU layers for
#' pretraining.
#' @param momentum (Tunable) Numeric: Momentum.
#' @param pretraining_ratio (Tunable) Numeric: Pretraining ratio.
#' @param device Character: Device "cpu" or "cuda".
#' @param importance_sample_size Positive integer: Importance sample size.
#' @param early_stopping_monitor Character: Early stopping monitor. "valid_loss", "train_loss",
#' "auto".
#' @param early_stopping_tolerance Numeric: Minimum relative improvement to reset the patience
#' counter.
#' @param early_stopping_patience Positive integer: Number of epochs without improving before
#' stopping.
#' @param num_workers Positive integer: Number of subprocesses for data loading.
#' @param skip_importance Logical: If TRUE, skip importance calculation.
#' @param ifw Logical: If TRUE, use Inverse Frequency Weighting in classification.
#'
#' @return TabNetHyperparameters object.
#'
#' @author EDG
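#' @examples
#' # Usage sketch: constructing hyperparameters does not train a model;
#' # torch is only needed at training time.
#' hp <- setup_TabNet(epochs = 10L, learn_rate = 0.01)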
#' @export
setup_TabNet <- function(
  batch_size = 1024^2,
  penalty = 0.001,
  clip_value = NULL,
  loss = "auto",
  epochs = 50L,
  drop_last = FALSE,
  decision_width = NULL,
  attention_width = NULL,
  num_steps = 3L,
  feature_reusage = 1.3,
  mask_type = "sparsemax",
  virtual_batch_size = 256^2,
  valid_split = 0,
  learn_rate = 0.02,
  optimizer = "adam",
  lr_scheduler = NULL,
  lr_decay = 0.1,
  step_size = 30,
  checkpoint_epochs = 10L,
  cat_emb_dim = 1L,
  num_independent = 2L,
  num_shared = 2L,
  num_independent_decoder = 1L,
  num_shared_decoder = 1L,
  momentum = 0.02,
  pretraining_ratio = 0.5,
  device = "auto",
  importance_sample_size = NULL,
  early_stopping_monitor = "auto",
  early_stopping_tolerance = 0,
  early_stopping_patience = 0,
  num_workers = 0L,
  skip_importance = FALSE,
  ifw = FALSE
) {
  TabNetHyperparameters(
    batch_size = batch_size,
    penalty = penalty,
    clip_value = clip_value,
    loss = loss,
    epochs = epochs,
    drop_last = drop_last,
    decision_width = decision_width,
    attention_width = attention_width,
    num_steps = num_steps,
    feature_reusage = feature_reusage,
    mask_type = mask_type,
    virtual_batch_size = virtual_batch_size,
    valid_split = valid_split,
    learn_rate = learn_rate,
    optimizer = optimizer,
    lr_scheduler = lr_scheduler,
    lr_decay = lr_decay,
    step_size = step_size,
    checkpoint_epochs = checkpoint_epochs,
    cat_emb_dim = cat_emb_dim,
    num_independent = num_independent,
    num_shared = num_shared,
    num_independent_decoder = num_independent_decoder,
    num_shared_decoder = num_shared_decoder,
    momentum = momentum,
    pretraining_ratio = pretraining_ratio,
    device = device,
    importance_sample_size = importance_sample_size,
    early_stopping_monitor = early_stopping_monitor,
    early_stopping_tolerance = early_stopping_tolerance,
    early_stopping_patience = early_stopping_patience,
    num_workers = num_workers,
    skip_importance = skip_importance,
    ifw = ifw
  )
} # /setup_TabNet

get_tabnet_config <- function(hyperparameters) {
  check_is_S7(hyperparameters, TabNetHyperparameters)
  hpr <- hyperparameters@hyperparameters
  # Remove rtemis-specific hyperparameters not recognized by tabnet_config()
  hpr[["ifw"]] <- NULL
  do.call(tabnet::tabnet_config, hpr)
} # /get_tabnet_config