# inst/benchmarks/run_benchmark_suite.R

#!/usr/bin/env Rscript

# Example invocation:
#   Rscript bartMachine/inst/benchmarks/run_benchmark_suite.R --java-params="-Xmx10g" --cores=4 --folds=5 --repeats=3


# Default configuration for the benchmark run. Values parsed from the command
# line are merged over these defaults further down in the script.
opts <- list(
  # Directory that receives the CSV / RDS / PDF / session-info outputs.
  output_dir = file.path("inst", "benchmarks", "results"),
  # Cross-validation layout: folds per repeat, and number of repeats.
  folds = 5,
  repeats = 1,
  # Base seed; the main loop derives per-dataset / per-repeat seeds from it.
  seed = 123,
  # Passed to set_bart_machine_num_cores().
  cores = 1,
  # Registry filters. NULL / Inf / empty vector mean "no restriction".
  datasets = NULL,
  packages = NULL,
  max_datasets = Inf,
  skip_tags = character(0),
  # Listing modes: print registry entries and exit without benchmarking.
  list = FALSE,
  dry_run = FALSE,
  # Taken from the BART_JAVA_PARAMS environment variable; "" when it is unset.
  java_params = Sys.getenv("BART_JAVA_PARAMS", ""),
  # Forwarded to bartMachine() as use_xoshiro.
  use_xoshiro = FALSE,
  # Forwarded to bartMachine() as num_trees, num_burn_in and
  # num_iterations_after_burn_in respectively.
  bart_num_trees = 200,
  bart_burn_in = 800,
  bart_iter = 500,
  # Forwarded to randomForest() as ntree.
  rf_ntree = 200,
  # NA means mtry is not passed to randomForest() at all.
  rf_mtry = NA_real_
)

# Parse raw command-line arguments into a named list.
#
# Boolean switches (--list, --dry-run, --help, --use-xoshiro) are stored as
# TRUE under the names list, dry_run, help and use_xoshiro. Every other
# argument of the form --key=value is stored as a string under `key`, with
# hyphens in the key kept as-is (e.g. "max-datasets"). Only the FIRST "=" is
# treated as the separator, so values such as "-Dfoo=bar" survive intact.
# "--key" without "=" is stored as "". Arguments that do not start with "--",
# or whose key is empty, are ignored.
#
# args:    character vector, typically commandArgs(trailingOnly = TRUE).
# returns: named list of parsed options (empty list when nothing was parsed).
parse_args <- function(args) {
  switches <- c(
    "--list" = "list",
    "--dry-run" = "dry_run",
    "--help" = "help",
    "--use-xoshiro" = "use_xoshiro"
  )
  opts <- list()
  for (arg in args) {
    if (arg %in% names(switches)) {
      opts[[switches[[arg]]]] <- TRUE
      next
    }
    if (!startsWith(arg, "--")) {
      next
    }
    key_val <- substring(arg, 3L)
    # Indexing with [1] drops the attributes regexpr() attaches to its result.
    eq_pos <- regexpr("=", key_val, fixed = TRUE)[1]
    if (eq_pos > 0) {
      key <- substr(key_val, 1L, eq_pos - 1L)
      val <- substring(key_val, eq_pos + 1L)
    } else {
      key <- key_val
      val <- ""
    }
    if (!nzchar(key)) {
      next
    }
    opts[[key]] <- val
  }
  opts
}

# Directory containing the running script.
#
# Looks for a "--file=" entry among the full set of command-line arguments
# and returns the normalized directory of the first one found. When no such
# entry exists, the current working directory is returned instead.
get_script_dir <- function() {
  prefix <- "--file="
  all_args <- commandArgs(trailingOnly = FALSE)
  hits <- grep(prefix, all_args, value = TRUE)
  if (length(hits) > 0) {
    script_path <- sub(prefix, "", hits[1])
    return(dirname(normalizePath(script_path)))
  }
  getwd()
}

# Print the command-line help text to standard output.
#
# The option list includes --help, which the argument parser recognises.
print_usage <- function() {
  usage_lines <- c(
    "Usage: Rscript run_benchmark_suite.R [options]",
    "Options:",
    "  --help                    Show this help text and exit",
    "  --list                    List available datasets and exit",
    "  --dry-run                 List selected datasets and exit",
    "  --output-dir=PATH         Output directory (default: inst/benchmarks/results)",
    "  --folds=N                 Number of folds (default: 5)",
    "  --repeats=N               Number of CV repeats (default: 1)",
    "  --seed=N                  Random seed (default: 123)",
    "  --datasets=A,B            Only run specific dataset names",
    "  --packages=A,B            Only run datasets from packages",
    "  --max-datasets=N          Limit the number of datasets",
    "  --skip-tags=A,B           Skip datasets with these tags",
    "  --cores=N                 bartMachine cores (default: 1)",
    "  --java-params=STRING      Set options(java.parameters)",
    "  --use-xoshiro             Use Xoshiro256PlusPlus RNG (JDK 17+)",
    "  --bart-num-trees=N        bartMachine num_trees",
    "  --bart-burn-in=N          bartMachine num_burn_in",
    "  --bart-iter=N             bartMachine num_iterations_after_burn_in",
    "  --rf-ntree=N              randomForest ntree",
    "  --rf-mtry=N               randomForest mtry"
  )
  writeLines(usage_lines)
}

