xgb_predict | R Documentation
Predict from an xgboost model at a given number of rounds, across resamples
Usage

xgb_predict(
  object,
  newdata = NULL,
  niter = NULL,
  fns = "auto",
  add_data = TRUE,
  ...
)
Arguments

object
    an object output by xgb_fit().

newdata
    a data.frame to predict, with the same variables as those used for
    fitting (and possibly others). When NULL, predict the validation data
    of each resample.

niter
    number of boosting iterations to use in the prediction. Maps to the
    last bound of iterationrange in xgboost::predict.xgb.Booster().

fns
    a named list of summary functions to compute over the predictions of
    each observation, across resamples (when there is more than one). If
    NULL, return all predictions. If "auto", the default, choose a function
    appropriate for the type of response variable: the mean for a numeric
    response, the majority vote for a factor.

add_data
    logical; whether to add the original data to the output (defaults to
    TRUE, which is convenient for computing performance metrics).

...
    passed to xgboost::predict.xgb.Booster().
Value

A tibble with:

- the grouping columns of object (ungroup object before xgb_predict() if
  this is not the desired behaviour);
- the predictions, as pred_*** where *** is the name of a summary function
  (e.g. pred_mean), or as pred when no summary function is applied;
- the original data, when add_data is TRUE.
Examples

## Regression
# fit models over 4 folds of cross-validation, repeated 3 times
fits <- resample_cv(mtcars, k=4, n=3) %>%
  xgb_fit(resp="mpg", expl=c("cyl", "hp", "qsec"),
          eta=0.1, max_depth=2, nrounds=30)
# compute the predicted mpg, with 20 trees only and, by default, average
# across the 3 repetitions
res <- xgb_predict(fits, niter=20)
head(res)
# check that we have predicted all items in the dataset (should always be the
# case with cross validation)
nrow(res)
nrow(mtcars)
# compute the Root Mean Squared Error
sqrt( sum((res$pred_mean - res$mpg)^2) / nrow(res) )
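# NB: the line above is just the RMSE formula written out; using mean()
# gives the same result
sqrt(mean((res$pred_mean - res$mpg)^2))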
# compute several regression metrics
regression_metrics(res$pred_mean, res$mpg)
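# regression_metrics() is a helper from this package; assuming the standard
# definitions, similar metrics are easy to compute by hand in base R
mae <- mean(abs(res$pred_mean - res$mpg))        # mean absolute error
r2 <- 1 - sum((res$mpg - res$pred_mean)^2) /
          sum((res$mpg - mean(res$mpg))^2)       # coefficient of determination
c(mae=mae, r2=r2)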
# examine the variability among the 3 repetitions of the cross validation
# do not average over the repetitions => we get 3x32 lines
res <- xgb_predict(fits, niter=20, fns=NULL)
nrow(res)
# compute the mean but also the standard deviation and standard error
# across repetitions
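# NB: se() (the standard error of the mean) is assumed to be provided by
# the package; if it is not, this minimal definition is a reasonable stand-in
if (!exists("se")) se <- function(x) sd(x) / sqrt(length(x))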
res <- xgb_predict(fits, niter=20, fns=list(mean=mean, sd=sd, se=se))
head(res)
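# predict observations outside of the resampling scheme by passing newdata
# explicitly; any data.frame with the same explanatory variables works
# (the rows below are purely illustrative)
res <- xgb_predict(fits, newdata=head(mtcars), niter=20)
head(res)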
## Classification
# fit models over 4 folds of cross-validation, repeated 3 times
mtcarsf <- mutate(mtcars, cyl=factor(cyl))
fits <- resample_cv(mtcarsf, k=4, n=3) %>%
  xgb_fit(resp="cyl", expl=c("mpg", "hp", "qsec"),
          eta=0.1, max_depth=2, nrounds=30)
# compute the predicted number of cylinders (cyl) but with only 15 of the 30
# rounds; by default, use the majority vote across the 3 repetitions
res <- xgb_predict(fits, niter=15)
head(res)
# compute accuracy
sum(res$pred_maj == res$cyl) / nrow(res)
# compute several global classification metrics
classification_metrics(res$pred_maj, res$cyl)
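# a confusion matrix (base R) gives a more detailed view than accuracy alone
table(predicted=res$pred_maj, observed=res$cyl)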
# use a different objective for classification
fits <- resample_cv(mtcarsf, k=4, n=3) %>%
  xgb_fit(resp="cyl", expl=c("mpg", "hp", "qsec"),
          objective="multi:softprob",
          eta=0.1, max_depth=2, nrounds=30)
# because the objective is softprob, we predict the probability for each level
res <- xgb_predict(fits)
head(res)
# get the predicted class
res$max_prob_idx <- res %>%
  select(starts_with("pred_")) %>%
  apply(1, which.max)
res$pred_cyl <- refactor(res$max_prob_idx-1L, levels=levels(mtcarsf$cyl))
head(res)
# NB: refactor() uses 0-based indexing and needs integers, hence the -1L
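# for reference, the same lookup in base R, using which.max's 1-based index
# directly (so no conversion is needed)
res$pred_cyl2 <- factor(levels(mtcarsf$cyl)[res$max_prob_idx],
                        levels=levels(mtcarsf$cyl))
head(res$pred_cyl2)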