# R/rsa_model.R
#
# Defines functions: rsa_model, print.rsa_design, print.rsa_model, merge_results.rsa_model, train_model.rsa_model, run_lm_semipartial, run_lm_constrained, run_cor, run_lm, check_collinearity, run_rfit, rsa_model_mat, rsa_design, sanitize
#
# Documented in: merge_results.rsa_model, rsa_design, rsa_model, train_model.rsa_model

#' Turn formula term labels into plain identifier-like names
#'
#' Applies a fixed, ordered sequence of textual substitutions to \code{name}:
#' colons become dots, spaces are removed, parentheses become dots, commas
#' become underscores and a single trailing dot is dropped.
#'
#' @param name Character vector of labels to clean.
#' @return Character vector of the same length as \code{name}.
#' @noRd
#' @keywords internal
sanitize <- function(name) {
  # Rules are applied top to bottom; order matters because the trailing-dot
  # rule must see the dots produced by the parenthesis rule.
  rules <- list(
    list(pattern = ":",         replacement = ".", perl = FALSE),
    list(pattern = " ",         replacement = "",  perl = FALSE),
    list(pattern = "[\\(\\)]",  replacement = ".", perl = TRUE),
    list(pattern = ",",         replacement = "_", perl = FALSE),
    list(pattern = "\\.$",      replacement = "",  perl = FALSE)
  )

  for (rule in rules) {
    name <- gsub(rule$pattern, rule$replacement, name, perl = rule$perl)
  }

  name
}


#' Construct a design for an RSA (Representational Similarity Analysis) model
#'
#' This function constructs a design for an RSA model using the provided formula, data, and optional parameters.
#'
#' @param formula A formula expression specifying the dissimilarity-based regression function.
#' @param data A named list containing the dissimilarity matrices and any other auxiliary variables.
#' @param block_var An optional \code{formula}, \code{character} name or \code{integer} vector designating the block structure.
#' @param split_by An optional \code{formula} indicating grouping structure for evaluating test performance.
#' @param keep_intra_run A \code{logical} indicating whether to include within-run comparisons (default: FALSE).
#' @return A list with class attributes "rsa_design" and "list", containing:
#'   \describe{
#'     \item{formula}{The input formula}
#'     \item{data}{The input data}
#'     \item{split_by}{The split_by formula}
#'     \item{split_groups}{Grouping structure for split_by}
#'     \item{block_var}{Block structure}
#'     \item{include}{Logical vector for including/excluding comparisons}
#'     \item{model_mat}{Model matrix generated by rsa_model_mat}
#'   }
#' @details
#' The function creates an RSA design based on the input parameters. It checks the validity of the input data and
#' handles splitting conditions for evaluation of test performance. It also processes optional block structures and
#' within-run comparisons.
#' @importFrom assertthat assert_that
#' @export
#' @examples
#' dismat <- dist(matrix(rnorm(100*100), 100, 100))
#' rdes <- rsa_design(~ dismat, list(dismat=dismat))
rsa_design <- function(formula, data, block_var=NULL, split_by=NULL, keep_intra_run=FALSE) {
  assert_that(purrr::is_formula(formula))

  # Number of observations implied by each element of `data`: rows of a
  # matrix, the "Size" attribute of a "dist" object, or the length of a vector.
  # vapply() keeps the result a numeric vector regardless of the input, and the
  # class vector is collapsed so multi-class objects yield one readable message.
  n_obs <- vapply(data, function(x) {
    if (is.matrix(x)) {
      nrow(x)
    } else if (inherits(x, "dist")) {
      attr(x, "Size")
    } else if (is.vector(x)) {
      length(x)
    } else {
      stop(paste("illegal variable type", paste(class(x), collapse = "/")))
    }
  }, numeric(1))

  assert_that(all(n_obs == n_obs[1]), msg="all elements in 'data' must have the same number of rows")

  # Observation indices grouped by the levels of the split variable
  # (NULL when `split_by` is not supplied).
  # NOTE(review): no minimum group size is enforced here; a previously defined
  # but never-invoked local check for "fewer than 3 observations" was removed
  # as dead code. Confirm whether such a check should be enforced.
  split_groups <- if (!is.null(split_by)) {
    split_var <- parse_variable(split_by, data)
    split(seq_along(split_var), split_var)
  }

  # Resolve the block variable against `data` (NULL when not supplied).
  block_var <- if (!is.null(block_var)) {
    parse_variable(block_var, data)
  }

  # Unless within-run comparisons are kept, flag only those pairs whose block
  # values differ (non-zero distance between block labels). NULL means
  # "keep every pairwise comparison".
  include <- if (!is.null(block_var) && !keep_intra_run) {
    as.vector(dist(block_var)) != 0
  }

  des <- list(
    formula = formula,
    data = data,
    split_by = split_by,
    split_groups = split_groups,
    block_var = block_var,
    include = include
  )

  # The model matrix is derived from the partially built design, then attached.
  des$model_mat <- rsa_model_mat(des)

  class(des) <- c("rsa_design", "list")
  des
}


