R/compare_model_predictions.R

Defines functions standardize_cell_type_names compare_model_predictions

Documented in compare_model_predictions standardize_cell_type_names

#' Compare predictions from different models
#' 
#' This function runs the same input through multiple models and compares their predictions.
#' It provides both individual predictions and a consensus analysis.
#' 
#' @note This function uses create_standardization_prompt from prompt_templates.R
#
#
#
#'   Supported models:
#'   - OpenAI: 'gpt-5.5', 'gpt-5.4', 'gpt-5.4-mini'
#'   - Anthropic: 'claude-opus-4-7', 'claude-opus-4-6', 'claude-sonnet-4-6', 'claude-haiku-4-5-20251001'
#'   - DeepSeek: 'deepseek-v4-flash', 'deepseek-v4-pro'
#'   - Google: 'gemini-3.1-pro-preview', 'gemini-3-flash-preview', 'gemini-3.1-flash-lite'
#'   - Alibaba: 'qwen3.6-max-preview', 'qwen3.6-plus', 'qwen3.6-flash'
#'   - Stepfun: 'step-3.5-flash', 'step-3.5-flash-2603', 'step-3'
#'   - Zhipu/Z.AI: 'glm-5.1', 'glm-5-turbo', 'glm-5'
#'   - MiniMax: 'MiniMax-M2.7', 'MiniMax-M2.7-highspeed', 'MiniMax-M2.5'
#'   - X.AI: 'grok-4.3', 'grok-4.3-latest', 'grok-latest'
#'   - OpenRouter: Provides access to models from multiple providers through a single API. Format: 'provider/model-name'
#'     - OpenAI models: 'openai/gpt-5.5', 'openai/gpt-5.4-mini'
#'     - Anthropic models: 'anthropic/claude-opus-4.7', 'anthropic/claude-sonnet-4.6'
#'     - Google models: 'google/gemini-3.1-pro-preview', 'google/gemini-3-flash-preview'
#'     - X.AI models: 'x-ai/grok-4.3'
#'     - Stepfun models: 'stepfun/step-3.5-flash'
#
#'   1. With provider names as keys: `list("openai" = "sk-...", "anthropic" = "sk-ant-...", "openrouter" = "sk-or-...")`
#'   2. With model names as keys: `list("gpt-5.5" = "sk-...", "claude-sonnet-4-6" = "sk-ant-...")`
#'   
#'   The system first tries to find the API key using the provider name. If not found, it then tries using the model name.
#'   Example:
#'   ```r
#'   api_keys <- list(
#'     "openai" = Sys.getenv("OPENAI_API_KEY"),
#'     "anthropic" = Sys.getenv("ANTHROPIC_API_KEY"),
#'     "openrouter" = Sys.getenv("OPENROUTER_API_KEY"),
#'     "claude-opus-4-7" = "your-claude-opus-key"
#'   )
#'   ```
#'
#' @param input Either a data frame from Seurat's FindAllMarkers() containing columns 'cluster', 'gene', and 'avg_log2FC', or a list with 'genes' field for each cluster
#' @param tissue_name Tissue context (e.g., 'human PBMC', 'mouse brain') for more accurate annotations
#' @param models Vector of model names to use for comparison. Default includes top models from each provider
#' @param api_keys Named list of API keys for the models, with provider or model names as keys.
#'   Every model in \code{models} must resolve to a non-NULL API key.
#' @param top_gene_count Number of top genes to use per cluster when input is from Seurat. Default: 10
#' @param consensus_threshold Minimum agreement threshold for consensus (0-1). Default: 0.5.
#'   Consensus is only evaluated when at least two non-missing model predictions are available for a cluster.
#' @param base_urls Optional base URLs for API endpoints. Can be a string or named list for provider-specific custom endpoints.
#'
#' @return List containing individual model predictions and consensus analysis
#'   If a cluster has fewer than two valid predictions after alignment/padding,
#'   its consensus-related outputs are \code{NA}.
#' @export
#' @examples
#' \dontrun{
#' # Compare predictions using different models
#' api_keys <- list(
#'   "claude-sonnet-4-6" = "your-anthropic-key",
#'   "deepseek-v4-pro" = "your-deepseek-key",
#'   "gemini-3.1-pro-preview" = "your-gemini-key",
#'   "qwen3.6-plus" = "your-qwen-key"
#' )
#' 
#' results <- compare_model_predictions(
#'   input = list(gs1=c('CD4','CD3D'), gs2='CD14'),
#'   tissue_name = 'PBMC',
#'   api_keys = api_keys
#' )
#' }
compare_model_predictions <- function(input,
                                      tissue_name,
                                      models = c("claude-opus-4-7",
                                                 "gpt-5.5",
                                                 "gemini-3.1-pro-preview",
                                                 "deepseek-v4-flash",
                                                 "qwen3.6-plus",
                                                 "grok-4.3"),
                                      api_keys,
                                      top_gene_count = 10,
                                      consensus_threshold = 0.5,
                                      base_urls = NULL) {
  
  # Validate inputs
  if (!is.list(api_keys) || length(api_keys) == 0) {
    stop("api_keys must be a non-empty list with named elements corresponding to models")
  }
  
  # Validate model/API-key pairs without letting one bad model block the rest
  eligible_models <- character(0)
  for (m in models) {
    provider <- tryCatch(get_provider(m), error = function(e) NULL)
    api_key <- if (is.null(provider)) NULL else get_api_key(m, api_keys)
    if (is.null(provider)) {
      warning(sprintf("Skipping model '%s': unsupported model name", m))
    } else if (is.null(api_key)) {
      warning(sprintf("Skipping model '%s': no API key found for provider '%s' or model name", m, provider))
    } else {
      eligible_models <- c(eligible_models, m)
    }
  }
  if (length(eligible_models) == 0) {
    stop("No models have both a supported provider and an API key")
  }
  models <- eligible_models

  # Extract cluster IDs from input for display
  prompt_result <- create_annotation_prompt(input, tissue_name, top_gene_count)
  cluster_ids <- names(prompt_result$gene_lists)

  # Initialize results storage
  all_predictions <- list()
  successful_models <- character(0)

  # Get predictions from each model
  for (model in models) {
    message(sprintf("\nRunning predictions with model: %s", model))
    tryCatch({
      api_key <- get_api_key(model, api_keys)

      predictions <- annotate_cell_types(
        input = input,
        tissue_name = tissue_name,
        model = model,
        api_key = api_key,
        top_gene_count = top_gene_count,
        base_urls = base_urls
      )
      all_predictions[[model]] <- predictions
      successful_models <- c(successful_models, model)
    }, error = function(e) {
      warning(sprintf("Error with model %s: %s", model, e$message))
      log_warn(sprintf("Model %s failed during prediction", model), list(model = model, error = e$message))
    })
  }
  
  # Check if we have any successful predictions
  if (length(successful_models) == 0) {
    stop("No models successfully completed predictions")
  }
  
  # Standardize cell type names using LLM
  message("\nStandardizing cell type names...")
  standardized_predictions <- standardize_cell_type_names(all_predictions, successful_models, api_keys, base_urls = base_urls)
  
  # Pad all prediction vectors to equal length with NA to avoid vector recycling
  all_vectors <- all_predictions[successful_models]
  std_vectors <- standardized_predictions[successful_models]
  max_len <- max(vapply(std_vectors, length, integer(1)),
                 vapply(all_vectors, length, integer(1)))
  pad <- function(x) { length(x) <- max_len; x }
  cluster_labels <- cluster_ids
  length(cluster_labels) <- max_len
  missing_labels <- is.na(cluster_labels) | !nzchar(cluster_labels)
  cluster_labels[missing_labels] <- paste0("..prediction_", seq_len(max_len)[missing_labels])

  comparison_matrix <- do.call(cbind, lapply(std_vectors, pad))
  colnames(comparison_matrix) <- successful_models
  rownames(comparison_matrix) <- cluster_labels
  raw_matrix <- do.call(cbind, lapply(all_vectors, pad))
  colnames(raw_matrix) <- successful_models
  rownames(raw_matrix) <- cluster_labels
  n_clusters <- nrow(comparison_matrix)

  # Calculate consensus and agreement statistics
  consensus_results <- apply(comparison_matrix, 1, function(row) {
    # Remove NAs
    valid_predictions <- row[!is.na(row)]
    if (length(valid_predictions) < 2) {
      return(list(consensus = NA, consensus_proportion = NA, entropy = NA))
    }
    
    # Count occurrences of each prediction
    pred_table <- table(valid_predictions)
    max_agreement <- max(pred_table) / length(valid_predictions)
    
    # Get consensus if agreement meets threshold
    consensus <- if (max_agreement >= consensus_threshold) {
      names(pred_table)[which.max(pred_table)]
    } else {
      NA
    }
    
    # Calculate consensus proportion
    consensus_proportion <- if (!is.na(consensus)) {
      pred_table[consensus] / length(valid_predictions)
    } else {
      NA
    }
    
    # Calculate entropy
    entropy <- -sum(pred_table / length(valid_predictions) * log2(pred_table / length(valid_predictions)))
    
    list(
      consensus = consensus,
      consensus_proportion = consensus_proportion,
      entropy = entropy
    )
  })
  
  # Format results
  consensus_predictions <- sapply(consensus_results, function(x) x$consensus)
  consensus_proportions <- sapply(consensus_results, function(x) x$consensus_proportion)
  entropies <- sapply(consensus_results, function(x) x$entropy)
  names(consensus_predictions) <- cluster_labels
  names(consensus_proportions) <- cluster_labels
  names(entropies) <- cluster_labels

  # Calculate overall statistics
  model_agreement_matrix <- matrix(NA, 
                                   nrow = length(successful_models), 
                                   ncol = length(successful_models),
                                   dimnames = list(successful_models, successful_models))
  
  for (i in seq_along(successful_models)) {
    for (j in seq_along(successful_models)) {
      if (i != j) {
        valid_comparisons <- !is.na(comparison_matrix[, successful_models[i]]) &
          !is.na(comparison_matrix[, successful_models[j]])
        if (any(valid_comparisons)) {
          agreement <- mean(comparison_matrix[valid_comparisons, successful_models[i]] ==
                              comparison_matrix[valid_comparisons, successful_models[j]])
          model_agreement_matrix[i,j] <- agreement
        }
      }
    }
  }
  
  mean_or_na <- function(x) {
    if (all(is.na(x))) NA_real_ else mean(x, na.rm = TRUE)
  }

  # Prepare summary statistics
  summary_stats <- list(
    total_clusters = n_clusters,
    consensus_reached = sum(!is.na(consensus_predictions)),
    mean_consensus_proportion = mean_or_na(consensus_proportions),
    mean_entropy = mean_or_na(entropies),
    model_agreement_matrix = model_agreement_matrix
  )

  # Return results
  results <- list(
    individual_predictions = all_predictions[successful_models],
    standardized_predictions = standardized_predictions[successful_models],
    comparison_matrix = comparison_matrix,
    consensus_predictions = consensus_predictions,
    consensus_proportions = consensus_proportions,
    entropies = entropies,
    summary_stats = summary_stats
  )
  
  # Print summary
  message("\nModel Comparison Summary:")
  message(sprintf("Total clusters analyzed: %d", summary_stats$total_clusters))
  message(sprintf("Clusters with consensus: %d (%.1f%%)",
              summary_stats$consensus_reached,
              100 * summary_stats$consensus_reached / summary_stats$total_clusters))
  mean_consensus_label <- if (is.na(summary_stats$mean_consensus_proportion)) {
    "NA"
  } else {
    sprintf("%.2f", summary_stats$mean_consensus_proportion)
  }
  mean_entropy_label <- if (is.na(summary_stats$mean_entropy)) {
    "NA"
  } else {
    sprintf("%.2f", summary_stats$mean_entropy)
  }
  message(sprintf("Mean consensus proportion: %s", mean_consensus_label))
  message(sprintf("Mean entropy: %s", mean_entropy_label))

  message("\nPairwise Model Agreement:")
  message(paste(utils::capture.output(print(model_agreement_matrix)), collapse = "\n"))

  message("\nDetailed Results:")
  for (i in 1:n_clusters) {
    cluster_label <- cluster_labels[i]
    message(sprintf("\nCluster %s:", cluster_label))
    for (model in successful_models) {
      message(sprintf("  %s: %s (Standardized: %s)",
                model,
                raw_matrix[i, model],
                comparison_matrix[i, model]))
    }
    consensus_proportion_label <- if (is.na(consensus_proportions[i])) {
      "NA"
    } else {
      sprintf("%.2f", consensus_proportions[i])
    }
    entropy_label <- if (is.na(entropies[i])) {
      "NA"
    } else {
      sprintf("%.2f", entropies[i])
    }
    message(sprintf("  Consensus: %s (Consensus Proportion: %s, Entropy: %s)",
                consensus_predictions[i],
                consensus_proportion_label,
                entropy_label))
  }
  
  invisible(results)
}

