Nothing
#' Predict Sporulation Potential
#'
#' This function predicts the sporulation potential of MAGs using an ensemble
#' learning model. A Random Forest and an SVM classifier each produce a class
#' probability, and those two probabilities are the inputs of a meta-model
#' that produces the final prediction.
#'
#' The trained models are loaded from `extdata/models_sporulation.RData`
#' inside the installed SpoMAG package. Predictors used by the base models
#' but absent from `binary_matrix` are added as all-zero columns (gene
#' absent); columns of `binary_matrix` that are not model predictors are
#' ignored.
#'
#' @param binary_matrix A binary matrix (1/0) indicating gene presence/absence
#'   for each MAG, one row per genome. Must include a `genome_ID` column.
#'
#' @return A tibble with one row per genome and the columns `genome_ID`,
#'   `RF_Prob`, `SVM_Prob`, `Meta_Prob_Sporulating` (probabilities of the
#'   sporulating class from the Random Forest, the SVM and the meta-model)
#'   and `Meta_Prediction` (`"Sporulating"` when the meta-model probability
#'   is above 0.5, otherwise `"Non_sporulating"`).
#' @import dplyr
#' @importFrom stats predict
#'
#' @examples
#' # Load package
#' library(SpoMAG)
#'
#' # Load example annotation tables
#' file_spor <- system.file("extdata", "one_sporulating.csv.gz", package = "SpoMAG")
#' file_aspo <- system.file("extdata", "one_asporogenic.csv.gz", package = "SpoMAG")
#'
#' # Read files
#' df_spor <- readr::read_csv(file_spor, show_col_types = FALSE)
#' df_aspo <- readr::read_csv(file_aspo, show_col_types = FALSE)
#'
#' # Step 1: Extract sporulation-related genes
#' genes_spor <- sporulation_gene_name(df_spor)
#' genes_aspo <- sporulation_gene_name(df_aspo)
#'
#' # Step 2: Convert to binary matrix
#' bin_spor <- build_binary_matrix(genes_spor)
#' bin_aspo <- build_binary_matrix(genes_aspo)
#'
#' # Step 3: Predict using ensemble model (preloaded in package)
#'
#' result_spor <- predict_sporulation(bin_spor)
#' result_aspo <- predict_sporulation(bin_aspo)
#'
#'
#' @export
predict_sporulation <- function(binary_matrix) {
  # Validate the input first so a malformed table fails before any model I/O.
  if (!"genome_ID" %in% colnames(binary_matrix)) {
    stop("Input binary matrix must contain a 'genome_ID' column.")
  }

  # Locate model file from internal package data
  model_path <- system.file(
    "extdata", "models_sporulation.RData",
    package = "SpoMAG"
  )
  if (model_path == "" || !file.exists(model_path)) {
    stop("Model file not found. Please ensure 'models_sporulation.RData' is included in inst/extdata.")
  }

  # The packages are only needed at prediction time, so they are checked via
  # requireNamespace() rather than attached.
  required_packages <- c("caret", "kernlab", "randomForest")
  for (pkg in required_packages) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop(paste("Package", pkg, "is required"), call. = FALSE)
    }
  }

  # Load the trained models into a private environment so that they cannot
  # mask, or be confused with, objects that happen to exist elsewhere.
  model_env <- new.env(parent = emptyenv())
  loaded <- load(model_path, envir = model_env)
  expected_models <- c("rf_model", "svm_model", "meta_model")
  absent_models <- setdiff(expected_models, loaded)
  if (length(absent_models) > 0) {
    stop(
      "Model file is missing expected object(s): ",
      paste(absent_models, collapse = ", "),
      call. = FALSE
    )
  }
  rf_model <- model_env[["rf_model"]]
  svm_model <- model_env[["svm_model"]]
  meta_model <- model_env[["meta_model"]]

  # Remove identifiers; drop = FALSE keeps a one-column input tabular.
  feature_cols <- setdiff(colnames(binary_matrix), "genome_ID")
  features <- tibble::as_tibble(binary_matrix[, feature_cols, drop = FALSE])

  # Predictors used in RF and SVM (".outcome" is the response column stored
  # alongside the training data, not a predictor).
  rf_vars <- setdiff(colnames(rf_model$trainingData), ".outcome")
  svm_vars <- setdiff(colnames(svm_model$trainingData), ".outcome")
  all_vars <- union(rf_vars, svm_vars)

  # Add predictors the input lacks as "gene absent" (0).
  missing_vars <- setdiff(all_vars, colnames(features))
  for (var in missing_vars) {
    features[[var]] <- rep(0, nrow(features))
  }

  # Reorder columns to match training. This must select from the padded
  # `features` table: selecting from `binary_matrix` would throw away the
  # zero-filled columns and fail whenever a predictor is missing.
  features <- dplyr::select(features, dplyr::all_of(all_vars))

  # Generate predictions from base models
  prob_rf <- predict(rf_model, features, type = "prob")[, "Esporulante"]
  prob_svm <- predict(svm_model, features, type = "prob")[, "Esporulante"]

  # Create meta-model input
  meta_input <- data.frame(
    RF_Prob = prob_rf,
    SVM_Prob = prob_svm
  )

  # Final prediction using the meta-model
  prob_meta <- predict(meta_model, meta_input, type = "prob")[, "Esporulante"]
  class_meta <- ifelse(prob_meta > 0.5, "Sporulating", "Non_sporulating")

  # Combine results
  dplyr::tibble(
    genome_ID = binary_matrix[["genome_ID"]],
    RF_Prob = prob_rf,
    SVM_Prob = prob_svm,
    Meta_Prob_Sporulating = prob_meta,
    Meta_Prediction = class_meta
  )
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.