#' Construct a model matrix for an RSA (Representational Similarity Analysis) design
#'
#' This function constructs a model matrix for the given RSA design by processing distance matrices and other variables.
#'
#' @param rsa_des An RSA design object created by \code{rsa_design()}, containing formula, data, and optional parameters.
#' @return A named list of vectors, where:
#'   \itemize{
#'     \item Names correspond to sanitized variable names from the formula
#'     \item Each vector is the processed version of the corresponding input data
#'     \item For distance matrices, only the lower triangle is included
#'     \item If rsa_des$include is specified, vectors are subset accordingly
#'   }
#' @details
#' The function takes an RSA design object as input and processes the distance matrices and other variables to
#' construct a model matrix. It handles different types of input matrices, including symmetric and asymmetric
#' distance matrices, and can include or exclude within-run comparisons based on the RSA design.
#' @examples
#' dismat <- dist(matrix(rnorm(100*100), 100, 100))
#' rdes <- rsa_design(~ dismat, list(dismat=dismat))
#' rsa_model_mat(rdes)
#' @keywords internal
#' @noRd
rsa_model_mat <- function(rsa_des) {
  # Each term label of the formula is evaluated as an R expression against the
  # design's data, so terms may be bare variable names or calls on them.
  rvars <- labels(terms(rsa_des$formula))
  denv <- list2env(rsa_des$data)
  vset <- lapply(rvars, function(x) eval(parse(text = x), envir = denv))

  # Reduce every term to a vector of pairwise values:
  #   * "dist" objects are unrolled directly;
  #   * symmetric 2-d objects contribute their lower triangle;
  #   * anything else (plain vectors, non-symmetric matrices) goes through
  #     dist().
  # isSymmetric() has no method for dimensionless vectors and would raise
  # "no applicable method", so it is only consulted for 2-d inputs.
  vmatlist <- lapply(vset, function(v) {
    if (inherits(v, "dist")) {
      as.vector(v)
    } else if (length(dim(v)) == 2L && isSymmetric(v)) {
      v[lower.tri(v)]
    } else {
      as.vector(dist(v))
    }
  })

  # Drop excluded comparisons when the design carries an inclusion mask.
  if (!is.null(rsa_des$include)) {
    vmatlist <- lapply(vmatlist, function(v) v[rsa_des$include])
  }

  names(vmatlist) <- sanitize(rvars)

  vmatlist
}


#' Rank-based regression of a dissimilarity vector on the design predictors
#'
#' Builds the model \code{dvec ~ <all predictors in obj$design$model_mat>},
#' fits it with \code{Rfit::rfit()} and returns the fitted coefficients with
#' the first one removed (presumably the intercept).
#'
#' @param dvec Numeric response vector.
#' @param obj Object whose \code{design$model_mat} is a named list of predictor vectors.
#' @return The coefficient vector of the fit without its first element.
#' @keywords internal
#' @importFrom Rfit rfit
#' @noRd
run_rfit <- function(dvec, obj) {
  predictors <- obj$design$model_mat
  rhs <- paste(names(predictors), collapse = " + ")
  form <- paste("dvec", "~", rhs)

  # The response travels in the same list as the predictors.
  predictors$dvec <- dvec

  fit <- Rfit::rfit(form, data = predictors)
  coef(fit)[-1]
}


