R/fit_one_xgb.R

Defines functions fit_one_xgb

Documented in fit_one_xgb

#' Fit an xgboost model to *one* `resamples` object
#'
#' This is the "workhorse" function called by the different methods of xgb_fit.
#'
#' @param object of class resamples, created by a `resample_***()` function.
#' @param resp name of the response variable.
#' @param expl names of the explanatory variables.
#' @param params named list of parameters passed to `xgboost::xgb.train()`.
#' @param nrounds number of boosting rounds (i.e. number of trees).
#' @param verbose 0 = silent, 1 = display performance, 2 = more verbose.
#' @param weight a vector of observation-level weights, one per line of the
#'   training set.
#' @param threads number of paralel threads used to fit each model. This is set
#'   to 1 by default to avoid conflict with parallelisation per resample (in
#'   [`xgb_fit()`]), which is often more efficient. Set this to more than 1 when
#'   fitting only one model.
#' @param ... other parameters passed to `xgboost::xgb.Train()``
#'
#' @returns The input object (a tibble of class `resamples`) with an additional
#'   column called `model` containing the fitted model object.
#'
#' @importFrom dplyr `%>%`
#' @export
#' @examples
#' # regression
#' rs <- resample_identity(mtcars, 1)
#' m <- fit_one_xgb(object=rs,
#'   resp="mpg", expl=c("cyl", "hp", "qsec"),
#'   eta=0.1, max_depth=4,
#'   nrounds=20
#' )
#' m$model
#'
#' # classification
#' mtcarsf <- mutate(mtcars, cyl=factor(cyl))
#' rs <- resample_identity(mtcarsf, 1)
#' m <- fit_one_xgb(object=rs,
#'   resp="cyl", expl=c("mpg", "hp", "qsec"),
#'   eta=0.1, max_depth=4,
#'   nrounds=20
#' )
#' m$model
#'
#' # parameters can also be passed as a list
#' m_list <- fit_one_xgb(object=rs,
#'   resp="cyl", expl=c("mpg", "hp", "qsec"),
#'   params=list(eta=0.1, max_depth=4),
#'   nrounds=20
#' )
#' m$model[[1]]$params
#' m_list$model[[1]]$params
fit_one_xgb <- function(object, resp, expl, params=list(), nrounds, verbose=0,
                        weight=NULL, threads=1, ...) {
  # TODO Add checks for arguments

  # extract training set, in dMatrix form, for xgboost
  train  <- as.data.frame(object$train, check.names=FALSE)
  if (is.factor(train[[resp]])) {
    # we are in classification mode, record additional things
    classif <- TRUE
    levels <- levels(train[[resp]])
    num_class <- length(levels)
    # convert to integer for xgboost
    train[[resp]] <- defactor(train[[resp]])
  } else {
    classif <- FALSE
    num_class <- NULL
  }

  dTrain <- xgboost::xgb.DMatrix(
    data =data.matrix(train[expl]),
    label=train[[resp]],
    # use mono-core here, by default
    # we will parallelise at a higher level
    nthread=threads
  )
  if (!is.null(weight)) {
    xgboost::setinfo(dTrain, "weight", weight[as.integer(object$train[[1]])])
  }
  # TODO allow resp and expl to be unquoted, like in select()

  # do the same for the validation set, taking into the account the case when
  # it is missing
  val <- as.data.frame(object$val, check.names=FALSE)
  if (nrow(val) == 0) {
    val_list <- list()
  } else {
    if (classif) {
      val[[resp]] <- defactor(val[[resp]])
    }
    val_list <- list(
      val = xgboost::xgb.DMatrix(
        data =data.matrix(val[expl]),
        label=val[[resp]],
        nthread=threads
      )
    )
  }

  # train the model on the training set, with the provided params
  m <- xgboost::xgb.train(data=dTrain,
    params=c(params, num_class=num_class), nrounds=nrounds,
    watchlist=val_list,
    verbose=verbose,
    nthread=threads,
    ...
  )
  if (classif) {
    m$levels <- levels
  }

  # add model to this `resamples` object
  object$model <- list(m)
  return(object)
}
jiho/joml documentation built on Dec. 6, 2023, 5:50 a.m.