caret_removal_plan.md

Okay, this is an excellent and well-thought-out "surgery plan." It correctly identifies the key caret touchpoints and proposes very sensible, phased replacements. My additions will focus on formalizing this into a step-by-step project plan, adding a bit more detail to each replacement, considering potential edge cases, and ensuring the new helper functions are well-defined.

Formalized Plan: Excising caret from rMVPA

Overall Strategy: We will follow the phased approach outlined in the "surgery plan." Each phase will aim to replace a distinct piece of caret functionality, be tested thoroughly, and keep the package in a runnable state. The primary replacements will leverage rsample for resampling and yardstick for metrics, both from the tidymodels ecosystem, which are lightweight and modern. For tuning, we'll start with a custom loop (Option A) and consider tune (Option B) as a future enhancement.

Phase 0: Preparation

  1. Branching: Create a new feature branch in Git (e.g., feature/remove-caret).
  2. Dependency Update (Anticipatory): Add rsample and yardstick to Suggests in the DESCRIPTION file for now. We'll move them to Imports as they become strictly necessary. This allows for gradual integration.

Phase 1: Decouple Model Discovery and Basic Infrastructure

Phase 2: Resampling Control and Performance Metrics

Phase 3: Hyperparameter Tuning

Phase 4: Cleanup and Finalization

  1. Update DESCRIPTION:
    • Remove caret from Imports.
    • Ensure rsample, yardstick are in Imports. (purrr, dplyr, tibble are already there).
  2. Documentation:
    • Update any documentation, vignettes, or examples that refer to caret models not in MVPAModels or to caret-specific tuning options.
    • Document the MVPAModels structure and the register_mvpa_model() helper.
  3. Final R CMD check --as-cran: Ensure everything passes.

New Helper Functions to Implement:

  1. register_mvpa_model(name, model_spec) (e.g., in R/common.R or R/classifiers.R):

    • Allows users to add their own models to the MVPAModels environment.
    • model_spec must be a list with type, library, label, parameters, grid, fit, predict, prob elements matching the MVPAModels convention.

    #' Register a Custom MVPA Model
    #'
    #' Adds a user-defined model specification to the rMVPA model registry (MVPAModels).
    #'
    #' @param name A character string, the unique name for the model.
    #' @param model_spec A list containing the model specification. It must include
    #'   elements: type ("Classification" or "Regression"), library (character vector
    #'   of required packages), label (character, usually same as name), parameters
    #'   (data.frame of tunable parameters: parameter, class, label), grid (function to
    #'   generate tuning grid), fit (function), predict (function), and prob
    #'   (function, for classification).
    #' @export
    #' @examples
    #' \dontrun{
    #' my_model_spec <- list(
    #'   type = "Classification", library = "e1071", label = "my_svm",
    #'   parameters = data.frame(parameter = "cost", class = "numeric", label = "Cost"),
    #'   grid = function(x, y, len = NULL) data.frame(cost = 10^(-3:3)),
    #'   fit = function(x, y, wts, param, lev, ...) e1071::svm(x, y, cost = param$cost, probability = TRUE, ...),
    #'   predict = function(modelFit, newdata, ...) predict(modelFit, newdata),
    #'   prob = function(modelFit, newdata, ...) {
    #'     # For an SVM fitted with probability = TRUE, class probabilities live in
    #'     # attr(predict(...), "probabilities")
    #'     pred_obj <- predict(modelFit, newdata, probability = TRUE)
    #'     attr(pred_obj, "probabilities")
    #'   }
    #' )
    #' register_mvpa_model("my_svm", my_model_spec)
    #' loaded_model <- load_model("my_svm")
    #' }
    register_mvpa_model <- function(name, model_spec) {
      # Basic validation of model_spec structure
      required_elements <- c("type", "library", "label", "parameters", "grid", "fit", "predict", "prob")
      if (!all(required_elements %in% names(model_spec))) {
        stop("model_spec is missing one or more required elements: ",
             paste(setdiff(required_elements, names(model_spec)), collapse = ", "))
      }
      if (!is.data.frame(model_spec$parameters) ||
          !all(c("parameter", "class", "label") %in% names(model_spec$parameters))) {
        stop("'model_spec$parameters' must be a data.frame with columns: parameter, class, label.")
      }
      # Assign to the MVPAModels environment
      MVPAModels[[name]] <- model_spec
      invisible(NULL)
    }

  2. create_mvpa_folds() (e.g., in R/resampling_utils.R): Detailed in Phase 1 (Ticket #003 below); a minimal sketch follows.
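A minimal sketch of create_mvpa_folds(), assuming the signature create_mvpa_folds(y, k, list, seed) proposed for Phase 1 and rsample::vfold_cv under the hood; the exact stratification behaviour and edge-case handling (k close to n, very small classes) are assumptions to be pinned down during implementation.

    # Sketch only: the return contract and edge-case behaviour are assumptions.
    create_mvpa_folds <- function(y, k = 5, list = TRUE, seed = NULL) {
      if (!is.null(seed)) set.seed(seed)
      df <- data.frame(.y = y, .row = seq_along(y))
      # Stratify on y when it is a factor so class proportions are preserved per fold
      folds <- rsample::vfold_cv(df, v = k, strata = if (is.factor(y)) ".y" else NULL)
      # assessment() returns the held-out rows of each fold
      idx <- lapply(folds$splits, function(s) rsample::assessment(s)$.row)
      if (list) {
        idx
      } else {
        # Fold-membership vector (1..k), the format glmnet's foldid expects
        fold_id <- integer(length(y))
        for (i in seq_along(idx)) fold_id[idx[[i]]] <- i
        fold_id
      }
    }

    # Usage inside MVPAModels$glmnet_opt$fit (Ticket #004), replacing caret::createFolds():
    # foldid_vector <- create_mvpa_folds(y, k = 5, list = FALSE, seed = 1234)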

Benefits Gained (Reiterated):

  • Reduced Dependencies: Significant drop in installation footprint and complexity.
  • Modern Tooling: Leverages rsample and yardstick for robust and maintainable resampling and metrics.
  • Explicit Model Handling: Clearer distinction between internal MVPAModels and user-provided ones.
  • Control: More direct control over the tuning loop.
  • Future Flexibility: Easier to integrate other tuning backends (like tune::tune_grid) or advanced resampling strategies later without being tied to caret's large ecosystem.

This formalized plan builds directly on your excellent analysis and provides a concrete path forward. The main effort will be in thoroughly implementing and testing the new tune_model function and ensuring metric calculations align.

This is an excellent, detailed "surgery plan." It's well-structured, pragmatic, and correctly identifies the key caret dependencies and proposes sound replacements. The phased approach is spot on for managing complexity and ensuring the package remains runnable.

Here's a formalized path forward, building upon your plan, with some additional details, considerations, and helper function definitions.

Formalized Plan: Excising caret from rMVPA

Overall Strategy: We will adopt the phased approach outlined. The core idea is to replace caret's model management, resampling, and tuning infrastructure with a combination of in-house solutions and lightweight, focused tidymodels packages (rsample for resampling, yardstick for metrics). This will significantly reduce dependencies and provide more direct control.

Phase 0: Preparation

  1. Branching: Create a new feature branch in Git (e.g., feature/remove-caret).
  2. Dependency Management:
    • In DESCRIPTION:
      • Add rsample and yardstick to Imports:.
      • caret will remain in Imports: until Phase 4.
    • Add #' @importFrom rsample ... and #' @importFrom yardstick ... directives in the relevant R files as functions are used (illustrated below).
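For illustration, the roxygen directives might look like the following once the first rsample/yardstick calls land; the exact function list is an assumption and would grow as the phases are completed.

    #' @importFrom rsample vfold_cv bootstraps analysis assessment
    #' @importFrom yardstick accuracy_vec roc_auc_vec rmse_vec rsq_vec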

Phase 1: Decouple Model Discovery and Basic Infrastructure

Phase 2: Resampling Control and Performance Metrics

Phase 3: Hyperparameter Tuning

Phase 4: Cleanup and Finalization

  1. Update DESCRIPTION:
    • Remove caret and c060 (if glmnet_opt's epsgo is replaced or c060 is re-evaluated for direct use) from Imports.
    • Ensure rsample, yardstick, rlang (already there), purrr (already there) are in Imports. If glmnet_opt still uses epsgo from c060, then c060 needs to stay. (The plan above keeps epsgo for glmnet_opt but replaces its internal createFolds).
  2. Documentation:
    • Update vignettes, examples, and function documentation that might refer to caret or imply caret models can be used directly.
    • Clearly document the MVPAModels structure and register_mvpa_model().
  3. Final R CMD check --as-cran.

Helper Function: register_mvpa_model (R/common.R or R/classifiers.R) (As detailed in your Phase 3, copied here for completeness)

    #' Register a Custom MVPA Model
    #'
    #' Adds a user-defined model specification to the rMVPA model registry (MVPAModels).
    #'
    #' @param name A character string, the unique name for the model.
    #' @param model_spec A list containing the model specification. It must include
    #'   elements: `type` ("Classification" or "Regression"), `library` (character vector
    #'   of required packages for the *model itself*, not for rMVPA's wrappers), 
    #'   `label` (character, usually same as name), `parameters`
    #'   (data.frame of tunable parameters: parameter, class, label), `grid` (function to
    #'   generate tuning grid, takes x, y, len args), `fit` (function), `predict` (function), 
    #'   and `prob` (function for classification, takes modelFit, newdata; should return matrix/df with colnames as levels).
    #' @export
    #' @examples
    #' \dontrun{
    #' # Example of how a user might define an e1071 SVM spec
    #' my_svm_spec <- list(
    #'   type = "Classification", library = "e1071", label = "my_svm",
    #'   parameters = data.frame(parameter = "cost", class = "numeric", label = "Cost (C)"),
    #'   # grid should return a data.frame with columns matching 'parameter' names in 'parameters'
    #'   grid = function(x, y, len = NULL) { 
    #'      data.frame(cost = if (is.null(len) || len == 1) 1 else 10^seq(-2, 2, length.out = len))
    #'   },
    #'   # fit function receives: x, y, wts (weights), param (current params from grid), 
    #'   # lev (levels of y), last (unused), weights (unused), classProbs (unused by e1071::svm)
    #'   fit = function(x, y, wts, param, lev, last, weights, classProbs, ...) {
    #'      e1071::svm(x, y, cost = param$cost, probability = TRUE, ...) # Ensure probability=TRUE for prob
    #'   },
    #'   # predict function receives: modelFit (output of $fit), newdata
    #'   predict = function(modelFit, newdata, ...) {
    #'      predict(modelFit, newdata, ...)
    #'   },
    #'   # prob function receives: modelFit, newdata
    #'   # Should return a matrix/df with columns named as in levels(y)
    #'   prob = function(modelFit, newdata, ...) {
    #'     pred_obj <- predict(modelFit, newdata, probability = TRUE)
    #'     attr(pred_obj, "probabilities") 
    #'   }
    #' )
    #' register_mvpa_model("my_svm", my_svm_spec)
    #' # Now load_model("my_svm") would work.
    #' }
    register_mvpa_model <- function(name, model_spec) {
      required_elements <- c("type", "library", "label", "parameters", "grid", "fit", "predict", "prob")
      if (!all(required_elements %in% names(model_spec))) {
        stop("model_spec is missing one or more required elements: ", 
             paste(setdiff(required_elements, names(model_spec)), collapse=", "))
      }
      if (!is.data.frame(model_spec$parameters) || 
          !all(c("parameter", "class", "label") %in% names(model_spec$parameters))) {
        stop("'model_spec$parameters' must be a data.frame with columns: parameter, class, label.")
      }
      MVPAModels[[name]] <- model_spec
      invisible(NULL)
    }

Okay, here's a ticketed list of items for the caret refactoring, based on the formalized plan. Each ticket includes the targeted file(s), a description of the task, and acceptance criteria (Definition of Done - DoD).

Project: Refactor rMVPA - Remove caret Dependency

Preamble:

  • All work should be done on a dedicated feature branch (e.g., feature/remove-caret).
  • After each phase (or significant ticket), run R CMD check and relevant existing tests to ensure the package remains in a runnable state.
  • Incrementally update NAMESPACE with @importFrom directives as new functions from rsample and yardstick are used.

Phase 0: Preparation

  1. [ ] Ticket #001: Setup - Branching and Initial Dependency Declaration
    • File(s): DESCRIPTION
    • Task:
      1. Create a new Git feature branch (e.g., feature/remove-caret).
      2. Add rsample and yardstick to the Suggests: field in the DESCRIPTION file. (They will be moved to Imports: later as their functions are directly used).
    • DoD:
      • New Git branch created and checked out.
      • DESCRIPTION file updated with rsample and yardstick in Suggests:.
      • Package installs and loads without error.

Phase 1: Decouple Model Discovery and Basic Infrastructure

  1. [ ] Ticket #002: Modify load_model() - Remove caret Fallback

    • File(s): R/common.R
    • Task: Remove the else if (length(caret::getModelInfo(name)) > 0) … block from load_model(). Update the error message for unknown models (a minimal sketch of the resulting lookup logic appears after this phase's tickets).
    • DoD:
      • load_model() no longer calls caret::getModelInfo().
      • load_model() successfully loads models defined in MVPAModels.
      • load_model() throws an error for models not in MVPAModels.
      • Relevant unit tests for load_model pass.
  2. [ ] Ticket #003: Implement create_mvpa_folds() Helper

    • File(s): New file (e.g., R/resampling_utils.R)
    • Task: Implement the create_mvpa_folds(y, k, list, seed) function using rsample::vfold_cv for stratified k-fold CV (if y is factor) or simple random k-fold CV.
    • DoD:
      • create_mvpa_folds() function exists and is exported (or internally available if preferred).
      • Unit tests for create_mvpa_folds() cover:
        • Correct number of folds returned.
        • list=TRUE returns a list of index vectors.
        • list=FALSE returns a vector of fold assignments.
        • Stratification works correctly for factor y.
        • seed argument ensures reproducibility.
        • Handles edge cases (e.g., k > n, k = n).
  3. [ ] Ticket #004: Patch glmnet_opt Model for Fold Creation

    • File(s): R/classifiers.R (within MVPAModels$glmnet_opt$fit)
    • Task: Replace the caret::createFolds() call with a call to the new create_mvpa_folds(y, k=5, list=FALSE, seed=1234) to generate the foldid_vector.
    • DoD:
      • MVPAModels$glmnet_opt$fit no longer calls caret::createFolds.
      • glmnet_opt model trains successfully using folds generated by create_mvpa_folds.
      • Existing tests for glmnet_opt (if any specifically test fold generation) pass.
  4. [ ] Ticket #005: Verify Prediction Method Calls

    • File(s): R/model_fit.R (predict.class_model_fit, predict.regression_model_fit)
    • Task: Review these prediction methods. Confirm that object$model$prob and object$model$predict correctly refer to functions within the MVPAModels specifications after load_model changes, not caret model objects.
    • DoD:
      • Code review confirms correct dispatch to MVPAModels functions.
      • No code changes are required if the assumption holds.
      • Existing prediction tests continue to pass.
  5. [ ] Ticket #006: Remove load_caret_libs() Function

    • File(s): R/model_fit.R
    • Task: Delete the load_caret_libs() function.
    • DoD:
      • Function load_caret_libs() is removed.
      • Package compiles and loads without error.
  6. [ ] Ticket #007: Phase 1 Integration Testing

    • File(s): N/A (Testing activity)
    • Task: Run R CMD check. Execute all existing unit tests. Pay special attention to model loading, glmnet_opt training, and general prediction pathways.
    • DoD:
      • R CMD check passes with no new errors/warnings related to these changes.
      • All existing relevant unit tests pass.
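As referenced in Ticket #002, here is a minimal sketch of what the caret-free lookup in load_model() might reduce to. The surrounding structure of the real function is not shown in this plan (it may wrap the spec in a richer object), so treat this only as an illustration of the lookup-and-error branch that replaces the caret::getModelInfo() fallback.

    # Illustrative only: the existing body of load_model() is assumed, not quoted.
    load_model <- function(name) {
      if (!exists(name, envir = MVPAModels, inherits = FALSE)) {
        stop("Unknown model: '", name, "'. Available models: ",
             paste(sort(ls(MVPAModels)), collapse = ", "),
             ". Use register_mvpa_model() to register a custom specification.")
      }
      MVPAModels[[name]]
    }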

Phase 2: Resampling Control and Performance Metrics

  1. [ ] Ticket #008: Rewrite get_control()

    • File(s): R/model_fit.R
    • Task: Rewrite get_control(y, nreps) to return a simple list containing the metric name (e.g., "roc_auc", "accuracy", "rmse") and number (for nreps), removing the caret::trainControl logic (a combined sketch of the Phase 2 helpers appears after this phase's tickets).
    • DoD:
      • get_control() no longer calls caret::trainControl.
      • Function returns the expected list structure based on y.
  2. [ ] Ticket #009: Remove mclass_summary() Function

    • File(s): R/model_fit.R
    • Task: Delete the mclass_summary() function.
    • DoD:
      • Function mclass_summary() is removed.
      • Package compiles and loads.
  3. [ ] Ticket #010: Update binary_perf() using yardstick

    • File(s): R/performance.R
    • Task: Rewrite binary_perf(observed, predicted, probs) to use yardstick::accuracy_vec and yardstick::roc_auc_vec.
    • DoD:
      • binary_perf() uses yardstick functions.
      • Returns named vector c(Accuracy = ..., AUC = ...).
      • Metrics are validated against known examples or previous caret outputs.
  4. [ ] Ticket #011: Update multiclass_perf() using yardstick

    • File(s): R/performance.R
    • Task: Rewrite multiclass_perf(observed, predicted, probs, class_metrics) to use yardstick::accuracy_vec and yardstick::roc_auc_vec (with estimator="macro" for average AUC). Implement per-class AUC if class_metrics=TRUE.
    • DoD:
      • multiclass_perf() uses yardstick functions.
      • Returns named vector including Accuracy, AUC, and per-class AUCs if requested.
      • Metrics validated.
  5. [ ] Ticket #012: Update performance.regression_result() using yardstick

    • File(s): R/performance.R
    • Task: Rewrite performance.regression_result(x, split_list, ...) to use yardstick::rsq_vec, yardstick::rmse_vec, and stats::cor for Spearman correlation.
    • DoD:
      • performance.regression_result() uses yardstick and stats::cor.
      • Returns named vector c(R2=..., RMSE=..., spearcor=...).
      • Metrics validated.
  6. [ ] Ticket #013: Update Performance Wrappers in mvpa_model.R

    • File(s): R/mvpa_model.R
    • Task: Ensure get_multiclass_perf(), get_binary_perf(), get_regression_perf() correctly call the updated S3 performance methods.
    • DoD:
      • Wrapper functions correctly dispatch to the new yardstick-based performance calculators.
      • compute_performance.mvpa_model functions as expected.
  7. [ ] Ticket #014: Phase 2 Integration Testing

    • File(s): N/A (Testing activity)
    • Task: Run R CMD check. Execute all existing unit tests, focusing on those that calculate and report performance metrics. Compare outputs with previous versions if possible.
    • DoD:
      • R CMD check passes.
      • Performance metrics are correctly calculated and reported.
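A combined, minimal sketch of the Phase 2 helpers (Tickets #008, #010, #012), assuming the yardstick vector API. The metric chosen per outcome type, the column layout of probs (one column per class level, named by level), and the helper name regression_perf (the actual target is performance.regression_result()) are illustrative assumptions; the per-class handling of Ticket #011 is only noted in a comment.

    # Assumed replacement for caret::trainControl(): just the metric name and rep count.
    get_control <- function(y, nreps) {
      metric <- if (is.factor(y)) {
        if (nlevels(y) == 2) "roc_auc" else "accuracy"
      } else {
        "rmse"
      }
      list(metric = metric, number = nreps)
    }

    # Binary classification performance (Ticket #010); probs is assumed to be a
    # matrix/data.frame with one column per class level.
    binary_perf <- function(observed, predicted, probs) {
      observed  <- factor(observed)
      predicted <- factor(predicted, levels = levels(observed))
      c(Accuracy = yardstick::accuracy_vec(observed, predicted),
        AUC      = yardstick::roc_auc_vec(observed, probs[, levels(observed)[1]],
                                          event_level = "first"))
    }
    # For the multiclass case (Ticket #011), roc_auc_vec() accepts a probability
    # matrix (columns ordered as levels(observed)) plus estimator = "macro".

    # Core of the regression metrics (Ticket #012).
    regression_perf <- function(observed, predicted) {
      c(R2       = yardstick::rsq_vec(observed, predicted),
        RMSE     = yardstick::rmse_vec(observed, predicted),
        spearcor = stats::cor(observed, predicted, method = "spearman"))
    }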

Phase 3: Hyperparameter Tuning

  1. [ ] Ticket #015: Rewrite tune_model() - Custom Tuning Loop

    • File(s): R/model_fit.R
    • Task: Implement the new tune_model(mspec, x, y, wts, param_grid, nreps) function (a minimal sketch appears after this phase's tickets). This will involve:
      1. Determining the optimization metric and direction (higher better/lower better) from get_control().
      2. Setting up resamples (e.g., rsample::bootstraps).
      3. Looping through param_grid.
      4. For each parameter set, looping through resamples:
        • Fitting the model (mspec$model$fit) on the analysis set.
        • Predicting (mspec$model$predict or mspec$model$prob) on the assessment set.
        • Calculating the chosen metric using yardstick.
      5. Averaging the metric across resamples for each parameter set.
      6. Selecting the param_grid row that yields the best average metric.
    • DoD:
      • tune_model() function implemented, no longer calls caret::train.
      • Uses rsample for resampling and yardstick for metric calculation.
      • Correctly identifies and returns the best parameter set from param_grid.
      • Unit tests for tune_model with simple mock models and grids verify its logic.
  2. [ ] Ticket #016: Integrate new tune_model() into train_model.mvpa_model

    • File(s): R/model_fit.R (train_model.mvpa_model)
    • Task: Modify train_model.mvpa_model so that if the current_param_grid (obtained from tune_grid(obj, ...)) has more than one row, it calls the new tune_model() to determine best_param. Otherwise, current_param_grid is used as best_param.
    • DoD:
      • train_model.mvpa_model correctly calls the new tune_model when appropriate.
      • The final model is fitted using the best_param determined by tune_model or the single provided parameter set.
      • Existing tests for models involving tuning (e.g., sda with its default grid, corclass with multiple options) pass and select reasonable parameters.
  3. [ ] Ticket #017: Phase 3 Integration Testing

    • File(s): N/A (Testing activity)
    • Task: Run R CMD check. Thoroughly test models with tuning grids. Verify that the tuning process completes and reasonable parameters are selected. This is a critical validation step.
    • DoD:
      • R CMD check passes.
      • Models with hyperparameter tuning (tune_grid in mvpa_model call results in multiple rows) train successfully.
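As referenced in Ticket #015, a minimal sketch of the custom tuning loop. It assumes the MVPAModels fit/predict/prob conventions shown earlier, the get_control() list from Ticket #008, out-of-bag bootstrap resampling via rsample, and probability columns named by class level; the extra caret-style fit arguments (last, classProbs) and richer metric selection are deliberately omitted.

    tune_model <- function(mspec, x, y, wts, param_grid, nreps = 10) {
      ctrl <- get_control(y, nreps)
      higher_is_better <- ctrl$metric != "rmse"

      # Bootstrap resamples; assessment() yields the out-of-bag rows of each split
      df     <- data.frame(.row = seq_along(y))
      splits <- rsample::bootstraps(df, times = nreps)$splits

      score_param <- function(param) {
        scores <- vapply(splits, function(s) {
          train_idx <- rsample::analysis(s)$.row
          test_idx  <- rsample::assessment(s)$.row
          fit <- mspec$model$fit(x[train_idx, , drop = FALSE], y[train_idx],
                                 wts, param, lev = levels(y))
          if (is.factor(y)) {
            probs <- as.matrix(mspec$model$prob(fit, x[test_idx, , drop = FALSE]))
            if (nlevels(y) == 2) {
              yardstick::roc_auc_vec(y[test_idx], probs[, levels(y)[1]])
            } else {
              yardstick::roc_auc_vec(y[test_idx], probs[, levels(y), drop = FALSE],
                                     estimator = "macro")
            }
          } else {
            preds <- mspec$model$predict(fit, x[test_idx, , drop = FALSE])
            yardstick::rmse_vec(y[test_idx], as.numeric(preds))
          }
        }, numeric(1))
        mean(scores, na.rm = TRUE)
      }

      # Average the metric over resamples for every row of param_grid, then pick the best
      avg  <- vapply(seq_len(nrow(param_grid)),
                     function(i) score_param(param_grid[i, , drop = FALSE]),
                     numeric(1))
      best <- if (higher_is_better) which.max(avg) else which.min(avg)
      param_grid[best, , drop = FALSE]
    }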

Phase 4: Cleanup and Finalization

  1. [ ] Ticket #018: Update DESCRIPTION File - Final Dependencies

    • File(s): DESCRIPTION
    • Task:
      1. Remove caret from Imports:.
      2. If glmnet_opt still uses epsgo from c060, ensure c060 is in Imports:. Otherwise, remove it if epsgo usage is also refactored out (outside current scope).
      3. Ensure rsample and yardstick are in Imports:.
    • DoD:
      • DESCRIPTION file accurately reflects the new dependencies.
      • Package installs correctly with the updated dependencies.
  2. [ ] Ticket #019: Implement register_mvpa_model() Helper

    • File(s): R/common.R or R/classifiers.R
    • Task: Implement the register_mvpa_model(name, model_spec) function to allow users to add models to MVPAModels. Include basic validation for the model_spec structure.
    • DoD:
      • register_mvpa_model() function implemented and exported.
      • A simple test case demonstrates adding a custom model specification and then loading it with load_model().
  3. [ ] Ticket #020: Documentation Update

    • File(s): All Rd files, vignettes, README.
    • Task:
      1. Remove any mentions of caret models being loadable via load_model if they are not part of MVPAModels.
      2. Update documentation for functions whose behavior or dependencies changed (e.g., mvpa_model regarding tuning if internals changed significantly).
      3. Document the MVPAModels structure and the new register_mvpa_model() function.
      4. Ensure examples are updated and do not rely on implicit caret behavior.
    • DoD:
      • All user-facing documentation is accurate and reflects the removal of caret.
      • Examples are functional with the refactored code.
  4. [ ] Ticket #021: Final Comprehensive Check and Merge Preparation

    • File(s): Entire codebase.
    • Task:
      1. Perform a final R CMD check --as-cran.
      2. Review all changes for any lingering caret references or unintended consequences.
      3. Ensure all unit tests pass.
      4. Consider building vignettes to check for issues.
    • DoD:
      • R CMD check --as-cran is clean.
      • All tests pass.
      • Code review complete.
      • Branch is ready to be merged into the main development line.

This ticketed list provides a systematic way to track progress and ensure all aspects of the refactoring are addressed.