#' Stop if the RSA predictors are (near-)collinear
#'
#' Binds the predictors column-wise and raises an error when any pair has an
#' absolute correlation above \code{threshold}, or when the resulting matrix
#' is rank deficient according to a QR decomposition. Nothing is checked when
#' fewer than two predictors are present.
#'
#' @param model_mat A named list of equal-length numeric vectors (as produced
#'   by \code{rsa_model_mat()}), a data frame, or a numeric matrix whose
#'   columns are the predictors.
#' @param threshold Absolute correlation above which two predictors are
#'   reported as collinear. Defaults to 0.99.
#' @return \code{NULL}, invisibly, when no problem is found; otherwise an
#'   error is raised.
#' @keywords internal
#' @importFrom stats coef cor dist rnorm terms lm sd
#' @noRd
check_collinearity <- function(model_mat, threshold = 0.99) {
  # as.matrix() on a list of vectors yields a k x 1 list-matrix (one cell per
  # predictor), which made the check a silent no-op. cbind() produces the
  # intended observations x predictors numeric matrix.
  design_matrix <- if (is.matrix(model_mat)) {
    model_mat
  } else {
    do.call(cbind, model_mat)
  }

  # Collinearity needs at least two columns.
  if (is.null(design_matrix) || ncol(design_matrix) < 2L) {
    return(invisible(NULL))
  }

  vnames <- colnames(design_matrix)
  if (is.null(vnames)) {
    vnames <- paste0("V", seq_len(ncol(design_matrix)))
  }

  # Only the numerical work is wrapped, so the diagnostics raised further down
  # are not re-caught and re-labelled by this handler.
  diagnostics <- tryCatch({
    list(
      cor  = cor(design_matrix, use = "pairwise.complete.obs"),
      rank = qr(design_matrix)$rank
    )
  }, error = function(e) {
    msg <- conditionMessage(e)
    if (grepl("NA/NaN/Inf", msg)) {
      stop("Cannot compute correlations due to NA/NaN/Inf values in predictor variables.")
    }
    stop(paste("Error checking collinearity:", msg))
  })

  # Look at the upper triangle only so each offending pair is listed once.
  # NA correlations (e.g. constant columns) are ignored by which().
  too_high <- abs(diagnostics$cor) > threshold & upper.tri(diagnostics$cor)
  high_cor_pairs <- which(too_high, arr.ind = TRUE)
  if (nrow(high_cor_pairs) > 0) {
    problem_vars <- paste(vnames[high_cor_pairs[, 1]], "and", vnames[high_cor_pairs[, 2]])
    stop(sprintf("Collinearity detected among predictors: %s. Consider removing one of the correlated variables.",
                 paste(problem_vars, collapse = "; ")))
  }

  # Linear dependencies involving more than two predictors.
  if (diagnostics$rank < ncol(design_matrix)) {
    stop(sprintf("Design matrix is rank deficient (rank %d < %d columns). Some predictors are linear combinations of others.",
                 diagnostics$rank, ncol(design_matrix)))
  }

  invisible(NULL)
}


#' Ordinary least squares fit returning one t-value per predictor
#'
#' Fits \code{dvec ~ <all predictors in obj$design$model_mat>} with
#' \code{lm()} and returns the t-statistic of every predictor (the intercept
#' is excluded).
#'
#' @param dvec Numeric response vector.
#' @param obj Object whose \code{design$model_mat} is a named list of predictor vectors.
#' @return Named numeric vector with one entry per predictor, in the order of
#'   \code{names(obj$design$model_mat)}. Predictors that \code{lm()} could not
#'   estimate (aliased terms) are reported as \code{NA}.
#' @keywords internal
#' @importFrom stats coef cor dist rnorm terms lm sd
#' @noRd
run_lm <- function(dvec, obj) {
  predictors <- obj$design$model_mat
  vnames <- names(predictors)
  form <- paste("dvec", "~", paste(vnames, collapse = " + "))
  predictors$dvec <- dvec

  fit <- lm(form, data = predictors)
  coef_tab <- coef(summary(fit))

  tvals <- rep(NA_real_, length(vnames))
  names(tvals) <- vnames

  if (nrow(coef_tab) - 1L == length(vnames)) {
    # Every predictor was estimated: rows follow the formula order.
    tvals[] <- coef_tab[-1, "t value"]
  } else {
    # summary.lm() omits aliased coefficients, so the table is shorter than
    # the predictor list; align by name and leave the missing ones as NA
    # instead of failing on a length mismatch.
    estimated <- intersect(vnames, rownames(coef_tab))
    tvals[estimated] <- coef_tab[estimated, "t value"]
  }

  tvals
}


