# tests/testthat/test_model_compatibility.R

require(xgboost)
require(jsonlite)

# Label under which testthat reports the expectations in this file.
context("Models from previous versions of XGBoost can be loaded")

# Constants describing the stored models that the checks below compare against.
# run_booster_check() multiplies kForests * kRounds (and additionally kClasses
# for the 'cls' model) to obtain the expected number of trees.
# NOTE(review): kRows, kCols and kMaxDepth are not read anywhere in this file;
# presumably they mirror the settings of the script that generated the stored
# models -- confirm against that script.
metadata <- list(
  kRounds = 2,
  kRows = 1000,
  kCols = 4,
  kForests = 2,
  kMaxDepth = 2,
  kClasses = 3
)

# Assert the learner settings that every stored model is expected to share:
# a feature count of '4' (compared as a string, the form it has in the parsed
# configuration) and the 'gbtree' booster.
#
# config: parsed model configuration (nested list) exposing
#   learner$learner_model_param and learner$learner_train_param.
run_model_param_check <- function(config) {
  learner <- config$learner

  feature_count <- learner$learner_model_param$num_feature
  testthat::expect_equal(feature_count, '4')

  booster_kind <- learner$learner_train_param$booster
  testthat::expect_equal(booster_kind, 'gbtree')
}

# Count the trees held by a booster.
#
# The text dump returned by xgb.dump() is scanned for entries containing
# 'booster[<digits>]'; every such entry is counted once.
#
# booster: object accepted by xgb.dump().
# Returns an integer count. An empty dump (or one without any matching entry)
# yields 0L; the earlier Reduce()-based fold returned NULL for an empty dump,
# which made the follow-up comparison fail with a confusing message.
get_num_tree <- function(booster) {
  dump <- xgb.dump(booster)
  sum(grepl('booster\\[[0-9]+\\]', dump, perl = TRUE))
}

# Validate a loaded booster against the expectations for its model kind.
#
# booster: the loaded model object.
# name: model kind tag. 'cls', 'logitraw', 'logit' and 'ltr' each have their
#   own set of checks; any other value is asserted to be 'reg' and checked as
#   a regression model.
#
# For every kind the shared learner settings are verified through
# run_model_param_check() and the tree count is compared with the value
# derived from `metadata`.
run_booster_check <- function(booster, name) {
  # An xgb.Booster whose handle is reported as null must be restored with
  # xgb.Booster.complete() before xgb.config() can be used on it.
  handle_missing <- inherits(booster, "xgb.Booster") &&
    xgboost:::is.null.handle(booster$handle)
  if (handle_missing) {
    booster <- xgb.Booster.complete(booster)
  }

  config <- jsonlite::fromJSON(xgb.config(booster))
  run_model_param_check(config)

  model_param <- config$learner$learner_model_param
  train_param <- config$learner$learner_train_param
  # Tree count shared by all kinds; the multi-class model has this many trees
  # per class.
  base_tree_count <- metadata$kForests * metadata$kRounds

  if (name == 'cls') {
    testthat::expect_equal(get_num_tree(booster), base_tree_count * metadata$kClasses)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, 'multi:softmax')
    testthat::expect_equal(as.numeric(model_param$num_class), metadata$kClasses)
  } else if (name == 'logitraw') {
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(train_param$objective, 'binary:logitraw')
  } else if (name == 'logit') {
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(train_param$objective, 'binary:logistic')
  } else if (name == 'ltr') {
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(train_param$objective, 'rank:ndcg')
  } else {
    # Fallback branch: the only remaining kind that is accepted is 'reg'.
    testthat::expect_equal(name, 'reg')
    testthat::expect_equal(get_num_tree(booster), base_tree_count)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, 'reg:squarederror')
  }
}

# Download the archive of models saved by earlier XGBoost releases, load every
# model in it, run a prediction and validate its configuration.
#
# File names are expected to look like 'xgboost-<version>.<kind>.<ext>'; the
# version decides which kind of warning is expected when an RDS file is read
# and the kind selects the checks applied by run_booster_check().
test_that("Models from previous versions of XGBoost can be loaded", {
  bucket <- 'xgboost-ci-jenkins-artifacts'
  region <- 'us-west-2'
  file_name <- 'xgboost_r_model_compatibility_test.zip'

  # Keep the download and the extracted models in a scratch directory under the
  # session temp dir instead of the working directory, so the test does not
  # leave files behind in the source tree.
  work_dir <- tempfile('xgboost_model_compat_')
  dir.create(work_dir)
  zipfile <- file.path(work_dir, file_name)
  model_dir <- file.path(work_dir, 'models')
  download.file(paste0('https://', bucket, '.s3-', region, '.amazonaws.com/', file_name),
                destfile = zipfile, mode = 'wb', quiet = TRUE)
  unzip(zipfile, exdir = work_dir, overwrite = TRUE)

  pred_data <- xgb.DMatrix(matrix(c(0, 0, 0, 0), nrow = 1, ncol = 4))

  model_files <- list.files(model_dir)
  # Guard against a vacuous pass when the archive holds no models.
  expect_true(length(model_files) > 0)

  for (model_name in model_files) {
    model_file <- file.path(model_dir, model_name)
    # Parse the base name only, so that directory names on the path can never
    # be mistaken for the version/kind part of the file name.
    parts <- regexec("xgboost-([0-9\\.]+)\\.([a-z]+)\\.[a-z]+", model_name, perl = TRUE)
    parts <- regmatches(model_name, parts)[[1]]
    model_xgb_ver <- parts[2]
    name <- parts[3]
    is_rds <- endsWith(model_file, '.rds')
    # RDS files written by a version older than 1.1.1.1 raise an R warning;
    # RDS files from later versions produce a C++ warning on the console.
    is_old_rds <- is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0

    cpp_warning <- capture.output({
      if (is_old_rds) {
        # The booster is re-read before the second check so that each call
        # starts from the object exactly as stored on disk.
        booster <- readRDS(model_file)
        expect_warning(predict(booster, newdata = pred_data))
        booster <- readRDS(model_file)
        expect_warning(run_booster_check(booster, name))
      } else {
        if (is_rds) {
          booster <- readRDS(model_file)
        } else {
          booster <- xgb.load(model_file)
        }
        predict(booster, newdata = pred_data)
        run_booster_check(booster, name)
      }
    })
    cpp_warning <- paste0(cpp_warning, collapse = ' ')
    if (is_rds && !is_old_rds) {
      # Expect the C++ warning about loading a serialized model in the
      # captured console output.
      has_warning <- grepl(paste0('.*If you are loading a serialized model ',
                                  '\\(like pickle in Python, RDS in R\\).*',
                                  'for more details about differences between ',
                                  'saving model and serializing.*'), cpp_warning, perl = TRUE)
      expect_true(length(has_warning) > 0 && all(has_warning))
    }
  }

  # Best-effort cleanup; anything left after an early error is removed with the
  # session temp dir when R exits.
  unlink(work_dir, recursive = TRUE)
})

# Try the xgboost package in your browser
#
# Any scripts or data that you put into this service are public.
#
# xgboost documentation built on March 31, 2023, 10:05 p.m.