# Convert a command-line string to an integer, falling back to a default.
#
# value:   string from the parsed arguments, or NULL when the flag is absent.
# default: returned unchanged when `value` is NULL or "" (it need not be an
#          integer itself, e.g. Inf).
# returns: as.integer(value), or `default`.
# raises:  an error when `value` is non-empty but cannot be converted to an
#          integer, instead of silently yielding NA and failing later on.
as_int <- function(value, default) {
  if (is.null(value) || !nzchar(value)) {
    return(default)
  }
  # The coercion warning is replaced by the explicit error below.
  parsed <- suppressWarnings(as.integer(value))
  if (is.na(parsed)) {
    stop("Expected an integer value but got: ", value, call. = FALSE)
  }
  parsed
}

# Split a comma-separated command-line string into a character vector.
#
# value:   string such as "a,b,c", or NULL when the flag is absent.
# returns: the pieces between commas, or NULL for a NULL / empty input.
as_char_vec <- function(value) {
  if (!is.null(value) && nzchar(value)) {
    pieces <- strsplit(value, ",", fixed = TRUE)
    return(pieces[[1]])
  }
  NULL
}

# Load the dataset registry that lives next to this script. The rest of the
# script reads `benchmark_registry` and `benchmark_registry_version`, which
# are expected to be defined by that file.
script_dir <- get_script_dir()
registry_path <- file.path(script_dir, "benchmark_registry.R")
if (!file.exists(registry_path)) {
  stop("Cannot find benchmark registry at: ", registry_path)
}
invisible(source(registry_path, local = TRUE))

# Parse the command line; --help short-circuits everything else.
raw_opts <- parse_args(commandArgs(trailingOnly = TRUE))
if (isTRUE(raw_opts[["help"]])) {
  print_usage()
  quit(status = 0)
}

# Merge parsed values over the defaults. Keys are looked up with [[ ]] so
# that list partial matching can never pick up a different flag.
#
# parse_args() keeps hyphens in keys, so --output-dir=PATH arrives under
# "output-dir". The underscore spelling is still honoured for callers that
# relied on --output_dir=PATH.
output_dir_arg <- raw_opts[["output-dir"]]
if (is.null(output_dir_arg)) {
  output_dir_arg <- raw_opts[["output_dir"]]
}
if (!is.null(output_dir_arg) && nzchar(output_dir_arg)) {
  opts$output_dir <- output_dir_arg
}
opts$folds <- as_int(raw_opts[["folds"]], opts$folds)
opts$repeats <- as_int(raw_opts[["repeats"]], opts$repeats)
opts$seed <- as_int(raw_opts[["seed"]], opts$seed)
opts$cores <- as_int(raw_opts[["cores"]], opts$cores)
opts$datasets <- as_char_vec(raw_opts[["datasets"]])
opts$packages <- as_char_vec(raw_opts[["packages"]])
opts$max_datasets <- as_int(raw_opts[["max-datasets"]], opts$max_datasets)
opts$skip_tags <- as_char_vec(raw_opts[["skip-tags"]])
if (!is.null(raw_opts[["java-params"]])) {
  opts$java_params <- raw_opts[["java-params"]]
}
opts$bart_num_trees <- as_int(raw_opts[["bart-num-trees"]], opts$bart_num_trees)
opts$bart_burn_in <- as_int(raw_opts[["bart-burn-in"]], opts$bart_burn_in)
opts$bart_iter <- as_int(raw_opts[["bart-iter"]], opts$bart_iter)
opts$rf_ntree <- as_int(raw_opts[["rf-ntree"]], opts$rf_ntree)
rf_mtry_arg <- raw_opts[["rf-mtry"]]
if (!is.null(rf_mtry_arg) && nzchar(rf_mtry_arg)) {
  opts$rf_mtry <- as.numeric(rf_mtry_arg)
}
if (!is.null(raw_opts[["use_xoshiro"]])) {
  opts$use_xoshiro <- isTRUE(raw_opts[["use_xoshiro"]])
}
opts$list <- isTRUE(raw_opts[["list"]])
opts$dry_run <- isTRUE(raw_opts[["dry_run"]])

# Print one "package | name | task" line per registry entry.
#
# registry: list of entries, each providing `package`, `name` and `task`.
print_registry <- function(registry) {
  describe <- function(entry) {
    paste(entry$package, entry$name, entry$task, sep = " | ")
  }
  listing <- paste(vapply(registry, describe, character(1)), collapse = "\n")
  cat(listing, "\n")
}

# Narrow the registry according to the command-line filters, in the order:
# dataset names, packages, skip tags, then the dataset-count cap.
registry <- benchmark_registry

