# R/active_learning.R

#' Calculate information gain
#'
#' Compares the graph returned by `ctsdag()` for the current intervention set
#' with the graph returned after also intervening on `candidate`, and measures
#' how many undirected arcs disappear.
#'
#' @param net a DAG
#' @param selected a set of already-selected intervention targets
#' @param candidate a candidate for intervention
#' @param beta an optional dataframe for the edgewise-prior
#' @return Half the difference between the row counts of
#'   `bnlearn::undirected.arcs()` for interventions `selected` and for
#'   `c(selected, candidate)`. (The halving presumably compensates for each
#'   undirected arc being listed in both directions -- confirm.)
#' @export
info_gain <- function(net, selected, candidate, beta = NULL) {
  # Row count of the undirected-arc table after intervening on `targets`.
  n_undirected <- function(targets) {
    graph <- ctsdag(net, interventions = targets, beta = beta)
    nrow(bnlearn::undirected.arcs(graph))
  }
  before <- n_undirected(selected)
  after <- n_undirected(c(selected, candidate))
  (before - after) / 2
}

#' Bayesian Hypothesis Test of Zero Information Gain
#'
#' Returns the probability of no information gain, estimated as the fraction
#' of posterior samples whose information gain is exactly zero.
#'
#' @param info_gain_dist A distribution of information gain values sampled from the posterior
#' @param debug if TRUE, reports the estimated probability with `message()`.
#' @return A single number: the share of zero entries in `info_gain_dist`
#'   (`NaN` for an empty input, `NA` if the input contains `NA`).
zero_gain_prob <- function(info_gain_dist, debug = FALSE) {
  n_zero <- sum(info_gain_dist == 0)
  p0 <- n_zero / length(info_gain_dist)
  if (debug) {
    message("Probability of no gain: ", round(p0, 4))
  }
  p0
}

#' Select a Candidate for Intervention
#'
#' Selects a target for intervention from a list of candidates based on which
#' is expected to provide the most information gain, averaged over a sample of
#' DAGs.
#'
#' @param dags A sample of high scoring DAG structures
#' @param selected a set of already selected interventions
#' @param candidates a set of potential targets
#' @param k threshold stopping criterion (threshold for odds of no information
#'   gain for an intervention). Currently unused: the stopping test is only
#'   reported through `p0`, never enforced.
#' @param beta dataframe for edge-wise probability prior
#' @param cluster a cluster object such as in the package "parallel"; when
#'   `NULL` the DAGs are processed sequentially with `lapply()`.
#' @param debug boolean for reporting debugging messages
#' @return A list with elements
#'   * `top_candidate`: the candidate with the largest positive mean gain, or
#'     `NULL` if no candidate has a positive mean gain;
#'   * `predicted_gain`: that mean gain (0 when there is no top candidate);
#'   * `top_dist`: the per-DAG gains of the top candidate (or `NULL`);
#'   * `p0`: `zero_gain_prob()` of `top_dist`, or 1 when there is no top
#'     candidate.
#' @export
select_next_intervention <- function(dags, selected, candidates, k,
                                     beta = NULL, cluster = NULL,
                                     debug = FALSE) {
  # Per-DAG information gain of `node`, spread over `cluster` when one is given.
  # `parallel::` is explicit so the cluster path works even when the parallel
  # package is not attached.
  gain_dist <- function(node) {
    gains <- if (is.null(cluster)) {
      lapply(dags, info_gain,
             selected = selected, candidate = node, beta = beta)
    } else {
      parallel::parLapply(cluster, dags, info_gain,
                          selected = selected, candidate = node, beta = beta)
    }
    unlist(gains)
  }

  top_candidate <- NULL
  max_gain <- 0
  top_dist <- NULL
  p0 <- 1
  for (node in candidates) {
    info_gain_dist <- gain_dist(node)
    info_gain_est <- mean(info_gain_dist)
    if (debug) {
      message("information gain estimate for ", node, " is ", info_gain_est)
    }
    # Strict `>`: only a positive mean gain qualifies, and ties keep the
    # earliest candidate.
    if (info_gain_est > max_gain) {
      top_candidate <- node
      max_gain <- info_gain_est
      top_dist <- info_gain_dist
    }
  }
  # Reporting results of stopping test rather than enforcing a hard stop.
  if (!is.null(top_dist)) {
    p0 <- zero_gain_prob(top_dist, debug)
  }
  list(top_candidate = top_candidate,
       predicted_gain = max_gain,
       top_dist = top_dist,
       p0 = p0)
}

