#' Stratify learner fits by a single variable
#'
#' @docType class
#' @importFrom R6 R6Class
#'
#' @export
#'
#' @keywords data
#'
#' @return Learner object with methods for training and prediction. See
#' \code{\link{Lrnr_base}} for documentation on learners.
#'
#' @format \code{\link{R6Class}} object.
#'
#' @family Learners
#'
#' @section Parameters:
#' \describe{
#' \item{\code{learner="learner"}}{An initialized Lrnr_* object.
#' }
#' \item{\code{variable_stratify="variable_stratify"}}{\code{character} giving
#' the variable in the covariates on which to stratify. Supports only
#' variables with discrete levels coded as \code{numeric}.
#' }
#' \item{\code{...}}{Other parameters passed directly to
#' \code{learner$train}. See its documentation for details.
#' }
#' }
#'
#' @examples
#' library(data.table)
#'
#' # load example data set
#' data(cpp_imputed)
#' setDT(cpp_imputed)
#'
#' # use covariates of interest and the outcome to build a task object
#' covars <- c("apgar1", "apgar5", "sexn")
#' task <- sl3_Task$new(cpp_imputed, covariates = covars, outcome = "haz")
#'
#' hal_lrnr <- Lrnr_hal9001$new(fit_control = list(n_folds = 3))
#' stratified_hal <- Lrnr_stratified$new(
#' learner = hal_lrnr,
#' variable_stratify = "sexn"
#' )
#'
#' # stratified learner
#' set.seed(123)
#' stratified_hal_fit <- stratified_hal$train(task)
#' stratified_prediction <- stratified_hal_fit$predict(task = task)
# Wrapper learner that fits one copy of `learner` per level of a single
# stratification variable and routes each prediction row to the fit that was
# trained on the matching level.
Lrnr_stratified <- R6Class(
  classname = "Lrnr_stratified", inherit = Lrnr_base,
  portable = TRUE, class = TRUE,
  public = list(
    # learner: an already initialized learner object; it is trained once per
    #   stratum.
    # variable_stratify: name of the column in the task data whose levels
    #   define the strata.
    # ...: stored alongside the two arguments above in the parameter list and
    #   also forwarded to the parent initializer.
    initialize = function(learner, variable_stratify, ...) {
      params <- list(
        learner = learner, variable_stratify = variable_stratify,
        ...
      )
      super$initialize(params = params, ...)
    }
  ),
  active = list(
    # Learner name of the form "strat_<variable_stratify>_<learner name>".
    name = function() {
      paste(
        "strat", self$params$variable_stratify, self$params$learner$name,
        sep = "_"
      )
    }
  ),
  private = list(
    .properties = c("continuous", "binomial", "wrapper"),

    # Train the wrapped learner separately on the rows of each stratum.
    # Returns a named list of fits, keyed by `as.character()` of the stratum
    # value.
    .train = function(task) {
      args <- self$params
      variable_stratify <- args$variable_stratify
      strata_ids <- unlist(task$data[, variable_stratify, with = FALSE])
      strata_levels <- unique(strata_ids)

      # fit_object is a dictionary of trained Lrnr_* fits, one per stratum.
      # NOTE(review): NA strata are not handled specially; `strata_ids == NA`
      # selects no rows, so such a stratum yields an empty sub-task.
      fit_object <- list()
      for (strata in strata_levels) {
        index_in_strata <- which(strata_ids == strata)
        sub_task <- task$subset_task(
          row_index = index_in_strata,
          drop_folds = TRUE
        )
        # The stratifier is constant within a stratum, so drop it from the
        # covariates handed to the wrapped learner.
        covariates <- sub_task$nodes$covariates
        sub_task <- sub_task$next_in_chain(
          covariates = covariates[covariates != variable_stratify]
        )
        # assume that `variable_stratify` is a numeric multinomial factor
        fit_object[[as.character(strata)]] <- args$learner$train(sub_task)
      }
      return(fit_object)
    },

    # Predict by splitting `task` on the stratification variable, predicting
    # each subset with the fit trained on that stratum, and reassembling the
    # per-stratum predictions using the recorded original row indices.
    # Stops if `task` contains a stratum that has no trained fit.
    .predict = function(task = NULL) {
      learner_dict <- self$fit_object
      trained_strata <- names(learner_dict)
      variable_stratify <- self$params$variable_stratify
      strata_ids <- unlist(task$data[, variable_stratify, with = FALSE])
      new_strata <- unique(strata_ids)

      if (length(setdiff(new_strata, trained_strata)) > 0) {
        # Message is assembled from pieces so that it contains no embedded
        # line break.
        stop(
          "There is new strata in the prediction data that is not present in ",
          "training data!"
        )
      }

      prediction_df_dict <- list()
      for (strata in new_strata) {
        index_subtask <- which(strata_ids == strata)
        # construct the sub-task for this stratum, without the stratifier
        sub_task <- task$subset_task(row_index = index_subtask)
        covariates <- sub_task$nodes$covariates
        sub_task <- sub_task$next_in_chain(
          covariates = covariates[covariates != variable_stratify]
        )
        # predict on the sub-task with the fit for this stratum
        prediction_subtask <- learner_fit_predict(
          learner_dict[[as.character(strata)]],
          sub_task
        )
        # keep the row positions so the original order can be restored
        prediction_df_dict[[as.character(strata)]] <- list(
          prediction = prediction_subtask,
          original_index = index_subtask
        )
      }

      # Regroup the per-stratum (prediction, original_index) pairs into one
      # list per field; `as.list` keeps `apply` from simplifying the result.
      results <- apply(do.call(rbind, prediction_df_dict), 2, as.list)
      results <- origami::combine_results(results)
      # Reorder the combined predictions by their original row indices so
      # they line up with the rows of `task`.
      predictions <- aorder(results$prediction, order(results$original_index))
      return(predictions)
    },
    .required_packages = NULL
  )
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.