R/fit_earth.R

#' Fit an earth model
#'
#' Wrapper around [earth::earth()] with parameter validation and automatic
#' cross-validation when interaction terms are enabled or a variance model
#' is requested.
#'
#' @param df A data frame containing the modeling data.
#' @param target Character vector. Name(s) of the response variable(s).
#'   Supplying more than one name fits a multiple-response model.
#' @param predictors Character vector. Names of predictor variables.
#' @param categoricals Character vector. Names of predictors to treat as
#'   categorical (converted to factors before fitting). Default is `NULL`.
#' @param linpreds Character vector. Names of predictors constrained to enter
#'   the model linearly (no hinge functions). Default is `NULL`.
#' @param type_map Named list or character vector. Maps column names to
#'   declared types (e.g., `"numeric"`, `"Date"`, `"factor"`). When provided,
#'   columns are coerced before fitting: Date/POSIXct columns are converted
#'   to numeric (days/seconds since epoch), and type-derived categoricals are
#'   merged into `categoricals`. Default is `NULL` (no coercion).
#' @param degree Integer. Maximum degree of interaction, between 1 and 10.
#'   Default is 1 (no interactions). When >= 2, cross-validation is
#'   automatically enabled.
#' @param allowed_func Function or `NULL`. An allowed function as returned by
#'   [build_allowed_function()]. Only used when `degree >= 2`.
#' @param allowed_matrix Logical matrix or `NULL`. The allowed interaction
#'   matrix. Stored in the result for report export. Not used for fitting
#'   (use `allowed_func` instead).
#' @param nfold Integer. Number of cross-validation folds. Automatically set
#'   to 10 when `degree >= 2` or `varmod.method` is not `"none"`, unless
#'   explicitly provided. Set to 0 to disable.
#' @param nprune Integer or `NULL`. Maximum number of terms in the pruned model.
#' @param thresh Numeric. Forward stepping threshold. Default is earth's default
#'   (0.001).
#' @param penalty Numeric. Generalized cross-validation penalty per knot.
#'   Default is earth's default (if `degree > 1`, `3`; otherwise `2`).
#' @param minspan Integer or `NULL`. Minimum number of observations between
#'   knots.
#' @param endspan Integer or `NULL`. Minimum number of observations before
#'   the first and after the last knot.
#' @param fast.k Integer. Maximum number of parent terms considered at each
#'   step of the forward pass. Default is earth's default (20).
#' @param pmethod Character. Pruning method. One of `"backward"`, `"none"`,
#'   `"exhaustive"`, `"forward"`, `"seqrep"`, `"cv"`. Default is `"backward"`.
#' @param glm List or `NULL`. If provided, passed to earth's `glm` argument
#'   to fit a GLM on the earth basis functions.
#' @param trace Numeric. Trace earth's execution. 0 (default) = none,
#'   0.3 = variance model, 0.5 = cross validation, 1-5 = increasing detail.
#' @param nk Integer or `NULL`. Maximum number of model terms before pruning.
#' @param newvar.penalty Numeric or `NULL`. Penalty for adding a new variable
#'   in the forward pass (Friedman's gamma). Default 0.
#' @param fast.beta Numeric or `NULL`. Fast MARS ageing coefficient. Default 1.
#' @param ncross Integer or `NULL`. Number of cross-validations. Default 1.
#'   A supplied value below 3 is raised to 3 when a variance model is
#'   requested.
#' @param stratify Logical or `NULL`. Stratify cross-validation samples.
#'   Default `TRUE`.
#' @param varmod.method Character or `NULL`. Variance model method. One of
#'   `"none"`, `"const"`, `"lm"`, `"rlm"`, `"earth"`, `"gam"`, `"power"`,
#'   `"power0"`, `"x.lm"`, `"x.rlm"`, `"x.earth"`, `"x.gam"`.
#' @param varmod.exponent Numeric or `NULL`. Power transform for variance model.
#' @param varmod.conv Numeric or `NULL`. Convergence criterion for IRLS.
#' @param varmod.clamp Numeric or `NULL`. Minimum estimated standard deviation.
#' @param varmod.minspan Integer or `NULL`. minspan for internal variance model.
#' @param keepxy Logical or `NULL`. Retain x, y in model object. Default `FALSE`.
#' @param Scale.y Logical or `NULL`. Scale response internally. Default `TRUE`.
#' @param Adjust.endspan Numeric or `NULL`. Interaction endspan multiplier.
#'   Default 2.
#' @param Auto.linpreds Logical or `NULL`. Auto-detect linear predictors.
#'   Default `TRUE`.
#' @param Force.weights Logical or `NULL`. Force weighted code path. Default
#'   `FALSE`.
#' @param Use.beta.cache Logical or `NULL`. Cache coefficients in forward pass.
#'   Default `TRUE`.
#' @param Force.xtx.prune Logical or `NULL`. Force X'X-based pruning. Default
#'   `FALSE`.
#' @param Get.leverages Logical or `NULL`. Calculate hat values. Default `TRUE`.
#' @param Exhaustive.tol Numeric or `NULL`. Condition number threshold for
#'   exhaustive pruning. Default 1e-10.
#' @param wp Numeric vector or `NULL`. Response weights.
#' @param weights Numeric vector or `NULL`. Case weights passed to earth, one
#'   per row of `df`. Weights for rows removed because of missing values are
#'   dropped.
#' @param ... Additional arguments passed to [earth::earth()].
#' @param .capture_trace Logical. If `TRUE` (default), capture earth's trace
#'   output. Set to `FALSE` when running in a background process.
#'
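#' @return A list with class `"earthUI_result"` containing:
#'   \describe{
#'     \item{model}{The fitted earth model object.}
#'     \item{target}{Name(s) of the response variable(s).}
#'     \item{predictors}{Names of predictor variables used, excluding any
#'       dropped single-level factor or character columns.}
#'     \item{categoricals}{Names of categorical predictors.}
#'     \item{linpreds}{Names of predictors constrained to enter linearly.}
#'     \item{degree}{Degree of interaction used.}
#'     \item{cv_enabled}{Logical; whether cross-validation was used.}
#'     \item{allowed_matrix}{The `allowed_matrix` argument, stored unchanged.}
#'     \item{data}{The data frame used for fitting, after type coercion and
#'       removal of incomplete rows.}
#'     \item{elapsed}{Elapsed fitting time in seconds.}
#'     \item{trace_output}{Character vector of captured trace output; empty
#'       when `.capture_trace` is `FALSE`.}
#'   }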
#'
#' @export
#' @examples
#' \donttest{
#' # Using the included demo appraisal dataset
#' demo_file <- system.file("extdata", "Appraisal_1.csv", package = "earthUI")
#' df <- import_data(demo_file)
#' result <- fit_earth(df, target = "sale_price",
#'                     predictors = c("living_sqft", "lot_size", "age"))
#' format_summary(result)
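#'
#' # A hedged sketch, assuming the same demo columns as above: with
#' # `degree = 2` and no explicit `nfold`, fit_earth() requests 10-fold
#' # cross-validation from earth and records that choice in `cv_enabled`.
#' result2 <- fit_earth(df, target = "sale_price",
#'                      predictors = c("living_sqft", "lot_size", "age"),
#'                      degree = 2L)
#' result2$cv_enabled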
#' }
fit_earth <- function(df, target, predictors, categoricals = NULL,
                      linpreds = NULL, type_map = NULL,
                      degree = 1L, allowed_func = NULL,
                      allowed_matrix = NULL,
                      nfold = NULL, nprune = NULL, thresh = NULL,
                      penalty = NULL, minspan = NULL, endspan = NULL,
                      fast.k = NULL, pmethod = NULL, glm = NULL,
                      trace = NULL, nk = NULL, newvar.penalty = NULL,
                      fast.beta = NULL, ncross = NULL, stratify = NULL,
                      varmod.method = NULL, varmod.exponent = NULL,
                      varmod.conv = NULL, varmod.clamp = NULL,
                      varmod.minspan = NULL, keepxy = NULL,
                      Scale.y = NULL, Adjust.endspan = NULL,
                      Auto.linpreds = NULL, Force.weights = NULL,
                      Use.beta.cache = NULL, Force.xtx.prune = NULL,
                      Get.leverages = NULL, Exhaustive.tol = NULL,
                      wp = NULL, weights = NULL, ..., .capture_trace = TRUE) {

  # --- Input validation ---
  if (!is.data.frame(df)) {
    stop("`df` must be a data frame.", call. = FALSE)
  }
  if (!is.character(target) || length(target) < 1L) {
    stop("`target` must be a character vector of one or more variable names.",
         call. = FALSE)
  }
  missing_targets <- setdiff(target, names(df))
  if (length(missing_targets) > 0L) {
    stop("Target variable(s) not found in data frame: ",
         paste(missing_targets, collapse = ", "), call. = FALSE)
  }
  if (!is.character(predictors) || length(predictors) == 0L) {
    stop("`predictors` must be a non-empty character vector.", call. = FALSE)
  }
  missing_preds <- setdiff(predictors, names(df))
  if (length(missing_preds) > 0L) {
    stop("Predictor(s) not found in data frame: ",
         paste(missing_preds, collapse = ", "), call. = FALSE)
  }
  overlap <- intersect(target, predictors)
  if (length(overlap) > 0L) {
    stop("Target variable(s) must not be in the predictors list: ",
         paste(overlap, collapse = ", "), call. = FALSE)
  }

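  # Coerce once; NA or non-scalar input gets the same clear error as a
  # value outside the supported range.
  degree <- suppressWarnings(as.integer(degree))
  if (length(degree) != 1L || is.na(degree) || degree < 1L || degree > 10L) {
    stop("`degree` must be a single integer between 1 and 10.", call. = FALSE)
  }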

  # --- Prepare data ---
  multi_response <- length(target) > 1L
  model_df <- df[, c(target, predictors), drop = FALSE]

  # Apply declared type coercions (Date->numeric, logical, factor, etc.)
  if (!is.null(type_map)) {
    model_df <- coerce_types_(model_df, type_map, predictors)
    # Merge type-derived categoricals
    type_cats <- names(type_map)[vapply(type_map, function(t) {
      t %in% c("factor", "character")
    }, logical(1L))]
    type_cats <- intersect(type_cats, predictors)
    categoricals <- union(categoricals, type_cats)
  }

  # Convert categoricals to factors
  if (!is.null(categoricals)) {
    categoricals <- intersect(categoricals, predictors)
    for (col in categoricals) {
      model_df[[col]] <- as.factor(model_df[[col]])
    }
  }

  # Remove rows with NA in target or predictors
  if (!is.null(weights) && length(weights) != nrow(model_df)) {
    stop("`weights` must have one element per row of `df`.", call. = FALSE)
  }
  # Count NAs per column before filtering so the error below can report them
  na_counts <- vapply(model_df, function(col) sum(is.na(col)), integer(1L))
  complete <- stats::complete.cases(model_df)
  n_removed <- sum(!complete)
  if (n_removed > 0L) {
    message("Removed ", n_removed, " rows with missing values.")
    model_df <- model_df[complete, , drop = FALSE]
    # Subset row weights to match remaining rows
    if (!is.null(weights)) weights <- weights[complete]
  }
  if (nrow(model_df) < 10L) {
    na_info <- na_counts[na_counts > 0L]
    detail <- if (length(na_info) > 0L) {
      paste0(" Columns with NAs: ",
             paste(sprintf("%s (%d)", names(na_info), na_info), collapse = ", "),
             ".")
    } else {
      ""
    }
    stop("Insufficient data: need at least 10 complete observations, have ",
         nrow(model_df), " (from ", nrow(df), " original rows).", detail,
         call. = FALSE)
  }

  # Drop factor/character columns with fewer than 2 unique values (causes contrasts error)
  drop_cols <- character(0)
  for (col in names(model_df)) {
    if (col %in% target) next
    if (is.factor(model_df[[col]])) {
      if (nlevels(droplevels(model_df[[col]])) < 2L) {
        drop_cols <- c(drop_cols, col)
      }
    } else if (is.character(model_df[[col]])) {
      if (length(unique(model_df[[col]])) < 2L) {
        drop_cols <- c(drop_cols, col)
      }
    }
  }
  if (length(drop_cols) > 0L) {
    message("Dropped factor columns with < 2 levels: ",
            paste(drop_cols, collapse = ", "))
    model_df <- model_df[, !names(model_df) %in% drop_cols, drop = FALSE]
    predictors <- setdiff(predictors, drop_cols)
    if (!is.null(categoricals)) {
      categoricals <- setdiff(categoricals, drop_cols)
    }
    if (length(predictors) == 0L) {
      stop("No predictors remaining after dropping single-level factors.",
           call. = FALSE)
    }
  }

  # Drop unused factor levels
  for (col in names(model_df)) {
    if (is.factor(model_df[[col]])) {
      model_df[[col]] <- droplevels(model_df[[col]])
    }
  }

  # --- Build earth arguments ---
  if (multi_response) {
    lhs <- paste0("cbind(", paste0("`", target, "`", collapse = ", "), ")")
    formula <- stats::as.formula(paste(lhs, "~ ."))
  } else {
    formula <- stats::as.formula(paste0("`", target, "` ~ ."))
  }

  earth_args <- list(formula = formula, data = model_df, degree = degree)

  # Auto-enable CV when degree >= 2 or variance model is requested
  needs_cv <- degree >= 2L ||
    (!is.null(varmod.method) && varmod.method != "none")
  cv_enabled <- FALSE
  if (!is.null(nfold)) {
    if (nfold > 0L) {
      earth_args$nfold <- as.integer(nfold)
      cv_enabled <- TRUE
    }
  } else if (needs_cv) {
    earth_args$nfold <- 10L
    cv_enabled <- TRUE
  }

  # Linear predictors (no hinge functions)
  if (!is.null(linpreds) && length(linpreds) > 0L) {
    linpreds <- intersect(linpreds, names(model_df))
    if (length(linpreds) > 0L) {
      # earth expects column indices (1-based, relative to predictor columns)
      pred_cols <- setdiff(names(model_df), target)
      lin_idx <- match(linpreds, pred_cols)
      lin_idx <- lin_idx[!is.na(lin_idx)]
      if (length(lin_idx) > 0L) {
        earth_args$linpreds <- lin_idx
      }
    }
  }

  if (!is.null(nprune))          earth_args$nprune          <- as.integer(nprune)
  if (!is.null(thresh))          earth_args$thresh           <- thresh
  if (!is.null(penalty))         earth_args$penalty          <- penalty
  if (!is.null(minspan))         earth_args$minspan          <- as.integer(minspan)
  if (!is.null(endspan))         earth_args$endspan          <- as.integer(endspan)
  if (!is.null(fast.k))          earth_args$fast.k           <- as.integer(fast.k)
  if (!is.null(pmethod))         earth_args$pmethod          <- pmethod
  if (!is.null(glm))             earth_args$glm              <- glm
  if (!is.null(trace))           earth_args$trace            <- trace
  if (!is.null(nk))              earth_args$nk               <- as.integer(nk)
  if (!is.null(newvar.penalty))  earth_args$newvar.penalty   <- newvar.penalty
  if (!is.null(fast.beta))       earth_args$fast.beta        <- fast.beta
  if (!is.null(ncross)) {
    # Raise a supplied ncross below 3 to 3 when a variance model is requested
    # (a requested variance model already implies cross-validation above)
    if (!is.null(varmod.method) && varmod.method != "none" &&
        as.integer(ncross) < 3L) {
      earth_args$ncross <- 3L
    } else {
      earth_args$ncross <- as.integer(ncross)
    }
  }
  if (!is.null(stratify))        earth_args$stratify         <- stratify
  if (!is.null(varmod.method) && varmod.method != "none") {
    earth_args$varmod.method <- varmod.method
  }
  if (!is.null(varmod.exponent)) earth_args$varmod.exponent  <- varmod.exponent
  if (!is.null(varmod.conv))     earth_args$varmod.conv      <- varmod.conv
  if (!is.null(varmod.clamp))    earth_args$varmod.clamp     <- varmod.clamp
  if (!is.null(varmod.minspan))  earth_args$varmod.minspan   <- as.integer(varmod.minspan)
  if (!is.null(keepxy))          earth_args$keepxy           <- keepxy
  if (!is.null(Scale.y))         earth_args$Scale.y          <- Scale.y
  if (!is.null(Adjust.endspan))  earth_args$Adjust.endspan   <- Adjust.endspan
  if (!is.null(Auto.linpreds))   earth_args$Auto.linpreds    <- Auto.linpreds
  if (!is.null(Force.weights))   earth_args$Force.weights    <- Force.weights
  if (!is.null(Use.beta.cache))  earth_args$Use.beta.cache   <- Use.beta.cache
  if (!is.null(Force.xtx.prune)) earth_args$Force.xtx.prune  <- Force.xtx.prune
  if (!is.null(Get.leverages))   earth_args$Get.leverages    <- Get.leverages
  if (!is.null(Exhaustive.tol))  earth_args$Exhaustive.tol   <- Exhaustive.tol
  if (!is.null(wp))              earth_args$wp               <- wp
  if (!is.null(weights))         earth_args$weights          <- weights

  if (degree >= 2L && !is.null(allowed_func)) {
    earth_args$allowed <- allowed_func
  }

  # Merge additional ... args
  dots <- list(...)
  earth_args <- c(earth_args, dots)

  # --- Fit model (with timing and trace capture) ---
  start_time <- proc.time()
  if (.capture_trace) {
    trace_output <- utils::capture.output({
      model <- do.call(earth::earth, earth_args)
    })
  } else {
    # Let trace go to stdout (for callr background process capture)
    model <- do.call(earth::earth, earth_args)
    trace_output <- character(0)
  }
  elapsed <- as.numeric((proc.time() - start_time)["elapsed"])

  # --- Return structured result ---
  result <- list(
    model          = model,
    target         = target,
    predictors     = predictors,
    categoricals   = if (is.null(categoricals)) character(0) else categoricals,
    linpreds       = if (is.null(linpreds)) character(0) else linpreds,
    degree         = degree,
    cv_enabled     = cv_enabled,
    allowed_matrix = allowed_matrix,
    data           = model_df,
    elapsed        = elapsed,
    trace_output   = trace_output
  )
  class(result) <- "earthUI_result"
  result
}

earthUI documentation built on March 26, 2026, 1:07 a.m.