wanted_names <- opts$datasets
if (!is.null(wanted_names)) {
  registry <- Filter(function(entry) entry$name %in% wanted_names, registry)
}

wanted_packages <- opts$packages
if (!is.null(wanted_packages)) {
  registry <- Filter(function(entry) entry$package %in% wanted_packages, registry)
}

banned_tags <- opts$skip_tags
if (length(banned_tags) > 0) {
  # Entries without tags are always kept.
  registry <- Filter(
    function(entry) is.null(entry$tags) || !any(entry$tags %in% banned_tags),
    registry
  )
}

if (is.finite(opts$max_datasets)) {
  registry <- head(registry, opts$max_datasets)
}

# Listing modes: --list shows the full registry, --dry-run the filtered one.
if (opts$list || opts$dry_run) {
  print_registry(if (opts$list) benchmark_registry else registry)
  quit(status = 0)
}



# Java options are set before the packages below are attached (presumably so
# the JVM started on load picks them up - confirm). The incubator vector
# module flag is only appended when the user supplied some Java parameters.
java_flags <- opts$java_params
if (nzchar(java_flags)) {
  options(java.parameters = c(java_flags, "--add-modules=jdk.incubator.vector"))
}

# Verify both modelling packages are installed, then attach them quietly.
suppressPackageStartupMessages({
  for (pkg in c("bartMachine", "randomForest")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop(pkg, " is not installed.")
    }
  }
  library(bartMachine)
  library(randomForest)
})

set_bart_machine_num_cores(opts$cores, verbose = FALSE)

# Coerce predictor columns to the types used by the benchmark.
#
# Character and logical columns become factors; ordered factors become plain
# (unordered) factors. All other columns are left untouched.
#
# df:      data frame of predictors.
# returns: the data frame with converted columns.
normalize_predictors <- function(df) {
  for (j in seq_along(df)) {
    column <- df[[j]]
    if (is.character(column) || is.logical(column)) {
      df[[j]] <- as.factor(column)
    } else if (inherits(column, "ordered")) {
      df[[j]] <- factor(column, ordered = FALSE)
    }
  }
  df
}

# Remove unused levels from every factor column of a data frame.
#
# df:      data frame; non-factor columns are left as they are.
# returns: the data frame with droplevels() applied to each factor column.
drop_unused_factor_levels <- function(df) {
  factor_positions <- which(vapply(df, is.factor, logical(1)))
  for (j in factor_positions) {
    df[[j]] <- droplevels(df[[j]])
  }
  df
}

# Re-code test-set factors onto the training-set levels.
#
# For every factor column of `train_df`, the matching test column is rebuilt
# with the training levels. Test rows holding a value the training data never
# saw (non-missing before, NA after re-coding) are removed from both
# `test_df` and `test_y`.
#
# returns: list(train, test, test_y, dropped), where `dropped` is the number
#          of removed test rows.
align_factor_levels <- function(train_df, test_df, test_y) {
  unseen <- rep(FALSE, nrow(test_df))
  for (nm in names(train_df)) {
    reference <- train_df[[nm]]
    if (!is.factor(reference)) {
      next
    }
    original <- test_df[[nm]]
    recoded <- factor(as.character(original), levels = levels(reference))
    unseen <- unseen | (is.na(recoded) & !is.na(original))
    test_df[[nm]] <- recoded
  }
  if (any(unseen)) {
    keep <- !unseen
    test_df <- test_df[keep, , drop = FALSE]
    test_y <- test_y[keep]
  }
  list(train = train_df, test = test_df, test_y = test_y, dropped = sum(unseen))
}

# Remove factor predictors with more than `max_levels` levels.
#
# Level counts are taken from `train_df`; the same columns are removed from
# `test_df` by position.
#
# returns: list(train, test, dropped), where `dropped` holds the names of the
#          removed columns (character(0) when nothing was removed).
drop_high_cardinality_factors <- function(train_df, test_df, max_levels) {
  is_factor_col <- vapply(train_df, is.factor, logical(1))
  too_wide <- character(0)
  if (any(is_factor_col)) {
    factor_names <- names(train_df)[is_factor_col]
    n_levels <- vapply(train_df[factor_names], nlevels, integer(1))
    too_wide <- factor_names[n_levels > max_levels]
    if (length(too_wide) > 0) {
      retained <- !(names(train_df) %in% too_wide)
      train_df <- train_df[, retained, drop = FALSE]
      test_df <- test_df[, retained, drop = FALSE]
    }
  }
  list(train = train_df, test = test_df, dropped = too_wide)
}

