# R/performance.R

#' Classify inferred edges against a ground truth network
#'
#' Names the edges of both networks with \code{name_edges} and compares the
#' two sets of edge names.
#'
#' @param inferred the inferred network structure; passed to \code{name_edges}.
#' @param ground_truth the reference network structure.  An object of class
#' \code{bn.fit} is first converted with \code{fit2net}.
#' @return A list with three elements: \code{tp} (edge names present in both
#' networks), \code{fp} (edge names present only in the inferred network) and
#' \code{fn} (edge names present only in the ground truth network).
#' @export
performance_outcomes <- function(inferred, ground_truth){
  # Fitted networks are reduced to their structure before edges are named.
  truth_net <- if(inherits(ground_truth, "bn.fit")) fit2net(ground_truth) else ground_truth
  truth_names <- name_edges(truth_net)
  found_names <- name_edges(inferred)
  list(
    tp = intersect(found_names, truth_names),
    fp = setdiff(found_names, truth_names),
    fn = setdiff(truth_names, found_names)
  )
}


#' Performance predictions and labels for structure inference with model averaging
#'
#' Compares model averaging results in Bayesian network structure inference to a ground truth
#' network and obtains a table of predictions and labels for the presence of edges.
#'
#' Note: this looks at performance at the edge level, not the network level.  A DAG assembled from
#' model averaging results frequently requires removal of some high scoring edges to avoid creation of cycles.
#' By focusing purely on edge discovery, that constraint is avoided.
#'
#' In causal network inference edge direction has a causal interpretation, but generally edges do not reflect
#' causality, just conditional dependence relationships. By setting the consider_direction argument to FALSE,
#' prediction values for undirected edges are calculated; in that case the ground truth network is replaced
#' by its moral graph (see ?moral) before edges are labeled.  This enables evaluation of performance of recovery
#' of conditional dependence relationships, and may be more informative than direct edge recovery if the
#' network inference doesn't meet the assumptions of causal inference.
#'
#' @param strength model averaging results, an object of class bn.strength.  See ?bn.strength.
#' @param ground_truth An object of class bn representing the ground truth network.  An object of
#' class bn.fit is first converted with fit2net.
#' @param consider_direction boolean that defaults to TRUE.  TRUE if you want to evaluate the recovery of
#' directed edges, FALSE for the recovery of undirected edges.  An error is raised if TRUE and the ground
#' truth network contains undirected arcs.
#' @return A table containing the from and to nodes, the strength and direction values, a prediction
#' score (strength * direction for directed edges, strength otherwise), and a 0/1 label indicating
#' presence in the reference graph.
#' @examples
#' net <- simGaussianNet(8)
#' net_structure <- bn.net(net)
#' sim_data <- rbn(net, 1000)
#' ma_results <- boot.strength(sim_data, R = 100, m = 1000, algorithm = "tabu",
#'                            algorithm.args = list(score = "bic-g"))
#' output <- get_performance(ma_results, net_structure)
#' library(ROCR)
#' output %$%
#'    prediction %>%
#'    performance("tpr", "fpr") %>%
#'    plot
#' @export
get_performance <- function(strength, ground_truth, consider_direction = TRUE){
  if(inherits(ground_truth, "bn.fit")) ground_truth <- fit2net(ground_truth)
  # Directed evaluation requires every ground truth arc to have a direction.
  if(consider_direction && nrow(undirected.arcs(ground_truth)) > 0){
    stop("consider_direction is set to true but some edges are undirected.")
  }
  if(!consider_direction) ground_truth <- moral(ground_truth)
  edge_labels <- name_edge_df(ground_truth) %>% # coerce to data frame
    mutate(label = 1) %>% # label edges present in the ground truth net as 1
    {dplyr::select(., edge_name, label)}
  # Name the edges.  The `.` placeholder is used on purpose: inside mutate the
  # bare name `strength` would refer to the strength column, not the whole table.
  scored <- strength %>%
    mutate(edge_name = arcs2names(.[c("from", "to")], directed_edges = consider_direction))
  # Plain if/else instead of return() inside a pipe step, which magrittr >= 2.0
  # does not support.
  if(consider_direction){
    scored <- mutate(scored, prediction = strength * direction)
  } else {
    # Undirected edges appear once per direction; keep the first occurrence only.
    scored <- scored %>%
      mutate(prediction = strength) %>%
      filter(!duplicated(edge_name))
  }
  scored %>%
    merge(edge_labels, by = "edge_name", all.x = TRUE) %>% # merge both tables by 'edge_name'
    mutate(label = ifelse(is.na(label), 0, 1)) %>% # edges missing from the ground truth get label 0
    mutate(prediction = as.numeric(prediction), label = as.numeric(label)) %>%
    dplyr::select(from, to, strength, direction, prediction, label)
}



