# demo/predict_leaf_indices.R

library(xgboost)
library(data.table)
library(Matrix)

set.seed(1982)

# load in the agaricus dataset
data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")
dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label)
dtest <- xgb.DMatrix(data = agaricus.test$data, label = agaricus.test$label)

param <- list(max_depth = 2, eta = 1, objective = "binary:logistic")
nrounds <- 4

# train the baseline model for `nrounds` boosting rounds
bst <- xgb.train(params = param, data = dtrain, nrounds = nrounds, nthread = 2)

# Model accuracy without new features: the share of test rows whose predicted
# probability, thresholded at 0.5, agrees with the label.
accuracy.before <- mean((predict(bst, dtest) >= 0.5) == agaricus.test$label)

# by default, we predict using all the trees; with predleaf = TRUE the result
# has one column per tree holding the leaf index each row falls into
pred_with_leaf <- predict(bst, dtest, predleaf = TRUE)
head(pred_with_leaf)

# Append, for every tree of `model`, dummy columns recording which leaf each
# row of `original.features` lands in.
#
# model:             a trained xgboost booster (anything `predict()` accepts
#                    with `predleaf = TRUE`).
# original.features: the feature matrix to score, e.g. a sparse dgCMatrix.
#
# Returns `original.features` with the leaf-indicator columns appended via
# `cbind()`. The new columns are named `tree<k>.leaf<leaf id>`.
#
# NOTE(review): `sparse.model.matrix(~ . - 1, ...)` keeps every level of the
# first tree's factor but drops the reference (first) level of each later
# tree, so only the first tree is fully one-hot encoded.
# NOTE(review): leaf levels are collected from the rows passed in. Different
# inputs (e.g. train vs test) can therefore yield different column sets. A
# tree whose rows all share one leaf (e.g. a single input row) gives a
# one-level factor, which the contrasts step of `sparse.model.matrix()`
# rejects.
create.new.tree.features <- function(model, original.features) {
  # one column per tree: the leaf index reached by each row
  leaf.matrix <- predict(model, original.features, predleaf = TRUE)
  # Iterate over every predleaf column rather than `model$niter` rounds:
  # multi-class / parallel-tree models grow several trees per round.
  tree.ids <- seq_len(ncol(leaf.matrix))
  cols <- lapply(tree.ids, function(i) {
    leaf.id <- sort(unique(leaf.matrix[, i]))
    factor(leaf.matrix[, i], levels = leaf.id)
  })
  # Explicit names keep the generated column names short and readable instead
  # of deparsing each factor's contents.
  names(cols) <- paste0("tree", tree.ids, ".leaf")
  cbind(original.features, sparse.model.matrix(~ . - 1, as.data.frame(cols)))
}

# Augment the original features with per-tree leaf-indicator columns derived
# from the baseline model
new.features.train <- create.new.tree.features(bst, agaricus.train$data)
new.features.test <- create.new.tree.features(bst, agaricus.test$data)
# Leaf levels are collected separately from each data set, so this copy of
# column names assumes both sets reach the same leaves (it errors if the
# column counts differ).
colnames(new.features.test) <- colnames(new.features.train)

# learning with new features
new.dtrain <- xgb.DMatrix(data = new.features.train, label = agaricus.train$label)
new.dtest <- xgb.DMatrix(data = new.features.test, label = agaricus.test$label)
bst <- xgb.train(params = param, data = new.dtrain, nrounds = nrounds, nthread = 2)

# Model accuracy with new features
accuracy.after <- mean((predict(bst, new.dtest) >= 0.5) == agaricus.test$label)

# Here the accuracy was already good and is now perfect.
cat(paste("The accuracy was", accuracy.before, "before adding leaf features and it is now",
          accuracy.after, "!\n"))

# Try the xgboost package in your browser

# Any scripts or data that you put into this service are public.

# xgboost documentation built on March 31, 2023, 10:05 p.m.