# Load one registry entry and split it into predictors and response.
#
# def:     registry entry; this function reads def$package, def$name,
#          def$response, def$task and the optional def$drop (column names to
#          remove before anything else).
# returns: on success list(ok = TRUE, X, y, data), where X holds the
#          normalized predictors, y the response and data the complete-case
#          data frame; otherwise list(ok = FALSE, reason = <string>).
#
# Rows with any missing value are removed. For classification the response
# must end up as a two-level factor; it is re-coded so that the first original
# level becomes "1" and the second "0". A classification target that cannot be
# turned into a factor is reported as a skip reason rather than raising an
# error from droplevels().
load_dataset <- function(def) {
  fail <- function(reason) list(ok = FALSE, reason = reason)

  if (!requireNamespace(def$package, quietly = TRUE)) {
    return(fail(paste("Package missing:", def$package)))
  }
  # Load into a private environment so the global workspace is untouched.
  env <- new.env(parent = emptyenv())
  data(list = def$name, package = def$package, envir = env)
  if (!exists(def$name, envir = env)) {
    return(fail("Dataset not found"))
  }
  dat <- get(def$name, envir = env)
  if (is.matrix(dat)) {
    dat <- as.data.frame(dat)
  } else if (is.list(dat) && !is.data.frame(dat)) {
    # Plain lists are only supported in the list(x = ..., y = ...) layout.
    if (is.null(dat$x) || is.null(dat$y)) {
      return(fail("Dataset format not supported"))
    }
    dat <- data.frame(dat$x, y = dat$y)
  }
  dat <- as.data.frame(dat)
  if (!is.null(def$drop)) {
    dat <- dat[, !(names(dat) %in% def$drop), drop = FALSE]
  }
  if (!(def$response %in% names(dat))) {
    return(fail("Response column not found"))
  }
  dat <- dat[complete.cases(dat), , drop = FALSE]
  y <- dat[[def$response]]
  X <- dat[, setdiff(names(dat), def$response), drop = FALSE]
  X <- normalize_predictors(X)

  if (def$task == "classification") {
    if (is.logical(y) || is.numeric(y) || is.character(y)) {
      y <- as.factor(y)
    }
    if (!is.factor(y)) {
      return(fail("Classification target is not binary"))
    }
    y <- droplevels(y)
    if (nlevels(y) != 2) {
      return(fail("Classification target is not binary"))
    }
    # Enforce uniformity: first level becomes "1", second becomes "0".
    original_levels <- levels(y)
    y <- factor(ifelse(y == original_levels[1], "1", "0"), levels = c("1", "0"))
  } else {
    if (!is.numeric(y)) {
      return(fail("Regression target is not numeric"))
    }
    y <- as.numeric(y)
  }
  list(ok = TRUE, X = X, y = y, data = dat)
}

# Assign each observation to one of `k` cross-validation folds.
#
# y:        response vector; its levels are used when stratifying.
# k:        number of folds.
# seed:     passed to set.seed() before any sampling (this resets the global
#           random number stream).
# stratify: when TRUE, folds are dealt out separately within each level of
#           `y`; otherwise the fold labels are shuffled over all rows.
# returns:  integer vector of fold ids, one per element of `y`.
make_folds <- function(y, k, seed, stratify) {
  set.seed(seed)
  n <- length(y)
  if (!stratify) {
    return(sample(rep(seq_len(k), length.out = n)))
  }
  fold_ids <- integer(n)
  for (lvl in levels(y)) {
    members <- which(y == lvl)
    # Permute by position: sample(members) would treat a single index m > 1
    # as "sample from 1:m" and scatter fold ids over unrelated rows.
    shuffled <- members[sample.int(length(members))]
    fold_ids[shuffled] <- rep(seq_len(k), length.out = length(shuffled))
  }
  fold_ids
}

# Regression metrics for one set of predictions.
#
# pred, y: numeric vectors of predictions and observed values.
# returns: named numeric vector with rmse, mae and rsq. rsq is NA when the
#          observed values have zero total sum of squares.
score_regression <- function(pred, y) {
  residual <- pred - y
  total_ss <- sum((y - mean(y)) ^ 2)
  if (total_ss > 0) {
    r_squared <- 1 - sum(residual ^ 2) / total_ss
  } else {
    r_squared <- NA_real_
  }
  c(
    rmse = sqrt(mean(residual ^ 2)),
    mae = mean(abs(residual)),
    rsq = r_squared
  )
}

# Classification metrics for one set of predicted labels.
#
# pred: predicted labels; re-coded onto the levels of `y` before comparison.
# y:    observed labels as a factor.
# prob: accepted for interface compatibility; not used by the computation.
# returns: named numeric vector with misclass and accuracy.
score_classification <- function(pred, y, prob = NULL) {
  is_hit <- factor(pred, levels = levels(y)) == y
  c(misclass = mean(!is_hit), accuracy = mean(is_hit))
}

# Accumulators filled by the main benchmark loop: one data frame per
# fold/model result and one per skipped dataset. fold_id is the next free
# slot in fold_rows.
fold_rows <- list()
skip_rows <- list()
fold_id <- 1

# Metric columns, and the full column layout of the incremental progress CSV
# (identifying and timing columns followed by the metric columns).
metrics_cols <- c("rmse", "mae", "rsq", "misclass", "accuracy")
progress_cols <- c(
  "dataset", "package", "task", "model", "repeat_id", "fold",
  "n", "n_train", "n_test", "p",
  "train_seconds", "predict_seconds",
  metrics_cols
)