#' Randomly Select Candidates for Intervention
#'
#' This function provides all the output of select_next_intervention, except
#' that it selects the intervention uniformly at random. This is used for
#' benchmarking purposes.
#'
#' @param dags A sample of high scoring DAG structures
#' @param selected a set of already selected interventions
#' @param candidates a set of potential targets (must be non-empty)
#' @param k threshold stopping criterion (threshold for odds of no information
#'   gain for an intervention). Currently unused.
#' @param cluster a cluster object such as in the package "parallel"; when
#'   `NULL` the DAGs are processed sequentially with `lapply()`.
#' @param debug boolean for reporting debugging messages
#' @param beta optional dataframe for edge-wise probability prior, forwarded
#'   to `info_gain()`. Defaults to `NULL` (no prior), the previous behavior.
#' @return A list with `candidate` (the randomly drawn target),
#'   `predicted_gain` (its mean gain over `dags`), `gain_dist` (per-DAG gains)
#'   and `p0` (`zero_gain_prob()` of `gain_dist`).
#' @export
select_random_intervention <- function(dags, selected, candidates, k,
                                       cluster = NULL, debug = FALSE,
                                       beta = NULL) {
  if (length(candidates) == 0) {
    stop("`candidates` must contain at least one target", call. = FALSE)
  }
  # Draw by index: `sample(candidates, 1)` would sample from
  # `seq_len(candidates)` when `candidates` is a single number.
  node <- candidates[sample.int(length(candidates), 1L)]
  gains <- if (is.null(cluster)) {
    lapply(dags, info_gain,
           selected = selected, candidate = node, beta = beta)
  } else {
    parallel::parLapply(cluster, dags, info_gain,
                        selected = selected, candidate = node, beta = beta)
  }
  info_gain_dist <- unlist(gains)
  info_gain_est <- mean(info_gain_dist)
  if (debug) {
    message("information gain estimate for ", node, " is ", info_gain_est)
  }
  # Reporting results of stopping test rather than enforcing a hard stop.
  p0 <- zero_gain_prob(info_gain_dist, debug)
  list(candidate = node,
       predicted_gain = info_gain_est,
       gain_dist = info_gain_dist,
       p0 = p0)
}

#' Active Learning Algorithm
#'
#' Orders candidate intervention targets by their expected information gain.
#' At each step the candidate returned by `select_next_intervention()` is
#' appended to the selected set; the search stops when no candidates remain or
#' when the best predicted gain is zero.
#'
#' @param dags A sample of high scoring DAG structures.
#' @param candidates candidate intervention targets
#' @param beta dataframe for edge-wise probability prior
#' @param cluster a cluster object such as in the package "parallel", or
#'   `NULL` to compute sequentially.
#' @param debug if TRUE prints messages useful for debugging (including the
#'   selected set after each step).
#' @param k stopping threshold forwarded to `select_next_intervention()`.
#' @return A list with `selected` (targets in selection order), `info_gains`
#'   (predicted gain of each step, including a trailing 0 if the search
#'   stopped early), `p0` (probability of zero gain at each step) and
#'   `all_results` (per-step results, named by the cumulative selection joined
#'   with `"-"`).
#' @export
active_learning <- function(dags, candidates, beta = NULL, cluster = NULL,
                            debug = FALSE, k = NULL) {
  info_gains <- NULL
  all_results <- NULL
  selected <- NULL
  p0_vals <- NULL
  while (length(candidates) > 0) {
    sim_results <- select_next_intervention(dags, selected, candidates,
                                            k = k, beta = beta,
                                            cluster = cluster, debug = debug)
    next_inh <- sim_results$top_candidate
    next_gain <- sim_results$predicted_gain
    p0_vals <- c(p0_vals, sim_results$p0)
    info_gains <- c(info_gains, next_gain)
    message("next_gain: ", round(next_gain, 4))
    # No remaining candidate has a positive expected gain.
    if (next_gain == 0) break
    candidates <- setdiff(candidates, next_inh)
    selected <- c(selected, next_inh)
    step_name <- paste0(selected, collapse = "-")
    all_results <- c(all_results,
                     structure(list(sim_results), names = step_name))
    if (debug) print(selected)
  }
  list(selected = selected, info_gains = info_gains, p0 = p0_vals,
       all_results = all_results)
}

#' Passive Learning (Active Learning Algorithm Benchmark)
#'
#' Randomly orders candidate intervention targets. At each step a candidate is
#' drawn by `select_random_intervention()` and appended to the selected set;
#' the loop stops when no candidates remain or when the drawn candidate's
#' predicted gain is zero.
#'
#' @param dags A sample of high scoring DAG structures.
#' @param candidates candidate intervention targets
#' @param cluster a cluster object such as in the package "parallel", or
#'   `NULL` to compute sequentially.
#' @param debug if TRUE prints messages useful for debugging (including the
#'   selected set after each step).
#' @param k stopping threshold forwarded to `select_random_intervention()`.
#' @return A list with `selected` (targets in selection order), `info_gains`
#'   (predicted gain of each step, including a trailing 0 if the loop stopped
#'   early), `p0` (probability of zero gain at each step) and `all_results`
#'   (per-step results, named by the cumulative selection joined with `"-"`).
#' @export
passive_learning <- function(dags, candidates, cluster = NULL, debug = FALSE,
                             k = NULL) {
  info_gains <- NULL
  all_results <- NULL
  selected <- NULL
  p0_vals <- NULL
  while (length(candidates) > 0) {
    sim_results <- select_random_intervention(dags, selected, candidates,
                                              cluster = cluster, k = k,
                                              debug = debug)
    next_inh <- sim_results$candidate
    next_gain <- sim_results$predicted_gain
    p0_vals <- c(p0_vals, sim_results$p0)
    info_gains <- c(info_gains, next_gain)
    message("next_gain: ", round(next_gain, 4))
    # The randomly drawn candidate is expected to be uninformative: stop.
    if (next_gain == 0) break
    candidates <- setdiff(candidates, next_inh)
    selected <- c(selected, next_inh)
    step_name <- paste0(selected, collapse = "-")
    all_results <- c(all_results,
                     structure(list(sim_results), names = step_name))
    if (debug) print(selected)
  }
  list(selected = selected, info_gains = info_gains, p0 = p0_vals,
       all_results = all_results)
}
# robertness/bninfo documentation built on May 27, 2019, 10:32 a.m.