R/construct_models.R

Defines functions construct_models

Documented in construct_models

#' Construct a list of models for synthesis
#'
#' Builds one model per variable in the roadmap's visit sequence. Models are
#' assigned in three passes, each overriding the previous one. First, every
#' variable gets `default_regression_model` when its mode is `"regression"`
#' and `default_classification_model` otherwise. Second, variables named in
#' `custom_models` get the model paired with them; if a variable appears in
#' several entries, the last entry wins. Third, variables flagged in the
#' roadmap schema as having no variation get an `"identity"` placeholder
#' model, regardless of any default or custom model.
#'
#' @param roadmap A roadmap object
#' @param default_regression_model A `parsnip` model object used for 
#' regression in numeric outcome variables
#' @param default_classification_model A `parsnip` model object used for 
#' classification in categorical outcome variables
#' @param custom_models A list of lists, each with a `vars` element (a
#' character vector of variable names from the `visit_sequence`) and a `model`
#' element (a `parsnip` model object). Listed variables use `model` instead of
#' the default; unlisted variables keep their default. Supplying an entry for
#' every variable allows omitting the default models entirely.
#'
#' @return A list with one model per variable, named by the `visit_sequence`
#' 
#' @examples
#' 
#' # construct_models() can create a sequence of models using a fully-default 
#' # approach, a hybrid approach, or a fully-customized approach. All approaches
#' # require a roadmap and model objects. 
#' 
#' rm <- roadmap(
#'   conf_data = acs_conf_nw,
#'   start_data = acs_start_nw
#' )
#' 
#' rpart_mod_reg <- parsnip::decision_tree() |>
#'   parsnip::set_engine(engine = "rpart") |>
#'   parsnip::set_mode(mode = "regression")
#'
#' rpart_mod_class <- parsnip::decision_tree() |>
#'   parsnip::set_engine(engine = "rpart") |>
#'   parsnip::set_mode(mode = "classification")
#' 
#' lm_mod <- parsnip::linear_reg() |> 
#'   parsnip::set_engine("lm") |>
#'   parsnip::set_mode(mode = "regression")
#' 
#' # Fully-default approach
#' 
#' construct_models(
#'   roadmap = rm, 
#'   default_regression_model = lm_mod, 
#'   default_classification_model = rpart_mod_class
#' )
#' 
#' # Hybrid approach
#' 
#' construct_models(
#'   roadmap = rm, 
#'   default_regression_model = lm_mod,
#'   default_classification_model = rpart_mod_class,
#'   custom_models = list(
#'     list(vars = "age", model = lm_mod)
#'   )
#' )
#' 
#' # Fully-customized approach
#' 
#' construct_models(
#'   roadmap = rm, 
#'   custom_models = list(
#'     list(vars = c("hcovany", "empstat", "classwkr"), model = rpart_mod_class),
#'     list(vars = c("age", "famsize", "transit_time", "inctot"), model = rpart_mod_reg)
#'   )
#' )
#' 
#' @export
construct_models <- function(
    roadmap, 
    default_regression_model = NULL, 
    default_classification_model = NULL,
    custom_models = NULL
) {

  if (!is_roadmap(roadmap)) {
    stop("`roadmap` must be a roadmap object")
  }

  # Variables to synthesize (in order) and the mode of each one; any mode
  # other than "regression" receives the classification default below.
  visit_sequence <- roadmap[["visit_sequence"]][["visit_sequence"]]
  mode <- .extract_mode(roadmap)

  # validate inputs
  .validate_construct_inputs_required(
    visit_sequence = visit_sequence,
    mode = mode,
    default_reg = default_regression_model,
    default_class = default_classification_model,
    custom_list = custom_models,
    type_check_func = .is_model,
    obj_name = "model(s)"
  )

  # Pass 1: defaults chosen by mode -------------------------------------------
  synth_models <- lapply(mode, function(var_mode) {
    if (var_mode == "regression") {
      default_regression_model
    } else {
      default_classification_model
    }
  })
  names(synth_models) <- visit_sequence

  # Pass 2: custom overrides --------------------------------------------------
  # Entries are visited in order, so the last entry naming a variable wins.
  # Assigning a NULL model removes the key, which leaves the default in place.
  custom_lookup <- list()
  for (entry in custom_models) {
    for (var in intersect(entry[["vars"]], visit_sequence)) {
      custom_lookup[[var]] <- entry[["model"]]
    }
  }
  for (var in names(custom_lookup)) {
    synth_models[[var]] <- custom_lookup[[var]]
  }

  # Pass 3: identity placeholder for variables with no variation --------------
  # NOTE(review): parsnip model specs call this field `user_specified_mode`;
  # `use_specified_mode` looks like a typo but is kept unchanged in case
  # downstream code or tests depend on the exact field name.
  identity_model <- list(
    args = NULL,
    eng_args = NULL,
    mode = "identity",
    use_specified_mode = TRUE,
    method = NULL,
    engine = "identity",
    user_specified_engine = TRUE
  )

  # `no_variation` flags variables by name. Filter(isTRUE, ...) keeps only
  # entries that are exactly TRUE (tolerating NULL and NA), and intersect()
  # ensures no new elements are added to `synth_models`.
  no_variation <- roadmap[["schema"]][["no_variation"]]
  no_var_vars <- intersect(
    names(Filter(isTRUE, no_variation)),
    names(synth_models)
  )

  if (length(no_var_vars) > 0) {
    synth_models[no_var_vars] <- list(identity_model)
  }

  synth_models
}

Try the tidysynthesis package in your browser

Any scripts or data that you put into this service are public.

tidysynthesis documentation built on March 17, 2026, 1:06 a.m.