#' Correlate a dissimilarity vector with each design predictor
#'
#' Used for the \code{"pearson"} and \code{"spearman"} regression types. The
#' correlation method follows \code{obj$regtype}; previously
#' \code{obj$distmethod} was used, which made the two regression types
#' indistinguishable. \code{obj$distmethod} is still used as a fallback when
#' \code{obj$regtype} is absent or is not a correlation method.
#'
#' @param dvec Numeric response vector.
#' @param obj Object with \code{design$model_mat} (named list of predictor
#'   vectors) and \code{regtype} / \code{distmethod} fields.
#' @return Named numeric vector with one correlation per predictor.
#' @keywords internal
#' @noRd
run_cor <- function(dvec, obj) {
  cor_methods <- c("pearson", "spearman")
  method <- if (isTRUE(obj$regtype %in% cor_methods)) {
    obj$regtype
  } else {
    obj$distmethod
  }

  predictors <- obj$design$model_mat
  # vapply() guarantees a numeric vector even for an empty predictor list.
  res <- vapply(predictors, function(x) cor(dvec, x, method = method), numeric(1))
  names(res) <- names(predictors)
  res
}


################################################################################
# NEW: Constrained LM with glmnet
################################################################################
#' @keywords internal
#' @importFrom glmnet glmnet
#' @importFrom stats predict
#' @noRd
run_lm_constrained <- function(dvec, obj) {
  # glmnet must be installed; this only checks and reports, it does not
  # install anything.
  if (!requireNamespace("glmnet", quietly = TRUE)) {
    stop("Package 'glmnet' is required for non-negative constraints. Please install it with: install.packages('glmnet')")
  }

  # Predictors as the columns of a matrix.
  predictors <- obj$design$model_mat
  var_names <- names(predictors)
  X <- do.call(cbind, predictors)
  colnames(X) <- var_names

  if (is.null(obj$nneg)) {
    stop("run_lm_constrained called, but obj$nneg is NULL.")
  }

  # A predictor is constrained when its name appears among names(obj$nneg).
  is_constrained <- var_names %in% names(obj$nneg)

  # Lower coefficient bound: 0 for constrained predictors, unbounded otherwise.
  lower_lim_vec <- ifelse(is_constrained, 0, -Inf)
  names(lower_lim_vec) <- var_names

  # Tiny penalty on constrained predictors and none on the rest, so that not
  # every penalty factor is zero.
  penalty_factor <- ifelse(is_constrained, 1e-5, 0)
  names(penalty_factor) <- var_names

  # Small mixing parameter and near-zero lambda.
  alpha_val <- 1e-5
  lambda_val <- 1e-5

  fit <- tryCatch(
    glmnet::glmnet(
      x = X,
      y = dvec,
      alpha = alpha_val,
      lambda = lambda_val,
      penalty.factor = penalty_factor,
      lower.limits = lower_lim_vec,
      standardize = FALSE,
      intercept = TRUE
    ),
    error = function(e) {
      # Second attempt without custom penalty factors.
      warning("First glmnet attempt failed, trying with uniform penalties: ", e$message)
      glmnet::glmnet(
        x = X,
        y = dvec,
        alpha = alpha_val,
        lambda = lambda_val,
        lower.limits = lower_lim_vec,
        standardize = FALSE,
        intercept = TRUE
      )
    }
  )

  # Coefficients at the smallest lambda of the fit; the first entry is the
  # intercept and is dropped.
  smallest_lambda <- min(fit$lambda)
  coef_matrix <- predict(fit, type = "coefficients", s = smallest_lambda)
  betas <- as.numeric(coef_matrix)[-1]
  names(betas) <- var_names

  betas
}


