R/orsf_R6.R

# TODO:
# - add nocov to cpp
# - tests for survival forest w/no censored


# ObliqueForest class ----

ObliqueForest <- R6::R6Class(
 "ObliqueForest",


 public = list(

  # user-facing fields
  data = NULL,
  trained = FALSE,
  forest = list(),
  importance = vector(mode = 'double', length = 0),
  pred_oobag = matrix(ncol = 0, nrow = 0),
  eval_oobag = list(),

  # fitting fields
  formula = NULL,
  control = NULL,
  weights = NULL,

  # count fields
  n_tree = NULL,
  n_split = NULL,
  n_retry = NULL,
  n_thread = NULL,
  n_obs = NULL,
  mtry = NULL,

  # sampling fields
  sample_with_replacement = NULL,
  sample_fraction = NULL,

  # this is used if length(tree_seeds) = 1
  forest_seed = NULL,

  # tree fields
  tree_seeds = NULL,
  tree_type = NULL,
  leaf_min_events = NULL,
  leaf_min_obs = NULL,
  split_rule = NULL,
  split_min_events = NULL,
  split_min_obs = NULL,
  split_min_stat = NULL,

  # out-of-bootstrap aggregate (oobag) fields
  oobag_pred_mode = NULL,
  pred_type = NULL,
  pred_horizon = NULL,
  oobag_eval_type = NULL,
  oobag_eval_every = NULL,
  oobag_eval_function = NULL,

  # prediction fields
  pred_aggregate = NULL,

  # variable importance fields
  importance_type = NULL,
  importance_max_pvalue = NULL,
  importance_group_factors = NULL,

  # missing data fields
  na_action = NULL,

  # miscellaneous
  verbose_progress = NULL,


  # Assign incoming member variables, then run input checks.
  initialize = function(data,
                        formula,
                        control,
                        weights = NULL,
                        n_tree,
                        n_split,
                        n_retry,
                        n_thread,
                        mtry = NULL,
                        sample_with_replacement,
                        sample_fraction,
                        leaf_min_events,
                        leaf_min_obs,
                        split_rule,
                        split_min_events,
                        split_min_obs,
                        split_min_stat,
                        pred_type,
                        oobag_pred_horizon,
                        oobag_eval_every = NULL,
                        oobag_fun = NULL,
                        importance_type,
                        importance_max_pvalue,
                        importance_group_factors,
                        tree_seeds,
                        na_action,
                        verbose_progress) {

   # if these arguments were specified in the initial
   # call or in updates, then they should be checked.
   # Note the arguments in this list are NULL in orsf()
   # by default, which means their defaults are dynamic.
   # other arguments with non-dynamic defaults should
   # always be checked if a user wants to use update().

   private$user_specified <- list(
    control             = !is.null(control),
    lincomb_df_target   = !is.null(control$lincomb_df_target),
    weights             = !is.null(weights),
    mtry                = !is.null(mtry),
    split_rule          = !is.null(split_rule),
    split_min_stat      = !is.null(split_min_stat),
    pred_type           = !is.null(pred_type),
    pred_horizon        = !is.null(oobag_pred_horizon),
    oobag_eval_function = !is.null(oobag_fun),
    tree_seeds          = !is.null(tree_seeds),
    oobag_eval_every    = !is.null(oobag_eval_every)
   )

   self$data     <- data
   self$formula  <- formula
   self$control  <- control
   self$weights  <- weights
   self$n_tree   <- n_tree
   self$n_split  <- n_split
   self$n_retry  <- n_retry
   self$n_thread <- n_thread
   self$mtry     <- mtry
   self$sample_with_replacement  <- sample_with_replacement
   self$sample_fraction          <- sample_fraction
   self$leaf_min_events          <- leaf_min_events
   self$leaf_min_obs             <- leaf_min_obs
   self$split_rule               <- split_rule
   self$split_min_events         <- split_min_events
   self$split_min_obs            <- split_min_obs
   self$split_min_stat           <- split_min_stat
   self$pred_type                <- pred_type
   self$pred_horizon             <- oobag_pred_horizon
   self$oobag_eval_every         <- oobag_eval_every
   self$oobag_eval_function      <- oobag_fun
   self$importance_type          <- importance_type
   self$importance_max_pvalue    <- importance_max_pvalue
   self$importance_group_factors <- importance_group_factors
   self$tree_seeds               <- tree_seeds
   self$na_action                <- na_action
   self$verbose_progress         <- verbose_progress

   private$init()

  },

  # Update: re-initialize if dynamic args were unspecified
  update = function(args) {

   # for args with default values of NULL, keep track of whether
   # user specified them or not. If they were un-specified, then
   # the standard init() function is called. If they were specified,
   # the standard check function is called.
   # this allows someone to set control to NULL and revert
   # to using a default control for the given forest.

   data <- args$data %||% self$data

   if("formula" %in% names(args)){
    formula <- args[['formula']]
    terms_old <- terms(self$formula, data = data)
    self$formula <- stats::update(as.formula(terms_old), new = formula)
   }

   null_defaults <- c(
    control = "control",
    weights = "weights",
    mtry = "mtry",
    split_rule = "split_rule",
    split_min_stat = "split_min_stat",
    pred_type = "oobag_pred_type",
    pred_horizon = "oobag_pred_horizon",
    oobag_eval_every = "oobag_eval_every",
    oobag_eval_function = "oobag_fun",
    tree_seeds = "tree_seeds"
   )

   hard_defaults <- c(
    n_tree = "n_tree",
    n_split = "n_split",
    n_retry = "n_retry",
    n_thread = "n_thread",
    sample_with_replacement = "sample_with_replacement",
    sample_fraction = "sample_fraction",
    leaf_min_events = "leaf_min_events",
    leaf_min_obs = "leaf_min_obs",
    split_min_events = "split_min_events",
    split_min_obs = "split_min_obs",
    importance_type = "importance",
    importance_max_pvalue = "importance_max_pvalue",
    importance_group_factors = "group_factors",
    na_action = "na_action",
    verbose_progress = "verbose_progress"
   )

   for( i in seq_along(null_defaults) ){

    input_name <- null_defaults[i]

    if(input_name %in% names(args)){

     input <- args[[input_name]]

     r6_name <- names(null_defaults)[i]

     if(is.null(input)){
      private$user_specified[[r6_name]] <- FALSE
     } else {
      self[[r6_name]] <- input
      private$user_specified[[r6_name]] <- TRUE
     }

    }

   }

   # browser()

   for(i in seq_along(hard_defaults)){

    input_name <- hard_defaults[i]

    if(input_name %in% names(args)){

     input <- args[[input_name]]
     r6_name <- names(hard_defaults)[i]

     self[[r6_name]] <- input

    }

   }

   private$init(data = data)

  },

  # Print object with data dependent on tree_type and lincomb_type
  print = function(){

   # TODO: make tree-specific get_ function for this
   info_model_type <- switch(self$tree_type,
                             'survival' = "Cox regression",
                             'regression' = "Linear regression",
                             'classification' = "Logistic regression")

   info_lincomb_type <- switch(self$control$lincomb_type,
                               'glm'    = NULL,
                               'net'    = "Penalized ",
                               'custom' = "Custom user function")

   if(self$control$lincomb_type != 'custom'){

    if(self$control$lincomb_iter_max == 1){
     info_lincomb_type <- "Accelerated "
    }

    info_lincomb_type <- paste0(info_lincomb_type, info_model_type)

   }

   oobag_stat <- "none"
   oobag_type <- self$oobag_eval_type

   if(!is_empty(self$eval_oobag$stat_values)){

    oobag_stat <- last_value(self$eval_oobag$stat_values)
    oobag_stat <- round_magnitude(oobag_stat)

   }


   forest_text <- paste("Oblique random", self$tree_type, "forest")

   header <- paste0('---------- ', forest_text, '\n')

   if(!self$trained){
    header <- paste0('Untrained ', tolower(forest_text), '\n')
   }


   n_predictors <- length(private$data_names$x_original)

   cat(header,
       paste0('     Linear combinations: ', info_lincomb_type             ),
       paste0('          N observations: ', self$n_obs                    ),

       if(self$tree_type == 'survival')
        paste0('                N events: ', private$n_events             ),

       if(self$tree_type == 'classification')
        paste0('               N classes: ', self$n_class             ),

       paste0('                 N trees: ', self$n_tree                   ),
       paste0('      N predictors total: ', n_predictors                  ),
       paste0('   N predictors per node: ', self$mtry                     ),
       paste0(' Average leaves per tree: ', private$mean_leaves           ),
       paste0('Min observations in leaf: ', self$leaf_min_obs             ),

       if(self$tree_type == 'survival')
        paste0('      Min events in leaf: ', self$leaf_min_events         ),

       paste0('          OOB stat value: ', oobag_stat                    ),
       paste0('           OOB stat type: ', oobag_type                    ),
       paste0('     Variable importance: ', self$importance_type          ),
       '\n-----------------------------------------',
       sep = '\n')

   invisible(self)

  },


  # train an untrained ObliqueForest
  #
  #  ... parameters to override defaults in orsf_cpp. Review
  #   prep_cpp_args() before trying to use this, as some input names
  #   are slightly different in the call to cpp.
  #
  # note: ... functionality is not used in exported
  # orsf_train() currently. I think it's better to
  # keep updating functionality in the orsf_update
  # function.
  train = function(...){

   private$compute_means()
   private$compute_stdev()
   private$compute_modes()
   private$compute_bounds()

   private$prep_x()
   private$prep_y()
   private$prep_w()

   # for survival, inputs should be sorted by time
   private$sort_inputs()

   # allow re-training
   if(self$trained){ self$forest <- list() }

   cpp_args <- private$prep_cpp_args(...)

   cpp_output <- do.call(orsf_cpp, args = cpp_args)

   cpp_output$eval_oobag$stat_type <- self$oobag_eval_type

   .dots <- list(...)
   if(length(.dots) > 0)
    for(i in names(.dots)) self[[i]] <- .dots[[i]]

   self$forest <- cpp_output$forest
   self$importance <- cpp_output$importance
   self$pred_oobag <- cpp_output$pred_oobag
   self$eval_oobag <- cpp_output$eval_oobag

   # don't let rows_oobag contain an empty vector, otherwise it
   # will crash R when cpp tries to load the tree later
   empty_oob_rows <- vapply(self$forest$rows_oobag, is_empty, logical(1))

   if(any(empty_oob_rows)) {
    for(i in which(empty_oob_rows)){
     self$forest$rows_oobag[[i]] <- numeric(1)
    }
   }

   if(self$importance_type != 'none'){
    private$clean_importance()
   }

   if(self$pred_type != 'none'){
    private$clean_pred_oobag()
   } else {
    # revert pred type to NULL so it is correctly
    # initialized when this object is passed to
    # other orsf functions
    self$pred_type <- NULL
   }

   self$trained <- TRUE

   private$compute_mean_leaves()

   # free up space
   private$x <- NULL
   private$y <- NULL
   private$w <- NULL

  },

  # this is the inverse of train.
  untrain = function(){
   self$forest <- NULL
   self$importance <- NULL
   self$pred_oobag <- NULL
   self$eval_oobag <- NULL
   self$trained <- FALSE
   private$mean_leaves <- 0
   private$data_means <- NULL
   private$data_modes <- NULL
   private$data_stdev <- NULL
   private$data_bounds <- NULL
  },


  # Prediction always returns matrix or array of predictions
  predict = function(new_data,
                     pred_type,
                     pred_horizon,
                     pred_aggregate,
                     pred_simplify,
                     oobag,
                     na_action,
                     boundary_checks,
                     n_thread,
                     verbose_progress){

   public_state <- list(data             = self$data,
                        pred_type        = self$pred_type,
                        pred_horizon     = self$pred_horizon,
                        pred_aggregate   = self$pred_aggregate,
                        na_action        = self$na_action,
                        n_thread         = self$n_thread,
                        verbose_progress = self$verbose_progress)

   private_state <- list(data_rows_complete = private$data_rows_complete)

   # assign new_data to original data if user did not supply any.
   # (this is usually done to compute oobag predictions).
   # DO NOT skip checking. A user might modify the data in a forest
   # object and then use this function. We need checks for that case.
   new_data <- new_data %||% self$data


   # run checks before you assign new values to object.
   # otherwise, if a check throws an error, the object will
   # not be restored to its normal state.

   private$check_data(new = TRUE, data = new_data)
   private$check_na_action(new = TRUE, na_action = na_action)
   private$check_var_missing(new = TRUE, data = new_data, na_action)
   private$check_var_values(new = TRUE, data = new_data)
   private$check_units(data = new_data)
   private$check_boundary_checks(boundary_checks)
   private$check_n_thread(n_thread)
   private$check_verbose_progress(verbose_progress)
   private$check_pred_aggregate(pred_aggregate)
   private$check_pred_simplify(pred_simplify)

   # check and/or set self$pred_horizon and self$pred_type
   # with defaults depending on tree type
   private$init_pred(pred_horizon, pred_type, boundary_checks)
   # set the rest
   self$data             <- new_data
   self$na_action        <- na_action
   self$n_thread         <- n_thread
   self$verbose_progress <- verbose_progress
   self$pred_aggregate   <- pred_aggregate


   out <- try(
    expr = {

     private$init_data_rows_complete()
     private$prep_x()

     if(oobag && nrow(private$x) != self$n_obs){
      stop("input data must have ", self$n_obs, " observations to ",
           "compute out-of-bag predictions.", call. = FALSE)
     }

     # y and w do not need to be prepped for prediction,
     # but they need to match orsf_cpp()'s expectations
     private$y <- matrix(0, nrow = 1, ncol = 1)
     private$w <- rep(1, nrow(private$x))

     private$predict_internal(simplify = pred_simplify,
                              oobag = oobag)

    },
    silent = TRUE
   )

   # return fields we modified to their original state
   private$restore_state(public_state, private_state)

   # free up space
   private$x <- NULL
   private$y <- NULL
   private$w <- NULL

   # errors within the try statement are not expected,
   # but if they do occur we want to make sure the
   # object is restored to its original state.
   if(is_error(out)) stop(out, call. = FALSE)

   # NaNs may occur with oobag = TRUE and small n_tree
   coerce_nans(out, to = NA_real_)
   # out

  },

  # Variable importance
  # returns a named numeric vector with importance values

  compute_vi = function(type_vi,
                        oobag_fun,
                        n_thread,
                        verbose_progress){

   if(!is.null(oobag_fun)){

    private$check_oobag_eval_function(oobag_fun)
    oobag_eval_function <- oobag_fun
    oobag_eval_type <- "user-specified function"

   } else {

    oobag_eval_function <- self$oobag_eval_function
    oobag_eval_type <- self$oobag_eval_type

   }

   private$prep_x()
   private$prep_y()
   private$prep_w()

   private$sort_inputs()

   # oobag should be FALSE for computing importance
   # even though it is called 'oobag' importance.
   # this is because we subset the x/y/w mats to only
   # contain oobag observations in cpp.

   cpp_args <-
    private$prep_cpp_args(
     oobag_eval_function = oobag_eval_function,
     oobag_eval_type = oobag_eval_type,
     importance_type = type_vi,
     pred_type = self$get_pred_type_vi(),
     pred_mode = FALSE,
     pred_aggregate = TRUE,
     oobag = FALSE,
     write_forest = FALSE,
     run_forest = TRUE,
     n_thread = n_thread %||% self$n_thread,
     verbosity = verbose_progress %||% self$verbose_progress
    )

   out <- do.call(orsf_cpp, args = cpp_args)$importance
   rownames(out) <- colnames(private$x)
   out

  },

  # Partial dependence
  # returns a data.table with dependence values

  compute_dependence = function(pd_data,
                                pred_spec,
                                pred_horizon,
                                pred_type,
                                na_action,
                                expand_grid,
                                prob_values,
                                prob_labels,
                                boundary_checks,
                                n_thread,
                                verbose_progress,
                                oobag,
                                type_output){

   na_action <- na_action %||% self$na_action

   public_state <- list(data             = self$data,
                        na_action        = self$na_action,
                        pred_horizon     = self$pred_horizon,
                        n_thread         = self$n_thread,
                        verbose_progress = self$verbose_progress)

   private_state <- list(data_rows_complete = private$data_rows_complete)

   # run checks before you assign new values to object.
   private$check_boundary_checks(boundary_checks)

   type_input <- 'user'

   if(inherits(pred_spec, 'pspec_auto')){

    private$check_var_names(.names = pred_spec,
                         data = private$data_names$x_original,
                         location = "pred_spec")

    pred_spec <- list_init(pred_spec)

    for(i in names(pred_spec)){

     pred_spec[[i]] <- unique(self$get_var_bounds(i))

    }

   } else if (inherits(pred_spec, 'pspec_intr')){

    type_input <- 'intr'

    pairs <- utils::combn(pred_spec, m = 2)

    pred_spec <- vector(mode = 'list', length = ncol(pairs))


    for(i in seq_along(pred_spec)){

     pred_spec[[i]] <- expand.grid(
      self$get_var_bounds(pairs[1,i]),
      self$get_var_bounds(pairs[2,i])
     )

     colnames(pred_spec[[i]]) <- pairs[, i, drop = TRUE]

    }

   } else {

    private$check_pred_spec(pred_spec, boundary_checks)

   }

   private$check_n_thread(n_thread)
   private$check_expand_grid(expand_grid)
   private$check_verbose_progress(verbose_progress)
   private$check_oobag_pred_mode(oobag, label = 'oobag')

   prob_values <- prob_values %||% c(0.025, 0.50, 0.975)
   prob_labels <- prob_labels %||% c('lwr', 'medn', 'upr')

   private$check_prob_values(prob_values)
   private$check_prob_labels(prob_labels)

   if(length(prob_values) != length(prob_labels)){
    stop("prob_values and prob_labels must have the same length.",
         call. = FALSE)
   }

   # oobag=FALSE to match the format of arg in orsf_pd().
   private$check_pred_type(pred_type, oobag = FALSE,
                           context = 'partial dependence')

   pred_type <- pred_type %||% self$pred_type

   private$check_pred_horizon(pred_horizon, boundary_checks, pred_type)

   if(!oobag){
    private$check_data(new = TRUE, data = pd_data)
    # say new = FALSE to prevent na_action = 'pass'
    private$check_na_action(new = FALSE, na_action = na_action)
    private$check_var_missing(new = TRUE, data = pd_data, na_action)
    private$check_units(data = pd_data)
    self$data <- pd_data
   }

   self$na_action <- na_action
   self$n_thread <- n_thread
   self$verbose_progress <- verbose_progress

   out <- try(
    private$compute_dependence_internal(pred_spec = pred_spec,
                                        pred_type = pred_type,
                                        pred_horizon = pred_horizon,
                                        type_input  = type_input,
                                        type_output = type_output,
                                        expand_grid = expand_grid,
                                        prob_labels = prob_labels,
                                        prob_values = prob_values,
                                        oobag = oobag),
    silent = FALSE
   )

   private$restore_state(public_state, private_state)

   # free up space
   private$x <- NULL
   private$y <- NULL
   private$w <- NULL

   # errors within the try statement are not expected,
   # but if they do occur we want to make sure the
   # object is restored to its original state.
   if(is_error(out)) stop(out, call. = FALSE)

   out

  },

  # Variable selection
  # returns a data.table with variable selection info
  select_variables = function(n_predictor_min, verbose_progress){

   public_state <- list(verbose_progress = self$verbose_progress,
                        forest           = self$forest,
                        control          = self$control)

   object_trained <- self$trained

   out <- try(
    private$select_variables_internal(n_predictor_min, verbose_progress)
   )

   private$restore_state(public_state, private_state = NULL)

   # you can pass a specification to orsf_vs, but it will be trained
   # during the variable selection process. Use untrain() to prevent
   # a specification from being modified when it is passed to orsf_vs
   if(!object_trained) self$untrain()

   if(is_error(out)) stop(out, call. = FALSE)

   out


  },

  # Univariate summary of predictors
  # calls both orsf_vi and orsf_pd_oob to create a summary object
  summarize_uni = function(n_variables = NULL,
                           pred_horizon = NULL,
                           pred_type = NULL,
                           importance_type = NULL,
                           class = NULL,
                           verbose_progress = FALSE){

   # check incoming values if they were specified.
   private$check_n_variables(n_variables)
   private$check_verbose_progress(verbose_progress)
   private$check_class(class)

   if(!is.null(pred_horizon)){
    private$check_pred_horizon(pred_horizon, boundary_checks = TRUE)
   }

   private$check_pred_type(pred_type, oobag = FALSE,
                        context = 'partial dependence')
   private$check_importance_type(importance_type)

   names_x <- private$data_names$x_original

   # use existing values if incoming ones were not specified
   n_variables     <- n_variables     %||% length(names_x)
   pred_horizon    <- pred_horizon    %||% self$pred_horizon
   pred_type       <- pred_type       %||% self$pred_type
   importance_type <- importance_type %||% self$importance_type

   # bindings for CRAN check
   value <- NULL
   level <- NULL

   # TODO: make this go away. Just sort alphabetically if no importance
   if(importance_type == 'none' && is_empty(self$importance_type))
    stop("importance cannot be 'none' if object does not have variable",
         " importance values.", call. = FALSE)

   vi <- switch(
    importance_type,
    'anova' = orsf_vi_anova(self, group_factors = TRUE,
                            verbose_progress = verbose_progress),
    'negate' = orsf_vi_negate(self, group_factors = TRUE,
                              verbose_progress = verbose_progress),
    'permute' = orsf_vi_permute(self, group_factors = TRUE,
                                verbose_progress = verbose_progress),
    'none' = NULL
   )

   bounds <- private$data_bounds
   fctrs  <- private$data_fctrs
   n_obs  <- self$n_obs

   names_vi <- names(vi) %||% names_x

   pred_spec <- list_init(names_vi)[seq(n_variables)]

   for(i in names(pred_spec)){

    if(i %in% colnames(bounds)){

     pred_spec[[i]] <- self$get_var_bounds(i)
      # unique(
      #  as.numeric(bounds[c('25%','50%','75%'), i])
      # )

    } else if (i %in% fctrs$cols) {

     pred_spec[[i]] <- fctrs$lvls[[i]]

    }

   }

   pd_output <- orsf_pd_oob(object = self,
                            pred_spec = pred_spec,
                            expand_grid = FALSE,
                            pred_type = pred_type,
                            prob_values = c(0.25, 0.50, 0.75),
                            pred_horizon = pred_horizon,
                            boundary_checks = FALSE,
                            verbose_progress = verbose_progress)

   fctrs_unordered <- c()

   # did the orsf have factor variables?
   if(!is_empty(fctrs$cols)){
    fctrs_unordered <- fctrs$cols[!fctrs$ordr]
   }

   # some cart-wheels here for backward compatibility.
   f <- as.factor(pd_output$variable)

   name_rep <- rle(as.integer(f))

   pd_output$importance <- rep(vi[levels(f)[name_rep$values]],
                               times = name_rep$lengths)

   pd_output[, value := fifelse(test = is.na(value),
                                yes = as.character(level),
                                no = round_magnitude(value))]

   # if a := is used inside a function with no DT[] before the end of the
   # function, then the next time DT or print(DT) is typed at the prompt,
   # nothing will be printed. A repeated DT or print(DT) will print.
   # To avoid this: include a DT[] after the last := in your function.
   pd_output[]

   new_order <- c('variable', 'importance', 'value',
                  'mean', 'medn', 'lwr', 'upr')

   if(self$tree_type == 'classification'){
    new_order <- insert_vals(new_order, 2, 'class')
    if(!is.null(class)){
     .class <- class # prevents mix-up with class in dt
     pd_output <- pd_output[class == .class]
    } else {
     # put the highest level class on top
     pd_output <- pd_output[order(-class)]
    }
   }

   setcolorder(pd_output, new_order)

   structure(
    .Data = list(dt = pd_output,
                 pred_type = pred_type,
                 pred_horizon = pred_horizon),
    class = 'orsf_summary_uni'
   )


  },

  # setter

  set_field = function(...){

   .dots <- list(...)

   valid_fields <- c("data",
                     "formula",
                     "control",
                     "weights",
                     "n_tree",
                     "n_split",
                     "n_retry",
                     "n_thread",
                     "mtry",
                     "sample_with_replacement",
                     "sample_fraction",
                     "leaf_min_events",
                     "leaf_min_obs",
                     "split_rule",
                     "split_min_events",
                     "split_min_obs",
                     "split_min_stat",
                     "pred_type",
                     "oobag_pred_horizon",
                     "oobag_eval_every",
                     "oobag_fun",
                     "importance_type",
                     "importance_max_pvalue",
                     "importance_group_factors",
                     "tree_seeds",
                     "na_action",
                     "verbose_progress")

   if(!is_empty(.dots)){

    for(i in names(.dots)){

     if(!(i %in% valid_fields))
      stop("field ", i, " is unrecognized")

     self[[i]] <- .dots[[i]]

    }

   }

  },

  # getters
  # the get_ functions return a private member variable in
  # standard or requested format, allowing for info about
  # specific variables to be returned.
  get_names_x = function(ref_coded = FALSE){

   if(ref_coded) return(private$data_names$x_ref_code)

   return(private$data_names$x_original)

  },

  get_names_y = function(){
   return(private$data_names$y)
  },

  get_var_bounds = function(.name){

   if(.name %in% private$data_names$x_numeric){

    out <- unique(as.numeric(private$data_bounds[, .name]))

    if(length(out) < 5){
     # too few unique values to use quantiles,
     # so use the most common unique values instead.
     unis <- sort(table(self$data[[.name]]), decreasing = TRUE)
     n_items <- min(5, length(unis))
     out <- sort(as.numeric(names(unis)[seq(n_items)]))
    }

    return(out)

   } else {

    return(private$data_fctrs$lvls[[.name]])

   }

  },

  get_var_type = function(.name){
   return(class(self$data[[.name]])[1])
  },

  get_var_unit = function(.name){
   return(private$data_units[[.name]])
  },

  get_fctr_info = function(){
   return(private$data_fctrs)
  },

  get_importance_raw = function(){
   return(private$importance_raw)
  },

  get_importance_clean = function(importance_raw = NULL,
                                  group_factors = NULL){


   input <- importance_raw %||% private$importance_raw
   group_factors <- group_factors %||% self$group_factors
   private$clean_importance(input, group_factors)

   return(self$importance)

  },

  get_mean_leaves_per_tree = function(){
   return(private$mean_leaves)
  },

  get_means = function(){
   return(private$data_means)
  },

  get_modes = function(){
   return(private$data_modes)
  },

  get_stdev = function(){
   return(private$data_stdev)
  },

  get_bounds = function(){
   return(private$data_bounds)
  }

 ),

 # private ----
 private = list(

  user_specified = NULL,
  data_rows_complete = NULL,
  data_types = NULL,
  data_names = NULL,
  data_fctrs = NULL,
  data_units = NULL,
  data_means = NULL,
  data_stdev = NULL,
  data_modes = NULL,
  data_bounds = NULL,

  x = NULL,
  y = NULL,
  w = NULL,

  importance_raw = NULL,

  mean_leaves = 0,

  # checkers
  check_data = function(data = NULL, new = FALSE){

   # additional data checks are run during initialization.
   input <- data %||% self$data

   check_arg_is(arg_value = input,
                arg_name = 'data',
                expected_class = 'data.frame')

   # Minimal checks for now, other checks occur later
   if(nrow(input) == 0 || ncol(input) ==  0){
    data_label <- if(new) "new" else "training"
    stop(data_label, " data are empty",
         call. = FALSE)

   }

   if(!new){

    # check for blanks first b/c the check for non-standard symbols
    # will detect blanks with >1 empty characters
    blank_names <- grepl(pattern = '^\\s*$', x = names(input))

    if(any(blank_names)){

     s_if_plural_blank_otherwise <- ""

     to_list <- which(blank_names)

     if(length(to_list) > 1) s_if_plural_blank_otherwise <- "s"

     last <- ifelse(length(to_list) == 2, ' and ', ', and ')

     stop("Blank or empty names detected in training data: see column",
          s_if_plural_blank_otherwise, " ",
          paste_collapse(x = to_list, last = last),
          call. = FALSE)

    }

    ns_names <- grepl(pattern = '[^a-zA-Z0-9\\.\\_]+', x = names(input))

    if(any(ns_names)){

     last <- ifelse(sum(ns_names) == 2, ' and ', ', and ')

     stop("Non-standard names detected in training data: ",
          paste_collapse(x = names(input)[ns_names],
                         last = last),
          call. = FALSE)

    }

   } else {

    private$check_var_names_new(new_data = input)
    private$check_var_types_new(new_data = input)
    private$check_fctrs_new(new_data = input)

   }


  },


  check_var_names = function(.names,
                             data = NULL,
                             location = "formula"){

   data <- data %||% self$data

   if(is.character(data)){
    data_names <- data
   } else {
    data_names <- names(data)
   }

   names_not_found <- setdiff(c(.names), data_names)

   if(!is_empty(names_not_found)){
    msg <- paste0(
     "variables in ", location, " were not found in data: ",
     paste_collapse(names_not_found, last = ' and ')
    )
    stop(msg, call. = FALSE)
   }

  },

  check_var_names_new = function(new_data,
                                 check_new_in_ref = FALSE,
                                 check_ref_in_new = TRUE){

   check_new_data_names(new_data,
                        ref_names = private$data_names$x_original,
                        label_new = "new_data",
                        label_ref = 'training data',
                        check_new_in_ref = check_new_in_ref,
                        check_ref_in_new = check_ref_in_new)


  },

  # Check variable types
  #
  # orsf() should only be run with certain types of variables. This function
  #   checks input data to make sure all variables have a primary (i.e., first)
  #   class that is within the list of valid options.
  #
  # return an error if something is wrong.
  check_var_types = function(var_types, var_names, valid_types){

   good_vars <- var_types %in% valid_types

   if(!all(good_vars)){

    bad_vars <- which(!good_vars)

    vars_to_list <- var_names[bad_vars]
    types_to_list <- var_types[bad_vars]

    meat <- paste0(' <', vars_to_list, '> has type <',
                   types_to_list, '>', collapse = '\n')

    msg <- paste0("some variables have unsupported type:\n",
                  meat, '\nsupported types are ',
                  paste_collapse(valid_types, last = ' and '))

    stop(msg, call. = FALSE)

   }

   var_types


  },

  check_var_types_new = function(new_data){

   check_new_data_types(new_data,
                        ref_names = private$data_names$x_original,
                        ref_types = private$data_types$x,
                        label_new = "new_data",
                        label_ref = "training data")

  },

  check_var_missing = function(data = NULL,
                               new = FALSE,
                               na_action = NULL){

   input <- data %||% self$data
   na_action <- na_action %||% self$na_action

   if(na_action == 'fail'){

    if(any(is.na(select_cols(input, private$data_names$x_original)))){
     data_label <- if(new) "new data" else "training data"
     stop("Please remove missing values from ", data_label, ", or impute them.",
          call. = FALSE)
    }

   }

   if(na_action == 'omit'){

    if(length(private$data_rows_complete) == 0){

     data_label <- if(new) "new data" else "training data"
     stop("There are no observations in ",
          data_label, " with complete data ",
          "for the predictors used by this orsf object.",
          call. = FALSE)

    }

   }

   if(!new){

    if(any(is.na(select_cols(input, private$data_names$y))))
     stop("Please remove missing values from the outcome variable(s).",
          call. = FALSE)

   }

  },

  check_var_values = function(data = NULL, new = FALSE){

   input <- data %||% self$data

   for(i in private$data_names$x_original){

    if(collapse::allNA(input[[i]])){
     stop("column ", i, " has no observed values.",
          call. = FALSE)
    }

    if(any(is.infinite(input[[i]]))){
     stop("Please remove infinite values from ", i, ".",
          call. = FALSE)
    }

    if(!new){
     if(collapse::fnunique(collapse::na_omit(input[[i]])) == 1L){
      stop("column ", i, " is constant.",
           call. = FALSE)
     }
    }

   }

  },

  check_fctrs_new = function(new_data){

   check_new_data_fctrs(new_data,
                        names_x = private$data_names$x_original,
                        fi_ref = private$data_fctrs,
                        label_new = "new_data")

  },

  check_formula = function(formula = NULL){

   input <- formula %||% self$formula

   check_arg_is(arg_value = input,
                arg_name = 'formula',
                expected_class = 'formula')

   if(length(input) != 3){
    stop("formula must be two sided, i.e. left side ~ right side",
         call. = FALSE)
   }

   formula_deparsed <- as.character(input)[[3]]

   for( symbol in c("*", "^", ":", "(", ")", "["," ]", "|", "%") ){

    if(grepl(symbol, formula_deparsed, fixed = TRUE)){

     stop("unrecognized symbol in formula: ", symbol,
          "\norsf recognizes '+', '-', and '.' symbols.",
          call. = FALSE)

    }

   }

  },
  check_control = function(control = NULL){

   input <- control %||% self$control

   check_arg_is(arg_value = input,
                arg_name = 'control',
                expected_class = 'orsf_control')

   if(input$lincomb_type == 'net'){

    if (!requireNamespace("glmnet", quietly = TRUE)) {
     stop(
      "Package \"glmnet\" must be installed to use",
      " orsf_control_net() with orsf().",
      call. = FALSE
     )
    }

   }

  },
  check_weights = function(weights = NULL){

   input <- weights %||% self$weights

   check_arg_type(arg_value = input,
                  arg_name = 'weights',
                  expected_type = 'numeric')

   check_arg_gteq(arg_value = input,
                  arg_name = 'weights',
                  bound = 0)

   check_arg_length(arg_value = input,
                    arg_name  = 'weights',
                    expected_length = self$n_obs)

  },
  check_n_tree = function(n_tree = NULL){

   input <- n_tree %||% self$n_tree

   check_arg_type(arg_value = input,
                  arg_name = 'n_tree',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = input,
                        arg_name = 'n_tree')

   check_arg_gteq(arg_value = input,
                  arg_name = 'n_tree',
                  bound = 1)

   check_arg_lteq(arg_value = input,
                  arg_name = 'n_tree',
                  bound = 10000)

   check_arg_length(arg_value = input,
                    arg_name = 'n_tree',
                    expected_length = 1)

  },
  check_n_split = function(n_split = NULL){

   input <- n_split %||% self$n_split

   check_arg_type(arg_value = input,
                  arg_name = 'n_split',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = input,
                        arg_name = 'n_split')

   check_arg_gteq(arg_value = input,
                  arg_name = 'n_split',
                  bound = 1)

   check_arg_length(arg_value = input,
                    arg_name = 'n_split',
                    expected_length = 1)

  },
  check_n_retry = function(n_retry = NULL){

   input <- n_retry %||% self$n_retry

   check_arg_type(arg_value = input,
                  arg_name = 'n_retry',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = input,
                        arg_name = 'n_retry')

   check_arg_gteq(arg_value = input,
                  arg_name = 'n_retry',
                  bound = 0)

   check_arg_length(arg_value = input,
                    arg_name = 'n_retry',
                    expected_length = 1)

  },
  check_n_thread = function(n_thread = NULL){

   input <- n_thread %||% self$n_thread

   check_arg_type(arg_value = input,
                  arg_name = 'n_thread',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = input,
                        arg_name = 'n_thread')

   check_arg_gteq(arg_value = input,
                  arg_name = 'n_thread',
                  bound = 0)

   check_arg_length(arg_value = input,
                    arg_name = 'n_thread',
                    expected_length = 1)

  },

  check_n_variables = function(n_variables = NULL){

   # n_variables is not a field of ObliqueForest,
   # so it is only checked as an incoming input.

   if(!is.null(n_variables)){

    check_arg_type(arg_value = n_variables,
                   arg_name = 'n_variables',
                   expected_type = 'numeric')

    check_arg_is_integer(arg_value = n_variables,
                         arg_name = 'n_variables')

    check_arg_gteq(arg_value = n_variables,
                   arg_name = 'n_variables',
                   bound = 1)

    check_arg_lteq(arg_value = n_variables,
                   arg_name = 'n_variables',
                   bound = length(private$data_names$x_original),
                   append_to_msg = "(total number of predictors)")


    check_arg_length(arg_value = n_variables,
                     arg_name = 'n_variables',
                     expected_length = 1)

   }

  },

  check_mtry = function(mtry = NULL){

   input <- mtry %||% self$mtry

   n_predictors <- length(private$data_names$x_ref_code)

   # okay for this to be unspecified at startup
   if(!is.null(input)){

    check_arg_type(arg_value = input,
                   arg_name = 'mtry',
                   expected_type = 'numeric')

    check_arg_is_integer(arg_value = input,
                         arg_name = 'mtry')

    check_arg_gteq(arg_value = input,
                   arg_name = 'mtry',
                   bound = 1)

    check_arg_length(arg_value = input,
                     arg_name = 'mtry',
                     expected_length = 1)

    if(!is.null(n_predictors)){
     check_arg_lteq(
      arg_value = input,
      arg_name = 'mtry',
      bound = n_predictors,
      append_to_msg = "(number of columns in the reference coded x-matrix)"
     )
    }

   }

  },
  check_sample_with_replacement = function(sample_with_replacement = NULL){

   input <- sample_with_replacement %||% self$sample_with_replacement

   check_arg_type(arg_value = input,
                  arg_name = 'sample_with_replacement',
                  expected_type = 'logical')

   check_arg_length(arg_value = input,
                    arg_name = 'sample_with_replacement',
                    expected_length = 1)

  },
  check_sample_fraction = function(sample_fraction = NULL){

   input <- sample_fraction %||% self$sample_fraction

   check_arg_type(arg_value = input,
                  arg_name = 'sample_fraction',
                  expected_type = 'numeric')

   check_arg_gt(arg_value = input,
                arg_name = 'sample_fraction',
                bound = 0)

   check_arg_lteq(arg_value = input,
                  arg_name = 'sample_fraction',
                  bound = 1)

   check_arg_length(arg_value = input,
                    arg_name = 'sample_fraction',
                    expected_length = 1)

  },
  check_leaf_min_obs = function(leaf_min_obs = NULL,
                                n_obs = NULL){

   input <- leaf_min_obs %||% self$leaf_min_obs
   n_obs <- n_obs %||% self$n_obs

   # users should never get this error but it may help me debug
   if(is.null(n_obs))
    stop("cannot check leaf_min_obs when n_obs is unspecified")

   check_arg_type(arg_value = input,
                  arg_name = 'leaf_min_obs',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = input,
                        arg_name = 'leaf_min_obs')

   check_arg_gteq(arg_value = input,
                  arg_name = 'leaf_min_obs',
                  bound = 1)

   check_arg_length(arg_value = input,
                    arg_name = 'leaf_min_obs',
                    expected_length = 1)

   check_arg_lteq(arg_value = input,
                  arg_name = "leaf_min_obs",
                  bound = round(n_obs / 2),
                  append_to_msg = "(number of observations divided by 2)")

  },
  check_split_rule = function(split_rule = NULL){

   input <- split_rule %||% self$split_rule

   # okay to pass a NULL value on startup
   if(!is.null(input)){

    check_arg_type(arg_value = input,
                   arg_name = 'split_rule',
                   expected_type = 'character')

    check_arg_length(arg_value = input,
                     arg_name = 'split_rule',
                     expected_length = 1)

    private$check_split_rule_internal()

   }


  },
  check_split_rule_internal = function(){

   stop("this method should be defined in a derived class.")

  },
  check_split_min_obs = function(split_min_obs = NULL,
                                 n_obs = NULL){

   input <- split_min_obs %||% self$split_min_obs
   n_obs <- n_obs %||% self$n_obs

   # users should never get this error but it may help me debug
   if(is.null(n_obs))
    stop("cannot check split_min_obs when n_obs is unspecified")

   check_arg_type(arg_value = input,
                  arg_name = 'split_min_obs',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = input,
                        arg_name = 'split_min_obs')

   check_arg_gteq(arg_value = input,
                  arg_name = 'split_min_obs',
                  bound = 1)

   check_arg_length(arg_value = input,
                    arg_name = 'split_min_obs',
                    expected_length = 1)

   check_arg_lt(arg_value = input,
                arg_name = "split_min_obs",
                bound = n_obs,
                append_to_msg = "(number of observations)")

  },
  check_split_min_stat = function(split_min_stat = NULL,
                                  split_rule = NULL){

   input <- split_min_stat %||% self$split_min_stat

   split_rule <- split_rule %||% self$split_rule

   # okay to pass a NULL value on startup
   if(!is.null(input)){

    check_arg_type(arg_value = input,
                   arg_name = 'split_min_stat',
                   expected_type = 'numeric')

    check_arg_gteq(arg_value = input,
                   arg_name = 'split_min_stat',
                   bound = 0)

    check_arg_length(arg_value = input,
                     arg_name = 'split_min_stat',
                     expected_length = 1)

    # also okay for split_rule to be NULL on startup
    if(!is.null(split_rule)){

     if(split_rule %in% c('cstat', 'gini')){

      append <- paste0("(split stat <", self$split_rule, "> is always < 1)")

      check_arg_lt(arg_value = input,
                   arg_name = 'split_min_stat',
                   bound = 1,
                   append_to_msg = append)

     }

    }

   }

  },

  # must specify oobag when you call this to make sure it isn't forgotten
  check_pred_type = function(pred_type = NULL, oobag, context = NULL){

   input <- pred_type %||% self$pred_type

   # okay to pass a NULL value on startup
   if(!is.null(input)){

    arg_name <- if(oobag) 'oobag_pred_type' else "pred_type"

    check_arg_type(arg_value = input,
                   arg_name = arg_name,
                   expected_type = 'character')

    check_arg_length(arg_value = input,
                     arg_name = arg_name,
                     expected_length = 1)

    private$check_pred_type_internal(oobag = oobag,
                                  pred_type = pred_type,
                                  context = context)

   }



  },

  check_pred_aggregate = function(pred_aggregate = NULL){

   input <- pred_aggregate %||% self$pred_aggregate

   check_arg_type(arg_value = input,
                  arg_name = 'pred_aggregate',
                  expected_type = 'logical')

   check_arg_length(arg_value = input,
                    arg_name = 'pred_aggregate',
                    expected_length = 1)

  },

  check_pred_simplify = function(pred_simplify){

   check_arg_type(arg_value = pred_simplify,
                  arg_name = 'pred_simplify',
                  expected_type = 'logical')

   check_arg_length(arg_value = pred_simplify,
                    arg_name = 'pred_simplify',
                    expected_length = 1)

  },

  check_oobag_eval_every = function(oobag_eval_every = NULL,
                                    n_tree = NULL){

   input <- oobag_eval_every %||% self$oobag_eval_every

   n_tree <- n_tree %||% self$n_tree

   check_arg_type(arg_value = input,
                  arg_name = 'oobag_eval_every',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = input,
                        arg_name = 'oobag_eval_every')

   check_arg_gteq(arg_value = input,
                  arg_name = 'oobag_eval_every',
                  bound = 1)

   check_arg_length(arg_value = input,
                    arg_name = 'oobag_eval_every',
                    expected_length = 1)

   check_arg_lteq(arg_value = self$oobag_eval_every,
                  arg_name = 'oobag_eval_every',
                  bound = n_tree)


  },

  check_importance_type = function(importance_type = NULL){

   input <- importance_type <- self$importance_type

   check_arg_type(arg_value = input,
                  arg_name = 'importance',
                  expected_type = 'character')

   check_arg_length(arg_value = input,
                    arg_name = 'importance',
                    expected_length = 1)

   check_arg_is_valid(arg_value = input,
                      arg_name = 'importance',
                      valid_options = c("none",
                                        "anova",
                                        "negate",
                                        "permute"))

  },
  check_importance_max_pvalue = function(importance_max_pvalue = NULL){

   input <- importance_max_pvalue %||% self$importance_max_pvalue

   check_arg_type(arg_value = input,
                  arg_name = 'importance_max_pvalue',
                  expected_type = 'numeric')

   check_arg_gt(arg_value = input,
                arg_name = 'importance_max_pvalue',
                bound = 0)

   check_arg_lt(arg_value = input,
                arg_name = 'importance_max_pvalue',
                bound = 1)

   check_arg_length(arg_value = input,
                    arg_name = 'importance_max_pvalue',
                    expected_length = 1)


  },

  check_importance_group_factors = function(importance_group_factors = NULL){

   input <- importance_group_factors %||% self$importance_group_factors

   check_arg_type(arg_value = input,
                  arg_name = 'group_factors',
                  expected_type = 'logical')

   check_arg_length(arg_value = input,
                    arg_name = 'group_factors',
                    expected_length = 1)

  },
  check_tree_seeds = function(tree_seeds = NULL,
                              n_tree = NULL){

   input <- tree_seeds %||% self$tree_seeds
   n_tree <- n_tree %||% self$n_tree

   # okay for this to be unspecified at start-up
   if(!is.null(input)){

    check_arg_type(arg_value = input,
                   arg_name = 'tree_seed',
                   expected_type = 'numeric')

    check_arg_is_integer(arg_value = input,
                         arg_name = 'tree_seeds')

    if(!is.null(n_tree)){

     if(length(input) > 1 && length(input) != n_tree){

      stop('tree_seeds should have length = 1 or length = ",
          "n_tree <', self$n_tree,
           "> (the number of trees) but instead has length <",
           length(input), ">", call. = FALSE)

     }

    }

   }

  },
  check_na_action = function(na_action = NULL, new = FALSE){

   input <- na_action %||% self$na_action

   check_arg_type(arg_value = input,
                  arg_name = 'na_action',
                  expected_type = 'character')

   check_arg_length(arg_value = input,
                    arg_name = 'na_action',
                    expected_length = 1)

   valid_options <- c("fail", "omit", "impute_meanmode")

   if(new) valid_options <- c(valid_options, 'pass')

   check_arg_is_valid(arg_value = input,
                      arg_name = 'na_action',
                      valid_options = valid_options)

  },
  check_oobag_eval_function = function(oobag_eval_function = NULL){

   input <- oobag_eval_function %||% self$oobag_eval_function

   if(!is.null(input)){

    oobag_fun_args <- names(formals(input))

    if(length(oobag_fun_args) != 3) stop(
     "oobag_fun should have 3 input arguments but instead has ",
     length(oobag_fun_args),
     call. = FALSE
    )

    if(oobag_fun_args[1] != 'y_mat') stop(
     "the first input argument of oobag_fun should be named 'y_mat' ",
     "but is instead named '", oobag_fun_args[1], "'",
     call. = FALSE
    )

    if(oobag_fun_args[2] != 'w_vec') stop(
     "the second input argument of oobag_fun should be named 'w_vec' ",
     "but is instead named '", oobag_fun_args[1], "'",
     call. = FALSE
    )

    if(oobag_fun_args[3] != 's_vec') stop(
     "the third input argument of oobag_fun should be named 's_vec' ",
     "but is instead named '", oobag_fun_args[2], "'",
     call. = FALSE
    )

    test_output <- private$check_oobag_eval_function_internal(input)

    if(!is.numeric(test_output)) stop(
     "oobag_fun should return a numeric output but instead returns ",
     "output of type ", class(test_output)[1],
     call. = FALSE
    )

    if(length(test_output) != 1) stop(
     "oobag_fun should return output of length 1 instead returns ",
     "output of length ", length(test_output),
     call. = FALSE
    )

   }

  },


  check_lincomb_R_function = function(lincomb_R_function = NULL){

   input <- lincomb_R_function %||% self$control$lincomb_R_function

   args <- names(formals(input))

   if(length(args) != 3) stop(
    "input should have 3 input arguments but instead has ",
    length(args),
    call. = FALSE
   )

   arg_names_expected <- c("x_node",
                           "y_node",
                           "w_node")

   arg_names_refer <- c('first', 'second', 'third')

   for(i in seq_along(arg_names_expected)){
    if(args[i] != arg_names_expected[i])
     stop(
      "the ", arg_names_refer[i], " input argument of input ",
      "should be named '", arg_names_expected[i],"' ",
      "but is instead named '", args[i], "'",
      call. = FALSE
     )
   }

   test_output <- private$check_lincomb_R_function_internal(input)

   if(!is.matrix(test_output)) stop(
    "user-supplied function should return a matrix output ",
    "but instead returns output of type ", class(test_output)[1],
    call. = FALSE
   )

   if(ncol(test_output) != 1) stop(
    "user-supplied function should return a matrix with 1 column ",
    "but instead returns a matrix with ", ncol(test_output), " columns.",
    call. = FALSE
   )

   if(nrow(test_output) != 3L) stop(
    "user-supplied function should return a matrix with 1 row for each ",
    " column in x_node but instead returns a matrix with ",
    nrow(test_output), " rows ", "in a testing case where x_node has ",
    3L, " columns",
    call. = FALSE
   )

  },

  check_verbose_progress = function(verbose_progress = NULL){

   input <- verbose_progress %||% self$verbose_progress

   # check_arg_type(arg_value = input,
   #                arg_name = 'verbose_progress',
   #                expected_type = 'logical')
   #
   # check_arg_length(arg_value = input,
   #                  arg_name = 'verbose_progress',
   #                  expected_length = 1)

  },
  check_units = function(data = NULL){

   input <- data %||% self$data

   ui_new <- unit_info(data = input, .names = names(private$data_units))

   ui_missing <- setdiff(names(private$data_units), names(ui_new))

   if(!is_empty(ui_missing)){

    if(length(ui_missing) == 1){
     stop(ui_missing, " had unit attributes in training data but",
          " did not have unit attributes in testing data.",
          " Please ensure that variables in new data have the same",
          " units as their counterparts in the training data.",
          call. = FALSE)
    }

    stop(length(ui_missing),
         " variables (",
         paste_collapse(ui_missing, last = ' and '),
         ") had unit attributes in training",
         " data but did not have unit attributes in new data.",
         " Please ensure that variables in new data have the same",
         " units as their counterparts in the training data.",
         call. = FALSE)

   }

   for(i in names(private$data_units)){

    if(private$data_units[[i]]$label != ui_new[[i]]$label){

     msg <- paste0("variable <", i, "> has unit '",
                   private$data_units[[i]]$label,
                   "' in training data but has unit '",
                   ui_new[[i]]$label, "' in new data")

     stop(msg, call. = FALSE)

    }

   }

  },

  check_pred_spec = function(pred_spec, boundary_checks){

   if(is_empty(pred_spec)){
    stop("pred_spec is empty", call. = FALSE)
   }

   if(is_empty(names(pred_spec))){
    stop("pred_spec is unnamed", call. = FALSE)
   }

   bad_name_index <- which(
    is.na(match(names(pred_spec), private$data_names$x_original))
   )

   if(!is_empty(bad_name_index)){

    bad_names <- names(pred_spec)[bad_name_index]

    stop("variables in pred_spec are not recognized as predictors in object: ",
         paste_collapse(bad_names, last = ' and '),
         call. = FALSE)

   }

   numeric_bounds <- private$data_bounds
   numeric_names <- intersect(colnames(numeric_bounds), names(pred_spec))

   if(!is_empty(numeric_names) && boundary_checks){

    for(.name in numeric_names){

     vals_above_stop <- which(pred_spec[[.name]] > numeric_bounds['90%', .name])
     vals_below_stop <- which(pred_spec[[.name]] < numeric_bounds['10%', .name])

     boundary_error <- FALSE
     vals_above_list <- vals_below_list <- " "

     if(!is_empty(vals_above_stop)){
      vals_above_list <- paste_collapse(
       round_magnitude(pred_spec[[.name]][vals_above_stop]),
       last = ' and '
      )

      boundary_error <- TRUE
      vals_above_list <-
       paste0(" (",vals_above_list," > ", numeric_bounds['90%', .name],") ")

     }

     if(!is_empty(vals_below_stop)){

      vals_below_list <- paste_collapse(
       round_magnitude(pred_spec[[.name]][vals_below_stop]),
       last = ' and '
      )

      boundary_error <- TRUE

      vals_below_list <-
       paste0(" (",vals_below_list," < ", numeric_bounds['10%', .name],") ")

     }

     if(boundary_error)
      stop("Some values for ",
           .name,
           " in pred_spec are above",
           vals_above_list,
           "or below",
           vals_below_list,
           "90th or 10th percentiles in training data.",
           " Change pred_spec or set boundary_checks = FALSE",
           " to prevent this error",
           call. = FALSE)

    }

   }



  },

  check_boundary_checks = function(boundary_checks){

   # not a field so boundary_checks should never be null

   check_arg_type(arg_value = boundary_checks,
                  arg_name = 'boundary_checks',
                  expected_type = 'logical')

   check_arg_length(arg_value = boundary_checks,
                    arg_name = 'boundary_checks',
                    expected_length = 1)


  },

  check_expand_grid = function(expand_grid){

   # not a field so boundary_checks should never be null
   check_arg_type(arg_value = expand_grid,
                  arg_name = 'expand_grid',
                  expected_type = 'logical')

   check_arg_length(arg_value = expand_grid,
                    arg_name = 'expand_grid',
                    expected_length = 1)

  },

  check_prob_values = function(prob_values){

   check_arg_type(arg_value = prob_values,
                  arg_name = 'prob_values',
                  expected_type = 'numeric')

   check_arg_gteq(arg_value = prob_values,
                  arg_name = 'prob_values',
                  bound = 0)

   check_arg_lteq(arg_value = prob_values,
                  arg_name = 'prob_values',
                  bound = 1)

  },

  check_prob_labels = function(prob_labels){


   check_arg_type(arg_value = prob_labels,
                  arg_name = 'prob_labels',
                  expected_type = 'character')

  },

  check_oobag_pred_mode = function(oobag_pred_mode, label,
                                   sample_fraction = NULL){

   sample_fraction <- sample_fraction %||% self$sample_fraction

   check_arg_type(arg_value = oobag_pred_mode,
                  arg_name = label,
                  expected_type = 'logical')

   check_arg_length(arg_value = oobag_pred_mode,
                    arg_name = label,
                    expected_length = 1)

   if(!is.null(sample_fraction)){

    if(oobag_pred_mode && sample_fraction == 1){
     stop(
      "cannot compute out-of-bag predictions if no samples are out-of-bag.",
      " Try setting sample_fraction < 1 or oobag_pred_type = 'none'.",
      call. = FALSE
     )
    }

   }

  },

  check_lincomb_df_target = function(lincomb_df_target = NULL,
                                     mtry = NULL){

   input <- lincomb_df_target %||% self$control$lincomb_df_target
   mtry <- mtry %||% self$mtry

   check_arg_lteq(
    arg_value = input,
    arg_name = 'df_target',
    bound = mtry,
    append_to_msg = "(number of randomly selected predictors)"
   )

  },

  check_class = function(class = NULL){

   if(!is.null(class)){

    check_arg_is(arg_value = class,
                 arg_name = "class",
                 expected_class = "character")

    check_arg_length(arg_value = class,
                     arg_name = "class",
                     expected_length = 1L)

    check_arg_is_valid(arg_value = class,
                       arg_name = "class",
                       valid_options = self$class_levels)

   }


  },

  # runs checks and sets defaults where needed.
  # data is NULL when we are creating a new forest,
  # but may be non-NULL if we update an existing one
  init = function(data = NULL) {

   # look for odd symbols in formula before you check variables in data
   private$check_formula()
   # check & init data should be near first bc they set up other checks
   private$check_data(data)
   private$init_data(data)

   # if data is not null, it means we are updating an orsf spec
   # and in that process applying it to a new dataset, so:
   if(!is.null(data)) self$data <- data

   if(private$user_specified$control){
    private$check_control()
   } else {
    private$init_control()
   }


   if(private$user_specified$mtry){
    private$check_mtry()
   } else {
    private$init_mtry()
   }

   if(private$user_specified$lincomb_df_target){
    private$check_lincomb_df_target()
   } else {
    private$init_lincomb_df_target()
   }

   if(private$user_specified$weights){
    private$check_weights()
   } else {
    private$init_weights()
   }

   if(private$user_specified$pred_type){
    private$check_pred_type(oobag = TRUE)
   } else {
    private$init_pred_type()
   }

   if(private$user_specified$split_rule){
    private$check_split_rule()
   } else {
    private$init_split_rule()
   }

   if(private$user_specified$split_min_stat){
    private$check_split_min_stat()
   } else {
    private$init_split_min_stat()
   }

   if(private$user_specified$oobag_eval_function){
    private$check_oobag_eval_function()
    self$oobag_eval_type <- "User-specified function"
   } else {
    private$init_oobag_eval_function()
   }

   if(private$user_specified$oobag_eval_every){
    private$check_oobag_eval_every()
   } else {
    private$init_oobag_eval_every()
   }

   if(self$control$lincomb_type == 'custom'){
    private$check_lincomb_R_function()
   } else if (is.null(self$control$lincomb_R_function)){
    private$init_lincomb_R_function()
   }

   # arguments with hard defaults do not need an init option
   private$check_n_tree()
   private$check_n_split()
   private$check_n_retry()
   private$check_n_thread()
   private$check_sample_with_replacement()
   private$check_sample_fraction()
   private$check_leaf_min_obs()
   private$check_split_min_obs()
   private$check_importance_type()
   private$check_importance_max_pvalue()
   private$check_importance_group_factors()
   private$check_na_action()
   private$check_verbose_progress()

   # args below depend on at least one upstream arg

   if(private$user_specified$tree_seeds && is.null(self$forest_seed)){
    private$check_tree_seeds()
   } else if (!is.null(self$forest_seed)){
    # this only happens when the forest is updated
    private$plant_tree_seeds(self$forest_seed)
   } else {
    private$init_tree_seeds()
   }

   if(length(self$tree_seeds) == 1 && self$n_tree > 1){
    private$plant_tree_seeds(self$tree_seeds)
   }

   # oobag_pred_mode depends on pred_type, which is checked above,
   # so there is no reason to check it here.
   private$init_oobag_pred_mode()
   # check if sample_fraction conflicts with oobag_pred_mode
   private$check_oobag_pred_mode(self$oobag_pred_mode,
                              label = 'oobag_pred_mode',
                              sample_fraction = self$sample_fraction)

   private$init_internal()


  },
  init_data = function(data = NULL){

   if(!is.null(data)) self$data <- data

   formula_terms <- suppressWarnings(
    stats::terms(x = self$formula, data = self$data)
   )

   if(attr(formula_terms, 'response') == 0)
    stop("formula should have a response", call. = FALSE)

   names_x_data <- attr(formula_terms, 'term.labels')
   names_y_data <- all.vars(self$formula[[2]])

   # check factors in x data
   fctr_check(self$data, names_x_data)
   fctr_id_check(self$data, names_x_data)

   private$check_var_names(c(names_x_data, names_y_data), data = self$data)

   private$data_names <- list(y = names_y_data,
                              x_original = names_x_data)


   # coerce logicals to 0/1 integers.
   # only do this for outcomes because changing predictor types here
   # would cause issues down the road. E.g., if X[,1] is an integer
   # in the training data and X[,1] is a logical in the testing data.
   for(.name in c(names_y_data)){

    if(self$get_var_type(.name) == 'logical'){
     self$data[[.name]] <- as.integer(self$data[[.name]])
    }

   }

   types_y_data <- vapply(names_y_data, self$get_var_type, character(1))
   types_x_data <- vapply(names_x_data, self$get_var_type, character(1))

   valid_y_types <- c('numeric', 'integer', 'units', 'factor', "Surv")
   valid_x_types <- c(valid_y_types, 'ordered')


   private$check_var_types(types_y_data, names_y_data, valid_y_types)
   private$check_var_types(types_x_data, names_x_data, valid_x_types)

   private$data_types <- list(y = types_y_data, x = types_x_data)
   private$data_fctrs <- fctr_info(self$data, private$data_names$x_original)

   private$init_numeric_names()
   private$init_ref_code_names()
   private$init_data_rows_complete()

   self$n_obs <- ifelse(test = self$na_action == 'omit',
                        yes  = length(private$data_rows_complete),
                        no   = nrow(self$data))

   private$check_var_missing()
   private$check_var_values()

   unit_names <- c(names_y_data[types_y_data == 'units'],
                   names_x_data[types_x_data == 'units'])

   private$data_units <- unit_info(data = self$data, .names = unit_names)

  },
  init_data_rows_complete = function(){

   private$data_rows_complete <- collapse::whichv(
    collapse::missing_cases(self$data, private$data_names$x_original),
    value = FALSE
   )

  },
  init_tree_seeds = function(){

   if(is.null(self$tree_seeds)){
    self$tree_seeds <- sample(1e6, size = 1)
   }

  },
  init_numeric_names = function(){

   pattern <- "^integer$|^numeric$|^units$"

   index <- grep(pattern = pattern, x = private$data_types$x)

   private$data_names[["x_numeric"]] = private$data_names$x_original[index]

  },
  init_ref_code_names = function(){

   xref <- xnames <- private$data_names$x_original
   fi <- private$data_fctrs

   for (i in seq_along(fi$cols)){

    if(fi$cols[i] %in% xnames){

     if(!fi$ordr[i]){

      xref <- insert_vals(
       vec = xref,
       where = xref %==% fi$cols[i],
       what = fi$keys[[i]][-1]
      )

     }
    }

   }

   private$data_names$x_ref_code = xref

  },
  init_oobag_pred_mode = function(){

   # if pred_type is null when this is run, it means
   # the user did not specify pred_type, which means the
   # family-specific default will be used, which means
   # pred_type will not be 'none', so it is safe to assume
   # oobag_pred_mode is TRUE if pred_type is currently null

   if(is.null(self$pred_type)){
    self$oobag_pred_mode <- TRUE
   } else {
    self$oobag_pred_mode <- self$pred_type != "none"
   }

  },
  init_mtry = function(){

   n_col_x <- length(private$data_names$x_ref_code)

   self$mtry <- ceiling(sqrt(n_col_x))

  },




  init_lincomb_df_target = function(mtry = NULL){

   mtry <- mtry %||% self$mtry

   self$control$lincomb_df_target <- mtry

  },

  init_weights = function(){

   # set weights as 1 if user did not supply them.
   # length of weights depends on how missing are handled.
   self$weights <- rep(1, self$n_obs)

  },


  init_oobag_eval_function = function(){

   self$oobag_eval_function <- function(y_mat, w_vec, s_vec){
    return(1)
   }

  },

  init_oobag_eval_every = function(n_tree = NULL){
   self$oobag_eval_every = n_tree %||% self$n_tree
  },

  init_lincomb_R_function = function(){

   self$control$lincomb_R_function <- function(x) x

  },

  # use a starter seed to create n_tree seeds
  plant_tree_seeds = function(start_seed){

   self$forest_seed <- start_seed
   set.seed(start_seed)
   self$tree_seeds <- sample(1e5, size = self$n_tree)

  },

  # computers

  compute_means = function(){

   numeric_data <- select_cols(self$data, private$data_names$x_numeric)

   if(self$na_action == 'omit'){
    numeric_data <- collapse::fsubset(numeric_data, private$data_rows_complete)
   }

   private$data_means <- collapse::fmean(numeric_data, w = self$weights)

  },


  compute_modes = function(){

   nominal_data <- select_cols(self$data, private$data_fctrs$cols)

   if(self$na_action == 'omit'){
    nominal_data <- collapse::fsubset(nominal_data, private$data_rows_complete)
   }

   private$data_modes <- vapply(nominal_data,
                                collapse::fmode,
                                FUN.VALUE = integer(1),
                                w = self$weights)

  },

  compute_stdev = function(){

   numeric_data <- select_cols(self$data, private$data_names$x_numeric)

   if(self$na_action == 'omit'){
    numeric_data <- collapse::fsubset(numeric_data, private$data_rows_complete)
   }

   private$data_stdev <- collapse::fsd(numeric_data, w = self$weights)

  },

  compute_bounds = function(){

   numeric_data <- select_cols(self$data, private$data_names$x_numeric)

   if(is_empty(numeric_data)){

    private$data_bounds <- matrix(NA, nrow=5, ncol=1,
                                  dimnames = list(c('10%', '25%', '50%', '75%', '90%'),
                                                  "placeholder"))
    return()

   }

   if(self$na_action == 'omit'){
    numeric_data <- collapse::fsubset(numeric_data, private$data_rows_complete)
   }

   private$data_bounds <- matrix(
    data = c(
     collapse::fnth(numeric_data, 0.10, w = self$weights),
     collapse::fnth(numeric_data, 0.25, w = self$weights),
     collapse::fnth(numeric_data, 0.50, w = self$weights),
     collapse::fnth(numeric_data, 0.75, w = self$weights),
     collapse::fnth(numeric_data, 0.90, w = self$weights)
    ),
    nrow = 5,
    byrow = TRUE,
    dimnames = list(c('10%', '25%', '50%', '75%', '90%'),
                    names(numeric_data))
   )

  },

  compute_mean_leaves = function(){

   leaf_counts <- vapply(X = self$forest$leaf_summary,
                         FUN = function(x) sum(x != 0),
                         FUN.VALUE = integer(1))

   private$mean_leaves <- collapse::fmean(leaf_counts)

  },

  compute_dependence_internal = function(pred_spec,
                                         pred_type,
                                         pred_horizon = NULL,
                                         type_input,
                                         type_output,
                                         prob_labels,
                                         prob_values,
                                         expand_grid,
                                         oobag){

   # make a visible binding for CRAN
   id_variable = NULL

   pred_horizon <- pred_horizon %||% self$pred_horizon %||% 1

   # just use first value for these pred_types
   if(pred_type %in% c("mort", "time")) pred_horizon <- pred_horizon[1]

   pred_horizon_order <- order(pred_horizon)
   pred_horizon_ordered <- pred_horizon[pred_horizon_order]


   private$init_data_rows_complete()
   private$prep_x()
   # y and w do not need to be prepped for prediction,
   # but they need to match orsf_cpp()'s expectations
   private$prep_y(placeholder = TRUE)
   private$w <- rep(1, nrow(private$x))

   if(oobag){ private$sort_inputs(sort_y = FALSE) }

   # the values in pred_spec need to be centered & scaled to match x,
   # which is also centered and scaled
   means <- private$data_means
   stdev <- private$data_stdev

   if(type_input == 'user'){

    for(i in intersect(names(means), names(pred_spec))){
     pred_spec[[i]] <- (pred_spec[[i]] - means[i]) / stdev[i]
    }

   } else {

    for(j in seq_along(pred_spec)){

     for(i in intersect(names(means), names(pred_spec[[j]]))){
      pred_spec[[j]][[i]] <- (pred_spec[[j]][[i]] - means[i]) / stdev[i]
     }

    }

    expand_grid <- FALSE

   }

   fi <- private$data_fctrs

   if(expand_grid){

    if(!is.data.frame(pred_spec))
     pred_spec <- expand.grid(pred_spec, stringsAsFactors = TRUE)

    for(i in seq_along(fi$cols)){

     ii <- fi$cols[i]

     if(is.character(pred_spec[[ii]]) && !fi$ordr[i]){

      pred_spec[[ii]] <- factor(pred_spec[[ii]], levels = fi$lvls[[ii]])

     }

    }

    check_new_data_fctrs(new_data  = pred_spec,
                         names_x   = private$data_names$x_original,
                         fi_ref    = fi,
                         label_new = "pred_spec")

    pred_spec_new <- ref_code(x_data = pred_spec, fi = fi,
                              names_x_data = names(pred_spec))

    x_cols <- list(match(names(pred_spec_new), colnames(private$x))-1)

    pred_spec_new <- list(as.matrix(pred_spec_new))

    pd_bind <- list(pred_spec)

   } else if(type_input == 'user'){

    pred_spec_new <- pd_bind <- x_cols <- list()

    for(i in seq_along(pred_spec)){

     pred_spec_new[[i]]  <- as.data.frame(pred_spec[i])
     pd_name <- names(pred_spec)[i]

     pd_bind[[i]] <- data.frame(
      variable = pd_name,
      value = rep(NA_real_, length(pred_spec[[i]])),
      level = rep(NA_character_, length(pred_spec[[i]]))
     )

     if(pd_name %in% fi$cols) {

      pd_bind[[i]]$level <- as.character(pred_spec[[i]])

      pred_spec_new[[i]] <- ref_code(pred_spec_new[[i]],
                                     fi = fi,
                                     names_x_data = pd_name)

     } else {

      pd_bind[[i]]$value <- pred_spec[[i]]

     }

     x_cols[[i]] <- match(names(pred_spec_new[[i]]), colnames(private$x)) - 1
     pred_spec_new[[i]] <- as.matrix(pred_spec_new[[i]])

    }

   } else {

    pd_bind <- pred_spec_new <- pred_spec

    for(i in seq_along(pred_spec_new)){
     pred_spec_new[[i]] <- ref_code(pred_spec_new[[i]], fi = fi,
                                    names_x_data = names(pred_spec_new[[i]]))
    }

    pred_spec_new <- lapply(pred_spec_new, collapse::qM)

    x_cols <- lapply(
     pred_spec_new,
     function(x){
      match(colnames(x), colnames(private$x)) - 1
     }
    )

   }

   cpp_args <- private$prep_cpp_args(x = private$x,
                                     y = private$y,
                                     w = private$w,
                                     importance_type = 'none',
                                     pred_type = pred_type,
                                     pred_mode = FALSE,
                                     pred_aggregate = TRUE,
                                     pred_horizon = pred_horizon_ordered,
                                     oobag = oobag,
                                     oobag_eval_type = 'none',
                                     pd_type_R = switch(type_output,
                                                        "smry" = 1L,
                                                        "ice" = 2L),
                                     pd_x_vals = pred_spec_new,
                                     pd_x_cols = x_cols,
                                     pd_probs = prob_values,
                                     write_forest = FALSE,
                                     run_forest = TRUE)

   pd_vals <- do.call(orsf_cpp, cpp_args)$pd_values

   row_delim <- switch(self$tree_type,
                       "survival" = pred_horizon_ordered,
                       "regression" = 1,
                       "classification" = self$class_levels)

   row_delim_label <- switch(self$tree_type,
                             "survival" = "pred_horizon",
                             "regression" = "pred_row",
                             "classification" = "class")

   for(i in seq_along(pd_vals)){

    pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]]))

    for(j in seq_along(pd_vals[[i]])){

     # nans <- which(is.nan(pd_vals[[i]][[j]]))
     nans <- which(pd_vals[[i]][[j]]==0)

     if(!is_empty(nans)){
      pd_vals[[i]][[j]][nans] <- NA_real_
     }

     pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]],
                                 nrow=length(row_delim),
                                 byrow = T)

     rownames(pd_vals[[i]][[j]]) <- row_delim


     if(type_output=='smry'){

      pd_vals[[i]][[j]] <- t(
       apply(
        X = pd_vals[[i]][[j]],
        MARGIN = 1,
        function(x, p){
         c(collapse::fmean(x), stats::quantile(x, p, na.rm=TRUE))
        },
        prob_values
       )
      )

      colnames(pd_vals[[i]][[j]]) <- c('mean', prob_labels)

     } else {

      colnames(pd_vals[[i]][[j]]) <- c(paste(1:nrow(private$x)))

     }

     pd_vals[[i]][[j]] <- as.data.table(pd_vals[[i]][[j]],
                                        keep.rownames = row_delim_label)

     if(type_output == 'ice'){

      measure.vars <- setdiff(names(pd_vals[[i]][[j]]), row_delim_label)

      pd_vals[[i]][[j]] <- melt_aorsf(data = pd_vals[[i]][[j]],
                                      id.vars = row_delim_label,
                                      variable.name = 'id_row',
                                      value.name = 'pred',
                                      measure.vars = measure.vars)

     }

    }

    pd_vals[[i]] <- rbindlist(pd_vals[[i]], idcol = 'id_variable')

    # this seems awkward but the reason I convert back to data.frame
    # here is to avoid a potential memory leak from forder & bmerge.
    # I have no idea why this memory leak may be occurring but it does
    # not if I apply merge.data.frame instead of merge.data.table
    pd_vals[[i]] <- merge(as.data.frame(pd_vals[[i]]),
                          as.data.frame(pd_bind[[i]]),
                          by = 'id_variable')

    if(type_input == 'intr'){

     v1 <- colnames(pred_spec[[i]])[1]
     v2 <- colnames(pred_spec[[i]])[2]

     pd_vals[[i]][['var_1_name']] <- v1
     pd_vals[[i]][['var_2_name']] <- v2

     names(pd_vals[[i]])[names(pd_vals[[i]]) == v1] <- "var_1_value"
     names(pd_vals[[i]])[names(pd_vals[[i]]) == v2] <- "var_2_value"

     pd_vals[[i]]$var_1_value <- as.numeric(pd_vals[[i]]$var_1_value)
     pd_vals[[i]]$var_2_value <- as.numeric(pd_vals[[i]]$var_2_value)

    }

   }

   out <- rbindlist(pd_vals)

   # missings may occur when oobag=TRUE and n_tree is small
   if(type_output == 'ice') {
    out <- collapse::na_omit(out, cols = 'pred')
   }

   ids <- c('id_variable')

   if(type_output == 'ice') ids <- c(ids, 'id_row')

   mid <- setdiff(names(out), c(ids, 'mean', prob_labels, 'pred'))

   end <- setdiff(names(out), c(ids, mid))

   setcolorder(out, neworder = c(ids, mid, end))

   if(self$tree_type == 'classification'){
    out[, class := factor(class, levels = self$class_levels)]
    setkey(out, class)
   }


   if(self$tree_type == 'survival' && !(pred_type %in% c('mort', 'time')))
    out[, pred_horizon := as.numeric(pred_horizon)]

   if(self$tree_type == 'regression'){
    out[, pred_row := NULL]
   }

   if(pred_type %in% c('mort', 'time'))
    out[, pred_horizon := NULL]

   # not needed for summary
   if(type_output == 'smry')
    out[, id_variable := NULL]

   # put data back into original scale
   if(type_input == 'intr'){

    for(j in collapse::funique(out$var_1_name)){

     if(j %in% names(means)){
      var_index <- out$var_1_name %==% j
      var_value <- (out$var_1_value[var_index] * stdev[j]) + means[j]
      set(out, i = var_index, j = 'var_1_value', value = var_value)
     }


    }

    for(j in collapse::funique(out$var_2_name)){

     if(j %in% names(means)){
      var_index <- out$var_2_name %==% j
      var_value <- (out$var_2_value[var_index] * stdev[j]) + means[j]
      set(out, i = var_index, j = 'var_2_value', value = var_value)
     }

    }

   } else {

    for(j in intersect(names(means), names(pred_spec))){

     if(j %in% names(out)){

      var_index <- collapse::seq_row(out)
      var_value <- (out[[j]] * stdev[j]) + means[j]
      var_name  <- j

     } else {

      var_index <- out$variable %==% j
      var_value <- (out$value[var_index] * stdev[j]) + means[j]
      var_name  <- 'value'

     }

     set(out, i = var_index, j = var_name, value = var_value)

    }

   }


   # silent print after modify in place
   out[]

   out

  },

  select_variables_internal = function(n_predictor_min, verbose_progress){

   n_predictors <- length(private$data_names$x_original)

   # verbose progress on the forest should always be FALSE
   # because for orsf_vs, verbosity is coordinated in R
   self$verbose_progress <- FALSE

   oob_data <- data.table(
    n_predictors = seq(n_predictors),
    stat_value = rep(NA_real_, n_predictors),
    variables_included = vector(mode = 'list', length = n_predictors),
    predictors_included = vector(mode = 'list', length = n_predictors),
    predictor_dropped = rep(NA_character_, n_predictors)
   )

   # if the forest was not trained prior to variable selection
   if(!self$trained){
    private$compute_means()
    private$compute_stdev()
    private$compute_modes()
    private$compute_bounds()
   }

   private$prep_x()
   private$prep_y()
   private$prep_w()

   # for survival, inputs should be sorted by time
   private$sort_inputs()

   # allow re-training.
   self$forest <- list()

   pred_type <- switch(self$tree_type,
                       'survival' = 'mort',
                       'classification' = 'prob',
                       'regression' = 'mean')

   cpp_args <- private$prep_cpp_args(pred_type = pred_type,
                                     oobag_pred = TRUE,
                                     importance_group_factors = TRUE,
                                     write_forest = FALSE)


   fctr_key <- lapply(self$get_fctr_info()$keys, function(x) x[-1])

   fctr_key <- data.frame(
    variable = rep(names(fctr_key), sapply(fctr_key, length)),
    predictor = unlist(fctr_key),
    row.names = NULL
   )

   variable_key <- data.frame(
    predictor = self$get_names_x(ref_coded = TRUE)
   )

   if(is_empty(fctr_key)){

    variable_key$variable <- variable_key$predictor

   } else {

    variable_key <- merge(variable_key, fctr_key,
                          by = 'predictor',
                          all.x = TRUE)

   }

   variable_key$variable[is.na(variable_key$variable)] <-
    variable_key$predictor[is.na(variable_key$variable)]

   max_progress <- n_predictors - n_predictor_min
   current_progress <- 0
   start_time <- last_time <- Sys.time()

   while(n_predictors >= n_predictor_min){

    if(verbose_progress){

     time_passed <- as.numeric(
      as.difftime(Sys.time() - start_time),
      units = "secs"
     )

     time_lapsed <- as.numeric(
      as.difftime(Sys.time() - last_time),
      units = "secs"
     )

     if(current_progress > 0 && time_lapsed > 3){

      relative_progress <- current_progress / max_progress

      remaining_time = round((1 / relative_progress - 1) * time_passed)

      cat(
       "Selecting variables:",
       paste0( round_magnitude(relative_progress*100), "%" ),
       "~ time remaining:", beautifyTime(remaining_time),
       "\n")

      last_time <- Sys.time()

     }

    }

    mtry_safe <- ceiling(sqrt(n_predictors))

    if(self$control$lincomb_df_target > mtry_safe){
     self$control$lincomb_df_target <- mtry_safe
    }

    cpp_args$mtry <- mtry_safe
    cpp_output <- do.call(orsf_cpp, args = cpp_args)

    worst_index <- which.min(cpp_output$importance)
    worst_predictor <- colnames(cpp_args$x)[worst_index]


    .variables_included <- with(
     variable_key,
     unique(variable[predictor %in% colnames(cpp_args$x)])
    )

    oob_data[n_predictors,
             `:=`(n_predictors = n_predictors,
                  stat_value = cpp_output$eval_oobag$stat_values[1,1],
                  variables_included = .variables_included,
                  predictors_included = colnames(cpp_args$x),
                  predictor_dropped = worst_predictor)]

    cpp_args$x <- cpp_args$x[, -worst_index, drop = FALSE]
    n_predictors <- n_predictors - 1
    current_progress <- current_progress + 1

   }

   if(verbose_progress){
    cat("Selecting variables: 100%\n")
   }


   collapse::na_omit(oob_data)

  },

  # preppers

  prep_x = function(){

   data_x <- select_cols(self$data, private$data_names$x_original)

   if(any(is.na(data_x))){

    switch(
     self$na_action,

     'fail' = {
      # this should already have been checked but just in case
      stop("Please remove missing values from data, or impute them.",
           call. = FALSE)
     },

     'omit' = {
      data_x <- data_x[private$data_rows_complete, ]
     },

     'pass' = {
      data_x <- data_x[private$data_rows_complete, ]
     },

     'impute_meanmode' = {

      data_x <- data_impute(data_x,
                            cols = names(data_x),
                            values = c(as.list(private$data_means),
                                       as.list(private$data_modes)))
     }
    )

   }

   x <- ref_code(data_x, private$data_fctrs, names(data_x))

   for(i in private$data_names$x_numeric){

    if(has_units(x[[i]])) x[[i]] <- as.numeric(x[[i]])

    # can't modify by reference here, it would modify the user's data
    x[[i]] <- (x[[i]] - private$data_means[i]) / private$data_stdev[i]

   }

   private$x = as_matrix(x)

  },

  # Prep the outcome variable
  #
  # Coerce the outcome to be compatible with C++ routines.
  # If there is no feasible way to make it work, throw an error.
  # If there aren't any problems, return the outcome.
  #
  # placeholder (*logical*) whether to engage with the
  #   outcome member variable or just make a version that is safe
  #   to pass to orsf_cpp().
  #
  # For survival, y is a matrix with two columns:
  #   - first column: time values
  #   - second column: status values
  #
  # For classification, y is a factor or character vector
  #
  # For regression, y is a numeric vector
  #
  # @return stored the outcome in a matrix format in private$y
  #
  # - Survival outcomes are transformed to have status values of 0 and 1
  # - Classification outcomes are transformed to a one-hot matrix
  # - Regression outcomes are converted to a matrix
  #
  prep_y = function(placeholder = FALSE){

   private$y <- select_cols(self$data, private$data_names$y)

   if(self$na_action == 'omit' && !placeholder)
    private$y <- private$y[private$data_rows_complete, , drop = FALSE]

   private$prep_y_internal(placeholder)

  },

  prep_w = function(){

   # re-scale so that sum(w) == nrow(data)
   private$w <- self$weights * length(self$weights) / sum(self$weights)

  },

  prep_cpp_args = function(...){

   .dots <- list(...)

   args <- list(
    x = private$x,
    y = private$y,
    w = private$w,
    tree_type_R = switch(self$tree_type,
                         'classification' = 1,
                         'regression'= 2,
                         'survival' = 3),
    tree_seeds = self$tree_seeds,
    loaded_forest = self$forest,
    n_tree = .dots$n_tree %||% self$n_tree,
    mtry = .dots$mtry %||% self$mtry,
    sample_with_replacement = .dots$sample_with_replacement %||% self$sample_with_replacement,
    sample_fraction = .dots$sample_fraction %||% self$sample_fraction,
    vi_type_R = switch(.dots$importance_type %||% self$importance_type,
                       "none"    = 0,
                       "negate"  = 1,
                       "permute" = 2,
                       "anova"   = 3),
    vi_max_pvalue = .dots$importance_max_pvalue %||% self$importance_max_pvalue,
    leaf_min_events = .dots$leaf_min_events %||% self$leaf_min_events %||% 1,
    leaf_min_obs = .dots$leaf_min_obs %||% self$leaf_min_obs,
    split_rule_R = switch(self$split_rule,
                          "logrank"  = 1,
                          "cstat"    = 2,
                          "gini"     = 3,
                          "variance" = 4),
    split_min_events = .dots$split_min_events %||% self$split_min_events %||% 1,
    split_min_obs = .dots$split_min_obs %||% self$split_min_obs,
    split_min_stat = .dots$split_min_stat %||% self$split_min_stat,
    split_max_cuts = .dots$split_max_cuts %||% self$n_split,
    split_max_retry = .dots$split_max_retry %||% self$n_retry,
    lincomb_R_function = self$control$lincomb_R_function,
    lincomb_type_R = switch(self$control$lincomb_type,
                            'glm'    = 1,
                            'random' = 2,
                            'net'    = 3,
                            'custom' = 4),
    lincomb_eps = self$control$lincomb_eps,
    lincomb_iter_max = self$control$lincomb_iter_max,
    lincomb_scale = self$control$lincomb_scale,
    lincomb_alpha = self$control$lincomb_alpha,
    lincomb_df_target = self$control$lincomb_df_target,
    lincomb_ties_method = switch(tolower(self$control$lincomb_ties_method),
                                 'breslow' = 0,
                                 'efron'   = 1),
    pred_type_R = switch(.dots$pred_type %||% self$pred_type,
                         "none"  = 0,
                         "risk"  = 1,
                         "surv"  = 2,
                         "chf"   = 3,
                         "mort"  = 4,
                         "mean"  = 5,
                         "prob"  = 6,
                         "class" = 7,
                         "leaf"  = 8,
                         "time"  = 2), # time=2 is not a typo
    pred_mode = .dots$pred_mode %||% FALSE,
    pred_aggregate = .dots$pred_aggregate %||% (self$pred_type != 'leaf'),
    pred_horizon = if(self$pred_type == 'time'){
     self$event_times
    } else {
     .dots$pred_horizon %||% self$pred_horizon %||% 1
    },
    oobag = .dots$oobag %||% self$oobag_pred_mode,
    oobag_R_function = .dots$oobag_eval_function %||% self$oobag_eval_function,
    oobag_eval_type_R = switch(
     tolower(.dots$oobag_eval_type %||% self$oobag_eval_type),
     "none" = 0,
     "harrell's c-index" = 1,
     "auc-roc" = 1,
     "user-specified function" = 2,
     "mse" = 3,
     "rsq" = 4
    ),
    oobag_eval_every = .dots$oobag_eval_every %||% self$oobag_eval_every,
    # switch(pd_type, "none" = 0L, "smry" = 1L, "ice" = 2L)
    pd_type_R = .dots$pd_type %||% 0L,
    pd_x_vals = .dots$pd_x_vals %||% list(matrix(0, ncol=0, nrow=0)),
    pd_x_cols = .dots$pd_x_cols %||% list(matrix(0, ncol=0, nrow=0)),
    pd_probs  = .dots$pd_probs %||% 0,
    n_thread  = .dots$n_thread %||% self$n_thread,
    write_forest = .dots$write_forest %||% TRUE,
    run_forest   = .dots$run_forest %||% TRUE,
    verbosity    = as.integer(.dots$verbosity %||% self$verbose_progress)
   )

  },

  sort_inputs = function(sort_y = NULL,
                         sort_x = NULL,
                         sort_w = NULL){
   NULL
  },

  # cleaners

  clean_importance = function(importance = NULL, group_factors = NULL){

   out <- importance %||% self$importance
   group_factors <- group_factors %||% self$importance_group_factors

   # nan indicates a variable was never used
   out[is.nan(out)] <- 0

   rownames(out) <- private$data_names$x_ref_code

   private$importance_raw <- out

   if(group_factors){

    fi <- private$data_fctrs

    if(!is_empty(fi$cols)){

     for(f in fi$cols[!fi$ordr]){

      f_lvls <- fi$lvls[[f]]
      f_rows <- match(paste(f, f_lvls[-1], sep = '_'), rownames(out))
      f_wts <- 1

      if(length(f_lvls) > 2){
       f_wts <- prop.table(x = table(self$data[[f]], useNA = 'no')[-1])
      }

      f_vi <- sum(out[f_rows] * f_wts, na.rm = TRUE)

      out[f_rows] <- f_vi
      rownames(out)[f_rows] <- f

     }

     if(!is_empty(fi$cols[!fi$ordr])) {
      # take extra care in case there are duplicate vi values.
      out <- data.frame(variable = rownames(out), value = out)
      out <- unique(out)
      out$variable <- NULL
      out <- as.matrix(out)
      colnames(out) <- NULL
     }

    }

   }

   self$importance <- rev(out[order(out), , drop=TRUE])

  },

  clean_pred_oobag = function(){

   if(self$pred_type == 'leaf'){

    all_rows <- seq(self$n_obs)

    for(i in seq(self$n_tree)){

     rows_inbag <- setdiff(all_rows, self$forest$rows_oobag[[i]]+1)
     self$pred_oobag[rows_inbag, i] <- NA_real_

    }

   }

   self$pred_oobag[is.nan(self$pred_oobag)] <- NA_real_

   private$clean_pred_oobag_internal()

  },

  clean_pred_oobag_internal = function(){
   NULL
  },

  clean_pred_new = function(preds){

   if(self$na_action == "pass"){

    out <- matrix(nrow = nrow(self$data),
                  ncol = ncol(preds))

    out[private$data_rows_complete, ] <- preds

   } else {

    out <- preds

   }

   if(self$pred_type == "leaf" || !self$pred_aggregate) return(out)

   private$clean_pred_new_internal(out)

  },

  restore_state = function(public_state = list(),
                           private_state = list()){

   if(!is_empty(public_state)){
    for(i in names(public_state)){
     self[[i]] <- public_state[[i]]
    }
   }

   if(!is_empty(private_state)){
    for(i in names(private_state)){
     private[[i]] <- private_state[[i]]
    }
   }

  }


 )

)

