knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(add2xgb)
library(xgboost)
library(tidyverse)
library(whisker)
set.seed(123)
train_data <- mtcars %>% 
    rename(y = am)
dtrain <- xgb.DMatrix(
    data = as.matrix(train_data %>% select(-y)),
    label = train_data$y
)
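A quick structural check (optional, illustrative, not part of the original code): mtcars has 32 rows and 11 columns, and moving am into the label leaves 10 feature columns, so the DMatrix should report 32 rows by 10 columns.
dim(dtrain) # expect 32 rows, 10 feature columns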
xgb_model <- xgb.train(
    data = dtrain,
    nrounds = 10,
    seed = 1,
    max_depth = 1,
    objective = "binary:logistic",
    base_score = mean(train_data$y) # start from the observed event rate so the scores are calibrated
)
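A quick sanity check one might run here (illustrative, not part of the original code): with base_score set to the observed event rate, the average in-sample predicted probability should sit close to mean(train_data$y).
mean(predict(xgb_model, newdata = dtrain)) # average predicted probability
mean(train_data$y)                         # observed event rate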
model_trees <- jsonlite::fromJSON(
    xgb.dump(xgb_model, with_stats = FALSE, dump_format = "json"),
    simplifyDataFrame = FALSE
)
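A brief look at the parsed dump (illustrative): model_trees is a plain list with one element per boosted round, and each element is a nested list of nodes with fields such as split, split_condition, and children, while leaves carry nodeid and leaf. These are the fields the SQL template consumes.
length(model_trees)                  # one element per boosted round
str(model_trees[[1]], max.level = 2) # node fields of the first tree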
feature_dict <- as.list(xgb_model$feature_names)
body(add2xgb::xgb_tree_sql)

The core template line is WHEN {{{split_long}}} < {{{split_condition}}} THEN {{{yes_sql}}}. Refer mainly to the xgb.dump documentation: the strict less-than sign here mirrors how XGBoost writes split conditions in the dump, so the generated SQL routes each row to the same branch as the model.
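A minimal whisker sketch of that fragment (the placeholder values below are hypothetical, chosen only to show the rendering; this is not add2xgb's internal call):
whisker.render(
    "WHEN {{{split_long}}} < {{{split_condition}}} THEN {{{yes_sql}}}",
    data = list(
        split_long = "wt",       # hypothetical feature name
        split_condition = "3.2", # hypothetical split threshold
        yes_sql = "0.1"          # hypothetical yes-branch SQL (e.g. a leaf value)
    )
)
#> [1] "WHEN wt < 3.2 THEN 0.1"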

body(add2xgb::xgb_sql_score_query)
queries <- xgb_sql_score_query(
    model_trees, 
    'mtcars',
    feature_dict,
    base_score = mean(train_data$y)
)
queries %>% cat()
queries %>% write_file("mtcars_model_code.sql")
pred_from_model <- predict(xgb_model, newdata = dtrain)
library(sqldf)
pred_from_sql <-
    read_file("mtcars_model_code.sql") %>%
    str_remove("id,") %>% # drop the id column from the select list; mtcars itself has no id column
    sqldf() %>%
    .$score
library(tidypredict)
pred_from_tidypredict <-
    tidypredict_sql(xgb_model, dbplyr::simulate_dbi()) %>%
    paste("select ", ., " from mtcars") %>%
    sqldf() %>%
    pull()
# mean absolute difference between in-R predictions and the SQL-scored predictions
(pred_from_model - pred_from_sql) %>% abs %>% mean
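The same check against the tidypredict translation (all three vectors score the mtcars rows in the same order, since every query runs over the same table):
(pred_from_model - pred_from_tidypredict) %>% abs %>% mean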