################################################################################
# NEW: Semi-Partial LM
################################################################################
#' Semi-partial correlations of each predictor from a single linear fit
#'
#' Fits \code{dvec ~ <all predictors in obj$design$model_mat>} with
#' \code{lm()} and converts each predictor's t-value into a signed
#' semi-partial correlation: \code{sr_i = sign(t_i) * sqrt(t_i^2 * MSE / TSS)},
#' where \code{MSE} is the residual variance of the fit and \code{TSS} the
#' total sum of squares of the response actually used by the fit.
#'
#' @param dvec Numeric response vector.
#' @param obj Object whose \code{design$model_mat} is a named list of predictor vectors.
#' @return Named numeric vector with one entry per predictor, in the order of
#'   \code{names(obj$design$model_mat)}. Predictors that \code{lm()} could not
#'   estimate (aliased terms) are reported as \code{NA}.
#' @keywords internal
#' @noRd
run_lm_semipartial <- function(dvec, obj) {
  predictors <- obj$design$model_mat
  vnames <- names(predictors)
  form <- paste("dvec", "~", paste(vnames, collapse = " + "))
  predictors$dvec <- dvec

  fit  <- lm(form, data = predictors)
  smry <- summary(fit)
  coef_tab <- coef(smry)

  tvals <- rep(NA_real_, length(vnames))
  names(tvals) <- vnames

  if (nrow(coef_tab) - 1L == length(vnames)) {
    # Every predictor was estimated: rows follow the formula order.
    tvals[] <- coef_tab[-1, "t value"]
  } else {
    # summary.lm() omits aliased coefficients; align by name and keep NA for
    # the missing ones rather than failing on a length mismatch.
    estimated <- intersect(vnames, rownames(coef_tab))
    tvals[estimated] <- coef_tab[estimated, "t value"]
  }

  # Residual mean squared error of the fit.
  mse <- smry$sigma^2

  # Total sum of squares of the response rows retained in the model frame.
  y_used <- fit$model$dvec
  tss <- sum((y_used - mean(y_used))^2)

  sr <- sign(tvals) * sqrt((tvals^2 * mse) / tss)
  names(sr) <- vnames
  sr
}


################################################################################
# TRAINING METHOD
################################################################################
#' @rdname train_model
#' @param obj An object of class \code{rsa_model}.
#' @param train_dat The training data.
#' @param y The response variable.
#' @param indices The indices of the training data.
#' @param ... Additional arguments passed to the training method.
#' @return
#' Depending on \code{obj$regtype}:
#' \itemize{
#'   \item \code{"lm"} + no constraints + \code{obj$semipartial=TRUE}: semi-partial correlations
#'   \item \code{"lm"} + no constraints + \code{obj$semipartial=FALSE}: T-values of each predictor
#'   \item \code{"lm"} + \code{nneg} constraints: raw coefficients from constrained \code{glmnet}
#'   \item \code{"rfit"}: robust regression coefficients
#'   \item \code{"pearson"} or \code{"spearman"}: correlation coefficients
#' }
#' @method train_model rsa_model
#' @export
train_model.rsa_model <- function(obj, train_dat, y, indices, ...) {
  # Dissimilarity between rows of `train_dat`: one minus their correlation,
  # using the model's distance method. Only the lower triangle is kept.
  dissim <- 1 - cor(t(train_dat), method = obj$distmethod)
  dvec <- dissim[lower.tri(dissim)]

  # Apply the design's inclusion mask, when there is one.
  keep <- obj$design$include
  if (!is.null(keep)) {
    dvec <- dvec[keep]
  }

  # The "lm" regression type has three variants, tried in this priority:
  # non-negativity constraints, then semi-partial correlations, then t-values.
  fit_lm <- function() {
    if (!is.null(obj$nneg) && length(obj$nneg) > 0) {
      return(run_lm_constrained(dvec, obj))
    }
    if (isTRUE(obj$semipartial)) {
      return(run_lm_semipartial(dvec, obj))
    }
    run_lm(dvec, obj)
  }

  # Dispatch on the regression type; "pearson" falls through to the same
  # correlation routine as "spearman". An unrecognised type yields NULL.
  out <- switch(
    obj$regtype,
    rfit     = run_rfit(dvec, obj),
    lm       = fit_lm(),
    pearson  = ,
    spearman = run_cor(dvec, obj)
  )

  out
}


