R/lightgbm.R

Defines functions check_lightgbm_aliases categorical_features_to_int categorical_columns prepare_df_lgbm lightgbm_by_tree multi_predict._lgb.Booster predict_lightgbm_regression_numeric predict_lightgbm_classification_raw predict_lightgbm_classification_class predict_lightgbm_classification_prob reshape_lightgbm_multiclass_preds sort_args process_data process_bagging process_parallelism process_objective_function process_mtry train_lightgbm

Documented in multi_predict._lgb.Booster predict_lightgbm_classification_class predict_lightgbm_classification_prob predict_lightgbm_classification_raw predict_lightgbm_regression_numeric train_lightgbm

#' Boosted trees with lightgbm
#'
#' `train_lightgbm` is a wrapper for `lightgbm` tree-based models
#' where all of the model arguments are in the main function.
#'
#' This is an internal function, not meant to be directly called by the user.
#'
#' @param x A data frame or matrix of predictors
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
#' @param max_depth An integer for the maximum depth of the tree.
#' @param num_iterations An integer for the number of boosting iterations.
#' @param learning_rate A numeric value between zero and one to control the learning rate.
#' @param feature_fraction_bynode Fraction of predictors that will be randomly sampled
#' at each split.
#' @param min_data_in_leaf A numeric value for the minimum sum of instances needed
#'  in a child to continue to split.
#' @param min_gain_to_split A number for the minimum loss reduction required to make a
#'  further partition on a leaf node of the tree.
#' @param bagging_fraction Subsampling proportion of rows. Setting this argument
#'  to a non-default value will also set `bagging_freq = 1`. See the Bagging
#'  section in `?details_boost_tree_lightgbm` for more details.
#' @param early_stopping_round Number of iterations without an improvement in
#' the objective function occur before training should be halted.
#' @param validation The _proportion_ of the training data that are used for
#' performance assessment and potential early stopping.
#' @param counts A logical; should `feature_fraction_bynode` be interpreted as the
#' _number_ of predictors that will be randomly sampled at each split?
#' `TRUE` indicates that `mtry` will be interpreted in its sense as a _count_,
#' `FALSE` indicates that the argument will be interpreted in its sense as a
#' _proportion_.
#' @param quiet A logical; should logging by [lightgbm::lgb.train()] be muted?
#' @param ... Other options to pass to [lightgbm::lgb.train()]. Arguments
#' will be correctly routed to the `param` argument, or as a main argument,
#' depending on their name.
#' @return A fitted `lightgbm.Model` object.
#' @keywords internal
#' @export
train_lightgbm <- function(x, y, max_depth = -1, num_iterations = 100, learning_rate = 0.1,
                           feature_fraction_bynode = 1, min_data_in_leaf = 20,
                           min_gain_to_split = 0, bagging_fraction = 1,
                           early_stopping_round = NULL, validation = 0,
                           counts = TRUE, quiet = FALSE, ...) {

  force(x)
  force(y)

  if (!is.logical(quiet)) {
    rlang::abort("'quiet' should be a logical value.")
  }

  feature_fraction_bynode <-
    process_mtry(feature_fraction_bynode = feature_fraction_bynode,
                 counts = counts, x = x, is_missing = missing(feature_fraction_bynode))

  check_lightgbm_aliases(...)

  args <- list(
    param = list(
      num_iterations = num_iterations,
      learning_rate = learning_rate,
      max_depth = max_depth,
      feature_fraction_bynode = feature_fraction_bynode,
      min_data_in_leaf = min_data_in_leaf,
      min_gain_to_split = min_gain_to_split,
      bagging_fraction = bagging_fraction
    ),
    main = list(
      early_stopping_round = early_stopping_round,
      ...
    )
  )

  args <- process_objective_function(args, x, y)

  if (!is.numeric(y)) {
    y <- as.numeric(y) - 1
  }

  args <- process_parallelism(args)

  args <- process_bagging(args, ...)

  args <- process_data(args, x, y, validation, missing(validation),
                       early_stopping_round)

  args <- sort_args(args)

  if (!"verbose" %in% names(args$main)) {
    args$main$verbose <- 1L
  }

  compacted <-
    c(
      list(param = args$param),
      args$main[names(args$main) != "data"],
      list(data = quote(args$main$data))
    )

  call <- parsnip::make_call(fun = "lgb.train", ns = "lightgbm", compacted)

  if (quiet) {
    junk <- utils::capture.output(res <- rlang::eval_tidy(call, env = rlang::current_env()))
  } else {
    res <- rlang::eval_tidy(call, env = rlang::current_env())
  }

  res
}

