#' Calculate information gain
#'
#' Scores an intervention candidate by comparing the undirected-arc count of
#' the graph built by `ctsdag()` for the already-selected interventions with
#' the count after `candidate` is added to them.
#'
#' @param net a DAG
#' @param selected a set of already-selected intervention targets
#' @param candidate a candidate for intervention
#' @param beta an optional dataframe for the edgewise-prior, forwarded to
#'   `ctsdag()`
#' @return A number: half of the drop in `nrow(bnlearn::undirected.arcs(...))`
#'   between the two graphs. bnlearn lists each undirected edge as a pair of
#'   opposite arcs, so halving converts arc rows into edges.
#' @export
info_gain <- function(net, selected, candidate, beta = NULL){
  # Row count of the undirected-arc table for the graph under `targets`.
  n_undirected <- function(targets) {
    graph <- ctsdag(net, interventions = targets, beta = beta)
    nrow(bnlearn::undirected.arcs(graph))
  }
  before <- n_undirected(selected)
  after <- n_undirected(c(selected, candidate))
  (before - after) / 2
}
#' Bayesian Hypothesis Test of Zero Information Gain
#'
#' Returns the probability of no information gain, estimated as the fraction
#' of sampled gain values that are exactly zero.
#'
#' @param info_gain_dist A distribution of information gain values sampled from the posterior
#' @param debug if TRUE, reports the estimate (rounded to 4 digits) via
#'   `message()`.
#' @return The proportion of `info_gain_dist` equal to zero (`NaN` for an
#'   empty input, `NA` if the input contains `NA`).
zero_gain_prob <- function(info_gain_dist, debug = FALSE){
  is_null_gain <- info_gain_dist == 0
  null_prob <- sum(is_null_gain) / length(is_null_gain)
  if (debug) {
    message("Probability of no gain: ", round(null_prob, 4))
  }
  null_prob
}
#' Select a Candidate for Intervention
#'
#' Selects the target for intervention from a set of candidates based on which
#' is expected to provide the most information gain. For each candidate,
#' `info_gain()` is evaluated on every DAG in `dags`, and the candidate with
#' the largest mean gain is returned.
#'
#' @param dags a list of DAG structures (e.g. a sample of high-scoring DAGs);
#'   each element is passed to `info_gain()` as `net`.
#' @param selected a set of already selected interventions
#' @param candidates a set of potential targets
#' @param k threshold stopping criterion (threshold for odds of no information
#'   gain). Currently unused: the stopping test is only reported through `p0`,
#'   not enforced. Kept (now optional) for backward compatibility.
#' @param beta optional dataframe for the edge-wise probability prior, passed
#'   to `info_gain()`.
#' @param cluster optional cluster object from the "parallel" package; when
#'   supplied, per-DAG gains are computed with `parallel::parLapply()`.
#' @param debug boolean for reporting debugging messages
#' @return A list with
#'   * `top_candidate`: the chosen target, or `NULL` if no candidate has a
#'     strictly positive mean gain (ties keep the first candidate seen);
#'   * `predicted_gain`: its mean gain (0 if none);
#'   * `top_dist`: its per-DAG gain values (`NULL` if none);
#'   * `p0`: `zero_gain_prob()` of `top_dist` (1 if none).
#' @export
select_next_intervention <- function(dags, selected, candidates, k = NULL, beta = NULL, cluster = NULL, debug = FALSE){
  top_candidate <- NULL
  max_gain <- 0
  top_dist <- NULL
  p0 <- 1
  for (node in candidates) {
    if (is.null(cluster)) {
      info_gain_dist <- unlist(lapply(dags, info_gain, selected = selected, candidate = node, beta = beta))
    } else {
      # Namespace-qualified so this works without `parallel` being attached.
      info_gain_dist <- unlist(parallel::parLapply(cluster, dags, info_gain, selected = selected, candidate = node, beta = beta))
    }
    info_gain_est <- mean(info_gain_dist)
    if (debug) message("information gain estimate for ", node, " is ", info_gain_est)
    # Strict `>`: ties keep the earlier candidate; zero/negative gains never win.
    if (info_gain_est > max_gain) {
      top_candidate <- node
      max_gain <- info_gain_est
      top_dist <- info_gain_dist
    }
  }
  # Reporting results of stopping test rather than enforcing a hard stop.
  if (!is.null(top_dist)) {
    p0 <- zero_gain_prob(top_dist, debug)
  }
  list(top_candidate = top_candidate,
       predicted_gain = max_gain,
       top_dist = top_dist,
       p0 = p0)
}
#' Randomly Select Candidates for Intervention
#'
#' Benchmark counterpart of `select_next_intervention()`: draws one target
#' uniformly at random from `candidates` and reports its information-gain
#' distribution over `dags`. The output carries the same information as
#' `select_next_intervention()`, under the element names `candidate` and
#' `gain_dist`.
#'
#' @param dags a list of DAG structures (e.g. a sample of high-scoring DAGs);
#'   each element is passed to `info_gain()` as `net`.
#' @param selected a set of already selected interventions
#' @param candidates a set of potential targets (must be non-empty)
#' @param k threshold stopping criterion (threshold for odds of no information
#'   gain). Currently unused; kept for interface compatibility.
#' @param cluster optional cluster object from the "parallel" package; when
#'   supplied, per-DAG gains are computed with `parallel::parLapply()`.
#' @param debug boolean for reporting debugging messages
#' @param beta optional dataframe for the edge-wise probability prior, passed
#'   to `info_gain()` (same role as in `select_next_intervention()`).
#' @return A list with `candidate` (the drawn target), `predicted_gain` (its
#'   mean gain), `gain_dist` (its per-DAG gains) and `p0`
#'   (`zero_gain_prob()` of `gain_dist`).
#' @export
select_random_intervention <- function(dags, selected, candidates, k, cluster = NULL, debug = FALSE, beta = NULL){
  if (length(candidates) == 0) {
    stop("`candidates` must contain at least one target", call. = FALSE)
  }
  # Index-based draw: `sample(x, 1)` treats a single positive number `x` as the
  # population 1:x, so a length-one numeric candidate set would yield an
  # arbitrary integer. For length > 1 this consumes the RNG exactly as
  # `sample(candidates, 1)` does.
  node <- candidates[sample.int(length(candidates), 1L)]
  if (is.null(cluster)) {
    info_gain_dist <- unlist(lapply(dags, info_gain, selected = selected, candidate = node, beta = beta))
  } else {
    info_gain_dist <- unlist(parallel::parLapply(cluster, dags, info_gain, selected = selected, candidate = node, beta = beta))
  }
  info_gain_est <- mean(info_gain_dist)
  if (debug) message("information gain estimate for ", node, " is ", info_gain_est)
  # Reporting results of stopping test rather than enforcing a hard stop.
  p0 <- zero_gain_prob(info_gain_dist, debug)
  list(candidate = node,
       predicted_gain = info_gain_est,
       gain_dist = info_gain_dist,
       p0 = p0)
}
#' Active Learning Algorithm
#'
#' Greedily orders candidate intervention targets by their expected
#' information gain. At each step the candidate chosen by
#' `select_next_intervention()` is selected and removed from the pool. The
#' loop stops when the best remaining candidate has zero predicted gain or the
#' pool is exhausted.
#'
#' @param dags a list of DAG structures (e.g. a posterior sample of
#'   high-scoring DAGs) over which information gain is averaged.
#' @param candidates candidate intervention targets
#' @param beta dataframe for edge-wise probability prior
#' @param cluster optional cluster object from the "parallel" package, passed
#'   to `select_next_intervention()`.
#' @param debug if TRUE prints messages useful for debugging.
#' @return A list with
#'   * `selected`: the chosen targets, in selection order;
#'   * `info_gains`: predicted gain at each step. If the loop stopped on a
#'     zero gain, that final 0 is included, so this can be one element longer
#'     than `selected`;
#'   * `p0`: probability of zero gain reported at each step (same length as
#'     `info_gains`);
#'   * `all_results`: the full `select_next_intervention()` output for each
#'     accepted step, named by the hyphen-joined selection so far.
#' @export
active_learning <- function(dags, candidates, beta = NULL, cluster = NULL, debug = FALSE){
  info_gains <- NULL
  all_results <- NULL
  selected <- NULL
  p0_vals <- NULL
  while (length(candidates) > 0) {
    # `k` is not forwarded: it is unused by the selector, and this function has
    # no `k` of its own (forwarding `k = k` referenced an undefined variable).
    sim_results <- select_next_intervention(dags, selected, candidates, cluster = cluster, beta = beta, debug = debug)
    next_inh <- sim_results$top_candidate
    next_gain <- sim_results$predicted_gain
    p0_vals <- c(p0_vals, sim_results$p0)
    info_gains <- c(info_gains, next_gain)
    message("next_gain: ", round(next_gain, 4))
    if (next_gain == 0) break
    candidates <- setdiff(candidates, next_inh)
    selected <- c(selected, next_inh)
    all_results <- c(all_results, structure(list(sim_results), names = paste0(selected, collapse = "-")))
    print(selected)
  }
  list(selected = selected, info_gains = info_gains, p0 = p0_vals, all_results = all_results)
}
#' Passive Learning (Active Learning Algorithm Benchmark)
#'
#' Randomly orders candidate intervention targets, as a baseline for
#' `active_learning()`. At each step a target is drawn by
#' `select_random_intervention()` and removed from the pool. The loop stops
#' when the drawn target has zero predicted gain or the pool is exhausted.
#'
#' @param dags a list of DAG structures (e.g. a posterior sample of
#'   high-scoring DAGs) over which information gain is averaged.
#' @param candidates candidate intervention targets
#' @param cluster optional cluster object from the "parallel" package, passed
#'   to `select_random_intervention()`.
#' @param debug if TRUE prints messages useful for debugging.
#' @return A list with
#'   * `selected`: the drawn targets, in order;
#'   * `info_gains`: predicted gain at each step. If the loop stopped on a
#'     zero gain, that final 0 is included, so this can be one element longer
#'     than `selected`;
#'   * `p0`: probability of zero gain reported at each step;
#'   * `all_results`: the full `select_random_intervention()` output for each
#'     accepted step, named by the hyphen-joined selection so far.
#' @export
passive_learning <- function(dags, candidates, cluster = NULL, debug = FALSE){
  info_gains <- NULL
  all_results <- NULL
  selected <- NULL
  p0_vals <- NULL
  while (length(candidates) > 0) {
    # `k` is not forwarded: it is unused by the selector, and this function has
    # no `k` of its own (forwarding `k = k` referenced an undefined variable).
    sim_results <- select_random_intervention(dags, selected, candidates, cluster = cluster, debug = debug)
    next_inh <- sim_results$candidate
    next_gain <- sim_results$predicted_gain
    p0_vals <- c(p0_vals, sim_results$p0)
    info_gains <- c(info_gains, next_gain)
    message("next_gain: ", round(next_gain, 4))
    if (next_gain == 0) break
    candidates <- setdiff(candidates, next_inh)
    selected <- c(selected, next_inh)
    all_results <- c(all_results, structure(list(sim_results), names = paste0(selected, collapse = "-")))
    print(selected)
  }
  list(selected = selected, info_gains = info_gains, p0 = p0_vals, all_results = all_results)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.