# Make sure the output directory exists and start from a fresh progress file.
invisible(dir.create(opts$output_dir, recursive = TRUE, showWarnings = FALSE))
progress_path <- file.path(opts$output_dir, "benchmark_progress.csv")
if (file.exists(progress_path)) {
  invisible(file.remove(progress_path))
}
# Append result rows to the progress CSV at `progress_path`.
#
# rows: data frame containing at least the columns named in `progress_cols`;
#       only those columns are written, in that order. The header line is
#       written only when the file does not exist yet.
# Does nothing for an empty data frame; stops when a column is missing.
append_progress_rows <- function(rows) {
  if (nrow(rows) == 0) {
    return(NULL)
  }
  absent <- setdiff(progress_cols, colnames(rows))
  if (length(absent) > 0) {
    stop("Progress rows missing expected columns.")
  }
  file_started <- file.exists(progress_path)
  write.table(
    rows[, progress_cols, drop = FALSE],
    progress_path,
    sep = ",",
    quote = TRUE,
    row.names = FALSE,
    col.names = !file_started,
    append = file_started
  )
}

# Write a multi-page PDF with one bar chart per metric plus a timing chart.
#
# summary_df:  long-format summary with columns dataset, model, metric, mean.
# fold_df:     per-fold results with columns dataset, model, train_seconds
#              and predict_seconds.
# output_path: path of the PDF to create.
# returns:     NULL, invisibly. Nothing is written when both inputs are
#              empty, when ggplot2 is not installed, or when no plot could be
#              built.
create_summary_plots <- function(summary_df, fold_df, output_path) {
  if (nrow(summary_df) == 0 && nrow(fold_df) == 0) {
    return(invisible(NULL))
  }
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    message("ggplot2 is required to generate benchmark plots.")
    return(invisible(NULL))
  }
  rotated_x_labels <- ggplot2::theme(
    axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
  )
  plot_list <- list()
  if (nrow(summary_df) > 0) {
    for (metric_name in unique(summary_df$metric)) {
      # Filter by explicit indexing. subset(summary_df, metric == metric)
      # would compare the column with itself and keep every row, so each
      # page would show all metrics at once.
      metric_df <- summary_df[which(summary_df$metric == metric_name), , drop = FALSE]
      if (nrow(metric_df) == 0) {
        next
      }
      # Order datasets along the x axis by increasing mean.
      metric_df$dataset <- factor(
        metric_df$dataset,
        levels = unique(metric_df$dataset[order(metric_df$mean)])
      )
      # NOTE(review): the log scale cannot display non-positive means (for
      # example a zero misclassification rate); confirm this is acceptable.
      p <- ggplot2::ggplot(metric_df, ggplot2::aes(x = dataset, y = mean, fill = model)) +
        ggplot2::geom_col(position = ggplot2::position_dodge(width = 0.8)) +
        ggplot2::scale_y_log10() +
        ggplot2::labs(
          title = metric_name,
          x = "Dataset",
          y = "Mean (log scale)"
        ) +
        ggplot2::theme_minimal() +
        rotated_x_labels
      plot_list[[length(plot_list) + 1]] <- p
    }
  }
  if (nrow(fold_df) > 0) {
    # Average train / predict time per dataset and model.
    time_summary <- aggregate(
      cbind(train_seconds, predict_seconds) ~ dataset + model,
      data = fold_df,
      FUN = mean
    )
    if (nrow(time_summary) > 0) {
      # Reshape to long format: one row per dataset, model and stage.
      time_long <- data.frame(
        dataset = rep(time_summary$dataset, times = 2),
        model = rep(time_summary$model, times = 2),
        stage = rep(c("train", "predict"), each = nrow(time_summary)),
        seconds = c(time_summary$train_seconds, time_summary$predict_seconds)
      )
      time_long$dataset <- factor(time_long$dataset, levels = unique(time_long$dataset))
      p <- ggplot2::ggplot(time_long, ggplot2::aes(x = dataset, y = seconds, fill = stage)) +
        ggplot2::geom_col(position = ggplot2::position_dodge(width = 0.8)) +
        ggplot2::scale_y_log10() +
        ggplot2::facet_wrap(~model, scales = "free_y", ncol = 1) +
        ggplot2::labs(title = "Average Timing by Dataset", x = "Dataset", y = "Seconds (log scale)") +
        ggplot2::theme_minimal() +
        rotated_x_labels
      plot_list[[length(plot_list) + 1]] <- p
    }
  }
  if (length(plot_list) == 0) {
    return(invisible(NULL))
  }
  grDevices::pdf(output_path, width = 8, height = 6)
  # Close the device even if printing one of the plots fails.
  on.exit(grDevices::dev.off(), add = TRUE)
  for (plot_obj in plot_list) {
    print(plot_obj)
  }
  invisible(NULL)
}