process_mtry <- function(feature_fraction_bynode, counts, x, is_missing) {
  if (!is.logical(counts)) {
    rlang::abort("'counts' should be a logical value.")
  }

  ineq <- if (counts) {"greater"} else {"less"}
  interp <- if (counts) {"count"} else {"proportion"}
  opp <- if (!counts) {"count"} else {"proportion"}

  if ((feature_fraction_bynode < 1 & counts) | (feature_fraction_bynode > 1 & !counts)) {
    rlang::abort(
      glue::glue(
        "The supplied argument `mtry = {feature_fraction_bynode}` must be ",
        "{ineq} than or equal to 1. \n\n`mtry` is currently being interpreted ",
        "as a {interp} rather than a {opp}. Supply `counts = {!counts}` to ",
        "`set_engine` to supply this argument as a {opp} rather than ",
        # TODO: link to parsnip's lightgbm docs instead here
        "a {interp}. \n\nSee `?train_lightgbm` for more details."
      ),
      call = NULL
    )
  }

  if (rlang::is_call(feature_fraction_bynode)) {
    if (rlang::call_name(feature_fraction_bynode) == "tune") {
      rlang::abort(
        glue::glue(
          "The supplied `mtry` parameter is a call to `tune`. Did you forget ",
          "to optimize hyperparameters with a tuning function like `tune::tune_grid`?"
        ),
        call = NULL
      )
    }
  }

  if (counts && !is_missing) {
    feature_fraction_bynode <- feature_fraction_bynode / ncol(x)
  }

  feature_fraction_bynode
}

process_objective_function <- function(args, x, y) {
  # set the "objective" param argument, clear it out from main args
  if (!any(names(args$main) %in% c("objective"))) {
    if (is.numeric(y)) {
      args$param$objective <- "regression"
    } else {
      lvl <- levels(y)
      lvls <- length(lvl)
      if (lvls == 2) {
        args$param$num_class <- 1
        args$param$objective <- "binary"
      } else {
        args$param$num_class <- lvls
        args$param$objective <- "multiclass"
      }
    }
  } else {
    args$param$objective <- args$main$objective
  }

  args$main$objective <- NULL

  args
}

# supply the number of threads as num_threads in params, clear out
# any other thread args that might be passed as main arguments
process_parallelism <- function(args) {
  if (!is.null(args$main["num_threads"])) {
    args$param$num_threads <- args$main[names(args$main) == "num_threads"]
    args$main[names(args$main) == "num_threads"] <- NULL
  }

  args
}

process_bagging <- function(args, ...) {
  if (args$param$bagging_fraction != 1 &&
      (!"bagging_freq" %in% names(list(...)))) {
    args$param$bagging_freq <- 1
  }

  args
}

process_data <- function(args, x, y, validation, missing_validation,
                         early_stopping_round) {
  #                                           trn_index       | val_index
  #                                         ----------------------------------
  #  needs_validation &  missing_validation | 1:n               1:n
  #  needs_validation & !missing_validation | sample(1:n, m)    setdiff(trn_index, 1:n)
  # !needs_validation &  missing_validation | 1:n               NULL
  # !needs_validation & !missing_validation | sample(1:n, m)    setdiff(trn_index, 1:n)

  n <- nrow(x)
  needs_validation <- !is.null(early_stopping_round)

  if (missing_validation) {
    trn_index <- 1:n
    if (needs_validation) {
      val_index <- trn_index
    } else {
      val_index <- NULL
    }
  } else {
    m <- min(floor(n * (1 - validation)) + 1, n - 1)
    trn_index <- sample(1:n, size = max(m, 2))
    val_index <- setdiff(1:n, trn_index)
  }

  args$main$data <-
    lightgbm::lgb.Dataset(
      data = prepare_df_lgbm(x[trn_index, , drop = FALSE]),
      label = y[trn_index],
      categorical_feature = categorical_columns(x[trn_index, , drop = FALSE]),
      params = list(feature_pre_filter = FALSE)
    )

  if (!is.null(val_index)) {
    args$main$valids <-
      list(validation =
          lightgbm::lgb.Dataset(
          data = prepare_df_lgbm(x[val_index, , drop = FALSE]),
          label = y[val_index],
          categorical_feature = categorical_columns(x[val_index, , drop = FALSE]),
          params = list(feature_pre_filter = FALSE)
        )
      )
  }

  args
}

