
# Prototype parsnip code for boosted trees

#' Boosted trees
#'
#' @description
#'
#' `boost_tree()` defines a model that creates a series of decision trees
#' forming an ensemble. Each tree depends on the results of previous trees.
#' All trees in the ensemble are combined to produce a final prediction. This
#' function can fit classification, regression, and censored regression models.
#'
#' \Sexpr[stage=render,results=rd]{parsnip:::make_engine_list("boost_tree")}
#'
#' More information on how \pkg{parsnip} is used for modeling is at
#' \url{https://www.tidymodels.org/}.
#'
#' @param mode A single character string for the prediction outcome mode.
#'  Possible values for this model are "unknown", "regression",
#'  "classification", or "censored regression".
#' @param engine A single character string specifying what computational engine
#'  to use for fitting.
#' @param mtry A number for the number (or proportion) of predictors that will
#'  be randomly sampled at each split when creating the tree models
#' (specific engines only).
#' @param trees An integer for the number of trees contained in
#'  the ensemble.
#' @param min_n An integer for the minimum number of data points
#'  in a node that is required for the node to be split further.
#' @param tree_depth An integer for the maximum depth of the tree (i.e. number
#'  of splits) (specific engines only).
#' @param learn_rate A number for the rate at which the boosting algorithm adapts
#'   from iteration-to-iteration (specific engines only). This is sometimes referred to
#'   as the shrinkage parameter.
#' @param loss_reduction A number for the reduction in the loss function required
#'   to split further (specific engines only).
#' @param sample_size A number for the number (or proportion) of data points
#'  that are exposed to the fitting routine. For `xgboost`, the sampling is
#'  done at each iteration while `C5.0` samples once during training.
#' @param stop_iter The number of iterations without improvement before
#'   stopping (specific engines only).
#'
#' @templateVar modeltype boost_tree
#' @template spec-details
#'
#' @template spec-references
#'
#' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("boost_tree")},
#' [xgb_train()], [C5.0_train()]
#'
#' @examplesIf !parsnip:::is_cran_check()
#' show_engines("boost_tree")
#'
#' boost_tree(mode = "classification", trees = 20)
#' @export
boost_tree <-
  function(mode = "unknown",
           engine = "xgboost",
           mtry = NULL, trees = NULL, min_n = NULL,
           tree_depth = NULL, learn_rate = NULL,
           loss_reduction = NULL,
           sample_size = NULL,
           stop_iter = NULL) {
    args <- list(
      mtry = enquo(mtry),
      trees = enquo(trees),
      min_n = enquo(min_n),
      tree_depth = enquo(tree_depth),
      learn_rate = enquo(learn_rate),
      loss_reduction = enquo(loss_reduction),
      sample_size = enquo(sample_size),
      stop_iter = enquo(stop_iter)
    )

    new_model_spec(
      "boost_tree",
      args,
      eng_args = NULL,
      mode = mode,
      user_specified_mode = !missing(mode),
      method = NULL,
      engine = engine,
      user_specified_engine = !missing(engine)
    )
  }

# ------------------------------------------------------------------------------

#' @method update boost_tree
#' @rdname parsnip_update
#' @export
update.boost_tree <-
  function(object,
           parameters = NULL,
           mtry = NULL, trees = NULL, min_n = NULL,
           tree_depth = NULL, learn_rate = NULL,
           loss_reduction = NULL, sample_size = NULL,
           stop_iter = NULL,
           fresh = FALSE, ...) {

    args <- list(
      mtry = enquo(mtry),
      trees = enquo(trees),
      min_n = enquo(min_n),
      tree_depth = enquo(tree_depth),
      learn_rate = enquo(learn_rate),
      loss_reduction = enquo(loss_reduction),
      sample_size = enquo(sample_size),
      stop_iter = enquo(stop_iter)
    )

    update_spec(
      object = object,
      parameters = parameters,
      args_enquo_list = args,
      fresh = fresh,
      cls = "boost_tree",
      ...
    )
  }
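
# A minimal usage sketch for updating a boosted tree specification
# (hypothetical values; not run at package load):
#
#   spec <- boost_tree(trees = 100, tree_depth = 4)
#   update(spec, trees = 500)               # keeps tree_depth = 4
#   update(spec, trees = 500, fresh = TRUE) # resets all other arguments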

# ------------------------------------------------------------------------------

#' @export
translate.boost_tree <- function(x, engine = x$engine, ...) {
  if (is.null(engine)) {
    message("Used `engine = 'xgboost'` for translation.")
    engine <- "xgboost"
  }
  x <- translate.default(x, engine, ...)

  ## -----------------------------------------------------------------------------

  arg_vals <- x$method$fit$args

  if (engine == "spark") {
    if (x$mode == "unknown") {
      rlang::abort(
        glue::glue(
          "For spark boosted trees models, the mode cannot be 'unknown' ",
          "if the specification is to be translated."
        )
      )
    } else {
      arg_vals$type <- x$mode
    }
  }

  ## -----------------------------------------------------------------------------
  # Protect some arguments based on data dimensions

  # min_n parameters
  if (any(names(arg_vals) == "min_instances_per_node")) {
    arg_vals$min_instances_per_node <-
      rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x))
  }

  ## -----------------------------------------------------------------------------

  x$method$fit$args <- arg_vals

  x
}
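
# A sketch of what translation does for the xgboost engine (not run at
# package load): parsnip argument names are mapped to their engine
# counterparts, e.g. trees -> nrounds and min_n -> min_child_weight.
#
#   boost_tree(trees = 20, min_n = 5, mode = "regression") %>%
#     set_engine("xgboost") %>%
#     translate()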

# ------------------------------------------------------------------------------

#' @export
check_args.boost_tree <- function(object, call = rlang::caller_env()) {

  args <- lapply(object$args, rlang::eval_tidy)

  check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees")
  check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size")
  check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth")
  check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n")
  
  invisible(object)
}


# xgboost helpers --------------------------------------------------------------

#' Boosted trees via xgboost
#'
#' `xgb_train()` and `xgb_predict()` are wrappers for `xgboost` tree-based
#' models where all of the model arguments are in the main function.
#'
#' @param x A data frame or matrix of predictors.
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
#' @param weights An optional numeric vector of sample weights; attached to
#' the `xgb.DMatrix` as its `weight` field.
#' @param max_depth An integer for the maximum depth of the tree.
#' @param nrounds An integer for the number of boosting iterations.
#' @param eta A numeric value between zero and one to control the learning rate.
#' @param colsample_bytree Subsampling proportion of columns for each tree.
#' See the `counts` argument below. The default uses all columns.
#' @param colsample_bynode Subsampling proportion of columns for each node
#' within each tree. See the `counts` argument below. The default uses all
#' columns.
#' @param min_child_weight A numeric value for the minimum sum of instance
#'  weights needed in a child to continue to split.
#' @param gamma A number for the minimum loss reduction required to make a
#'  further partition on a leaf node of the tree.
#' @param subsample Subsampling proportion of rows. By default, all of the
#' training data are used.
#' @param validation The _proportion_ of the data that are used for performance
#' assessment and potential early stopping.
#' @param early_stop An integer or `NULL`. If not `NULL`, it is the number of
#' training iterations without improvement before stopping. If `validation` is
#' used, performance is based on the validation set; otherwise, the training set
#' is used.
#' @param counts A logical. If `FALSE`, `colsample_bynode` and
#' `colsample_bytree` are both assumed to be _proportions_ of the total number
#' of columns (instead of counts of columns).
#' @param event_level For binary classification, this is a single string of either
#' `"first"` or `"second"` to pass along describing which level of the outcome
#' should be considered the "event".
#' @param ... Other options to pass to `xgb.train()` or xgboost's method for `predict()`.
#' @return A fitted `xgboost` object.
#' @keywords internal
#' @export
xgb_train <- function(
    x, y, weights = NULL,
    max_depth = 6, nrounds = 15, eta  = 0.3, colsample_bynode = NULL,
    colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
    validation = 0, early_stop = NULL, counts = TRUE,
    event_level = c("first", "second"), ...) {

  event_level <- rlang::arg_match(event_level, c("first", "second"))
  others <- list(...)

  num_class <- length(levels(y))

  if (!is.numeric(validation) || validation < 0 || validation >= 1) {
    rlang::abort("`validation` should be on [0, 1).")
  }

  if (!is.null(early_stop)) {
    if (early_stop <= 1) {
      rlang::abort(paste0("`early_stop` should be on [2, ",  nrounds, ")."))
    } else if (early_stop >= nrounds) {
      early_stop <- nrounds - 1
      rlang::warn(paste0("`early_stop` was reduced to ", early_stop, "."))
    }
  }

  n <- nrow(x)
  p <- ncol(x)

  x <-
    as_xgb_data(x, y,
                validation = validation,
                event_level = event_level,
                weights = weights)


  if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
    rlang::abort("`subsample` should be on [0, 1].")
  }

  # initialize
  if (is.null(colsample_bytree)) {
    colsample_bytree <- 1
  } else {
    colsample_bytree <- recalc_param(colsample_bytree, counts, p)
  }
  if (is.null(colsample_bynode)) {
    colsample_bynode <- 1
  } else {
    colsample_bynode <- recalc_param(colsample_bynode, counts, p)
  }

  if (min_child_weight > n) {
    msg <- paste0(min_child_weight, " samples were requested but there were ",
                  n, " rows in the data. ", n, " will be used.")
    rlang::warn(msg)
    min_child_weight <- min(min_child_weight, n)
  }

  arg_list <- list(
    eta = eta,
    max_depth = max_depth,
    gamma = gamma,
    colsample_bytree = colsample_bytree,
    colsample_bynode = colsample_bynode,
    min_child_weight = min(min_child_weight, n),
    subsample = subsample
  )

  others <- process_others(others, arg_list)

  main_args <- c(
      list(
        data = quote(x$data),
        watchlist = quote(x$watchlist),
        params = arg_list,
        nrounds = nrounds,
        early_stopping_rounds = early_stop
      ),
      others
  )

  if (is.null(main_args$objective)) {
    if (is.numeric(y)) {
      main_args$objective <- "reg:squarederror"
    } else {
      if (num_class == 2) {
        main_args$objective <- "binary:logistic"
      } else {
        main_args$objective <- "multi:softprob"
      }
    }
  }

  if (!is.null(num_class) && num_class > 2) {
    main_args$num_class <- num_class
  }

  call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)

  eval_tidy(call, env = current_env())
}
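
# A hedged usage sketch for calling `xgb_train()` directly (it is an internal
# wrapper; most users should go through `boost_tree()`). Assumes the xgboost
# package is installed; not run at package load:
#
#   fit <- xgb_train(
#     x = as.matrix(mtcars[, -1]),
#     y = mtcars$mpg,
#     nrounds = 50,
#     eta = 0.1,
#     max_depth = 3
#   )
#   xgb_predict(fit, new_data = as.matrix(mtcars[1:5, -1]))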

process_others <- function(others, arg_list) {
  guarded <- c("data", "weights", "num_class", "watchlist")
  guarded_supplied <- names(others)[names(others) %in% guarded]

  if (length(guarded_supplied) > 0) {
    cli::cli_warn(
      c(
        "!" = "{cli::qty(guarded_supplied)}The argument{?s} {.arg {guarded_supplied}} \
               {?is/are} guarded by parsnip and will not be passed to {.fun xgb.train}."
      ),
      class = "xgboost_guarded_warning"
    )
  }

  others <-
    others[!(names(others) %in% guarded)]

  if (!is.null(others$params)) {
    cli::cli_warn(
      c(
        "!" = "Please supply elements of the `params` list argument as main arguments \
               to `set_engine()` rather than as part of `params`.",
        "i" = "See `?details_boost_tree_xgboost` for more information."
      ),
      class = "xgboost_params_warning"
    )

    params <- others$params[!names(others$params) %in% names(arg_list)]
    others <- c(others[names(others) != "params"], params)
  }

  if (!(any(names(others) == "verbose"))) {
    others$verbose <- 0
  }

  others
}
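
# Illustration of the guarding behavior above (hypothetical input, evaluated
# interactively):
#
#   process_others(list(watchlist = list(), verbose = 1), list(eta = 0.3))
#   #> warns that `watchlist` is guarded and drops it; returns list(verbose = 1)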

recalc_param <- function(x, counts, denom) {
  nm <- as.character(match.call()$x)
  if (is.null(x)) {
    x <- 1
  } else {
    if (counts) {
      maybe_proportion(x, nm)
      x <- min(denom, x)/denom
    }
  }
  x
}

maybe_proportion <- function(x, nm) {
  if (x < 1) {
    msg <- paste0(
      "The option `counts = TRUE` was used but parameter `", nm,
      "` was given as ", signif(x, 3), ". Please use a value >= 1 or use ",
      "`counts = FALSE`."
    )
    rlang::abort(msg)
  }
}
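
# Worked illustration of the two column-sampling helpers above (hypothetical
# values, evaluated interactively):
#
#   recalc_param(3, counts = TRUE, denom = 10)    # 3 of 10 columns -> 0.3
#   recalc_param(0.5, counts = FALSE, denom = 10) # already a proportion -> 0.5
#   recalc_param(0.5, counts = TRUE, denom = 10)  # errors via maybe_proportion()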

#' @rdname xgb_train
#' @param new_data A rectangular data object, such as a data frame.
#' @keywords internal
#' @export
xgb_predict <- function(object, new_data, ...) {
  if (!inherits(new_data, "xgb.DMatrix")) {
    new_data <- maybe_matrix(new_data)
    new_data <- xgboost::xgb.DMatrix(data = new_data, missing = NA)
  }

  res <- predict(object, new_data, ...)

  # switch() keyed on the objective string; the `%||% 3L` fallback picks the
  # third (default) alternative, i.e. the unmodified predictions
  x <- switch(
    object$params$objective %||% 3L,
    "binary:logitraw" = stats::binomial()$linkinv(res),
    "multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
    res)

  x
}


as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) {
  lvls <- levels(y)
  n <- nrow(x)

  if (is.data.frame(x)) {
    x <- as.matrix(x)
  }

  if (is.factor(y)) {
    if (length(lvls) < 3) {
      if (event_level == "first") {
        y <- -as.numeric(y) + 2
      } else {
        y <- as.numeric(y) - 1
      }
    } else {
      if (event_level == "second") rlang::warn("`event_level` can only be set for binary variables.")
      y <- as.numeric(y) - 1
    }
  }

  if (!inherits(x, "xgb.DMatrix")) {
    if (validation > 0) {
      # Split data
      m <- floor(n * (1 - validation)) + 1
      trn_index <- sample(seq_len(n), size = max(m, 2))
      val_data <- xgboost::xgb.DMatrix(
        data = x[-trn_index, , drop = FALSE],
        label = y[-trn_index],
        missing = NA
      )
      watch_list <- list(validation = val_data)

      info_list <- list(label = y[trn_index])
      if (!is.null(weights)) {
        info_list$weight <- weights[trn_index]
      }
      dat <- xgboost::xgb.DMatrix(
        data = x[trn_index, , drop = FALSE],
        missing = NA,
        info = info_list
      )

    } else {
      info_list <- list(label = y)
      if (!is.null(weights)) {
        info_list$weight <- weights
      }
      dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
      watch_list <- list(training = dat)
    }
  } else {
    # An xgb.DMatrix is modified in place and setinfo() returns a logical,
    # so keep the updated matrix itself rather than setinfo()'s return value
    xgboost::setinfo(x, "label", y)
    if (!is.null(weights)) {
      xgboost::setinfo(x, "weight", weights)
    }
    dat <- x
    watch_list <- list(training = dat)
  }

  list(data = dat, watchlist = watch_list)
}

get_event_level <- function(model_spec){
  if ("event_level" %in% names(model_spec$eng_args)) {
    event_level <- get_expr(model_spec$eng_args$event_level)
  } else {
    # "first" is the default for as_xgb_data() and xgb_train()
    event_level <- "first"
  }
  event_level
}


#' @export
#' @rdname multi_predict
#' @param trees An integer vector for the number of trees in the ensemble.
multi_predict._xgb.Booster <-
  function(object, new_data, type = NULL, trees = NULL, ...) {
    if (is.null(trees)) {
      trees <- object$fit$niter
    }
    trees <- sort(trees)

    if (is.null(type)) {
      if (object$spec$mode == "classification")
        type <- "class"
      else
        type <- "numeric"
    }

    res <-
      map(trees, xgb_by_tree, object = object, new_data = new_data,
          type = type, ...) %>%
      purrr::list_rbind()
    res <- arrange(res, .row, trees)
    res <- split(res[, -1], res$.row)
    names(res) <- NULL
    tibble(.pred = res)
  }
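
# A hedged usage sketch (assumes `xgb_fit` is a parsnip model fitted with the
# xgboost engine; not run at package load):
#
#   multi_predict(xgb_fit, new_data, trees = c(10, 25, 50))
#   #> a tibble with a `.pred` list-column: one nested tibble per row of
#   #> `new_data`, holding `trees` plus the prediction column(s)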

xgb_by_tree <- function(tree, object, new_data, type, ...) {
  pred <- xgb_predict(
    object$fit,
    new_data = new_data,
    iterationrange = c(1, tree + 1),
    ntreelimit = NULL
  )

  # switch based on prediction type
  if (object$spec$mode == "regression") {
    pred <- tibble(.pred = pred)
    nms <- names(pred)
  } else {
    if (type == "class") {
      pred <- object$spec$method$pred$class$post(pred, object)
      pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
    } else {
      pred <- object$spec$method$pred$prob$post(pred, object)
      pred <- as_tibble(pred)
      names(pred) <- paste0(".pred_", names(pred))
    }
    nms <- names(pred)
  }
  pred[["trees"]] <- tree
  pred[[".row"]] <- seq_len(nrow(new_data))
  pred[, c(".row", "trees", nms)]
}

# C5.0 helpers -----------------------------------------------------------------

#' Boosted trees via C5.0
#'
#' `C5.0_train` is a wrapper for the `C5.0()` function in the
#' \pkg{C50} package that fits tree-based models
#'  where all of the model arguments are in the main function.
#'
#' @param x A data frame or matrix of predictors.
#' @param y A factor vector with 2 or more levels.
#' @param trials An integer specifying the number of boosting
#'  iterations. A value of one indicates that a single model is
#'  used.
#' @param weights An optional numeric vector of case weights. Note
#'  that the data used for the case weights will not be used as a
#'  splitting variable in the model (see
#'  \url{https://www.rulequest.com/see5-info.html} for
#'  Quinlan's notes on case weights).
#' @param minCases An integer for the smallest number of samples
#'  that must be put in at least two of the splits.
#' @param sample A value between (0, .999) that specifies the
#'  random proportion of the data that should be used to train the model.
#'  By default, all the samples are used for model training. Samples
#'  not used for training are used to evaluate the accuracy of the
#'  model in the printed output. A value of zero means that all the training
#'  data are used.
#' @param ... Other arguments to pass.
#' @return A fitted C5.0 model.
#' @keywords internal
#' @export
C5.0_train <-
  function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) {
    other_args <- list(...)
    protect_ctrl <- c("minCases", "sample")
    protect_fit <- "trials"
    f_names <- names(formals(getFromNamespace("C5.0.default", "C50")))
    c_names <- names(formals(getFromNamespace("C5.0Control", "C50")))
    other_args <- other_args[!(names(other_args) %in% c(protect_ctrl, protect_fit))]
    ctrl_args <- other_args[names(other_args) %in% c_names]
    fit_args <- other_args[names(other_args) %in% f_names]

    n <- nrow(x)
    if (n == 0) {
      rlang::abort("There are zero rows in the predictor set.")
    }


    ctrl <- call2("C5.0Control", .ns = "C50")
    if (minCases > n) {
      msg <- paste0(minCases, " samples were requested but there were ",
                    n, " rows in the data. ", n, " will be used.")
      rlang::warn(msg)
      minCases <- n
    }
    ctrl$minCases <- minCases

    ctrl$sample <- sample
    ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)

    fit_call <- call2("C5.0", .ns = "C50")
    fit_call$x <- expr(x)
    fit_call$y <- expr(y)
    fit_call$trials <- trials
    fit_call$control <- ctrl
    if (!is.null(weights)) {
      fit_call$weights <- quote(weights)
    }
    fit_call <- rlang::call_modify(fit_call, !!!fit_args)

    eval_tidy(fit_call)
  }
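
# A minimal usage sketch (C5.0 fits classification models only, so `y` must
# be a factor; assumes the C50 package is installed; not run):
#
#   fit <- C5.0_train(x = iris[, 1:4], y = iris$Species, trials = 10)
#   predict(fit, iris[1:5, 1:4])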

#' @export
#' @rdname multi_predict
multi_predict._C5.0 <-
  function(object, new_data, type = NULL, trees = NULL, ...) {
    if (is.null(trees))
      trees <- min(object$fit$trials)
    trees <- sort(trees)

    if (is.null(type))
      type <- "class"

    res <-
      map(trees, C50_by_tree, object = object,
          new_data = new_data, type = type, ...) %>%
      purrr::list_rbind()
    res <- arrange(res, .row, trees)
    res <- split(res[, -1], res$.row)
    names(res) <- NULL
    tibble(.pred = res)
  }

C50_by_tree <- function(tree, object, new_data, type, ...) {
  pred <- predict(object$fit, newdata = new_data, trials = tree, type = type)

  # switch based on prediction type
  if (type == "class") {
    pred <- tibble(.pred_class = factor(pred, levels = object$lvl))
  } else {
    pred <- as_tibble(pred)
    names(pred) <- paste0(".pred_", names(pred))
  }
  nms <- names(pred)
  pred[["trees"]] <- tree
  pred[[".row"]] <- seq_len(nrow(new_data))
  pred[, c(".row", "trees", nms)]
}