R/acceptance_rates.R

Defines functions CalculateAcceptanceRates.cia_chain CalculateAcceptanceRates.cia_chains CalculateAcceptanceRates

Documented in CalculateAcceptanceRates

#' Calculate acceptance rates
#' 
#' This makes the assumption that the proposal has saved a variable "proposal_used"
#' and mcmc has saved a variable 'accept'.
#' 
#' @param chains MCMC chains.
#' @param group_by Vector of strings that are in c("chain", "proposal_used"). 
#' Default is NULL which will return the acceptance rates marginalised over
#' chains and the proposal used.
#' 
#' @returns Summary of acceptance rates per grouping.
#' 
#' @examples
#' data <- bnlearn::learning.test
#' 
#' dag <- UniformlySampleDAG(colnames(data))
#' partitioned_nodes <- DAGtoPartition(dag)
#' 
#' scorer <- CreateScorer(
#'   scorer = BNLearnScorer, 
#'   data = data
#'   )
#' 
#' results <- SampleChains(10, partitioned_nodes, PartitionMCMC(), scorer)
#' CalculateAcceptanceRates(results)
#' 
#'
#' @export
CalculateAcceptanceRates <- function(chains, group_by = NULL) UseMethod('CalculateAcceptanceRates')

#' @export
CalculateAcceptanceRates.cia_chains <- function(chains, group_by = NULL) { 
  n_chains <- length(chains)
  chain_info <- list()
  for (i in 1:n_chains) {
    chain_info[[i]] <- dplyr::bind_cols(
      proposal_used = sapply(chains[[i]]$proposal_info, function(x) x$proposal_used),
      accept = sapply(chains[[i]]$mcmc_info, function(x) x$accept),
      black_obeyed = sapply(chains[[i]]$mcmc_info, function(x) x$black_obeyed),
      white_obeyed = sapply(chains[[i]]$mcmc_info, function(x) x$white_obeyed)
    )
  }
  chain_info <- dplyr::bind_rows(chain_info, .id = 'chain')
  
  accept_summary <- chain_info |>
    dplyr::group_by_at(group_by) |>
    dplyr::summarise(mean_accept = mean(.data$accept),
                     mean_black_obeyed = mean(.data$black_obeyed),
                     mean_white_obeyed = mean(.data$white_obeyed),
                     n_accept = sum(.data$accept),
                     n_total = dplyr::n())
  
  return(accept_summary)
}

#' @export
CalculateAcceptanceRates.cia_chain <- function(chains, group_by = NULL) { 
  chains <- new_cia_chain(chains)
  accept_summary <- CalculateAcceptanceRates(chains)
  
  return(accept_summary)
}

Try the cia package in your browser

Any scripts or data that you put into this service are public.

cia documentation built on April 4, 2025, 5:23 a.m.