xgb_predict: Predict from an xgboost model at a given number of rounds,...

View source: R/xgb_predict.R

xgb_predict    R Documentation

Predict from an xgboost model at a given number of rounds, across resamples

Description

Predict from an xgboost model at a given number of rounds, across resamples

Usage

xgb_predict(
  object,
  newdata = NULL,
  niter = NULL,
  fns = "auto",
  add_data = TRUE,
  ...
)

Arguments

object

an object output by xgb_fit(), which contains a model column.

newdata

data.frame for which to compute predictions, with the same variables as those used for fitting (and possibly others). When NULL, predict the validation data of each resample.

niter

number of boosting iterations to use in the prediction. Maps to the upper bound of iterationrange in xgboost::predict.xgb.Booster(). niter=0 or NULL means use all boosting rounds. Other values are interpreted like nrounds in xgb_fit(): for example, niter=20 predicts with the first 20 rounds only.
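
To make the meaning of niter concrete, here is a minimal base R sketch of prediction with a truncated additive model. It does not call xgboost or joml: the contribution matrix is made up and predict_at() is a hypothetical helper that only mimics the documented rule (NULL or 0 uses all rounds, any other value uses the first niter rounds).

```r
# Toy additive model: a base score plus the contribution of each boosting
# round, for 2 observations (rows) and 5 rounds (columns); values are made up
base_score <- 0.5
contrib <- matrix(c(1, 0.5, 0.25, 0.125, 0.0625,
                    2, 1.0, 0.50, 0.250, 0.1250),
                  nrow = 2, byrow = TRUE)

# Hypothetical helper mimicking the documented meaning of niter:
# NULL or 0 uses all rounds, any other value uses the first niter rounds
predict_at <- function(contrib, base_score, niter = NULL) {
  if (is.null(niter) || niter == 0) niter <- ncol(contrib)
  base_score + rowSums(contrib[, seq_len(niter), drop = FALSE])
}

p2   <- predict_at(contrib, base_score, niter = 2)  # first 2 rounds only
pall <- predict_at(contrib, base_score)             # all 5 rounds
```

In a real model the per-round contributions are not stored this way; the point is only that a smaller niter drops the later rounds from the sum.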

fns

a named list of summary functions, to compute over the predictions of each observation, across resamples (when there is more than one). If NULL, return all predictions. If "auto", the default, choose a function appropriate for the type of response variable: base::mean() for a numeric, continuous variable; majority_vote() for a factor.
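
As an illustration of what a named list of summary functions does, the following base R sketch summarises toy predictions per observation across three resamples. It does not use joml: the data are made up, the way columns are assembled is an assumption about the general idea rather than the package's implementation, and majority_sketch() is a hypothetical stand-in for majority_vote(), which may differ (in particular in how it breaks ties).

```r
# Toy predictions: 3 observations, each predicted in 3 resamples
preds <- data.frame(
  id   = rep(1:3, times = 3),
  pred = c(10, 20, 30, 12, 18, 33, 11, 22, 27)
)

# A named list of summary functions, in the form expected by `fns`
fns <- list(mean = mean, sd = sd)

# Apply each function to the predictions of each observation;
# the names of the list become the suffix of the output columns
summaries <- lapply(fns, function(f) tapply(preds$pred, preds$id, f))
res <- data.frame(
  id = sort(unique(preds$id)),
  setNames(lapply(summaries, as.vector), paste0("pred_", names(fns)))
)
res

# Hypothetical stand-in for a majority vote over predicted classes
majority_sketch <- function(x) names(which.max(table(x)))
vote <- majority_sketch(c("4", "6", "4"))  # "4" is predicted twice, "6" once
```

Here res has one row per observation, with columns pred_mean and pred_sd named after the elements of the list.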

add_data

boolean, whether to add the original data to the output (defaults to TRUE, which is convenient for computing performance metrics).

...

passed to xgboost::predict.xgb.Booster()

Value

A tibble with

  • the grouping columns in object (ungroup the object before xgb_predict() if this is not the desired behaviour),

  • the prediction, as pred_*** where *** is a summary function (e.g. mean), or as pred when no summary function is chosen,

  • the original data if add_data is TRUE.

Examples

## Regression

# fit models over 4 folds of cross-validation, repeated 3 times
fits <- resample_cv(mtcars, k=4, n=3) %>%
  xgb_fit(resp="mpg", expl=c("cyl", "hp", "qsec"),
          eta=0.1, max_depth=2, nrounds=30)

# compute the predicted mpg, with 20 trees only and, by default, averaged
# across the 3 repetitions
res <- xgb_predict(fits, niter=20)
head(res)

# check that we have predicted all items in the dataset (should always be the
# case with cross validation)
nrow(res)
nrow(mtcars)
# compute the Root Mean Squared Error
sqrt( sum((res$pred_mean - res$mpg)^2) / nrow(res) )
# compute several regression metrics
regression_metrics(res$pred_mean, res$mpg)

# examine the variability among the 3 repetitions of the cross validation
# do not average over the repetitions => we get 3x32 lines
res <- xgb_predict(fits, niter=20, fns=NULL)
nrow(res)
# compute the mean but also the standard deviation and standard error across repetitions
res <- xgb_predict(fits, niter=20, fns=list(mean=mean, sd=sd, se=se))
head(res)


##  Classification

# fit models over 4 folds of cross-validation, repeated 3 times
mtcarsf <- mutate(mtcars, cyl=factor(cyl))
fits <- resample_cv(mtcarsf, k=4, n=3) %>%
  xgb_fit(resp="cyl", expl=c("mpg", "hp", "qsec"),
          eta=0.1, max_depth=2, nrounds=30)

# compute the predicted number of cylinders (cyl) but with only 15 of the 30
# rounds; by default, use the majority vote across the 3 repetitions
res <- xgb_predict(fits, niter=15)
head(res)
# compute accuracy
sum(res$pred_maj == res$cyl) / nrow(res)
# compute several global classification metrics
classification_metrics(res$pred_maj, res$cyl)

# use a different objective for classification
fits <- resample_cv(mtcarsf, k=4, n=3) %>%
  xgb_fit(resp="cyl", expl=c("mpg", "hp", "qsec"),
          objective="multi:softprob",
          eta=0.1, max_depth=2, nrounds=30)
# because the objective is softprob, we predict the probability for each level
res <- xgb_predict(fits)
head(res)
# get the predicted class
res$max_prob_idx <- res %>%
  select(starts_with("pred_")) %>%
  apply(1, which.max)
res$pred_cyl <- refactor(res$max_prob_idx-1L, levels=levels(mtcarsf$cyl))
head(res)
# NB: refactor() uses 0-based indexing and needs integers, hence the -1L
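
# The step from per-class probabilities to a predicted class can also be
# sketched in base R alone. This is an illustration with made-up probabilities,
# not the package's method: it indexes the class labels directly (1-based) and
# therefore needs no refactor() or -1L offset.

```r
# Made-up class probabilities for 2 observations and 3 classes
probs <- matrix(c(0.7, 0.2, 0.1,
                  0.1, 0.3, 0.6),
                nrow = 2, byrow = TRUE,
                dimnames = list(NULL, c("4", "6", "8")))

# Index of the most probable class in each row (1-based)
idx <- apply(probs, 1, which.max)

# Turn the index into a factor with the full set of levels
pred <- factor(colnames(probs)[idx], levels = colnames(probs))
pred
```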

jiho/joml documentation built on Dec. 6, 2023, 5:50 a.m.