# transfers arguments between param and main arguments
sort_args <- function(args) {
  # warn on arguments that won't be passed along
  protected <- c("obj", "init_model", "colnames",
                 "categorical_feature", "callbacks", "reset_data")

  if (any(names(args$main) %in% protected)) {
    protected_args <- names(args$main[names(args$main) %in% protected])

    rlang::warn(
      glue::glue(
        "The following argument(s) are guarded by bonsai and will not ",
        "be passed to `lgb.train`: {glue::glue_collapse(protected_args, sep = ', ')}"
      )
    )

    args$main[protected_args] <- NULL
  }

  # dots are deprecated in lgb.train -- pass to param instead
  to_main   <- c("nrounds", "eval", "verbose", "record", "eval_freq",
                 "early_stopping_round", "data", "valids")

  args$param <- c(args$param, args$main[!names(args$main) %in% to_main])

  args$main[!names(args$main) %in% to_main] <- NULL

  args
}

# in lightgbm <= 3.3.2, predict() for multiclass classification produced a single
# vector of length num_observations * num_classes, in row-major order
#
# in versions after that release, lightgbm produces a numeric matrix with shape
# [num_observations, num_classes]
#
# this function ensures that multiclass classification predictions are always
# returned as a [num_observations, num_classes] matrix, regardless of lightgbm version
reshape_lightgbm_multiclass_preds <- function(preds, num_rows) {
    n_preds_per_case <- length(preds) / num_rows
    if (is.vector(preds) && n_preds_per_case > 1) {
        preds <- matrix(preds, ncol = n_preds_per_case, byrow = TRUE)
    }
    preds
}

#' Internal functions
#'
#' Not intended for direct use.
#'
#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_classification_prob <- function(object, new_data, ...) {
  p <- stats::predict(object$fit, prepare_df_lgbm(new_data), ...)
  p <- reshape_lightgbm_multiclass_preds(preds = p, num_rows = nrow(new_data))

  if(is.vector(p)) {
    p <- tibble::tibble(p1 = 1 - p, p2 = p)
  }

  colnames(p) <- object$lvl

  tibble::as_tibble(p)
}

#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_classification_class <- function(object, new_data, ...) {
  p <- predict_lightgbm_classification_prob(object, prepare_df_lgbm(new_data), ...)

  q <- apply(p, 1, function(x) which.max(x))

  names(p)[q]
}

#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_classification_raw <- function(object, new_data, ...) {
  if (using_newer_lightgbm_version()) {
      p <- stats::predict(object$fit, prepare_df_lgbm(new_data), type = "raw", ...)
  } else {
      p <- stats::predict(object$fit, prepare_df_lgbm(new_data), rawscore = TRUE, ...)
  }
  reshape_lightgbm_multiclass_preds(preds = p, num_rows = nrow(new_data))
}

#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_regression_numeric <- function(object, new_data, ...) {
  p <-
    stats::predict(
      object$fit,
      prepare_df_lgbm(new_data),
      params = list(predict_disable_shape_check = TRUE),
      ...
    )
  p
}