#' Count the True Positive and False Positive Inferred Edges
#' @param inferred The learned/inferred network
#' @param ground_truth The ground truth network.  An object of class
#' \code{bn.fit} is first converted with \code{fit2net}.
#' @return A list with the number of true positive edges (\code{tp}) and the
#' number of false positive edges (\code{fp}), as classified by
#' \code{performance_outcomes}.
#' @export
count_positives <- function(inferred, ground_truth){
  truth_net <- if(inherits(ground_truth, "bn.fit")) fit2net(ground_truth) else ground_truth
  outcomes <- performance_outcomes(inferred, truth_net)
  # Only the sizes of the tp and fp sets are reported; false negatives are not counted.
  list(tp = length(outcomes$tp), fp = length(outcomes$fp))
}

#' Label arcs according to their detection in network inference
#'
#' Returns dataframes of arcs/edges labeled according to their status as
#' true positives, false positives, or false negatives in arc detection.
#' @param inferred an inferred network structure, an object of class \emph{bn}.
#' @param ground_truth the ground truth network structure, as in the gold standard against
#' which the inferred network is compared.
#' @return a list of two elements, each a dataframe of arcs.  \code{tp_and_fn} holds the
#' true positive and false negative arcs (looked up in the ground truth network),
#' \code{tp_and_fp} holds the true positive and false positive arcs (looked up in the
#' inferred network).  The first two columns are the "from" and "to" nodes with exactly
#' the same interpretation as one would expect from the \emph{arcs} function; the third
#' column, \code{type}, is the label.  A category without any arcs contributes zero rows.
#' @seealso performance_plot
#' @export
performance_arc_list <- function(inferred, ground_truth){
  ground_truth_df <- name_edge_df(ground_truth)
  ground_truth_edges <- name_edges(ground_truth)
  inferred_df <- name_edge_df(inferred)
  inferred_edges <- name_edges(inferred)
  detected_edges <- intersect(inferred_edges, ground_truth_edges)
  false_negatives <- setdiff(ground_truth_edges, inferred_edges)
  false_positives <- setdiff(inferred_edges, ground_truth_edges)
  # Look up the from/to nodes of a set of edge names in `arc_df` and tag every
  # row with `type`.
  label_arcs <- function(edge_names, arc_df, type){
    # as.character() keeps the edge_name column even when the set is NULL/empty.
    arcs_df <- merge(data.frame(edge_name = as.character(edge_names)), arc_df, by = "edge_name")
    arcs_df <- arcs_df[, c("from", "to"), drop = FALSE]
    # rep() sizes the label to the table: assigning a length-one value to a
    # zero-row data frame is an error.
    arcs_df$type <- rep(type, nrow(arcs_df))
    arcs_df
  }
  tp_and_fn <- rbind(label_arcs(detected_edges, ground_truth_df, "true positive"),
                     label_arcs(false_negatives, ground_truth_df, "false negative"))
  tp_and_fp <- rbind(label_arcs(detected_edges, inferred_df, "true positive"),
                     label_arcs(false_positives, inferred_df, "false positive"))
  list(tp_and_fn = tp_and_fn,
       tp_and_fp = tp_and_fp)
}