# ObliqueForestSurvival class ----

ObliqueForestSurvival <- R6::R6Class(
 "ObliqueForestSurvival",
 inherit = ObliqueForest,
 public = list(

  get_max_time = function(){
   return(private$max_time)
  },

  get_pred_type_vi = function(){
   return("mort")
  },

  leaf_min_events = NULL,
  split_min_events = NULL,
  pred_horizon = NULL,
  event_times = NULL

 ),

 # private ----
 private = list(

  pred_horizon_order = NULL,
  data_row_sort = NULL,
  max_time = NULL,
  n_events = NULL,

  check_split_rule_internal= function(){

   check_arg_is_valid(arg_value = self$split_rule,
                      arg_name = 'split_rule',
                      valid_options = c("logrank", "cstat"))

  },
  check_pred_type_internal = function(oobag,
                                      pred_type = NULL,
                                      context = NULL){

   input <- pred_type %||% self$pred_type

   arg_name <- if(oobag) 'oobag_pred_type' else 'pred_type'

   if(is.null(context)){
    valid_options <- c("none", "surv", "risk", "chf", "mort", "leaf", "time")
   } else {
    valid_options <- switch(
     context,
     'partial dependence' = c("surv", "risk", "chf", "mort", "time"),
     'prediction' = c("surv", "risk", "chf", "mort", "leaf", "time")
    )
    context <- paste(context, 'with survival forests')
   }

   check_arg_is_valid(arg_value = input,
                      arg_name = arg_name,
                      valid_options = valid_options,
                      context = context)

  },

  check_pred_horizon = function(pred_horizon = NULL,
                                boundary_checks = TRUE,
                                pred_type = NULL){

   pred_type <- pred_type %||% self$pred_type
   input <- pred_horizon %||% self$pred_horizon

   if(is.null(input) && pred_type %in% c('risk', 'surv', 'chf')){

    stop("pred_horizon must be specified for ",
         pred_type, " predictions.", call. = FALSE)

   }


   if(self$oobag_pred_mode)
    arg_name <- 'oobag_pred_horizon'
   else
    arg_name <- 'pred_horizon'

   check_arg_type(arg_value = input,
                  arg_name = 'pred_horizon',
                  expected_type = 'numeric')

   for(i in seq_along(input)){

    check_arg_gteq(arg_value = input[i],
                   arg_name = 'pred_horizon',
                   bound = 0)

   }

   if(boundary_checks){

    if(any(input > private$max_time)){

     stop("prediction horizon should ",
          "be <= max follow-up time ",
          "observed in training data: ",
          private$max_time,
          call. = FALSE)

    }

   }

  },
  check_leaf_min_events = function(){

   check_arg_type(arg_value = self$leaf_min_events,
                  arg_name = 'leaf_min_events',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = self$leaf_min_events,
                        arg_name = 'leaf_min_events')

   check_arg_gteq(arg_value = self$leaf_min_events,
                  arg_name = 'leaf_min_events',
                  bound = 1)

   check_arg_length(arg_value = self$leaf_min_events,
                    arg_name = 'leaf_min_events',
                    expected_length = 1)

   check_arg_lteq(
    arg_value = self$leaf_min_events,
    arg_name = 'leaf_min_events',
    bound = round(private$n_events / 2),
    append_to_msg = "(number of events divided by 2)"
   )

  },
  check_split_min_events = function(){

   check_arg_type(arg_value = self$split_min_events,
                  arg_name = 'split_min_events',
                  expected_type = 'numeric')

   check_arg_is_integer(arg_value = self$split_min_events,
                        arg_name = 'split_min_events')

   check_arg_gteq(arg_value = self$split_min_events,
                  arg_name = 'split_min_events',
                  bound = 1)

   check_arg_length(arg_value = self$split_min_events,
                    arg_name = 'split_min_events',
                    expected_length = 1)

   check_arg_lt(
    arg_value = self$split_min_events,
    arg_name = "split_min_events",
    bound = private$n_events,
    append_to_msg = "(number of events)"
   )

  },

  check_oobag_eval_function_internal = function(oobag_fun){

   test_time <- seq(from = 1, to = 5, length.out = 100)
   test_status <- rep(c(0,1), each = 50)

   .y_mat <- cbind(time = test_time, status = test_status)
   .w_vec <- rep(1, times = 100)
   .s_vec <- seq(0.9, 0.1, length.out = 100)

   test_output <- try(oobag_fun(y_mat = .y_mat,
                                w_vec = .w_vec,
                                s_vec = .s_vec),
                      silent = FALSE)

   if(is_error(test_output)){

    stop("oobag_fun encountered an error when it was tested. ",
         "Please make sure your oobag_fun works for this case:\n\n",
         "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
         "test_status <- rep(c(0,1), each = 50)\n\n",
         "y_mat <- cbind(time = test_time, status = test_status)\n",
         "w_vec <- rep(1, times = 100)\n",
         "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
         "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
         "test_output should be a numeric value of length 1",
         call. = FALSE)

   }

   test_output

  },

  check_lincomb_R_function_internal = function(lincomb_R_function = NULL){

   input <- lincomb_R_function %||% self$lincomb_R_function

   test_time <- seq(from = 1, to = 5, length.out = 100)
   test_status <- rep(c(0,1), each = 50)

   .x_node <- matrix(rnorm(300), ncol = 3)
   .y_node <- cbind(time = test_time, status = test_status)
   .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)


   out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)

   if(is_error(out)){

    stop("user-supplied function encountered an error when it was tested. ",
         "Please make sure the function works for this case:\n\n",
         "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
         "test_status <- rep(c(0,1), each = 50)\n\n",
         ".x_node <- matrix(seq(-1, 1, length.out = 300), ncol = 3)\n",
         ".y_node <- cbind(time = test_time, status = test_status)\n",
         ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n\n",
         "test_output <- user_function(.x_node, .y_node, .w_node)\n\n",
         "test_output should be a numeric matrix with 1 column and",
         " with nrow(test_output) = ncol(.x_node)",
         call. = FALSE)

   }

   out

  },

  sort_inputs = function(sort_x = TRUE,
                         sort_y = TRUE,
                         sort_w = TRUE){

   if(sort_x)
    private$x <- private$x[private$data_row_sort, , drop = FALSE]
   if(sort_y)
    private$y <- private$y[private$data_row_sort, , drop = FALSE]
   if(sort_w)
    private$w <- private$w[private$data_row_sort]

  },

  init_control = function(){

   self$control <- orsf_control_survival(method = 'glm',
                                         scale_x = FALSE,
                                         max_iter = 1)

  },

  init_pred_type = function(){
   self$pred_type <- 'risk'
  },

  init_split_rule = function(){
   self$split_rule <- 'logrank'
  },

  init_split_min_stat = function(){

   if(is.null(self$split_rule))
    stop("cannot init split_min_stat without split_rule", call. = FALSE)

   self$split_min_stat <- switch(self$split_rule,
                                 'logrank' = 3.841459,
                                 'cstat' = 0.55)
  },

  init_internal = function(){

   self$tree_type <- "survival"

   if(!is.function(self$control$lincomb_R_function) &&
      self$control$lincomb_type == 'net'){
    self$control$lincomb_R_function <- penalized_cph
   }

   y <- select_cols(self$data, private$data_names$y)

   if(inherits(y[[1]], 'Surv')){
    y <- as.matrix(y[[1]])
    y <- as.data.frame(y)
   }

   # strip units for these tests
   if(has_units(y[[1]])) y[[1]] <- as.numeric(y[[1]])

   check_arg_type(arg_value = y[[1]],
                  arg_name = "time to event",
                  expected_type = 'numeric')

   check_arg_gt(arg_value = y[[1]],
                arg_name = "time to event",
                bound = 0)

   # strip extra attributes for these tests
   if(is.factor(y[[2]])  ||
      is.logical(y[[2]]) ||
      has_units(y[[2]])  ){
    y[[2]] <- as.numeric(y[[2]])
   }

   check_arg_type(arg_value = y[[2]],
                  arg_name = "status indicator",
                  expected_type = 'numeric')

   if(self$na_action == 'omit')
    y <- y[private$data_rows_complete, ]


   # order observations by event time and event status
   private$data_row_sort <- collapse::radixorder(y[[1]], -y[[2]])
   # boundary check for pred horizon
   private$max_time <- y[last_value(private$data_row_sort), 1]
   # boundary check for event-based tree parameters
   private$n_events <- collapse::fsum(y[, 2])
   # unique event times sorted in ascending order
   self$event_times <- sort(collapse::funique(
    y[[1]][collapse::whichv(x = y[[2]], value = 1)]
   ), decreasing = FALSE)

   # if pred_horizon is unspecified, provide sensible default
   # if it is specified, check for correctness
   if(private$user_specified$pred_horizon){
    private$check_pred_horizon(self$pred_horizon, boundary_checks = TRUE)
   } else {
    self$pred_horizon <- collapse::fmedian(y[, 1])
   }

   private$check_leaf_min_events()
   private$check_split_min_events()

   if(!self$oobag_pred_mode) self$oobag_eval_type <- "none"

   # use default if eval type was not specified by user
   if(self$oobag_pred_mode && is.null(self$oobag_eval_type)){
    self$oobag_eval_type <- "Harrell's C-index"
   }

  },

  init_pred = function(pred_horizon = NULL, pred_type = NULL,
                       boundary_checks = TRUE){

   pred_type_supplied <- !is.null(pred_type)
   pred_horizon_supplied <- !is.null(pred_horizon)

   if(pred_type_supplied){
    private$check_pred_type(oobag = FALSE, pred_type = pred_type)
   } else {
    pred_type <- self$pred_type %||% "risk"
   }

   if(pred_horizon_supplied){
    private$check_pred_horizon(pred_horizon, boundary_checks, pred_type)
   } else {
    pred_horizon <- self$pred_horizon
    if(is.null(pred_horizon)){
     stop("pred_horizon was not specified and could not be found in object.",
          call. = FALSE)
    }
   }

   if(pred_type_supplied &&
      pred_horizon_supplied &&
      pred_type %in% c('leaf', 'mort', 'time')){

    extra_text <- if(length(pred_horizon)>1){
     " Predictions at each value of pred_horizon will be identical."
    } else {
     ""
    }

    warning("pred_horizon does not impact predictions",
            " when pred_type is '", pred_type, "'.",
            extra_text, call. = FALSE)

   }

   self$pred_horizon <- pred_horizon
   self$pred_type <- pred_type

  },

  prep_y_internal = function(placeholder = FALSE){


   if(placeholder){
    private$y <- matrix(0, ncol = 2, nrow = 1)
    return()
   }

   y <- private$y
   cols <- names(y)

   # if a surv object was included in the formula, it probably
   # doesn't need to be checked here, but it's still checked
   # just to be safe b/c if there is a problem with y it is
   # likely that it will cause R to crash in orsf_cpp().
   if(length(cols) == 1 && inherits(y[[1]], 'Surv')){
    y <- as.data.frame(as.matrix(y))
    cols <- names(y)
   }

   for(i in cols){
    if(has_units(y[[i]])) y[[i]] <- as.numeric(y[[i]])
   }

   # make sure y gets passed to CPP with status values of 0 and 1
   if(is.factor(y[[2]]) || is.logical(y[[2]])){
    y[[2]] <- as.numeric(y[[2]])
   }

   status_uni <- collapse::funique(y[[2]])

   # if nothing is censored
   if(all(status_uni == 1)) {

    private$y <- as_matrix(y)
    return()

   }

   # status values are modified if they are not all 0 and 1
   if(!is_equivalent(c(0, 1), status_uni)){

    # assume the lowest value of status indicates censoring
    censor_indicator <- collapse::fmin(status_uni)

    if(censor_indicator != as.integer(censor_indicator)){
     stop("the censoring indicator is not integer valued. ",
          "This can occur when the time column is swapped ",
          "with the status column by mistake. ",
          "Did you enter `status + time` when you meant ",
          "to put `time + status`?", call. = FALSE)
    }

    # assume the integer value 1 above censor is an event
    event_indicator <- censor_indicator + 1

    if( !(event_indicator %in% status_uni) ){
     stop("there does not appear to be an event indicator ",
          "in the status column. The censoring indicator appears to ",
          "be ", censor_indicator, " but there are no values of ",
          1 + censor_indicator, ", which we would assume to be an ",
          " event indicator. This can occur when the time column is ",
          "swapped with the status column by mistake. Did you enter ",
          "`status + time` when you meant  to enter `time + status`? ",
          "If this problem persists, try setting status to 0 for ",
          "censored observations and 1 for events.", call. = FALSE)
    }

    other_events <- which(y[[2]] > event_indicator)

    # we don't handle competing risks yet
    if(!is_empty(other_events)){

     stop("detected >1 event type in status variable. ",
          "Currently only 1 event type is supported. ",
          call. = FALSE)

     # y[other_events, 2] <- censor_indicator

    }

    # The status column needs to contain only 0s and 1s when it
    # gets passed into the C++ routines.
    y[[2]] <- y[[2]] - censor_indicator

    # check to make sure we don't send something into C++
    # that is going to make the R session crash.
    status_uni <- collapse::funique(y[[2]])

    if(!is_equivalent(c(0, 1), status_uni)){
     stop("could not coerce the status column to values of 0 and 1. ",
          "This can occur when the time column is ",
          "swapped with the status column by mistake. Did you enter ",
          "`status + time` when you meant  to enter `time + status`? ",
          "Please modify your data so that the status values are ",
          "0 and 1, with 0 indicating censored observations and 1 ",
          "indicating events.", call. = FALSE)
    }

   }

   private$y <- as_matrix(y)

  },

  clean_pred_oobag_internal = function(){

   # put the oob predictions into the same order as the training data.
   unsorted <- collapse::radixorder(private$data_row_sort)
   self$pred_oobag <- self$pred_oobag[unsorted, , drop = FALSE]

   if(self$pred_type == 'time'){

    self$eval_oobag$stat_values <-
     self$eval_oobag$stat_values[, ncol(self$pred_oobag), drop = FALSE]

    self$pred_oobag <- apply(self$pred_oobag,
                             MARGIN = 1,
                             FUN = rmst,
                             times = self$event_times)

    self$pred_oobag <- collapse::qM(self$pred_oobag)

   }

   # these predictions should always be 1 column
   # b/c they do not depend on the prediction horizon
   if(self$pred_type == 'mort'){

    self$eval_oobag$stat_values <-
     self$eval_oobag$stat_values[, 1L, drop = FALSE]

    self$pred_oobag <- self$pred_oobag[, 1L, drop = FALSE]

   }

  },
  clean_pred_new_internal = function(preds){

   if(self$pred_type == 'time'){

    preds <- apply(preds,
                   MARGIN = 1,
                   FUN = rmst,
                   times = self$event_times)
    preds <- collapse::qM(preds)

   }

   # don't let multiple pred horizon values through for mort
   if(self$pred_type %in% c('mort', 'time')){
    return(preds[, 1, drop = FALSE])
   }

   # output in the same order as user's pred_horizon vector
   preds <- preds[, order(private$pred_horizon_order), drop = FALSE]

   preds

  },

  predict_internal = function(simplify, oobag){

   private$pred_horizon_order <- order(self$pred_horizon)
   pred_horizon_ordered <- self$pred_horizon[private$pred_horizon_order]

   # must sort if oobag b/c when oobag_rows were originally created,
   # it was after the data had been sorted.
   if(oobag){
    private$sort_inputs(sort_y = FALSE)
    unsorted <- collapse::radixorder(private$data_row_sort)
   }

   cpp_args = private$prep_cpp_args(x = private$x,
                                    y = private$y,
                                    w = private$w,
                                    importance_type = 'none',
                                    pred_type = self$pred_type,
                                    pred_aggregate = self$pred_aggregate,
                                    pred_horizon = pred_horizon_ordered,
                                    oobag_pred = oobag,
                                    pred_mode = TRUE,
                                    write_forest = FALSE,
                                    run_forest = TRUE)

   if(length(self$pred_horizon) > 1 && !self$pred_aggregate){

    results <- vector(mode = 'list', length = length(self$pred_horizon))

    for(i in seq_along(results)){

     cpp_args$pred_horizon <- self$pred_horizon[i]

     results[[i]] <- do.call(orsf_cpp, args = cpp_args)$pred_new

     if(oobag){
      # put the oob predictions into the same order as the training data.
      results[[i]] <- results[[i]][unsorted, , drop = FALSE]
     }


     results[[i]] <- private$clean_pred_new(results[[i]])

    }

    # all components are the same if pred type is mort
    # (user also gets a warning if they ask for this)
    if(self$pred_type %in% c('mort', 'leaf', 'time')) return(results[[1]])

    if(simplify){
     results <- simplify2array(results)
    }

    return(results)

   }

   out_values <- do.call(orsf_cpp, args = cpp_args)$pred_new

   if(oobag){
    # put the oob predictions into the same order as the training data.
    out_values <- out_values[unsorted, , drop = FALSE]
   }

   private$clean_pred_new(out_values)

  }

 )

)

