#' Fit model
#'
#' Fit model, make predictions, and generate SHAP values.
#' @param tweets_transformed Tweets retrieved from `retrieve_tweets()` and passed through `transform_tweets()`
#' @param stem Either `"favorite"` or `"retweet"`
#' @param overwrite Whether to overwrite existing fit, predictions, and SHAP values.
#' @param ... Extra arguments to pass to `.transform_tweets()`
#' @param .overwrite A list with elements `tune` and `fit` giving booleans for overwriting the corresponding outputs saved to file. Each element defaults to the value of `overwrite`. Only use this if you know what you're doing.
#' @param .path A list with elements `tune` and `fit` giving specific paths to export the tuning results and fitted model to. If left as `NULL`, they are saved as `"res_tune_cv_{stem}.rds"` and `"fit_{stem}"` in the directory at `getOption("xengagement.dir_data")`. As with `.overwrite`, only use this if you know what you're doing.
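#' @examples
#' # A minimal usage sketch (not run); it assumes `retrieve_tweets()` and
#' # `transform_tweets()` are available and return data in the expected form.
#' \dontrun{
#' tweets <- retrieve_tweets()
#' tweets_transformed <- transform_tweets(tweets)
#' res <- do_fit(tweets_transformed, stem = 'favorite', overwrite = FALSE)
#' res$fit
#' }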
#' @export
do_fit <-
function(tweets_transformed,
stem = get_valid_stems(),
overwrite = TRUE,
...,
.overwrite = list(
tune = NULL,
fit = NULL
),
.path = list(
tune = NULL,
fit = NULL
)) {
.validate_stem(stem)
cols_lst <- .get_cols_lst(stem = stem)
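  # Keep only tweets that are no longer flagged as "fresh".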
data <-
tweets_transformed %>%
dplyr::filter(!is_fresh)
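  # Helpers for building stem-specific file paths (e.g. "res_tune_cv_{stem}.rds")
  # under the directory at getOption("xengagement.dir_data").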
.path_data_x <- function(file, ext = NULL) {
.path_data(file = sprintf('%s_%s', file, stem), ext = ext)
}
  .path_data_rds_x <- purrr::partial(.path_data_x, ext = 'rds', ... = )
  path_res_tune_cv <- .path$tune %||% .path_data_rds_x('res_tune_cv')
  path_fit <- .path$fit %||% .path_data_x('fit')
.overwrite$tune <- .overwrite$tune %||% overwrite
.overwrite$fit <- .overwrite$fit %||% overwrite
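  # Drop rows with a missing response, then build the design matrix from the
  # predictor columns.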
col_y_sym <- cols_lst$col_y %>% sym()
data <- data %>% tidyr::drop_na(!!col_y_sym)
  x_mat <- data %>% dplyr::select(dplyr::all_of(cols_lst$cols_x)) %>% .df2mat()
# TODO: Make these package options?
nrounds <- 2000
booster <- 'gbtree'
objective <- 'reg:squarederror'
eval_metrics <- list('rmse')
early_stopping_rounds <- 10
print_every_n <- 100
n_fold <- 10
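  # Response, case weights, and the xgboost DMatrix used for both tuning and
  # the final fit.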
y <- data[[cols_lst$col_y]]
wt <- data[[cols_lst$col_wt]]
n_col <- ncol(x_mat)
x_dmat <-
xgboost::xgb.DMatrix(
x_mat,
weight = wt,
label = y
)
.f_tune <- function() {
seed <- .get_seed()
set.seed(seed)
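    # Assign each row to one of `n_fold` cross-validation folds, stratified by
    # `cols_lst$col_strata`, then collect the row indices (`idx`) per fold.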
folds_ids <-
.create_folds(
data[[cols_lst$col_strata]],
k = n_fold,
list = FALSE,
returnTrain = FALSE
)
col_strata <- cols_lst$col_strata
col_strata_sym <- col_strata %>% sym()
folds <-
data %>%
dplyr::bind_cols(dplyr::tibble(fold = folds_ids)) %>%
dplyr::left_join(
data %>% dplyr::select(!!col_strata_sym, idx),
by = c('idx', col_strata)
) %>%
dplyr::select(fold, idx) %>%
split(.$fold) %>%
purrr::map(~dplyr::select(.x, -fold) %>% dplyr::pull(idx))
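    # Sanity check that the folds jointly cover the expected rows (indices 1..n).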
n_obs <- folds %>% purrr::flatten_int() %>% length()
max_idx <- folds %>% purrr::flatten_int() %>% max()
assertthat::assert_that(n_obs == max_idx)
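    # Latin-hypercube grid over xgboost hyperparameters. `mtry` is rescaled to a
    # column fraction (for colsample_bytree) and `learn_rate` is overridden with
    # an evenly spaced sequence up to 0.1.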
n_row <- 30
grid_params <-
dials::grid_latin_hypercube(
dials::mtry(range = c(round(n_col / 3), n_col)),
dials::min_n(),
dials::tree_depth(),
dials::learn_rate(),
size = n_row
) %>%
dplyr::mutate(
learn_rate = 0.1 * ((1:dplyr::n()) / dplyr::n()),
mtry = mtry / n_col,
idx = dplyr::row_number()
) %>%
dplyr::relocate(idx)
res_tune_cv <-
.tune_xgb_cv(
nrounds = nrounds,
stem = stem,
grid_params = grid_params,
folds = folds,
x_dmat = x_dmat,
booster = booster,
objective = objective,
eval_metrics = eval_metrics,
sample_weight = wt,
early_stopping_rounds = early_stopping_rounds,
print_every_n = print_every_n
)
res_tune_cv
}
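  # Run the tuning routine, or re-import cached results from `path_res_tune_cv`,
  # depending on `.overwrite$tune`.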
res_tune_cv <-
do_get(
f = .f_tune,
path = path_res_tune_cv,
f_import = readr::read_rds,
f_export = readr::write_rds,
append = FALSE,
export = TRUE,
overwrite = .overwrite$tune
)
.f_fit <- function() {
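    # Pick the hyperparameter row with the lowest held-out ("_tst") value of the
    # first evaluation metric.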
    eval_metric <- eval_metrics[[1]]
eval_metric_tst <- sprintf('%s_tst', eval_metric)
eval_metric_tst_sym <- eval_metric_tst %>% sym()
res_cv_best <- res_tune_cv %>% dplyr::slice_min(!!eval_metric_tst_sym)
.pluck_param <- function(x) {
res_cv_best %>% purrr::pluck(x)
}
params_best <-
list(
booster = booster,
objective = objective,
        eval_metric = eval_metric,
eta = .pluck_param('eta'),
colsample_bytree = .pluck_param('colsample_bytree'),
min_child_weight = .pluck_param('min_child_weight')
)
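    # Scale the best CV iteration up to account for CV models seeing only
    # (n_fold - 1) / n_fold of the data, pad by the early-stopping window, and
    # refit on the full data with the best hyperparameters.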
nrounds_best <- round((.pluck_param('iter') / ((n_fold - 1) / (n_fold))), 0) + early_stopping_rounds
    fit <-
      xgboost::xgboost(
        params = params_best,
        data = x_dmat,
        nrounds = nrounds_best,
        early_stopping_rounds = early_stopping_rounds,
        print_every_n = print_every_n,
        verbose = 1
      )
    fit
  }
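  # Fit the final model, or re-import a cached model from `path_fit`, depending
  # on `.overwrite$fit`.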
fit <-
do_get(
f = .f_fit,
path = path_fit,
f_import = xgboost::xgb.load,
f_export = xgboost::xgb.save,
append = FALSE,
export = TRUE,
overwrite = .overwrite$fit
)
list(res_tune_cv = res_tune_cv, fit = fit)
}