# Default knitr chunk options for this vignette.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

# Packages used throughout: the SQL generator under test, the model library,
# data manipulation helpers, and the mustache templating engine.
library(add2xgb)
library(xgboost)
library(tidyverse)
library(whisker)

# Fix R's RNG state so the vignette output is reproducible.
set.seed(123)
# Training data: mtcars with its `am` column renamed to `y` so it can serve
# as the label.
train_data <- rename(mtcars, y = am)

# xgboost input: every non-label column as the feature matrix, `y` as label.
dtrain <- xgb.DMatrix(
  data = train_data %>% select(-y) %>% as.matrix(),
  label = train_data[["y"]]
)
library(xgboost)

# Train a small boosted model: 10 rounds of depth-1 trees (one split each)
# with a binary logistic objective. Booster parameters go through `params`
# explicitly rather than being passed loose via `...`.
xgb_model <- xgb.train(
  params = list(
    max_depth = 1,
    objective = "binary:logistic",
    # Start boosting from the training-set label mean instead of the default
    # base score (original note: "fix uncalibration problem"). The SQL
    # scoring query below is built with the same base_score.
    base_score = mean(train_data$y),
    # NOTE(review): some xgboost R versions ignore `seed` (with a warning) and
    # rely on set.seed() instead; kept to preserve the original intent.
    seed = 1
  ),
  data = dtrain,
  nrounds = 10
)
# Dump the trained trees as JSON (without node statistics) and parse it into
# nested R lists, which xgb_sql_score_query() consumes further down.
model_trees <- xgb_model %>%
  xgb.dump(with_stats = FALSE, dump_format = "json") %>%
  jsonlite::fromJSON(simplifyDataFrame = FALSE)

# The model's feature names, as a list.
feature_dict <- xgb_model$feature_names %>% as.list()
library(add2xgb)

# Show the source of xgb_tree_sql(), which presumably renders a single tree
# as SQL (see the WHEN ... THEN template it uses).
body(add2xgb::xgb_tree_sql)
WHEN {{{split_long}}} < {{{split_condition}}} THEN {{{yes_sql}}}
具体规则以 xgb.dump 的文档为准：这里的分裂条件使用的是小于号（<），即特征值 < split_condition 时走 yes 分支。
q
的书写保证了每个距离差了一\n
# Show the source of the function that assembles the full scoring query.
add2xgb::xgb_sql_score_query %>% body()
# Build the SQL scoring query against table `mtcars`. base_score is the same
# label mean passed to xgb.train(), presumably so the SQL scores line up with
# predict().
queries <- xgb_sql_score_query(
  model_trees,
  "mtcars",
  feature_dict,
  base_score = mean(train_data$y)
)

# Print the query, then save it for the sqldf check below.
queries %>% cat()
queries %>% write_file("mtcars_model_code.sql")
# Reference scores from xgboost itself, computed on the training matrix.
pred_from_model <- predict(object = xgb_model, newdata = dtrain)
library(sqldf)

# Run the saved SQL against the in-memory mtcars data frame and keep the
# `score` column. The first "id," is stripped from the query first,
# presumably because mtcars has no `id` column.
pred_from_sql <- read_file("mtcars_model_code.sql") %>%
  str_remove("id,") %>%
  sqldf() %>%
  pull(score)
library(tidypredict)

# For comparison, have tidypredict translate the same model into a SQL
# expression (generic dialect via dbplyr::simulate_dbi()), wrap it in a
# SELECT over mtcars, run it with sqldf and keep the last (only) column.
# To inspect the generated query, pipe the paste0() result into cat().
pred_from_tidypredict <- xgb_model %>%
  tidypredict_sql(dbplyr::simulate_dbi()) %>%
  paste0("select ", ., " from mtcars") %>%
  sqldf() %>%
  pull()
# Mean absolute difference between xgboost's own predictions and the scores
# from the add2xgb-generated SQL; a value near 0 means the SQL reproduces
# the model.
mean(abs(pred_from_model - pred_from_sql))

# Same check for the tidypredict-generated SQL computed above.
mean(abs(pred_from_model - pred_from_tidypredict))
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.