#' @rdname performance_plot
#' @export
performance_plot_list <- function(inferred, ground_truth){
  truth_df <- name_edge_df(ground_truth)
  truth_names <- name_edges(ground_truth)
  learned_df <- name_edge_df(inferred)
  learned_names <- name_edges(inferred)
  hits <- intersect(learned_names, truth_names)
  missed <- setdiff(truth_names, learned_names)
  spurious <- setdiff(learned_names, truth_names)
  # Convert a named list of (edge names, colour, line type) specs into a list
  # with elements arcs/col/lty, looking up the from/to nodes of every edge
  # name in `lookup_df`.  Specs without any edge names are skipped.
  as_highlight <- function(specs, lookup_df){
    styled <- lapply(specs, function(spec){
      if(length(spec[[1]]) == 0) return(NULL)
      style_df <- data.frame(edge_name = spec[[1]], col = spec[[2]], lty = spec[[3]])
      select(merge(style_df, lookup_df, by = "edge_name"), from, to, col, lty)
    })
    combined <- do.call("rbind", styled)
    list(arcs = as.matrix(dplyr::select(combined, from, to)),
         col = as.character(combined$col),
         lty = as.character(combined$lty))
  }
  # Ground truth view: detected edges in solid green, undetected edges dashed black.
  on_truth <- as_highlight(
    list(detected = list(hits, "green", "solid"),
         fn = list(missed, "black", "dashed")),
    truth_df
  )
  # Inferred view: detected edges in solid green, false positives in solid dark red.
  on_inferred <- as_highlight(
    list(detected = list(hits, "green", "solid"),
         fp = list(spurious, "darkred", "solid")),
    learned_df
  )
  list(ground_truth = on_truth,
       inferred = on_inferred)
}
#' Visualizing Performance of Network Inference
#'
#' Given a ground truth network structure and an inferred (learned) network structure,
#' visualize the performance of the network inference.  On the ground truth network the
#' true positive edges (detected edges) and false negative edges (undetected edges) are
#' highlighted; on the inferred network the true positive edges and false positive edges
#' (detected edges that are not in the ground truth network) are highlighted.
#' True positives are drawn in green, false positives in dark red,
#' and false negatives as dashed black lines.
#'
#' @param inferred an inferred network structure, an object of class \emph{bn}.
#' @param ground_truth the ground truth network structure, as in the gold standard against
#' which the inferred network is compared.
#' @param plot_truth if TRUE then plot true positives and false negatives on the ground
#' truth network structure, if FALSE then plot true positives and false positives on the
#' inferred network structure.
#' @export
performance_plot <- function(inferred, ground_truth, plot_truth){
  edge_styles <- performance_plot_list(inferred, ground_truth)
  # Choose the network to draw together with its matching highlight list.
  if(plot_truth){
    shown_net <- ground_truth
    shown_style <- edge_styles$ground_truth
  } else {
    shown_net <- inferred
    shown_style <- edge_styles$inferred
  }
  graphviz.plot(shown_net, highlight = shown_style)
}


#' Visualizing Performance of Network Inference on Model Averaging Results
#'
#' A wrapper for the strength.plot function in bnlearn.  See ?strength.plot.
#' This does the same thing, except it colors edges by whether they are true positives (green),
#' false positives (red), or false negatives (blue), adding more information to the graph.
#'
#' @param strength model averaging results, an object of class bn.strength.  See ?bn.strength.
#' @param ground_truth the ground truth graph structure, an object of class bn.
#' @param directed_edges if TRUE (the default) edges are compared as directed arcs, if FALSE the
#' moral graphs of the consensus and ground truth networks are compared instead.
#' @param plot_truth if TRUE plot on 'true' (reference) graph structure, FALSE then plot on consensus graph structure
#' @param threshold a numeric value, the minimum strength required for an arc to be included in the averaged network.
#' If nothing is entered, the threshold attribute of the strength argument is used.
#' @export
strength_plot <- function(strength, ground_truth, directed_edges = TRUE, plot_truth = TRUE, threshold = NULL){
  highlight <- strength_plot_list(strength, ground_truth, directed_edges = directed_edges, plot_truth = plot_truth, threshold = threshold)
  if(plot_truth){
    strength.plot(ground_truth, strength, highlight = highlight)
  } else {
    # The consensus network is local to strength_plot_list, so it is rebuilt
    # here in the same way; the highlighted arcs then belong to the plotted graph.
    if(is.null(threshold)){
      consensus_net <- averaged.network(strength)
    } else {
      consensus_net <- averaged.network(strength, threshold = threshold)
    }
    if(!directed_edges) consensus_net <- moral(consensus_net)
    strength.plot(consensus_net, strength, highlight = highlight)
  }
}
#' @rdname strength_plot
#' @export
strength_plot_list <- function(strength, ground_truth, directed_edges = TRUE, plot_truth = TRUE, threshold = NULL){
  # Without an explicit threshold, averaged.network is called without one.
  consensus_net <- if(is.null(threshold)){
    averaged.network(strength)
  } else {
    averaged.network(strength, threshold = threshold)
  }
  # For undirected comparison both networks are replaced by their moral graphs.
  if(!directed_edges){
    consensus_net <- moral(consensus_net)
    ground_truth <- moral(ground_truth)
  }
  truth_df <- name_edge_df(ground_truth)
  # from/to nodes are looked up in the graph that is going to be plotted.
  graph_df <- if(plot_truth) truth_df else name_edge_df(consensus_net)
  consensus_names <- arcs2names(arcs(consensus_net), directed_edges)
  truth_names <- arcs2names(arcs(ground_truth), directed_edges)
  colour_specs <- list(
    tp = list(intersect(truth_names, consensus_names), "darkgreen"),
    fp = list(setdiff(consensus_names, truth_names), "darkred"),
    fn = list(setdiff(truth_names, consensus_names), "darkblue")
  )
  # One table of from/to/col per non-empty category; empty categories give NULL
  # and are dropped by rbind.
  coloured <- lapply(colour_specs, function(spec){
    if(length(spec[[1]]) == 0) return(NULL)
    colour_df <- data.frame(edge_name = spec[[1]], col = spec[[2]])
    merged <- merge(colour_df, graph_df, by = "edge_name")
    merged[, c("from", "to", "col")]
  })
  all_arcs <- do.call("rbind", coloured)
  list(arcs = as.matrix(dplyr::select(all_arcs, from, to)), col = as.character(all_arcs$col))
}