#' Merge Results for RSA Model
#'
#' This function takes the computed coefficients/correlations/t-values 
#' from \code{train_model.rsa_model} for a single ROI/searchlight and formats 
#' it into the standard output tibble.
#'
#' @param obj The RSA model specification.
#' @param result_set A tibble containing the results from the processor function for the current ROI/sphere. 
#'   Expected to have a \code{$result} column containing the named vector from \code{train_model.rsa_model}.
#' @param indices Voxel indices for the current ROI/searchlight sphere.
#' @param id Identifier for the current ROI/searchlight center.
#' @param ... Additional arguments (ignored).
#'
#' @return A tibble row with the formatted "performance" metrics (coefficients/t-values/correlations) for the ROI/sphere.
#' @importFrom tibble tibble
#' @importFrom futile.logger flog.error
#' @rdname merge_results-methods
#' @method merge_results rsa_model
#' @rdname merge_results-methods
merge_results.rsa_model <- function(obj, result_set, indices, id, ...) {

  # Every return path produces one row with the same six columns.
  make_row <- function(perf, failed, msg) {
    tibble::tibble(
      result        = list(NULL),
      indices       = list(indices),
      performance   = list(perf),
      id            = id,
      error         = failed,
      error_message = msg
    )
  }

  # 1 x k matrix of NA with the given column names, or NULL when no names
  # are available.
  na_performance <- function(col_names) {
    if (length(col_names) == 0) {
      return(NULL)
    }
    matrix(NA_real_, nrow = 1, ncol = length(col_names),
           dimnames = list(NULL, col_names))
  }

  # An error flagged upstream is passed on with its first recorded message.
  if (any(result_set$error)) {
    first_failure <- which(result_set$error)[1]
    return(make_row(NULL, TRUE, result_set$error_message[first_failure]))
  }

  # The model outputs must be present in the first element of `result`.
  no_result <- !"result" %in% names(result_set) ||
    length(result_set$result) == 0 ||
    is.null(result_set$result[[1]])
  if (no_result) {
    error_msg <- sprintf("merge_results (rsa_model): result_set missing or has NULL/empty 'result' field where model outputs were expected for ROI/ID %s.", id)
    futile.logger::flog.error(error_msg)
    return(make_row(na_performance(names(obj$design$model_mat)), TRUE, error_msg))
  }

  model_outputs <- result_set$result[[1]]

  # The outputs have to be a named, non-empty numeric vector.
  malformed <- !is.numeric(model_outputs) ||
    is.null(names(model_outputs)) ||
    length(model_outputs) == 0
  if (malformed) {
    error_msg <- sprintf("merge_results (rsa_model): Extracted model outputs are not a named non-empty numeric vector for ROI/ID %s.", id)
    futile.logger::flog.error(error_msg)

    # Prefer whatever names the outputs carry; otherwise use the design's.
    fallback_names <- names(model_outputs)
    if (is.null(fallback_names) || length(fallback_names) == 0) {
      fallback_names <- names(obj$design$model_mat)
    }
    return(make_row(na_performance(fallback_names), TRUE, error_msg))
  }

  # Success: the named vector becomes a one-row performance matrix and the
  # `result` column is left empty.
  perf_mat <- matrix(model_outputs, nrow = 1,
                     dimnames = list(NULL, names(model_outputs)))
  make_row(perf_mat, FALSE, "~")
}


################################################################################
# PRINT METHODS
################################################################################
#' @export
#' @method print rsa_model
print.rsa_model <- function(x, ...) {
  # Writes a human-readable summary of an rsa_model to the console and
  # returns `x` invisibly, as print methods conventionally do.
  # Line breaks use "\n"; the doubled backslash used before printed the two
  # characters backslash + n instead of starting a new line.

  # Header
  cat("\n", "RSA Model", "\n")
  cat(rep("-", 20), "\n\n")

  has_nneg <- !is.null(x$nneg) && length(x$nneg) > 0

  # Model configuration
  cat("Configuration:\n")
  cat("  |- Distance Method: ", x$distmethod, "\n")
  cat("  |- Regression Type: ", x$regtype, "\n")
  if (has_nneg) {
    cat("  |- Non-negativity on: ", paste(names(x$nneg), collapse=", "), "\n")
  }
  # Semi-partial output is only reported when no constraints are active.
  if (isTRUE(x$semipartial) && !has_nneg) {
    cat("  |- Semi-partial: TRUE\n")
  }
  cat("\n")

  # Dataset info: the last dimension is reported as the observation count.
  cat("Dataset:\n")
  dims <- dim(x$dataset$train_data)
  dim_str <- if (is.null(dims)) {
    # No dim attribute: fall back to the object's length.
    paste0(length(x$dataset$train_data), " observations")
  } else {
    paste0(
      paste(dims[-length(dims)], collapse=" x "),
      " x ", dims[length(dims)], " observations"
    )
  }
  cat("  |- Dimensions: ", dim_str, "\n")
  cat("  |- Type: ", class(x$dataset$train_data)[1], "\n")
  cat("\n")

  # Design info
  cat("Design:\n")
  cat("  |- Formula: ", deparse(x$design$formula), "\n")
  cat("  |- Predictors: ", paste(names(x$design$model_mat), collapse=", "), "\n")
  cat("\n")

  # Structure info
  cat("Structure:\n")

  if (!is.null(x$design$block_var)) {
    blocks <- table(x$design$block_var)
    cat("  |- Blocking: Present\n")
    cat("  |- Number of Blocks: ", length(blocks), "\n")
    cat("  |- Mean Block Size: ",
        format(mean(blocks), digits=2),
        " (SD: ",
        format(sd(blocks), digits=2),
        ")\n")
  } else {
    cat("  |- Blocking: None\n")
  }

  if (!is.null(x$design$split_by)) {
    cat("  |- Split Groups: ", length(x$design$split_groups), "\n")
  } else {
    cat("  |- Split Groups: None\n")
  }

  cat("\n")
  invisible(x)
}


