R/model-xgboost.R

Defines functions mlflow_predict.xgb.Booster mlflow_load_flavor.mlflow_flavor_xgboost mlflow_save_model.xgb.Booster

Documented in mlflow_save_model.xgb.Booster

#' @include model-utils.R
NULL

# Save an xgboost Booster as an MLflow model directory: the serialized booster
# ("model.xgb"), conda / python environment specs, and an MLmodel spec that
# declares both the `xgboost` and the `python_function` flavors.
#' @rdname mlflow_save_model
#' @export
mlflow_save_model.xgb.Booster <- function(model,
                                          path,
                                          model_spec = list(),
                                          conda_env = NULL,
                                          ...) {
  assert_pkg_installed("xgboost")

  # Remove any previous save at `path` so stale artifacts do not survive.
  if (dir.exists(path)) unlink(path, recursive = TRUE)
  # `recursive = TRUE` allows `path` to live under directories that do not
  # exist yet; check the outcome explicitly instead of failing later inside
  # xgb.save() with a less helpful message.
  dir.create(path, recursive = TRUE, showWarnings = FALSE)
  if (!dir.exists(path)) {
    stop("Could not create model directory: ", path, call. = FALSE)
  }

  # Must match the file name read back by mlflow_load_flavor().
  model_data_subpath <- "model.xgb"
  xgboost::xgb.save(model, fname = file.path(path, model_data_subpath))

  # Installed xgboost version, passed through remove_patch_version()
  # (presumably keeping only major.minor — confirm in model-utils.R).
  xgb_version <- remove_patch_version(
    as.character(utils::packageVersion("xgboost"))
  )

  # Python-side requirements used to build both environment specs.
  pip_deps <- list("mlflow", paste0("xgboost>=", xgb_version))
  conda_env <- create_default_conda_env_if_absent(
    path, conda_env,
    default_pip_deps = pip_deps
  )
  python_env <- create_python_env(path, dependencies = pip_deps)

  xgboost_conf <- list(
    xgboost = list(xgb_version = xgb_version, data = model_data_subpath)
  )
  pyfunc_conf <- create_pyfunc_conf(
    loader_module = "mlflow.xgboost",
    data = model_data_subpath,
    env = list(conda = conda_env, virtualenv = python_env)
  )
  # Equivalent to the nested append(): existing flavors first, then xgboost,
  # then python_function.
  model_spec$flavors <- c(model_spec$flavors, xgboost_conf, pyfunc_conf)

  mlflow_write_model_spec(path, model_spec)
}

#' @export
mlflow_load_flavor.mlflow_flavor_xgboost <- function(flavor, model_path) {
  # Read back the booster written by mlflow_save_model.xgb.Booster(); the
  # file name below must stay in sync with the one used there. `flavor` is
  # not inspected here — it only drives S3 dispatch to this method.
  assert_pkg_installed("xgboost")
  booster_file <- file.path(model_path, "model.xgb")
  xgboost::xgb.load(booster_file)
}

#' @export
mlflow_predict.xgb.Booster <- function(model, data, ...) {
  # Score `data` with an xgboost Booster: coerce it to a matrix, wrap that in
  # an xgb.DMatrix, and forward any extra arguments on to predict().
  assert_pkg_installed("xgboost")
  features <- xgboost::xgb.DMatrix(as.matrix(data))
  stats::predict(model, features, ...)
}

Try the mlflow package in your browser

Any scripts or data that you put into this service are public.

mlflow documentation built on Nov. 23, 2023, 9:13 a.m.