# ObliqueForestClassification class ----

ObliqueForestClassification <- R6::R6Class(
 "ObliqueForestClassification",
 inherit = ObliqueForest,
 public = list(

  n_class = NULL,

  class_levels = NULL,

  get_pred_type_vi = function(){
   return("prob")
  }

 ),

 # private ----
 private = list(

  check_split_rule_internal = function(){

   check_arg_is_valid(arg_value = self$split_rule,
                      arg_name = 'split_rule',
                      valid_options = c("gini", "cstat"))

  },
  check_pred_type_internal = function(oobag,
                                      pred_type = NULL,
                                      context = NULL){

   input <- pred_type %||% self$pred_type

   arg_name <- if(oobag) 'oobag_pred_type' else 'pred_type'

   if(is.null(context)){
    valid_options <- c("none", "prob", "class", "leaf")
   } else {
    valid_options <- switch(
     context,
     'partial dependence' = c("prob"),
     'prediction' = c("prob", "class", "leaf")
    )
    context <- paste(context, 'with classification forests')
   }

   check_arg_is_valid(arg_value = input,
                      arg_name = arg_name,
                      valid_options = valid_options,
                      context = context)

  },

  check_pred_horizon = function(pred_horizon = NULL,
                                boundary_checks = TRUE,
                                pred_type = NULL){

   # nothing to check
   NULL

  },

  check_oobag_eval_function_internal = function(oobag_fun){

   test_y <- rep(c(0,1), each = 50)

   .y_mat <- matrix(test_y, ncol = 1)
   .w_vec <- rep(1, times = 100)
   .s_vec <- seq(0.9, 0.1, length.out = 100)

   test_output <- try(oobag_fun(y_mat = .y_mat,
                                w_vec = .w_vec,
                                s_vec = .s_vec),
                      silent = FALSE)

   if(is_error(test_output)){

    stop("oobag_fun encountered an error when it was tested. ",
         "Please make sure your oobag_fun works for this case:\n\n",
         "test_y <- rep(c(0,1), each = 50)\n",
         "y_mat <- matrix(test_y, ncol = 1)\n",
         "w_vec <- rep(1, times = 100)\n",
         "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
         "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
         "test_output should be a numeric value of length 1",
         call. = FALSE)

   }

   test_output

  },

  check_lincomb_R_function_internal = function(lincomb_R_function = NULL){

   input <- lincomb_R_function %||% self$lincomb_R_function

   .x_node <- matrix(rnorm(300), ncol = 3)
   .y_node <- matrix(rbinom(100, size = 1, prob = 1/2), ncol = 1)
   .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)

   out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)

   if(is_error(out)){

    stop("user-supplied function encountered an error when it was tested. ",
         "Please make sure the function works for this case:\n\n",
         ".x_node <- matrix(rnorm(300), ncol = 3)\n",
         ".y_node <- matrix(rbinom(100, size = 1, prob = 1/2), ncol = 1)\n",
         ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n",
         "test_output <- your_function(.x_node, .y_node, .w_node)\n\n",
         "test_output should be a numeric matrix with 1 column and",
         " with nrow(test_output) = ncol(.x_node)",
         call. = FALSE)

   }

   out

  },

  init_control = function(){

   self$control <- orsf_control_classification(method = 'glm',
                                               scale_x = FALSE,
                                               max_iter = 1)

  },

  init_pred_type = function(){
   self$pred_type <- 'prob'
  },

  init_split_rule = function(){
   self$split_rule <- 'gini'
  },

  init_split_min_stat = function(){

   if(is.null(self$split_rule))
    stop("cannot init split_min_stat without split_rule", call. = FALSE)

   self$split_min_stat <- switch(self$split_rule,
                                 'gini' = 0,
                                 'cstat' = 0.55)

  },

  init_internal = function(){

   self$tree_type <- "classification"

   if(!is.function(self$control$lincomb_R_function) &&
      self$control$lincomb_type == 'net'){
    self$control$lincomb_R_function <- penalized_logreg
   }

   if(!self$oobag_pred_mode) self$oobag_eval_type <- "none"

   # use default if eval type was not specified by user
   if(self$oobag_pred_mode && is.null(self$oobag_eval_type)){
    self$oobag_eval_type <- "AUC-ROC"
   }

   y <- self$data[[private$data_names$y]]

   if(is.factor(y)){
    self$class_levels <- levels(y)
    self$n_class <- length(self$class_levels)
   } else {
    self$class_levels <- unique(y)
    self$n_class <- length(self$class_levels)
   }



  },

  init_pred = function(pred_horizon = NULL, pred_type = NULL,
                       boundary_checks = TRUE){

   if(!is.null(pred_horizon)){
    warning("pred_horizon does not impact predictions",
            " for classification forests", call. = FALSE)
   }

   if(!is.null(pred_type)){
    private$check_pred_type(oobag = FALSE, pred_type = pred_type)
   } else {
    pred_type <- self$pred_type %||% "prob"
   }

   self$pred_type <- pred_type

  },

  prep_y_internal = function(placeholder = FALSE){

   if(placeholder){
    private$y <- matrix(0, ncol = self$n_class, nrow = 1)
    return()
   }

   # y is always 1 column for classification (right?)
   y <- private$y[[1]]

   input_was_numeric <- !is.factor(y)

   if(input_was_numeric) y <- as.factor(y)

   n_class <- length(levels(y))

   if(n_class > 5 && input_was_numeric){
    stop("The outcome is numeric and has > 5 unique values.",
         " Did you mean to use orsf_control_regression()? If not,",
         " please convert ", private$data_names$y, " to a factor and re-run",
         call. = FALSE)
   }

   y <- as.numeric(y) - 1

   if(min(y) > 0) stop("y is less than 0")

   private$y <- expand_y_clsf(as_matrix(y), n_class)

  },

  clean_pred_oobag_internal = function(){

   if(self$pred_type %in% c("prob") && is.matrix(self$pred_oobag)){
    colnames(self$pred_oobag) <- self$class_levels
   }

  },

  predict_internal = function(simplify, oobag){

   # resize y to have the right number of columns
   private$y <- matrix(0, ncol = self$n_class)

   cpp_args = private$prep_cpp_args(x = private$x,
                                    y = private$y,
                                    w = private$w,
                                    importance_type = 'none',
                                    pred_type = self$pred_type,
                                    pred_aggregate = self$pred_aggregate,
                                    oobag_pred = oobag,
                                    pred_mode = TRUE,
                                    write_forest = FALSE,
                                    run_forest = TRUE)

   if(!self$pred_aggregate && self$n_class > 2 && self$pred_type == 'prob'){
    stop("unaggregated probability predictions for outcomes with >2 classes",
         " is not currently supported.", call. = FALSE)
   }

   out <- do.call(orsf_cpp, args = cpp_args)$pred_new

   out <- private$clean_pred_new(out)

   if(self$pred_type == 'prob' && self$pred_aggregate){

    colnames(out) <- self$class_levels

    if(simplify && self$n_class == 2)
     out <- out[, 2L, drop = TRUE]

   }

   if(self$pred_type == 'class'){

    # cpp class levels start at 0, R levels start at 1
    out <- out + 1

    # convert to factor (and coerce to vector) if asked
    if(simplify)
     out <- factor(out,
                   levels = seq(self$n_class),
                   labels = self$class_levels)
   }

   out

  },

  clean_pred_new_internal = function(preds){

   preds

  }

 )
)