#' @export
#' @method print rsa_design
print.rsa_design <- function(x, ...) {
  # Writes a human-readable summary of an rsa_design to the console and
  # returns `x` invisibly, as print methods conventionally do.
  # Line breaks use "\n"; the doubled backslash used before printed the two
  # characters backslash + n instead of starting a new line.

  # Header
  cat("\n", "RSA Design", "\n")
  cat(rep("-", 20), "\n\n")

  # Formula
  cat("Formula:\n")
  cat("  |- ", deparse(x$formula), "\n\n")

  # Variables: one descriptive label per element of the design data.
  cat("Variables:\n")
  var_types <- vapply(x$data, function(v) {
    if (inherits(v, "dist")) "distance matrix"
    else if (is.matrix(v)) "matrix"
    else if (is.vector(v)) "vector"
    else "other"
  }, character(1))

  cat("  |- Total Variables: ", length(x$data), "\n")
  for (i in seq_along(x$data)) {
    cat(sprintf("  |- %s: %s\n", names(x$data)[i], var_types[i]))
  }
  cat("\n")

  # Structure
  cat("Structure:\n")

  if (!is.null(x$block_var)) {
    blocks <- table(x$block_var)
    cat("  |- Blocking: Present\n")
    cat("  |- Number of Blocks: ", length(blocks), "\n")
    cat("  |- Block Sizes: ",
        paste0(names(blocks), ": ", blocks, collapse=", "), "\n")
  } else {
    cat("  |- Blocking: None\n")
  }

  # How many pairwise comparisons survive the inclusion mask.
  if (!is.null(x$include)) {
    n_comparisons <- length(x$include)
    n_included <- sum(x$include)
    cat("  |- Comparisons: ",
        n_included,
        " of ",
        n_comparisons,
        sprintf(" (%.1f%%)", 100*n_included/n_comparisons), "\n")
  } else {
    cat("  |- Comparisons: All included\n")
  }

  cat("\n")
  invisible(x)
}