# Main benchmark loop: for every selected dataset, run repeated k-fold
# cross-validation, fitting bartMachine and randomForest on the same folds,
# and record one result row per model and fold. Datasets that cannot be used
# are recorded in skip_rows with a reason.
for (dataset_index in seq_along(registry)) {
  def <- registry[[dataset_index]]
  message("Loading ", def$package, "::", def$name, "...")
  loaded <- load_dataset(def)
  if (!loaded$ok) {
    skip_rows[[length(skip_rows) + 1]] <- data.frame(
      dataset = def$name,
      package = def$package,
      reason = loaded$reason,
      stringsAsFactors = FALSE
    )
    message("  Skipping: ", loaded$reason)
    next
  }
  X <- loaded$X
  y <- loaded$y
  n <- nrow(X)
  p <- ncol(X)
  # Fewer rows than folds is treated as unusable.
  if (n < opts$folds) {
    skip_rows[[length(skip_rows) + 1]] <- data.frame(
      dataset = def$name,
      package = def$package,
      reason = "Not enough rows for folds",
      stringsAsFactors = FALSE
    )
    message("  Skipping: not enough rows for folds.")
    next
  }
  # Classification needs at least two observations in every class.
  if (def$task == "classification") {
    class_counts <- table(y)
    if (min(class_counts) < 2) {
      skip_rows[[length(skip_rows) + 1]] <- data.frame(
        dataset = def$name,
        package = def$package,
        reason = "Insufficient class counts",
        stringsAsFactors = FALSE
      )
      message("  Skipping: insufficient class counts.")
      next
    }
  }

  for (rep_idx in seq_len(opts$repeats)) {
    # Seed derived from the base seed, the dataset position and the repeat
    # number; folds are stratified for classification tasks only.
    fold_seed <- opts$seed + dataset_index * 1000 + rep_idx * 100
    fold_ids <- make_folds(y, opts$folds, fold_seed, def$task == "classification")

    for (fold in seq_len(opts$folds)) {
      test_idx <- which(fold_ids == fold)
      train_idx <- which(fold_ids != fold)
      train_X <- X[train_idx, , drop = FALSE]
      test_X <- X[test_idx, , drop = FALSE]
      train_y <- y[train_idx]
      test_y <- y[test_idx]

      train_X <- drop_unused_factor_levels(train_X)

      if (def$task == "classification") {
        train_y <- droplevels(train_y)
        if (length(levels(train_y)) < 2) {
          message("  Fold skipped: training data has one class only.")
          next
        }
        # Re-code the test response onto the training levels; an NA here
        # means the test fold holds a class absent from the training fold.
        test_y <- factor(test_y, levels = levels(train_y))
        if (any(is.na(test_y))) {
          message("  Fold skipped: test data has unseen response levels.")
          next
        }
      }

      # Test rows with factor values unseen in training are removed.
      aligned <- align_factor_levels(train_X, test_X, test_y)
      train_X <- aligned$train
      test_X <- aligned$test
      test_y <- aligned$test_y
      if (aligned$dropped > 0) {
        message("  Dropped ", aligned$dropped, " test rows with unseen factor levels.")
      }
      if (nrow(test_X) == 0) {
        message("  Fold skipped: no test rows after factor alignment.")
        next
      }

      # Factors with more than 53 levels are removed for both models; the
      # message below attributes this cap to randomForest compatibility.
      dropped_factors <- drop_high_cardinality_factors(train_X, test_X, 53L)
      train_X <- dropped_factors$train
      test_X <- dropped_factors$test
      if (length(dropped_factors$dropped) > 0) {
        message(
          "  Dropped high-cardinality factors for randomForest compatibility: ",
          paste(dropped_factors$dropped, collapse = ", ")
        )
      }
      if (ncol(train_X) == 0) {
        message("  Fold skipped: no predictors after dropping high-cardinality factors.")
        next
      }

      n_train <- nrow(train_X)
      n_test <- nrow(test_X)
      fold_p <- ncol(train_X)

      # NOTE(review): fold_seed is updated in place, so the offset accumulates
      # from fold to fold and is not advanced for folds skipped above -
      # confirm this is intended.
      fold_seed <- fold_seed + fold
      bart_params <- list(
        num_trees = opts$bart_num_trees,
        num_burn_in = opts$bart_burn_in,
        num_iterations_after_burn_in = opts$bart_iter,
        use_xoshiro = opts$use_xoshiro,
        verbose = FALSE
      )
      rf_params <- list(
        ntree = opts$rf_ntree
      )
      # mtry is only passed on when the user supplied a value.
      if (!is.na(opts$rf_mtry)) {
        rf_params$mtry <- opts$rf_mtry
      }

      # Both models are fitted after resetting R's RNG to the same seed;
      # only elapsed (wall-clock) time is kept.
      set.seed(fold_seed)
      bart_time <- system.time({
        bart_fit <- do.call(bartMachine, c(list(train_X, train_y), bart_params))
      })["elapsed"]
      set.seed(fold_seed)
      rf_time <- system.time({
        rf_fit <- do.call(randomForest, c(list(train_X, train_y), rf_params))
      })["elapsed"]

      if (def$task == "classification") {
        bart_pred_time <- system.time({
          bart_prob <- predict(bart_fit, test_X, type = "prob", verbose = FALSE)
        })["elapsed"]
        # bart_prob is treated as the probability of the first level of
        # train_y - presumably bartMachine's convention; confirm.
        bart_pred <- ifelse(bart_prob >= 0.5, levels(train_y)[1], levels(train_y)[2])
        bart_pred <- factor(bart_pred, levels = levels(train_y))
        bart_metrics <- score_classification(bart_pred, test_y, bart_prob)

        rf_pred_time <- system.time({
          rf_prob <- predict(rf_fit, test_X, type = "prob")
        })["elapsed"]
        rf_prob <- rf_prob[, levels(train_y)[1]]
        # Class labels come from a second predict() call outside the timed
        # block, so only the probability prediction is timed.
        rf_pred <- predict(rf_fit, test_X, type = "response")
        rf_metrics <- score_classification(rf_pred, test_y, rf_prob)
      } else {
        bart_pred_time <- system.time({
          bart_pred <- predict(bart_fit, test_X, verbose = FALSE)
        })["elapsed"]
        bart_metrics <- score_regression(bart_pred, test_y)

        rf_pred_time <- system.time({
          rf_pred <- predict(rf_fit, test_X)
        })["elapsed"]
        rf_metrics <- score_regression(rf_pred, test_y)
      }

      # One row per model and fold. Metrics that the task does not produce
      # (e.g. rmse for classification) are stored as NA.
      bart_row <- data.frame(
        dataset = def$name,
        package = def$package,
        task = def$task,
        model = "bartMachine",
        repeat_id = rep_idx,
        fold = fold,
        n = n,
        n_train = n_train,
        n_test = n_test,
        p = fold_p,
        train_seconds = as.numeric(bart_time),
        predict_seconds = as.numeric(bart_pred_time),
        rmse = ifelse("rmse" %in% names(bart_metrics), bart_metrics["rmse"], NA_real_),
        mae = ifelse("mae" %in% names(bart_metrics), bart_metrics["mae"], NA_real_),
        rsq = ifelse("rsq" %in% names(bart_metrics), bart_metrics["rsq"], NA_real_),
        misclass = ifelse("misclass" %in% names(bart_metrics), bart_metrics["misclass"], NA_real_),
        accuracy = ifelse("accuracy" %in% names(bart_metrics), bart_metrics["accuracy"], NA_real_),
        stringsAsFactors = FALSE
      )
      # Keep the row in memory and also append it to the progress CSV.
      fold_rows[[fold_id]] <- bart_row
      append_progress_rows(bart_row)
      fold_id <- fold_id + 1

      rf_row <- data.frame(
        dataset = def$name,
        package = def$package,
        task = def$task,
        model = "randomForest",
        repeat_id = rep_idx,
        fold = fold,
        n = n,
        n_train = n_train,
        n_test = n_test,
        p = fold_p,
        train_seconds = as.numeric(rf_time),
        predict_seconds = as.numeric(rf_pred_time),
        rmse = ifelse("rmse" %in% names(rf_metrics), rf_metrics["rmse"], NA_real_),
        mae = ifelse("mae" %in% names(rf_metrics), rf_metrics["mae"], NA_real_),
        rsq = ifelse("rsq" %in% names(rf_metrics), rf_metrics["rsq"], NA_real_),
        misclass = ifelse("misclass" %in% names(rf_metrics), rf_metrics["misclass"], NA_real_),
        accuracy = ifelse("accuracy" %in% names(rf_metrics), rf_metrics["accuracy"], NA_real_),
        stringsAsFactors = FALSE
      )
      fold_rows[[fold_id]] <- rf_row
      append_progress_rows(rf_row)
      fold_id <- fold_id + 1
    }
  }
}