#' Standardize cell type names using a language model
#' 
#' This function takes predictions from multiple models and standardizes the cell type
#' nomenclature to ensure consistent naming across different models' outputs.
#' 
#
#
#
#'   1. With provider names as keys: `list("openai" = "sk-...", "anthropic" = "sk-ant-...", "openrouter" = "sk-or-...")`
#'   2. With model names as keys: `list("gpt-5.5" = "sk-...", "claude-sonnet-4-6" = "sk-ant-...")`
#
#
#' @keywords internal
standardize_cell_type_names <- function(predictions,
                                       models,
                                       api_keys,
                                       standardization_model = "claude-sonnet-4-6",
                                       base_urls = NULL) {
  # Get API key for standardization model
  api_key <- get_api_key(standardization_model, api_keys)
  
  if (is.null(api_key)) {
    warning(sprintf("No API key found for standardization model '%s'. Using the first available model instead.",
                  standardization_model))
    log_warn("No API key for standardization model, falling back",
             list(requested = standardization_model, fallback = models[1]))
    standardization_model <- models[1]
    api_key <- get_api_key(standardization_model, api_keys)
  }
  
  # Get unique cell type names from all predictions
  all_cell_types <- unique(unlist(predictions[models]))
  all_cell_types <- all_cell_types[!is.na(all_cell_types)]
  
  if (length(all_cell_types) == 0) {
    warning("No valid cell type predictions found to standardize")
    log_warn("No valid cell type predictions found to standardize")
    return(predictions)
  }
  
  # Create a mapping of original cell types to standardized names
  message(sprintf("Using %s to standardize %d unique cell type names", 
                standardization_model, length(all_cell_types)))
  
  # Use the standardization prompt template from prompt_templates.R
  prompt <- create_standardization_prompt(all_cell_types)
  
  # Call the LLM to get standardized names
  tryCatch({
    response <- get_model_response(
      prompt = prompt,
      model = standardization_model,
      api_key = api_key,
      base_urls = base_urls
    )
    
    # Parse the response to extract mappings
    # response may be a character vector (one element per line from get_model_response),
    # so collapse first to ensure all lines are processed
    mapping_lines <- strsplit(paste(response, collapse = "\n"), "\n")[[1]]
    mapping <- list()
    
    # Function to clean cell type names by removing prefixes, numbers, etc.
    clean_cell_type <- function(cell_type) {
      if (is.na(cell_type)) return(cell_type)
      
      # Remove various number/cluster prefixes
      cleaned <- cell_type
      # Remove "1: ", "Cluster 1: ", etc.
      cleaned <- gsub("^\\s*\\d+\\s*:\\s*", "", cleaned)
      cleaned <- gsub("^\\s*[Cc]luster\\s*\\d+\\s*:\\s*", "", cleaned)
      # Also remove any leading numbers with any separator
      cleaned <- gsub("^\\s*\\d+\\s*[.:-]?\\s*", "", cleaned)
      # Trim any leading/trailing whitespace
      cleaned <- trimws(cleaned)
      
      return(cleaned)
    }

    valid_mapping_keys <- unique(c(all_cell_types, vapply(all_cell_types, clean_cell_type, character(1))))
    valid_mapping_keys <- valid_mapping_keys[order(nchar(valid_mapping_keys), decreasing = TRUE)]
    for (line in mapping_lines) {
      line <- trimws(line)
      if (!nzchar(line)) next

      for (original in valid_mapping_keys) {
        prefix <- paste0(original, ":")
        if (startsWith(line, prefix)) {
          standardized <- trimws(substring(line, nchar(prefix) + 1))
          if (nzchar(standardized)) {
            mapping[[original]] <- standardized
          }
          break
        }
      }
    }

    # Apply standardization to all predictions
    standardized_predictions <- predictions
    for (model in models) {
      for (i in seq_along(predictions[[model]])) {
        original <- predictions[[model]][i]
        if (!is.na(original)) {
          # Clean the original name first
          cleaned_original <- clean_cell_type(original)
          
          # Try direct mapping with original or cleaned name
          if (original %in% names(mapping)) {
            standardized <- mapping[[original]]
          } else if (cleaned_original %in% names(mapping)) {
            standardized <- mapping[[cleaned_original]]
          } else {
            # If no mapping found, use the cleaned original
            standardized <- cleaned_original
          }
          
          # Clean the standardized result as well to remove any remaining prefixes
          standardized_predictions[[model]][i] <- clean_cell_type(standardized)
        }
      }
    }
    
    # Print standardization mapping for reference
    message("\nCell Type Standardization Mapping:\n")
    for (original in names(mapping)) {
      message(sprintf("  %s -> %s\n", original, mapping[[original]]))
    }
    
    return(standardized_predictions)
    
  }, error = function(e) {
    warning(sprintf("Error in standardization: %s\nReturning original predictions.", e$message))
    log_warn("Standardization failed, returning original predictions", list(error = e$message))
    return(predictions)
  })
}

Try the mLLMCelltype package in your browser

Any scripts or data that you put into this service are public.

mLLMCelltype documentation built on May 11, 2026, 9:06 a.m.