#' Construct an RSA (Representational Similarity Analysis) model
#'
#' This function creates an RSA model object by taking an MVPA (Multi-Variate Pattern Analysis) dataset and an RSA design.
#'
#' @param dataset An instance of an \code{mvpa_dataset}.
#' @param design An instance of an \code{rsa_design} created by \code{rsa_design()}.
#' @param distmethod A character string specifying the method used to compute distances between observations. 
#'        One of: \code{"pearson"} or \code{"spearman"} (defaults to "spearman").
#' @param regtype A character string specifying the analysis method. 
#'        One of: \code{"pearson"}, \code{"spearman"}, \code{"lm"}, or \code{"rfit"} (defaults to "pearson").
#' @param check_collinearity Logical indicating whether to check for collinearity in the design matrix. 
#'        Only applies when \code{regtype="lm"}. Default is TRUE.
#' @param nneg A named list of variables (predictors) for which non-negative regression coefficients should be enforced 
#'        (only if \code{regtype="lm"}). Defaults to \code{NULL} (no constraints).
#' @param semipartial Logical indicating whether to compute semi-partial correlations in the \code{"lm"} case 
#'        (only if \code{nneg} is not used). Defaults to \code{FALSE}.
#'
#' @return An object of class \code{"rsa_model"} (and \code{"list"}), containing:
#' \itemize{
#'   \item \code{dataset}    : the input dataset
#'   \item \code{design}     : the RSA design
#'   \item \code{distmethod} : the distance method used
#'   \item \code{regtype}    : the regression type
#'   \item \code{nneg}       : a named list of constrained variables, if any
#'   \item \code{semipartial}: whether to compute semi-partial correlations
#' }
#' @examples
#' # Create a random MVPA dataset
#' data <- matrix(rnorm(100 * 100), 100, 100)
#' labels <- factor(rep(1:2, each = 50))
#' mvpa_data <- mvpa_dataset(data, labels)
#'
#' # Create an RSA design with two distance matrices
#' dismat1 <- dist(data)
#' dismat2 <- dist(matrix(rnorm(100*100), 100, 100))
#' rdes <- rsa_design(~ dismat1 + dismat2, list(dismat1=dismat1, dismat2=dismat2))
#'
#' # Create an RSA model with standard 'lm' (returns t-values):
#' rsa_mod <- rsa_model(mvpa_data, rdes, regtype="lm")
#'
#' # Create an RSA model enforcing non-negativity for dismat2 only:
#' # Requires the 'glmnet' package to be installed
#' # rsa_mod_nneg <- rsa_model(mvpa_data, rdes, regtype="lm",
#' #                          nneg = list(dismat2 = TRUE))
#'
#' # Create an RSA model using 'lm' but returning semi-partial correlations:
#' rsa_mod_sp <- rsa_model(mvpa_data, rdes, regtype="lm",
#'                         semipartial = TRUE)
#'
#' # Train the model
#' fit_params <- train_model(rsa_mod_sp, mvpa_data$train_data)
#' # 'fit_params' = named vector of semi-partial correlations for each predictor
#'
#' @export
rsa_model <- function(dataset, 
                      design, 
                      distmethod = "spearman", 
                      regtype = "pearson", 
                      check_collinearity = TRUE,
                      nneg = NULL,
                      semipartial = FALSE) {
  
  assert_that(inherits(dataset, "mvpa_dataset"))
  assert_that(inherits(design, "rsa_design"))
  
  distmethod <- match.arg(distmethod, c("pearson", "spearman"))
  regtype    <- match.arg(regtype, c("pearson", "spearman", "lm", "rfit"))
  
  # Constrained fits rely on the optional 'glmnet' package; fail now rather
  # than in the middle of an analysis.
  has_nneg <- !is.null(nneg) && length(nneg) > 0
  if (has_nneg && regtype == "lm") {
    if (!requireNamespace("glmnet", quietly = TRUE)) {
      stop("Package 'glmnet' is required for non-negative constraints. Please install it with: install.packages('glmnet')")
    }
  }
  
  # A design without predictors cannot be fitted; say so explicitly instead of
  # failing later with a subscript error.
  if (length(design$model_mat) == 0) {
    stop("RSA design contains no predictors (empty 'model_mat').")
  }
  
  # If using LM, optionally check for collinearity.
  # The logical argument shares its name with the helper function; in call
  # position R skips non-function bindings, so the call below reaches the
  # helper.
  if (regtype == "lm" && isTRUE(check_collinearity)) {
    message("Checking design matrix for collinearity...")
    check_collinearity(design$model_mat)
    message("Collinearity check passed.")
  }

  # Verify that the number of observations in the dataset matches the size
  # expected by the RSA design. The observation count is taken from the last
  # dimension of the training data (or its length when it has no dim).
  nobs_dataset <- {
    dims <- dim(dataset$train_data)
    if (is.null(dims)) length(dataset$train_data) else dims[length(dims)]
  }
  n_pairs <- choose(nobs_dataset, 2)
  
  expected_len <- n_pairs
  if (!is.null(design$include)) {
    # The predictors were already subset by `include`, so comparing their
    # length with sum(include) alone can never detect a mismatch. The mask
    # itself must span every pairwise comparison of the dataset.
    if (length(design$include) != n_pairs) {
      stop(sprintf(
        "Mismatch between dataset observations (%s) and RSA design: 'include' covers %s pairwise comparisons but %s are expected.",
        nobs_dataset, length(design$include), n_pairs
      ))
    }
    expected_len <- sum(design$include)
  }
  
  model_vec_len <- length(design$model_mat[[1]])
  if (model_vec_len != expected_len) {
    stop(sprintf(
      "Mismatch between dataset observations (%s) and RSA design length (%s).",
      nobs_dataset, model_vec_len
    ))
  }
  
  # Create the RSA model object
  obj <- create_model_spec(
    "rsa_model", 
    dataset, 
    design, 
    distmethod = distmethod, 
    regtype    = regtype,
    nneg       = nneg,
    semipartial = semipartial
  )
  obj
}
# bbuchsbaum/rMVPA documentation built on June 10, 2025, 8:23 p.m.