# Stack the per-fold rows into one data frame (empty when nothing ran).
fold_results <- data.frame()
if (length(fold_rows) > 0) {
  fold_results <- do.call(rbind, fold_rows)
}

# Long-format summary: mean, sd and count of every available metric for each
# dataset / model / task combination, in order of first appearance.
summary_results <- data.frame()
if (nrow(fold_results) > 0) {
  group_key <- paste(fold_results$dataset, fold_results$model, fold_results$task)
  groups <- split(fold_results, factor(group_key, levels = unique(group_key)))
  summary_rows <- list()
  for (group_df in groups) {
    for (metric in metrics_cols) {
      values <- group_df[[metric]]
      # Metrics that do not apply to the task are entirely NA; skip them.
      if (all(is.na(values))) {
        next
      }
      summary_rows[[length(summary_rows) + 1]] <- data.frame(
        dataset = group_df$dataset[1],
        package = group_df$package[1],
        task = group_df$task[1],
        model = group_df$model[1],
        metric = metric,
        mean = mean(values, na.rm = TRUE),
        sd = sd(values, na.rm = TRUE),
        n = sum(!is.na(values)),
        stringsAsFactors = FALSE
      )
    }
  }
  if (length(summary_rows) > 0) {
    summary_results <- do.call(rbind, summary_rows)
  }
}