#' Visualize edge detection performance in presence of prior edges.
#'
#' When some edges are known to exist before network structure inference, those
#' edges should not be scored like newly detected edges when evaluating
#' performance.  Given a gold standard network, this function plots that network
#' and highlights which of its edges were newly detected by the model averaging
#' results.  Prior edges are solid black lines, newly detected edges are solid
#' green lines, undetected edges are dashed black lines.
#'
#' @param model_averaging an object of class bn.strength
#' @param ref_net the gold standard network structure
#' @param prior_edges matrix of a priori known edges
#' @param threshold the cut off for edge inclusion, defaults to NULL where the threshold attribute
#' of the model_averaging argument is used.
#' @export
progress_plot <- function(model_averaging, ref_net, prior_edges, threshold = NULL){
  # The highlight list (arcs, col, lty) is built separately, then drawn on the reference network.
  edge_styles <- progress_plot_list(
    model_averaging = model_averaging,
    ref_net = ref_net,
    prior_edges = prior_edges,
    threshold = threshold
  )
  graphviz.plot(ref_net, highlight = edge_styles)
}
#' @rdname progress_plot
#' @export
progress_plot_list <- function(model_averaging, ref_net, prior_edges, threshold = NULL){
  if(!inherits(model_averaging, "bn.strength")) {
    stop("model_averaging argument should be of the 'bn.strength' class")
  }
  ref_net_df <- arcs(ref_net) %>%
    data.frame %>%
    mutate(edge_name = arcs2names(.))
  # Fall back to the threshold stored on the model averaging results.
  if(is.null(threshold)) threshold <- attr(model_averaging, "threshold")
  # Reference edges whose strength * direction reaches the threshold.
  # NOTE(review): edge names are built here with paste(sep = "->") but with
  # arcs2names() for the reference network; this assumes both produce the same
  # format - confirm against arcs2names.
  candidate_edges <- model_averaging %>%
    mutate(edge_name = paste(from, to, sep = "->")) %>%
    filter(edge_name %in% ref_net_df$edge_name) %>%
    filter(strength * direction >= threshold)
  # Newly detected edges: reference edges above threshold that were not known a priori.
  detected_edges <- ref_net_df$edge_name %>%
    intersect(candidate_edges$edge_name) %>%
    setdiff(arcs2names(prior_edges))
  undetected_edges <- setdiff(ref_net_df$edge_name, candidate_edges$edge_name)
  new_edges <- list(
    detected = list(detected_edges, "green", "solid"),
    fn = list(undetected_edges, "black", "dashed")
  ) %>%
    lapply(function(item){
      if(length(item[[1]]) > 0){
        data.frame(edge_name = item[[1]], col = item[[2]], lty = item[[3]]) %>%
          merge(ref_net_df, by = "edge_name") %>%
          dplyr::select(from, to, col, lty)
      } else {
        NULL
      }
    }) %>%
    {do.call("rbind", .)}
  # Prior edges are always drawn as solid black.  Converting to a data frame
  # first keeps the result a data frame even when new_edges is NULL (no
  # detected and no undetected edges).
  prior_df <- as.data.frame(cbind(prior_edges, col = "black", lty = "solid"),
                            stringsAsFactors = FALSE)
  highlight_df <- rbind(new_edges, prior_df)
  # Returned as the last expression (not an assignment) so the value is visible.
  list(arcs = as.matrix(dplyr::select(highlight_df, from, to)),
       col = as.character(highlight_df$col),
       lty = as.character(highlight_df$lty))
}
# robertness/bninfo documentation built on May 27, 2019, 10:32 a.m.