# R/booster2sql.R

#' Transform XGBoost model object to SQL query.
#'
#' This function generates a SQL query for in-database scoring of XGBoost models,
#' providing a robust and efficient way of model deployment. It takes the trained XGBoost model \code{xgbModel},
#' the name of the input database table \code{input_table_name},
#' and the name of a unique row identifier within that table \code{unique_id} as input,
#' and writes the SQL query to the file specified by \code{output_file_name}.
#' Note that the input database table should be generated from the raw table using the one-hot encoding query output by \code{onehot2sql()};
#' alternatively, pass that one-hot encoding query as \code{input_onehot_query}, and it will be embedded as a sub-query inside the final model scoring query.
#'
#' Each tree is translated into a nested \code{CASE} expression in its own \code{SELECT};
#' the per-tree selects are combined with \code{UNION ALL}, and the outer query sums the
#' tree values per \code{unique_id}, adds the base margin derived from \code{base_score},
#' and applies the inverse link of the objective. A \code{NULL} feature value follows the
#' branch that XGBoost marks as the missing-value direction. Columns are referenced as
#' square-bracket quoted identifiers, e.g. \code{[carat]}.
#'
#' @param xgbModel The trained model object of class \code{xgb.Booster}. Its \code{params$objective} must be set.
#' Columns are named after \code{xgbModel$feature_names}; if those are absent, \code{f0}, \code{f1}, ... are used instead.
#' The currently supported booster is \code{booster="gbtree"}; supported \code{objective} options are:
#' \itemize{
#'   \item \code{reg:linear} (also accepted as \code{reg:squarederror}): linear regression.
#'   \item \code{reg:logistic}: logistic regression.
#'   \item \code{binary:logistic}: logistic regression for binary classification, output probability.
#'   \item \code{binary:logitraw}: logistic regression for binary classification, output score before logistic transformation.
#'   \item \code{binary:hinge}: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities. The generated query uses an \code{IF()} function for this objective.
#'   \item \code{count:poisson}: poisson regression for count data, output mean of poisson distribution.
#'   \item \code{reg:gamma}: gamma regression with log-link, output mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be gamma-distributed.
#'   \item \code{reg:tweedie}: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be Tweedie-distributed.
#' }
#' Any other objective produces a warning and a query without the prediction column.
#' @param print_progress Logical; if \code{TRUE}, a progress line is printed to the console after each tree is converted.
#' @param unique_id Name of the row unique identifier column, which is required for in-database scoring of an XGBoost model. If not given, the SQL query is generated with the id name "ROW_KEY".
#' @param output_file_name Name of the file the SQL query is written to. It must not be empty; an existing file is overwritten.
#' @param input_table_name Name of the model-ready (one-hot encoded) table in the database that the SQL query selects from. If neither this nor \code{input_onehot_query} is given, the SQL query is generated with the table name "MODREADY_TABLE".
#' @param input_onehot_query SQL query of one-hot encoding generated by \code{onehot2sql}. It is used only when \code{input_table_name} is empty; in that case the final output query includes \code{input_onehot_query} as a sub-query aliased \code{MODREADY_TABLE}.
#' @return Called for its side effect: the SQL query is written to the file specified by \code{output_file_name}. Returns \code{NULL} invisibly.
#'
#' @export
#'
#' @examples
#' library(xgboost)
#' # load data
#' df <- data.frame(ggplot2::diamonds)
#' head(df)
#'
#' # data processing
#' out <- onehot2sql(df)
#' x <- out$model.matrix[,colnames(out$model.matrix)!='price']
#' y <- out$model.matrix[,colnames(out$model.matrix)=='price']
#'
#' # model training
#' bst <- xgboost(data = x,
#'                label = y,
#'                max.depth = 3,
#'                eta = .3,
#'                nround = 5,
#'                nthread = 1,
#'                objective = 'reg:linear')
#'
#' # generate model scoring SQL script with ROW_KEY and MODREADY_TABLE
#' booster2sql(bst, output_file_name='xgb.txt')
booster2sql <- function(xgbModel, print_progress = FALSE, unique_id = NULL,
                        output_file_name = NULL, input_table_name = NULL,
                        input_onehot_query = NULL) {

  ###### argument checks and defaults ######
  # Validate everything before touching the output file, so a bad call never
  # truncates an existing file.
  if (is.null(output_file_name)) {
    stop("output file not specified")
  }
  if (is.null(unique_id)) {
    unique_id <- "ROW_KEY"
    message("query is written to file with row unique id named as ROW_KEY")
  }
  if (is.null(input_table_name)) {
    if (is.null(input_onehot_query)) {
      input_table_name <- "MODREADY_TABLE"
      message("query is written to file with input table named as MODREADY_TABLE")
    } else {
      # Embed the one-hot encoding query as a derived table.
      input_table_name <- paste0("( \n", input_onehot_query, " \n) AS MODREADY_TABLE ")
    }
  }

  objective <- xgbModel$params$objective
  if (is.null(objective)) {
    stop("xgbModel$params$objective is not set; cannot determine the prediction transform")
  }
  base_score <- xgbModel$params$base_score
  if (is.null(base_score)) {
    base_score <- 0.5
  }
  base_score <- as.numeric(base_score)

  ###### outer prediction expression ######
  # cat() would round to getOption("digits") (7 by default); keep 15 digits.
  num <- function(x) format(x, digits = 15)
  if (objective %in% c("binary:logistic", "reg:logistic", "binary:logitraw")) {
    # base_score is a probability; the trees add to its log-odds margin.
    b0 <- num(log(base_score / (1 - base_score)))
    if (objective == "binary:logitraw") {
      pred_sql <- paste(b0, "+ SUM(ONETREE) AS XGB_PRED")
    } else {
      pred_sql <- paste("1/(1+exp(-(", b0, "+ SUM(ONETREE)))) AS XGB_PRED")
    }
  } else if (objective == "binary:hinge") {
    pred_sql <- paste("IF((", num(base_score), "+ SUM(ONETREE) )>0,1,0) AS XGB_PRED")
  } else if (objective %in% c("reg:linear", "reg:squarederror")) {
    pred_sql <- paste(num(base_score), "+ SUM(ONETREE) AS XGB_PRED")
  } else if (objective %in% c("reg:gamma", "count:poisson", "reg:tweedie")) {
    # log link: base_score is a mean, so the base margin is its log.
    pred_sql <- paste("exp(", num(log(base_score)), "+ SUM(ONETREE)) AS XGB_PRED")
  } else {
    warning("query is generated with unsupported objective")
    pred_sql <- ""
  }

  ###### parse the text dump ######
  # Each tree starts with a "booster[k]" line, followed by one line per node,
  # e.g. "0:[f2<2.45] yes=1,no=2,missing=1" or "3:leaf=0.1".
  xgb_dump <- xgboost::xgb.dump(xgbModel)
  tree_starts <- which(startsWith(xgb_dump, "b"))
  n_trees <- length(tree_starts)
  if (n_trees == 0L) {
    stop("no trees found in the xgb.dump output of xgbModel")
  }
  tree_ends <- c(tree_starts[-1] - 1, length(xgb_dump))
  # Leading integer of every dump line = node id within its tree (parsed once).
  node_ids <- as.numeric(sub("\\D*(\\d+).*", "\\1", xgb_dump))
  feature_names <- xgbModel$feature_names

  # Translate node `i` of one tree (and its whole subtree) into SQL text.
  # `tree_dump` / `tree_ids` are the dump lines and node ids of that tree only.
  node_to_sql <- function(tree_dump, tree_ids, i) {
    line <- tree_dump[i]
    if (grepl("leaf", line)) {
      return(sub(".*leaf= *(.*?)", "\\1", line))
    }
    capture <- function(pattern) {
      regmatches(line, regexec(pattern, line))[[1]][2]
    }
    feature_index <- as.numeric(capture("f(.*?)[<]")) + 1
    var_name <- if (is.null(feature_names)) {
      paste0("f", feature_index - 1)
    } else {
      feature_names[feature_index]
    }
    # Keep the threshold exactly as xgboost printed it.
    split_value <- capture("[<](.*?)[]]")
    # Map a yes/no/missing child id to its position within this tree.
    child <- function(pattern) {
      idx <- which(tree_ids == as.numeric(capture(pattern)))
      if (length(idx) != 1L) {
        stop("cannot resolve child node in xgb.dump line: ", line)
      }
      idx
    }
    paste0(
      "\n (CASE WHEN [", var_name, "] < ", split_value, " THEN ",
      node_to_sql(tree_dump, tree_ids, child("yes=(.*?)[,]")),
      "\n  WHEN  [", var_name, "] >= ", split_value, " THEN ",
      node_to_sql(tree_dump, tree_ids, child("no=(.*?)[,]")),
      "\n  WHEN  [", var_name, "] IS NULL THEN ",
      node_to_sql(tree_dump, tree_ids, child("missing=(.*?)$")),
      " END)"
    )
  }

  ###### one SELECT per tree ######
  tree_sql <- character(n_trees)
  for (k in seq_len(n_trees)) {
    rows <- (tree_starts[k] + 1):tree_ends[k]
    tree_sql[k] <- paste0(
      " \n SELECT ", unique_id, " ,",
      node_to_sql(xgb_dump[rows], node_ids[rows], 1),
      " AS ONETREE FROM  ", input_table_name, " "
    )
    if (isTRUE(print_progress)) {
      cat("====== Processing", k, "/", n_trees, "Tree ======\n")
    }
  }

  ###### assemble and write ######
  sql <- paste0(
    "SELECT  ", unique_id, " , ", pred_sql, "\nFROM (  ",
    paste(tree_sql, collapse = "\n UNION ALL \n"),
    "  \n) AS TREES_TABLE GROUP BY  ", unique_id
  )
  # A single write: no sink() is left open if anything above fails.
  writeLines(sql, con = output_file_name, sep = "")
  invisible(NULL)
}

# Try the xgb2sql package in your browser
#
# Any scripts or data that you put into this service are public.
#
# xgb2sql documentation built on May 2, 2019, 1:09 p.m.