# Paired comparison of the two models: for every dataset and metric, match
# results on (repeat_id, fold), run a two-sided paired t-test and record the
# means, the mean difference, the p-value and the number of pairs.
comparison_results <- data.frame()
both_models <- c("bartMachine", "randomForest")
if (nrow(fold_results) > 0 && all(both_models %in% fold_results$model)) {
  pair_keys <- c("repeat_id", "fold")
  comp_rows <- list()

  for (ds in unique(fold_results$dataset)) {
    ds_df <- fold_results[fold_results$dataset == ds, , drop = FALSE]
    bart_df <- ds_df[ds_df$model == "bartMachine", , drop = FALSE]
    rf_df <- ds_df[ds_df$model == "randomForest", , drop = FALSE]
    # A dataset is only compared when both models produced results for it.
    if (nrow(bart_df) == 0 || nrow(rf_df) == 0) {
      next
    }

    for (metric in metrics_cols) {
      # Inner join keeps only the folds present for both models.
      paired <- merge(
        bart_df[, c(pair_keys, metric)],
        rf_df[, c(pair_keys, metric)],
        by = pair_keys,
        suffixes = c("_bart", "_rf")
      )
      bart_vals <- paired[[paste0(metric, "_bart")]]
      rf_vals <- paired[[paste0(metric, "_rf")]]

      usable <- !(is.na(bart_vals) | is.na(rf_vals))
      bart_vals <- bart_vals[usable]
      rf_vals <- rf_vals[usable]
      if (length(bart_vals) < 2) {
        next
      }

      # A failing test is recorded as an NA p-value.
      pval <- tryCatch(
        t.test(bart_vals, rf_vals, paired = TRUE, alternative = "two.sided")$p.value,
        error = function(e) NA_real_
      )

      comp_rows[[length(comp_rows) + 1]] <- data.frame(
        dataset = ds,
        package = ds_df$package[1],
        task = ds_df$task[1],
        metric = metric,
        bart_mean = mean(bart_vals),
        rf_mean = mean(rf_vals),
        diff_mean = mean(bart_vals - rf_vals),
        p_value = pval,
        n = length(bart_vals),
        stringsAsFactors = FALSE
      )
    }
  }

  if (length(comp_rows) > 0) {
    comparison_results <- do.call(rbind, comp_rows)
    # Error-type metrics are won by the smaller mean, all others by the
    # larger mean; ties go to BART.
    lower_is_better <- comparison_results$metric %in% c("rmse", "mae", "misclass")
    bart_wins <- ifelse(
      lower_is_better,
      comparison_results$bart_mean <= comparison_results$rf_mean,
      comparison_results$bart_mean >= comparison_results$rf_mean
    )
    comparison_results$winner <- ifelse(bart_wins, "BART", "RF")
    row_order <- with(comparison_results, order(metric, dataset, p_value))
    comparison_results <- comparison_results[row_order, ]
  }
}

# Stack the skip records (empty when no dataset was skipped).
skip_results <- data.frame()
if (length(skip_rows) > 0) {
  skip_results <- do.call(rbind, skip_rows)
}

# Output locations inside the results directory.
fold_path <- file.path(opts$output_dir, "benchmark_folds.csv")
summary_path <- file.path(opts$output_dir, "benchmark_summary.csv")
comparison_path <- file.path(opts$output_dir, "benchmark_comparisons.csv")
skip_path <- file.path(opts$output_dir, "benchmark_skipped.csv")
rds_path <- file.path(opts$output_dir, "benchmark_results.rds")
session_path <- file.path(opts$output_dir, "sessionInfo.txt")

# Each CSV is written only when its table has at least one row.
csv_outputs <- list(
  list(table = fold_results, path = fold_path),
  list(table = summary_results, path = summary_path),
  list(table = comparison_results, path = comparison_path),
  list(table = skip_results, path = skip_path)
)
for (out in csv_outputs) {
  if (nrow(out$table) > 0) {
    write.csv(out$table, out$path, row.names = FALSE)
  }
}

# The RDS bundle and the session info are always written.
results_bundle <- list(
  config = opts,
  registry_version = benchmark_registry_version,
  summary = summary_results,
  comparison = comparison_results,
  folds = fold_results,
  skipped = skip_results
)
saveRDS(results_bundle, rds_path)
capture.output(sessionInfo(), file = session_path)

plot_path <- file.path(opts$output_dir, "benchmark_summary_plots.pdf")
create_summary_plots(summary_results, fold_results, plot_path)

message("Benchmark complete.")
message("Summary: ", summary_path)

# Try the bartMachine package in your browser
#
# Any scripts or data that you put into this service are public.
#
# bartMachine documentation built on Jan. 19, 2026, 9:06 a.m.