# ObliqueForestRegression class ----

ObliqueForestRegression <- R6::R6Class(
 "ObliqueForestRegression",
 inherit = ObliqueForest,
 public = list(

  get_pred_type_vi = function(){
   return("mean")
  }

 ),

 # private ----
 private = list(

  check_split_rule_internal = function(){

   check_arg_is_valid(arg_value = self$split_rule,
                      arg_name = 'split_rule',
                      valid_options = c("variance"))

  },

  check_pred_type_internal = function(oobag,
                                      pred_type = NULL,
                                      context = NULL){

   input <- pred_type %||% self$pred_type

   arg_name <- if(oobag) 'oobag_pred_type' else 'pred_type'

   if(is.null(context)){
    valid_options <- c("none", "mean", "leaf")
   } else {
    valid_options <- switch(
     context,
     'partial dependence' = c("mean"),
     'prediction' = c("mean", "leaf")
    )
    context <- paste(context, 'with regression forests')
   }

   check_arg_is_valid(arg_value = input,
                      arg_name = arg_name,
                      valid_options = valid_options,
                      context = context)

  },

  check_pred_horizon = function(pred_horizon = NULL,
                                boundary_checks = TRUE,
                                pred_type = NULL){

   # nothing to check
   NULL

  },

  check_oobag_eval_function_internal = function(oobag_fun){


   test_y <- seq(0, 1, length.out = 100)

   .y_mat <- matrix(test_y, ncol = 1)
   .w_vec <- rep(1, times = 100)
   .s_vec <- seq(0.9, 0.1, length.out = 100)

   test_output <- try(oobag_fun(y_mat = .y_mat,
                                w_vec = .w_vec,
                                s_vec = .s_vec),
                      silent = FALSE)

   if(is_error(test_output)){

    stop("oobag_fun encountered an error when it was tested. ",
         "Please make sure your oobag_fun works for this case:\n\n",
         "test_y <- seq(0, 1, length.out = 100)\n",
         "y_mat <- matrix(test_y, ncol = 1)\n",
         "w_vec <- rep(1, times = 100)\n",
         "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
         "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
         "test_output should be a numeric value of length 1",
         call. = FALSE)

   }

   test_output

  },

  check_lincomb_R_function_internal = function(lincomb_R_function = NULL){

   input <- lincomb_R_function %||% self$lincomb_R_function

   .x_node <- matrix(rnorm(300), ncol = 3)
   .y_node <- matrix(rnorm(100), ncol = 1)
   .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)

   out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)

   if(is_error(out)){

    stop("user-supplied function encountered an error when it was tested. ",
         "Please make sure the function works for this case:\n\n",
         ".x_node <- matrix(rnorm(300), ncol = 3)\n",
         ".y_node <- matrix(rnorm(100), ncol = 1)\n",
         ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n",
         "test_output <- your_function(.x_node, .y_node, .w_node)\n\n",
         "test_output should be a numeric matrix with 1 column and",
         " with nrow(test_output) = ncol(.x_node)",
         call. = FALSE)

   }

   out

  },

  init_control = function(){

   self$control <- orsf_control_regression(method = 'glm',
                                           scale_x = FALSE,
                                           max_iter = 1)

  },

  init_pred_type = function(){
   self$pred_type <- 'mean'
  },

  init_split_rule = function(){
   self$split_rule <- 'variance'
  },

  init_split_min_stat = function(){

   if(is.null(self$split_rule))
    stop("cannot init split_min_stat without split_rule", call. = FALSE)

   self$split_min_stat <- switch(self$split_rule, 'variance' = 0)

  },

  init_internal = function(){

   self$tree_type <- "regression"

   if(is.factor(self$data[[private$data_names$y]])){
    stop("Cannot fit regression trees to outcome ",
         private$data_names$y, " because it is a factor.",
         " Did you mean to use orsf_control_classification()?",
         call. = FALSE)
   }

   if(!is.function(self$control$lincomb_R_function) &&
      self$control$lincomb_type == 'net'){
    self$control$lincomb_R_function <- penalized_linreg
   }

   if(!self$oobag_pred_mode) self$oobag_eval_type <- "none"

   # use default if eval type was not specified by user
   if(self$oobag_pred_mode && is.null(self$oobag_eval_type)){
    self$oobag_eval_type <- "RSQ"
   }

  },

  init_pred = function(pred_horizon = NULL, pred_type = NULL,
                       boundary_checks = TRUE){

   if(!is.null(pred_horizon)){
    warning("pred_horizon does not impact predictions",
            " for regression forests", call. = FALSE)
   }

   if(!is.null(pred_type)){
    private$check_pred_type(oobag = FALSE, pred_type = pred_type)
   } else {
    pred_type <- self$pred_type %||% "mean"
   }

   self$pred_type <- pred_type

  },

  prep_y_internal = function(placeholder = FALSE){

   if(placeholder){
    private$y <- matrix(0, ncol = 1, nrow = 1)
    return()
   }

   # y is always 1 column for regression (for now)
   private$y <- as_matrix(private$y[[1]])
   colnames(private$y) <- private$data_names$y

  },

  predict_internal = function(simplify, oobag){

   # resize y to have the right number of columns
   private$y <- matrix(0, ncol = 1)

   cpp_args = private$prep_cpp_args(x = private$x,
                                    y = private$y,
                                    w = private$w,
                                    importance_type = 'none',
                                    pred_type = self$pred_type,
                                    pred_aggregate = self$pred_aggregate,
                                    oobag_pred = oobag,
                                    pred_mode = TRUE,
                                    write_forest = FALSE,
                                    run_forest = TRUE)


   out <- do.call(orsf_cpp, args = cpp_args)$pred_new

   out <- private$clean_pred_new(out)

   if(simplify) dim(out) <- NULL

   out

  },

  clean_pred_new_internal = function(preds){

   preds

  }

 )
)

Try the aorsf package in your browser

Any scripts or data that you put into this service are public.

aorsf documentation built on June 22, 2024, 10:31 a.m.