#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
multi_predict._lgb.Booster <- function(object, new_data, type = NULL, trees = NULL, ...) {
  if (any(names(rlang::enquos(...)) == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  trees <- sort(trees)

  res <- map_df(trees, lightgbm_by_tree,
                object = object, new_data = new_data, type = type)
  res <- dplyr::arrange(res, .row, trees)
  res <- split(res[, -1], res$.row)
  names(res) <- NULL

  tibble::tibble(.pred = res)

}

lightgbm_by_tree <- function(tree, object, new_data, type = NULL) {
  # switch based on prediction type
  if (object$spec$mode == "regression") {
    pred <- predict_lightgbm_regression_numeric(object, new_data, num_iteration = tree)

    pred <- tibble::tibble(.pred = pred)

    nms <- names(pred)
  } else {
    if (is.null(type) || type == "class") {
      pred <- predict_lightgbm_classification_class(object, new_data, num_iteration = tree)

      pred <- tibble::tibble(.pred_class = factor(pred, levels = object$lvl))
    } else {
      pred <- predict_lightgbm_classification_prob(object, new_data, num_iteration = tree)

      names(pred) <- paste0(".pred_", names(pred))
    }

    nms <- names(pred)
  }

  pred[["trees"]] <- tree
  pred[[".row"]] <- 1:nrow(new_data)
  pred[, c(".row", "trees", nms)]
}

prepare_df_lgbm <- function(x, y = NULL) {
  categorical_cols <- categorical_columns(x)

  x <- categorical_features_to_int(x, categorical_cols)

  x <- as.matrix(x)

  return(x)
}

categorical_columns <- function(x){
  categorical_cols <- NULL
  for (i in seq_along(x)) {
    if (is.factor(x[[i]])) {
      categorical_cols <- c(categorical_cols, i)
    }
  }
  categorical_cols
}

categorical_features_to_int <- function(x, cat_indices){
  for (i in cat_indices){
    x[[i]] <- as.integer(x[[i]]) -1
  }
  x
}

check_lightgbm_aliases <- function(...) {
  dots <- rlang::list2(...)

  for (param in names(dots)) {
    uses_alias <- lightgbm_aliases$alias %in% param
    if (any(uses_alias)) {
      main <- lightgbm_aliases$lightgbm[uses_alias]
      parsnip <- lightgbm_aliases$parsnip[uses_alias]
      cli::cli_abort(c(
      "!" = "The {.var {param}} argument passed to \\
             {.help [`set_engine()`](parsnip::set_engine)} is an alias for \\
             a main model argument.",
      "i" = "Please instead pass this argument via the {.var {parsnip}} \\
             argument to {.help [`boost_tree()`](parsnip::boost_tree)}."
      ), call = rlang::call2("fit"))
    }
  }

  invisible(TRUE)
}

lightgbm_aliases <-
  tibble::tribble(
    ~parsnip,         ~lightgbm,                 ~alias,
    # note that "tree_depth" -> "max_depth" has no aliases
    "trees",          "num_iterations",          "num_iteration",
    "trees",          "num_iterations",          "n_iter",
    "trees",          "num_iterations",          "num_tree",
    "trees",          "num_iterations",          "num_trees",
    "trees",          "num_iterations",          "num_round",
    "trees",          "num_iterations",          "num_rounds",
    "trees",          "num_iterations",          "nrounds",
    "trees",          "num_iterations",          "num_boost_round",
    "trees",          "num_iterations",          "n_estimators",
    "trees",          "num_iterations",          "max_iter",
    "learn_rate",     "learning_rate",           "shrinkage_rate",
    "learn_rate",     "learning_rate",           "eta",
    "mtry",           "feature_fraction_bynode", "sub_feature_bynode",
    "mtry",           "feature_fraction_bynode", "colsample_bynode",
    "min_n",          "min_data_in_leaf",        "min_data_per_leaf",
    "min_n",          "min_data_in_leaf",        "min_data",
    "min_n",          "min_data_in_leaf",        "min_child_samples",
    "min_n",          "min_data_in_leaf",        "min_samples_leaf",
    "loss_reduction", "min_gain_to_split",       "min_split_gain",
    "sample_size",    "bagging_fraction",        "sub_row",
    "sample_size",    "bagging_fraction",        "subsample",
    "sample_size",    "bagging_fraction",        "bagging",
    "stop_iter",      "early_stopping_round",    "early_stopping_rounds",
    "stop_iter",      "early_stopping_round",    "early_stopping",
    "stop_iter",      "early_stopping_round",    "n_iter_no_change"
  )

Try the bonsai package in your browser

Any scripts or data that you put into this service are public.

bonsai documentation built on Dec. 1, 2022, 1:28 a.m.