R/compositemodelfunctions.R

Defines functions ComputeMetaFeatureWeights ComputeIndividualPerformance PruneForwardStepwise PruneBackwardStepwise PruneStepwiseParallel PruneExhaustive PrunePredictors ObtainCompositeModels ObtainSubgraphNeighborhoods ComputeTScore ComputeSignificance CompositePrediction DoSignificancePropagation

Documented in CompositePrediction ComputeIndividualPerformance ComputeMetaFeatureWeights ComputeSignificance ComputeTScore DoSignificancePropagation ObtainCompositeModels ObtainSubgraphNeighborhoods PruneBackwardStepwise PruneExhaustive PruneForwardStepwise PrunePredictors PruneStepwiseParallel

#' Prune edge-wise predictors by propagating significance through the graph.
#'
#' Starting from one predictor per edge, predictors are pruned with
#' PrunePredictors(). When pooling is enabled, the surviving predictors are
#' grouped into composite models with ObtainCompositeModels() and pruned again
#' in a second layer. If more than one composite model remains after that, all
#' models are concatenated into a single model and pruned once more.
#' @param pairs A list with elements `edge` (the pairs to include in the
#' composite model), `source` and `target` (the source and target analytes of
#' those pairs).
#' @param modelResults A ModelResults object.
#' @param covar Covariates in the model. Currently not used by this function.
#' @param verbose Whether or not to print out each step.
#' @param pruningMethod The method to use for pruning. Right now, only "error.t.test" is valid.
#' @param modelRetention Strategy for model retention. "stringent" (the default)
#' retains only models that improve the prediction score. "lenient" also retains models that
#' neither improve nor reduce the prediction score.
#' @param minCutoff Minimum cutoff for the prediction.
#' @param maxCutoff Maximum cutoff for the prediction.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param weights Current weights
#' @param components A Model object containing the list of components in the graph.
#' @param pruningTechnique Pruning technique to use. Possible methods are "backward.stepwise",
#' "forward.stepwise", "individual.performance", and "exhaustive".
#' @param doPooling Whether or not to pool predictors together using the structure of the graph.
#' If FALSE, all predictors are pruned simultaneously in a single layer.
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output.
#' @param zeroOut This parameter zeros out predictors outside of the allowed
#' range.
#' @param feedbackPairs If not NULL, a feedback model is built by pruning all
#' edge-wise predictors together. Pruned models that score below the feedback
#' model get the feedback-retained pairs of their subgraph added back.
#' NOTE(review): only `is.null(feedbackPairs)` is checked; the value itself is
#' not otherwise used.
#' @param trimming Set to "edgewise" to trim edges at each layer or "modelwise"
#' to trim entire models (neighborhoods, connected components) at each layer.
#' Any other value stops after the first layer.
#' @return The object returned by the last call to PrunePredictors(), whose
#' `pairs`, `targets` and `sources` slots hold the retained predictors.
#' @export
DoSignificancePropagation <- function(pairs, modelResults, covar = c(), verbose = FALSE,
                                      pruningMethod = "error.t.test", modelRetention = "stringent",
                                      minCutoff, maxCutoff, useCutoff = FALSE, weights, components,
                                      pruningTechnique = "backward.stepwise", doPooling = TRUE,
                                      averaging = FALSE, zeroOut = FALSE, feedbackPairs = NULL,
                                      trimming = "modelwise"){
  # Convert true values to numeric. For a non-numeric outcome, samples equal to
  # the second category (in sorted order) are coded 1 and all others 0.
  trueVals <- as.matrix(modelResults@model.input@input.data@sampleMetaData[,modelResults@model.input@stype])
  if(!is.numeric(trueVals)){
    catNames <- sort(unique(trueVals))
    trueValuesChar <- trueVals
    trueVals <- matrix(rep(0, length(trueValuesChar)), ncol = ncol(trueValuesChar),
                       nrow = nrow(trueValuesChar))
    trueVals[which(trueValuesChar == catNames[2])] <- 1
  }

  # ---- Local helpers --------------------------------------------------------
  # These close over the prediction/pruning settings passed to this function,
  # so each call site only supplies what changes from layer to layer.

  # Prune one layer of predictors with the shared settings.
  pruneLayer <- function(compositeSubgraphs, previousModels, performance, layer){
    GENN::PrunePredictors(compositeSubgraphs = compositeSubgraphs,
                          previousModels = previousModels,
                          modelResults = modelResults,
                          verbose = verbose,
                          pruningMethod = pruningMethod,
                          modelRetention = modelRetention,
                          minCutoff = minCutoff,
                          maxCutoff = maxCutoff,
                          useCutoff = useCutoff,
                          weights = weights,
                          individualPerformance = performance,
                          layerNumber = layer,
                          pruningTechnique = pruningTechnique,
                          averaging = averaging,
                          zeroOut = zeroOut)
  }

  # Significance of the composite prediction built from the given pairs.
  scoreModel <- function(modelPairs, modelTargets, modelSources){
    prediction <- GENN::CompositePrediction(pairs = modelPairs,
                                            targets = modelTargets,
                                            sources = modelSources,
                                            modelResults = modelResults,
                                            minCutoff = minCutoff,
                                            maxCutoff = maxCutoff,
                                            useCutoff = useCutoff,
                                            weights = weights,
                                            averaging = averaging,
                                            zeroOut = zeroOut)
    GENN::ComputeSignificance(pred = unlist(prediction),
                              trueVal = trueVals,
                              pruningMethod = pruningMethod)
  }

  # Build a Model with one single-edge model per pair, in unlist() order.
  edgewiseModels <- function(pairList, targetList, sourceList){
    allPairs <- unlist(pairList)
    methods::new("Model", modelsKept = seq_along(allPairs),
                 pairs = as.list(allPairs),
                 targets = as.list(as.character(unlist(targetList))),
                 sources = as.list(as.character(unlist(sourceList))))
  }

  # Concatenate every model of a Model object into a single model.
  mergeModels <- function(model){
    methods::new("Model", pairs = list(unlist(model@pairs)),
                 sources = list(unlist(model@sources)),
                 targets = list(unlist(model@targets)),
                 modelsKept = unlist(model@modelsKept))
  }

  # Mapping that assigns each of n predictors to composite model 1.
  singleModelMapping <- function(n){
    data.frame(from = seq_len(n), to = rep(1, n))
  }

  # For each pruned model that scores below the feedback model, add back the
  # pairs retained by the feedback model that lie in that model's expanded
  # subgraph. Reads prunedFeedbackPairs and feedbackSignificance from this
  # function's frame, so it must only be called after they are set.
  reinsertFeedbackPairs <- function(prunedModels, subgraphs){
    for(m in seq_along(prunedModels@pairs)){
      significance <- scoreModel(prunedModels@pairs[[m]],
                                 prunedModels@targets[[m]],
                                 prunedModels@sources[[m]])
      if(significance < feedbackSignificance){
        newPairs <- unique(c(prunedModels@pairs[[m]],
                             intersect(prunedFeedbackPairs@pairs[[1]],
                                       subgraphs@expandedCompositeModels@pairs[[m]])))
        prunedModels@pairs[[m]] <- newPairs
        # Pair names have the form "<source>__<target>".
        prunedModels@targets[[m]] <- unlist(lapply(newPairs, function(pair){strsplit(pair, split = "__")[[1]][2]}))
        prunedModels@sources[[m]] <- unlist(lapply(newPairs, function(pair){strsplit(pair, split = "__")[[1]][1]}))
      }
    }
    prunedModels
  }

  # ---- Initialize consolidated pairs ---------------------------------------
  targets <- pairs$target
  sources <- pairs$source
  pairs <- pairs$edge
  prevModels <- edgewiseModels(pairs, targets, sources)
  currentModel <- methods::new("Model", modelsKept = seq_along(pairs),
                               pairs = pairs, targets = targets, sources = sources)
  consolidated <- methods::new("CompositeModelSet", compositeModels = currentModel,
                               expandedCompositeModels = currentModel,
                               mapping = data.frame(from = seq_along(unlist(pairs)),
                                                    to = rep(seq_along(pairs), lengths(pairs))))
  layerNumber <- 1

  # Compute individual performance of each edge-wise predictor.
  individualPerformance <- ComputeIndividualPerformance(predictions = modelResults@model.input@edge.wise.prediction,
                                                        trueVal = trueVals)
  individualPerformance <- individualPerformance[unlist(pairs)]

  # If we are not pooling, simply prune all predictors simultaneously.
  if(doPooling == FALSE){
    return(pruneLayer(consolidated, prevModels, individualPerformance, layerNumber))
  }

  # ---- Feedback model (optional) -------------------------------------------
  # Prune all edge-wise predictors together as one model and record the
  # significance of what survives.
  prunedFeedbackPairs <- NULL
  feedbackSignificance <- NULL
  if(!is.null(feedbackPairs)){
    allEdges <- methods::new("Model",
                             pairs = list(unlist(prevModels@pairs)),
                             targets = list(unlist(prevModels@targets)),
                             sources = list(unlist(prevModels@sources)))
    consolidatedFeedback <- methods::new("CompositeModelSet", compositeModels = allEdges,
                                         expandedCompositeModels = allEdges)
    consolidatedFeedback@mapping <- singleModelMapping(length(unlist(consolidatedFeedback@compositeModels@pairs)))
    prunedFeedbackPairs <- pruneLayer(consolidatedFeedback, prevModels, individualPerformance, layerNumber)
    feedbackSignificance <- scoreModel(unlist(prunedFeedbackPairs@pairs),
                                       unlist(prunedFeedbackPairs@targets),
                                       unlist(prunedFeedbackPairs@sources))
  }

  # ---- First layer -----------------------------------------------------------
  prunedPairs <- pruneLayer(consolidated, prevModels, individualPerformance, layerNumber)
  if(!is.null(feedbackSignificance)){
    prunedPairs <- reinsertFeedbackPairs(prunedPairs, consolidated)
  }

  # ---- Subsequent layers -----------------------------------------------------
  if(trimming == "modelwise"){
    # The models kept in layer 1 become the predictors of layer 2.
    consolidated <- GENN::ObtainCompositeModels(components = components,
                                                pairsInEachPredictor = consolidated@expandedCompositeModels,
                                                importantModels = prunedPairs)
    prevModels <- prunedPairs
    layerNumber <- layerNumber + 1

    # Individual performance is now the significance of each layer-1 model.
    individualPerformance <- unlist(lapply(seq_along(prevModels@pairs), function(m){
      scoreModel(prevModels@pairs[[m]], prevModels@targets[[m]], prevModels@sources[[m]])
    }))
    prunedPairs <- pruneLayer(consolidated, prevModels, individualPerformance, layerNumber)
    if(!is.null(feedbackSignificance)){
      prunedPairs <- reinsertFeedbackPairs(prunedPairs, consolidated)
    }
    consolidated <- GENN::ObtainCompositeModels(components = components,
                                                pairsInEachPredictor = consolidated@expandedCompositeModels,
                                                importantModels = prunedPairs)

    # If number of composite models > 1, then concatenate all models and prune.
    if(length(consolidated@compositeModels@pairs) > 1){
      prevModels <- prunedPairs
      consolidated@compositeModels <- mergeModels(consolidated@compositeModels)
      consolidated@expandedCompositeModels <- mergeModels(consolidated@expandedCompositeModels)
      consolidated@mapping$to <- rep(1, nrow(consolidated@mapping))
      prunedPairs <- pruneLayer(consolidated, prevModels, individualPerformance, layerNumber)
    }
  }else if(trimming == "edgewise"){
    # Predictors stay edge-wise: previous models are rebuilt with one model per
    # edge, individualPerformance keeps its edge-wise values, and layerNumber is
    # not advanced.
    # NOTE(review): an earlier comment said the pruning technique is switched to
    # backward-only here, but pruningTechnique is passed through unchanged.
    consolidated <- GENN::ObtainCompositeModels(components = components,
                                                pairsInEachPredictor = consolidated@expandedCompositeModels,
                                                importantModels = prunedPairs)
    modelPairs <- consolidated@compositeModels@pairs
    prevModels <- edgewiseModels(modelPairs,
                                 consolidated@compositeModels@targets,
                                 consolidated@compositeModels@sources)
    # Edge j maps to the index of the composite model that contains it.
    consolidated@mapping <- data.frame(from = seq_along(unlist(modelPairs)),
                                       to = as.numeric(rep(seq_along(modelPairs), lengths(modelPairs))))

    # Run the second layer.
    prunedPairs <- pruneLayer(consolidated, prevModels, individualPerformance, layerNumber)
    if(!is.null(feedbackSignificance)){
      prunedPairs <- reinsertFeedbackPairs(prunedPairs, consolidated)
    }
    consolidated <- GENN::ObtainCompositeModels(components = components,
                                                pairsInEachPredictor = consolidated@expandedCompositeModels,
                                                importantModels = prunedPairs)

    # If number of composite models > 1, then concatenate all models and prune.
    if(length(consolidated@compositeModels@pairs) > 1){
      prevModels <- edgewiseModels(consolidated@compositeModels@pairs,
                                   consolidated@compositeModels@targets,
                                   consolidated@compositeModels@sources)
      consolidated@mapping <- singleModelMapping(length(unlist(consolidated@compositeModels@pairs)))
      consolidated@compositeModels <- mergeModels(consolidated@compositeModels)
      consolidated@expandedCompositeModels <- mergeModels(consolidated@expandedCompositeModels)
      prunedPairs <- pruneLayer(consolidated, prevModels, individualPerformance, layerNumber)
    }
  }

  # Return consolidated pairs.
  prunedPairs
}

#' Obtain a prediction from a composite model, given the pairs to include in the model.
#'
#' Thin wrapper around Predict(). It predicts on the training data stored in
#' `modelResults` and always calls Predict() with `useActivation = FALSE`.
#' @param pairs A list of pairs to include in the composite model.
#' @param modelResults A ModelResults object. Its `model.input` slot supplies
#' the input data, the independent variable type and the outcome.
#' @param minCutoff Minimum cutoff for the prediction.
#' @param maxCutoff Maximum cutoff for the prediction.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param weights The weight for each predictor, calculated using ComputeMetaFeatureWeights()
#' @param targets Target analytes for all pairs
#' @param sources Source analytes for all pairs
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output.
#' @param zeroOut This parameter zeros out predictors outside of the allowed
#' range.
#' @return Whatever Predict() returns for these pairs: a predicted value for
#' each sample in the input data.
#' @export
CompositePrediction <- function(pairs, modelResults, minCutoff, maxCutoff, useCutoff = FALSE, weights,
                                targets, sources, averaging = FALSE, zeroOut = FALSE){
  modelInput <- modelResults@model.input

  # Predict on the training data held by the model.
  prediction <- Predict(pairs = pairs,
                        inputData = modelInput@input.data,
                        weights = weights,
                        model = modelResults,
                        minCutoff = minCutoff,
                        maxCutoff = maxCutoff,
                        useCutoff = useCutoff,
                        useActivation = FALSE,
                        independentVarType = modelInput@independent.var.type,
                        outcomeType = modelInput@outcome,
                        targets = targets,
                        sources = sources,
                        averaging = averaging,
                        zeroOut = zeroOut)
  prediction
}

#' Compute the significance value for a given prediction.
#'
#' Dispatches on `pruningMethod`. Only "error.t.test" is implemented; it
#' delegates to ComputeTScore(). Any other method is an error.
#' @param pred Predicted values
#' @param trueVal The true values (predictions or outcomes) of the input data.
#' @param pruningMethod The method to use for pruning. Right now, only "error.t.test" is valid.
#' @return The significance score computed by the selected method.
#' @export
ComputeSignificance <- function(pred, trueVal, pruningMethod = "error.t.test"){
  # Reject unsupported methods up front.
  if(pruningMethod != "error.t.test"){
    stop(paste(pruningMethod, "is not a valid pruning method."))
  }
  ComputeTScore(pred = pred, trueVal = trueVal)
}

#' Compute a t-score comparing the absolute errors of a prediction with the
#' absolute errors of a baseline that always predicts the mean true value.
#' Positive scores mean the prediction's errors are smaller than the baseline's.
#' The statistic is Welch's two-sample t-statistic on the two sets of absolute
#' errors (it is not a paired t-test).
#' @param pred Predictions
#' @param trueVal The true values (predictions or outcomes) of the input data.
#' @param includeVarianceTest Intended to scale the t-score by the f-statistic
#' (the ratio of variances). Currently ignored. Default is FALSE.
#' @return A single numeric score. Samples whose prediction or true value is
#' missing (NA or NaN) are dropped first. Returns -1000 when no samples remain,
#' the raw difference in mean absolute error when exactly one sample remains,
#' and 0 when both sets of absolute errors are identical constants.
#' @export
ComputeTScore <- function(pred, trueVal, includeVarianceTest = FALSE){
  # Drop samples with a missing prediction (is.na() covers both NA and NaN),
  # then samples with a missing true value, keeping the two vectors aligned.
  observedPred <- which(!is.na(pred))
  trueVal <- trueVal[observedPred]
  pred <- pred[observedPred]
  observedTrue <- which(!is.na(trueVal))
  trueVal <- trueVal[observedTrue]
  pred <- pred[observedTrue]

  # Sentinel score used when there is nothing to compare.
  tScore <- -1000
  if(length(trueVal) > 0 && length(pred) > 0){

    # Absolute errors of the mean-only baseline and of the prediction.
    meanTrue <- rep(mean(trueVal), length(pred))
    meanAbsError <- abs(trueVal - meanTrue)
    predAbsError <- abs(trueVal - pred)

    n <- length(meanAbsError)
    meanDifference <- mean(meanAbsError) - mean(predAbsError)
    tScore <- meanDifference
    if(n > 1){
      # Welch's t-statistic: difference in means over the combined standard error.
      stdErrMeanAbsError <- stats::sd(meanAbsError) / sqrt(n)
      stdErrPredAbsError <- stats::sd(predAbsError) / sqrt(n)
      stdErr <- sqrt((stdErrMeanAbsError ^ 2) + (stdErrPredAbsError ^ 2))
      if(isTRUE(stdErr == 0 && meanDifference == 0)){
        # Both error sets are the same constant; 0/0 would give NaN, which
        # breaks callers that compare scores inside if().
        tScore <- 0
      }else{
        # A zero standard error with a nonzero difference still yields +/-Inf.
        tScore <- meanDifference / stdErr
      }
    }
  }

  return(tScore)
}

#' Segment the co-regulation graph into edge neighborhoods.
#'
#' The co-regulation graph is restricted to the edges (pairs) that have an
#' edge-wise prediction. Edges are then taken from a queue in order. Each
#' dequeued edge yields one neighborhood: every edge incident to either of its
#' two nodes. All of those edges are then removed from the queue.
#' @param modelInput A ModelInput object
#' @return A list with three elements, each a list with one entry per
#' neighborhood: \code{edge} (pair names in "source__target" form),
#' \code{target} (the target analyte of each pair) and \code{source} (the
#' source analyte of each pair).
#' @export
ObtainSubgraphNeighborhoods <- function(modelInput){
  # Pair names use "source__target"; igraph's edge-name notation uses "source|target".
  fullgraph <- igraph::graph_from_adjacency_matrix(modelInput@coregulation.graph)
  edges <- gsub(pattern = "__", replacement = "|", fixed = TRUE,
                x = colnames(modelInput@edge.wise.prediction))
  graph <- igraph::subgraph.edges(fullgraph, edges)
  adjMat <- as.matrix(igraph::as_adjacency_matrix(graph))

  # Precompute the edges touching each node so that they do not need to be
  # looked up again for each edge.
  allNodes <- igraph::as_ids(igraph::V(graph))
  print("Finding pairs containing each analyte")

  # Evaluate a lookup, returning NULL if indexing fails (e.g. a node missing
  # from the adjacency matrix dimnames). Deliberately best-effort.
  lookupOrNull <- function(expr){
    tryCatch(expr, error = function(cond){ NULL })
  }
  edgesPerNode <- lapply(allNodes, function(node){
    whichTarget <- lookupOrNull(colnames(adjMat)[which(adjMat[node, ] == 1)])
    whichSource <- lookupOrNull(rownames(adjMat)[which(adjMat[, node] == 1)])
    nodeTargetEdges <- c()
    nodeSourceEdges <- c()
    if(length(whichTarget) > 0){
      nodeTargetEdges <- paste(node, whichTarget, sep = "|")
    }
    if(length(whichSource) > 0){
      nodeSourceEdges <- paste(whichSource, node, sep = "|")
    }
    c(nodeTargetEdges, nodeSourceEdges)
  })
  names(edgesPerNode) <- allNodes

  # Take edges from the queue in order. Each taken edge produces the
  # neighborhood of all edges touching its two nodes, and those edges are
  # removed from the queue.
  print("Segmenting graph into neighborhoods")
  edgesToCheckQueue <- edges
  edgeNeighborhoods <- list()
  targetNeighborhoods <- list()
  sourceNeighborhoods <- list()
  while(length(edgesToCheckQueue) > 0){
    edge <- edgesToCheckQueue[[1]]
    nodes <- igraph::ends(graph, edge)
    edgeAndNeighbors <- union(edgesPerNode[[nodes[1]]], edgesPerNode[[nodes[2]]])

    # Remove the current edge and all of its neighbors from the queue. The
    # current edge is removed explicitly so the loop always makes progress.
    edgesToCheckQueue <- setdiff(edgesToCheckQueue, c(edge, edgeAndNeighbors))

    # Record the neighborhood and the source/target analyte of each pair.
    edgeAndNeighborsFormatted <- gsub("|", "__", edgeAndNeighbors, fixed = TRUE)
    pairParts <- strsplit(edgeAndNeighborsFormatted, "__", fixed = TRUE)
    k <- length(edgeNeighborhoods) + 1
    edgeNeighborhoods[[k]] <- edgeAndNeighborsFormatted
    targetNeighborhoods[[k]] <- unname(vapply(pairParts, function(p){ p[2] }, character(1)))
    sourceNeighborhoods[[k]] <- unname(vapply(pairParts, function(p){ p[1] }, character(1)))
    cat(".")
  }

  # Return neighborhoods.
  print("Done")
  return(list(edge = edgeNeighborhoods, target = targetNeighborhoods, source = sourceNeighborhoods))
}

#' Obtain the list of pairs in a composite model. Do this by merging all composite
#' models with overlap. Use the full models so that we
#' consider the original overlap before pruning. This uses only the structure
#' of the graph and requires no dependencies on previous layers.
#' @param components A Model object containing the list of components in the graph.
#' @param pairsInEachPredictor An object whose \code{pairs}, \code{targets} and
#' \code{sources} slots list the contents of each predictor (presumably a Model).
#' @param importantModels A Model containing all pairs that were found to be
#' important in the previous layer.
#' @return A CompositeModelSet with three parts. \code{compositeModels} is a
#' Model holding the trimmed composite models, restricted to the important
#' pairs. \code{expandedCompositeModels} is a Model holding the untrimmed
#' composite models. \code{mapping} is a data frame mapping the predictors in
#' the previous stage (\code{from}) to the current stage (\code{to}).
#' @export
ObtainCompositeModels <- function(components, pairsInEachPredictor, importantModels){

  numImportant <- length(importantModels@pairs)

  # Only merge composite models if there is more than one composite model to begin with.
  if(numImportant > 1){
    numComponents <- length(components@pairs)
    pairMapping <- data.frame(from = seq_len(numImportant), to = seq_len(numImportant))

    # Preallocate one slot per component. Assign with `[i] <- list(...)` so
    # that an empty (NULL) entry keeps its slot; `[[i]] <- NULL` would silently
    # shorten the list.
    compositePredictors <- vector("list", numComponents)
    compositePredictorTargets <- vector("list", numComponents)
    compositePredictorSources <- vector("list", numComponents)
    compositeFullModels <- vector("list", numComponents)
    compositeFullTargets <- vector("list", numComponents)
    compositeFullSources <- vector("list", numComponents)

    # Loop invariants.
    importantPairNames <- unlist(importantModels@pairs)
    keptPredictorPairs <- pairsInEachPredictor@pairs[importantModels@modelsKept]

    # Each connected component of the graph is one merged composite model.
    # This is identical to merging all clusters that have overlap.
    for(i in seq_len(numComponents)){
      componentPairs <- components@pairs[[i]]
      componentTargets <- components@targets[[i]]
      componentSources <- components@sources[[i]]

      # A previous predictor maps to this component if all of its pairs lie in it.
      doesMap <- vapply(keptPredictorPairs, function(model){
        length(setdiff(model, componentPairs)) == 0
      }, logical(1))
      pairMapping$to[which(doesMap)] <- i

      # Trimmed model: only the component's pairs that were found important.
      whichImportant <- which(componentPairs %in% importantPairNames)
      compositePredictors[i] <- list(unlist(componentPairs[whichImportant]))
      compositePredictorTargets[i] <- list(unlist(componentTargets[whichImportant]))
      compositePredictorSources[i] <- list(unlist(componentSources[whichImportant]))

      # Untrimmed model: the whole component.
      compositeFullModels[i] <- list(componentPairs)
      compositeFullTargets[i] <- list(componentTargets)
      compositeFullSources[i] <- list(componentSources)
    }
  }else{
    compositePredictors <- importantModels@pairs
    compositeFullModels <- pairsInEachPredictor@pairs
    compositePredictorTargets <- importantModels@targets
    compositeFullTargets <- pairsInEachPredictor@targets
    compositePredictorSources <- importantModels@sources
    compositeFullSources <- pairsInEachPredictor@sources
    pairMapping <- data.frame(to = 1, from = 1)
  }

  # Return the data. seq_along() yields an empty index (not c(1, 0)) when there are no models.
  return(methods::new("CompositeModelSet",
                      compositeModels = methods::new("Model", pairs = compositePredictors,
                                                     targets = compositePredictorTargets,
                                                     sources = compositePredictorSources,
                                                     modelsKept = seq_along(compositePredictors)),
                      expandedCompositeModels = methods::new("Model", pairs = compositeFullModels,
                                                             targets = compositeFullTargets,
                                                             sources = compositeFullSources,
                                                             modelsKept = seq_along(compositeFullModels)),
                      mapping = pairMapping))
}

#' Given multiple composite predictors, prune the predictors that are not needed.
#'
#' @include AllClasses.R
#' 
#' @param compositeSubgraphs A CompositeModelSet. Its \code{compositeModels}
#' are pruned one subgraph at a time, and its \code{mapping} relates previous
#' models to subgraphs.
#' @param previousModels A list of the previous models that were consolidated.
#' @param modelResults A ModelResults object.
#' @param verbose Whether or not to print out each step.
#' @param makePlots Whether or not to plot the pruned model at each step.
#' Currently not used by this function.
#' @param pruningMethod The method to use for pruning. Right now, only "error.t.test" is valid.
#' @param tolerance Tolerance factor when computing equality of two numeric values.
#' @param modelRetention Strategy for model retention. "stringent" (the default)
#' retains only models that improve the prediction score. "lenient" also retains models that
#' neither improve nor reduce the prediction score.
#' @param minCutoff Minimum cutoff for the prediction.
#' @param maxCutoff Maximum cutoff for the prediction.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param weights The weights for each predictor, calculated using ComputeMetaFeatureWeights()
#' @param individualPerformance The score (using the pruning method) for each individual component of the model.
#' @param layerNumber Number of layer in the model.
#' @param pruningTechnique Pruning technique to use. Possible methods are "backward.stepwise",
#' "forward.stepwise", "both", "exhaustive", and "parallel". Any other value raises an error.
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output. Default is FALSE.
#' @param zeroOut This parameter zeros out predictors outside of the allowed
#' range.
#' @return A Model object. It holds the pruned pairs, targets and sources of
#' each subgraph that still has at least one pair, plus the indices of the
#' models kept.
#' @export
PrunePredictors <- function(compositeSubgraphs, previousModels, modelResults, verbose = FALSE, makePlots = FALSE,
                            pruningMethod = "error.t.test", tolerance = 1e-5,
                            modelRetention = "stringent", minCutoff, maxCutoff, useCutoff = FALSE, weights,
                            individualPerformance, layerNumber, pruningTechnique = "backward.stepwise",
                            averaging = FALSE, zeroOut = FALSE){
  # NOTE: the default for `averaging` was previously `averaging = averaging`, a
  # recursive default argument reference that errored whenever the argument
  # was omitted.

  # Convert true values to numeric if applicable. For a non-numeric outcome,
  # samples equal to the second sorted category are encoded as 1, all others as 0.
  trueVals <- as.matrix(modelResults@model.input@input.data@sampleMetaData[,modelResults@model.input@stype])
  if(!is.numeric(trueVals)){
    catNames <- sort(unique(trueVals))
    trueValuesChar <- trueVals
    trueVals <- matrix(rep(0, length(trueValuesChar)), ncol = ncol(trueValuesChar),
                       nrow = nrow(trueValuesChar))
    trueVals[which(trueValuesChar == catNames[2])] <- 1
  }

  # Predict with the given pairs and score that prediction against the true values.
  scoreModel <- function(scorePairs, scoreTargets, scoreSources){
    prediction <- GENN::CompositePrediction(pairs = scorePairs,
                                            targets = scoreTargets,
                                            sources = scoreSources,
                                            modelResults = modelResults,
                                            minCutoff = minCutoff,
                                            maxCutoff = maxCutoff,
                                            useCutoff = useCutoff,
                                            weights = weights,
                                            averaging = averaging,
                                            zeroOut = zeroOut)
    GENN::ComputeSignificance(pred = prediction,
                              trueVal = trueVals,
                              pruningMethod = pruningMethod)
  }

  # Extract relevant information.
  models <- compositeSubgraphs@compositeModels
  mapping <- compositeSubgraphs@mapping
  prunedSubgraphs <- lapply(seq_along(models@pairs), function(i){
    pairs <- models@pairs[[i]]
    targets <- models@targets[[i]]
    sources <- models@sources[[i]]

    # In the first layer, time one composite prediction and set the number of
    # models to remove (or add) per stepwise step to about
    # modelCount * runtime / 3600. The original intent was to keep pruning
    # within roughly an hour.
    numToRemoveAtOnce <- 1
    if(layerNumber == 1){
      if(length(pairs) > 0){
        startTime <- Sys.time()
        GENN::CompositePrediction(pairs = pairs,
                                  targets = targets,
                                  sources = sources,
                                  modelResults = modelResults,
                                  minCutoff = minCutoff,
                                  maxCutoff = maxCutoff,
                                  useCutoff = useCutoff,
                                  weights = weights,
                                  averaging = averaging,
                                  zeroOut = zeroOut)
        # Subtracting times yields a difftime with automatically chosen units;
        # force seconds so long runs are not misread as minutes or hours.
        runtime <- as.numeric(difftime(Sys.time(), startTime, units = "secs"))
        modelCount <- length(unlist(models@pairs))
        numDesiredIterations <- 60 * 60 / runtime
        # A zero runtime gives Inf iterations and hence 0; always remove at least one.
        numToRemoveAtOnce <- max(1, ceiling(modelCount / numDesiredIterations))
      }else{
        numToRemoveAtOnce <- 0
      }
    }

    # Local copies of this subgraph's pairs, targets and sources.
    # seq_along() avoids the c(1, 0) index that 1:length() gives for empty input.
    localPairs <- pairs[seq_along(pairs)]
    localTargets <- targets[seq_along(pairs)]
    localSources <- sources[seq_along(pairs)]
    # Treat a pair list containing NA as empty.
    if(any(is.na(localPairs))){
      localPairs <- c()
      localTargets <- c()
      localSources <- c()
    }
    prevIdx <- seq_along(previousModels@pairs)
    localPrevModels <- methods::new("Model", pairs = previousModels@pairs[prevIdx],
                                    targets = previousModels@targets[prevIdx],
                                    sources = previousModels@sources[prevIdx])
    localMapping <- mapping[prevIdx,]
    localMapping$from <- seq_along(prevIdx)

    # Print initial graph.
    maxToPrint <- 50
    if(verbose == TRUE){
      if(length(localPairs) < maxToPrint){
        print(paste("subgraph", i, ":", paste(sort(unlist(localPairs)), collapse = ", ")))
      }else{
        print(paste("subgraph", i, ":", paste(sort(unlist(localPairs[1:maxToPrint])), collapse = ", ")))
      }
    }
    localSources <- as.character(localSources)
    localTargets <- as.character(localTargets)
    localPairs <- as.character(localPairs)

    # Only prune when there is more than one pair; otherwise keep what is there.
    if(length(localPairs) > 1){
      if(pruningTechnique == "backward.stepwise"){
        importantModels <- PruneBackwardStepwise(pairs = localPairs,
                                                 targets = localTargets,
                                                 sources = localSources,
                                                 modelResults = modelResults,
                                                 minCutoff = minCutoff,
                                                 maxCutoff = maxCutoff,
                                                 useCutoff = useCutoff,
                                                 verbose = verbose,
                                                 weights = weights,
                                                 pred = unlist(models),
                                                 trueVal = trueVals,
                                                 pruningMethod = pruningMethod,
                                                 mapping = localMapping,
                                                 individualPerformance = individualPerformance,
                                                 tolerance = tolerance,
                                                 i = i,
                                                 previousModels = localPrevModels,
                                                 modelRetention = modelRetention,
                                                 averaging = averaging,
                                                 zeroOut = zeroOut, numToRemove = numToRemoveAtOnce)
      }else if(pruningTechnique == "both"){
        importantModelsBackward <- PruneBackwardStepwise(pairs = localPairs,
                                                         targets = localTargets,
                                                         sources = localSources,
                                                         modelResults = modelResults,
                                                         minCutoff = minCutoff,
                                                         maxCutoff = maxCutoff,
                                                         useCutoff = useCutoff,
                                                         verbose = verbose,
                                                         weights = weights,
                                                         pred = unlist(models),
                                                         trueVal = trueVals,
                                                         pruningMethod = pruningMethod,
                                                         mapping = localMapping,
                                                         individualPerformance = individualPerformance,
                                                         tolerance = tolerance,
                                                         i = i,
                                                         previousModels = localPrevModels,
                                                         modelRetention = modelRetention,
                                                         averaging = averaging,
                                                         zeroOut = zeroOut, numToRemove = numToRemoveAtOnce)
        importantModelsForward <- PruneForwardStepwise(pairs = localPairs,
                                                       targets = localTargets,
                                                       sources = localSources,
                                                       modelResults = modelResults,
                                                       minCutoff = minCutoff,
                                                       maxCutoff = maxCutoff,
                                                       useCutoff = useCutoff,
                                                       verbose = verbose,
                                                       weights = weights,
                                                       pred = unlist(models),
                                                       trueVal = trueVals,
                                                       pruningMethod = pruningMethod,
                                                       mapping = localMapping,
                                                       individualPerformance = individualPerformance,
                                                       tolerance = tolerance,
                                                       i = i,
                                                       previousModels = localPrevModels,
                                                       modelRetention = modelRetention,
                                                       averaging = averaging,
                                                       zeroOut = zeroOut, numToAdd = numToRemoveAtOnce)

        # For each backward model, keep the pairs that forward pruning also kept.
        # (Loop index is `j` so the subgraph index `i` is not shadowed.)
        forwardPairNames <- unlist(importantModelsForward@pairs)
        numBackward <- length(importantModelsBackward@pairs)
        sharedPairs <- vector("list", numBackward)
        sharedTargets <- vector("list", numBackward)
        sharedSources <- vector("list", numBackward)
        sharedNonEmpty <- rep(FALSE, numBackward)
        for(j in seq_len(numBackward)){
          whichInForward <- which(importantModelsBackward@pairs[[j]] %in% forwardPairNames)
          sharedPairs[j] <- list(importantModelsBackward@pairs[[j]][whichInForward])
          sharedTargets[j] <- list(importantModelsBackward@targets[[j]][whichInForward])
          sharedSources[j] <- list(importantModelsBackward@sources[[j]][whichInForward])
          sharedNonEmpty[j] <- length(whichInForward) > 0
        }
        sharedKept <- which(sharedNonEmpty)

        # Score both pruned models and, if non-empty, their intersection; keep the best.
        performanceForward <- scoreModel(unlist(importantModelsForward@pairs),
                                         unlist(importantModelsForward@targets),
                                         unlist(importantModelsForward@sources))
        performanceBackward <- scoreModel(unlist(importantModelsBackward@pairs),
                                          unlist(importantModelsBackward@targets),
                                          unlist(importantModelsBackward@sources))
        importantModels <- importantModelsForward
        if(length(sharedKept) > 0){
          # Use every non-empty shared model; `[[` with several indices would
          # index recursively instead.
          performanceShared <- scoreModel(unlist(sharedPairs[sharedKept]),
                                          unlist(sharedTargets[sharedKept]),
                                          unlist(sharedSources[sharedKept]))
          # isTRUE() keeps an NA score from aborting the comparison.
          if(isTRUE(performanceShared > performanceForward && performanceShared > performanceBackward)){
            importantModels <- methods::new("Model", pairs = sharedPairs[sharedKept],
                                            targets = sharedTargets[sharedKept],
                                            sources = sharedSources[sharedKept],
                                            modelsKept = importantModelsBackward@modelsKept[sharedKept])
          }else if(isTRUE(performanceForward >= performanceShared && performanceForward >= performanceBackward)){
            importantModels <- importantModelsForward
          }else if(isTRUE(performanceBackward >= performanceShared && performanceBackward >= performanceForward)){
            importantModels <- importantModelsBackward
          }
        }else if(isTRUE(performanceForward > performanceBackward)){
          importantModels <- importantModelsForward
        }else{
          importantModels <- importantModelsBackward
        }
      }else if(pruningTechnique == "forward.stepwise"){
        importantModels <- PruneForwardStepwise(pairs = localPairs,
                                                targets = localTargets,
                                                sources = localSources,
                                                modelResults = modelResults,
                                                minCutoff = minCutoff,
                                                maxCutoff = maxCutoff,
                                                useCutoff = useCutoff,
                                                verbose = verbose,
                                                weights = weights,
                                                pred = unlist(models),
                                                trueVal = trueVals,
                                                pruningMethod = pruningMethod,
                                                mapping = localMapping,
                                                individualPerformance = individualPerformance,
                                                tolerance = tolerance,
                                                i = i,
                                                previousModels = localPrevModels,
                                                modelRetention = modelRetention,
                                                averaging = averaging,
                                                zeroOut = zeroOut, numToAdd = numToRemoveAtOnce)
      }else if(pruningTechnique == "exhaustive"){
        importantModels <- PruneExhaustive(pairs = localPairs,
                                           targets = localTargets,
                                           sources = localSources,
                                           modelResults = modelResults,
                                           minCutoff = minCutoff,
                                           maxCutoff = maxCutoff,
                                           useCutoff = useCutoff,
                                           verbose = verbose,
                                           weights = weights,
                                           pred = unlist(models),
                                           trueVal = trueVals,
                                           mapping = localMapping,
                                           tolerance = tolerance,
                                           i = i,
                                           previousModels = localPrevModels,
                                           averaging = averaging,
                                           zeroOut = zeroOut)
      }else if(pruningTechnique == "parallel"){
        importantModels <- PruneStepwiseParallel(pairs = localPairs,
                                                 targets = localTargets,
                                                 sources = localSources,
                                                 modelResults = modelResults,
                                                 minCutoff = minCutoff,
                                                 maxCutoff = maxCutoff,
                                                 useCutoff = useCutoff,
                                                 verbose = verbose,
                                                 weights = weights,
                                                 pred = unlist(models),
                                                 trueVal = trueVals,
                                                 pruningMethod = pruningMethod,
                                                 mapping = localMapping,
                                                 individualPerformance = individualPerformance,
                                                 tolerance = tolerance,
                                                 i = i,
                                                 previousModels = localPrevModels,
                                                 modelRetention = modelRetention,
                                                 averaging = averaging,
                                                 zeroOut = zeroOut)
      }else{
        stop(paste("Pruning technique", pruningTechnique, "is not valid."))
      }

    }else{
      importantModels <- methods::new("Model", pairs = list(localPairs), targets = list(localTargets),
                                      sources = list(localSources),
                                      modelsKept = i)
    }

    # Print results.
    if(verbose == TRUE){
      if(length(unlist(importantModels@pairs)) < maxToPrint){
        print(paste("The pruned set of pairs is:", paste(sort(unlist(importantModels@pairs)), collapse = ",")))
      }else{
        print(paste("The pruned set of pairs is:", paste(sort(unlist(importantModels@pairs))[1:maxToPrint], collapse = ",")))
      }
    }else{
      cat(".")
    }
    return(importantModels)
  })

  # Drop subgraphs with no remaining pairs. Counting across all pair groups
  # (rather than only the first) also avoids an out-of-bounds error when the
  # pairs list is empty.
  nonEmpty <- vapply(prunedSubgraphs, function(g){
    length(unlist(g@pairs)) > 0
  }, logical(1))
  prunedSubgraphs <- prunedSubgraphs[nonEmpty]
  pairs <- lapply(prunedSubgraphs, function(g){return(unlist(g@pairs))})
  targets <- lapply(prunedSubgraphs, function(g){return(unlist(g@targets))})
  sources <- lapply(prunedSubgraphs, function(g){return(unlist(g@sources))})
  nums <- unlist(lapply(prunedSubgraphs, function(g){return(g@modelsKept)}))
  return(methods::new("Model", pairs = pairs, targets = targets, sources = sources, modelsKept = nums))
}

#' Given multiple composite predictors, evaluate each combination exhaustively.
#'
#' @include AllClasses.R
#' 
#' @param modelResults The ModelResults object that will be filled in during training,
#' obtained from DoModelSetup().
#' @param pred The predicted values using each individual predictor.
#' @param trueVal The true values of the phenotype.
#' @param previousModels A list of the previous models that were consolidated.
#' @param verbose Whether or not to print out each step.
#' @param pairs A list of pairs to include in the composite model.
#' @param targets A list of targets to include in the composite model.
#' @param sources A list of sources to include in the composite model.
#' @param minCutoff Minimum cutoff for the prediction.
#' @param maxCutoff Maximum cutoff for the prediction.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param weights The weights for each predictor, calculated using ComputeMetaFeatureWeights()
#' @param mapping A mapping from models to composite models for the next stage of pooling.
#' @param pruningMethod The method to use for pruning. Right now, only "error.t.test" is valid.
#' @param tolerance Tolerance factor when computing equality of two numeric values.
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output.
#' @param zeroOut This parameter zeros out predictors outside of the allowed
#' range.
#' @param i Model index
#' @return A ModelIDSet with the retained models.
#' @export
PruneExhaustive <- function(pairs, targets, sources, modelResults,
                                  minCutoff, maxCutoff, useCutoff,
                                  weights, pred, verbose, mapping,
                                  trueVal, previousModels,
                                  tolerance, i, pruningMethod = "error.t.test",
                            averaging = FALSE, zeroOut = zeroOut){
  
  
  # Figure out which of the previous models mapped to this one.
  previousModelsMapped <- mapping$from[which(mapping$to == i)]
  
  # Find each combination of m models, where m is a subset of the total
  # number of models.
  modelSignificanceAllList <- lapply(1:length(previousModelsMapped), function(m){
    # Find each combination of m models.
    modelCombinations <- utils::combn(previousModelsMapped, m)
    modelSignificance <- unlist(lapply(1:ncol(modelCombinations), function(comboIdx){
      
      # Extract relevant pairs and targets for this combination.
      combo <- modelCombinations[,comboIdx]
      importantModels <- modelCombinations
      importantPairs <- previousModels@pairs[combo]
      importantTargets <- previousModels@targets[combo]
      importantSources <- previousModels@sources[combo]

      # Verify that model to remove is under consideration.
      significance <- NULL
      if(length(intersect(unlist(importantPairs), pairs)) > 0){
        
        # Filter pairs.
        whichInFilteredList <- which(unlist(importantPairs) %in% pairs)
        pairsToInclude <- unlist(importantPairs)[whichInFilteredList]
        targetsToInclude <- unlist(importantTargets)[whichInFilteredList]
        sourcesToInclude <- unlist(importantSources)[whichInFilteredList]
        
        # De-duplicate pairs.
        pairCountTable <- table(pairsToInclude)
        dupedPairs <- names(pairCountTable)[which(pairCountTable > 1)]
        for(p in dupedPairs){
          pairIndices <- which(pairsToInclude == p)
          dupPairIndices <- sort(pairIndices)[2:length(pairIndices)]
          indicesToKeep <- sort(setdiff(1:length(pairsToInclude), dupPairIndices))
          pairsToInclude <- pairsToInclude[indicesToKeep]
          targetsToInclude <- targetsToInclude[indicesToKeep]
          sourcesToInclude <- sourcesToInclude[indicesToKeep]
        }
        if(length(pairsToInclude) > 0){
          compositeModel <- GENN::CompositePrediction(pairs = unlist(pairsToInclude),
                                                                           targets = unlist(targetsToInclude),
                                                                           sources = unlist(sourcesToInclude),
                                                                           modelResults = modelResults,
                                                                           minCutoff = minCutoff,
                                                                           maxCutoff = maxCutoff,
                                                                           useCutoff = useCutoff,
                                                                           weights = weights,
                                                                           averaging = averaging,
                                                                           zeroOut = zeroOut)
          significance <- GENN::ComputeSignificance(pred = unlist(compositeModel),
                                                                         trueVal = trueVal,
                                                                         pruningMethod = pruningMethod)
        }
      }
      return(significance)
   }))

    # Find the best significance of any of the models.
    greatestSignificance <- which.max(modelSignificance)
    return(list(combo = modelCombinations[,greatestSignificance],
                      significance = modelSignificance[which.max(modelSignificance)]))
  })
  modelSignificanceAll <- unlist(lapply(modelSignificanceAllList, function(sig){return(sig$significance)}))
  
  # Find the best significance over all combinations.
  bestCombination <- modelSignificanceAllList[[which.max(modelSignificanceAll)]]$combo

  return(methods::new("Model", pairs = previousModels@pairs[bestCombination], 
             targets = previousModels@targets[bestCombination], 
             sources = previousModels@sources[bestCombination],  modelsKept = i))
}

#' Given multiple composite predictors, prune the predictors that are not needed.
#'
#' The candidate thresholds are the 10th, 20th, ..., 100th percentiles of the
#' individual performance of the previous models mapped to composite model
#' \code{i}. For each threshold, a composite model is built from the mapped
#' models whose performance is at or above it:
#' \itemize{
#'   \item their pairs are restricted to \code{pairs};
#'   \item the pairs are de-duplicated, keeping the first occurrence;
#'   \item the model is scored with \code{ComputeSignificance()}.
#' }
#' The models at or above the best-scoring threshold are retained.
#'
#' @include AllClasses.R
#' 
#' @param modelResults The ModelResults object that will be filled in during training,
#' obtained from DoModelSetup().
#' @param pred The predicted values using each individual predictor. Not used by this method.
#' @param trueVal The true values of the phenotype.
#' @param previousModels The previous models that were consolidated: an object with
#' \code{pairs}, \code{targets}, and \code{sources} slots (presumably a Model).
#' @param verbose Whether or not to print out each step. Not used by this method.
#' @param pairs A list of pairs to include in the composite model.
#' @param targets A list of targets to include in the composite model. Not used by this method.
#' @param sources A list of sources to include in the composite model. Not used by this method.
#' @param minCutoff Mininum cutoff for the prediction.
#' @param maxCutoff Maximum cutoff for the prediction.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param weights The weights for each predictor, calculated using ComputeMetaFeatureWeights()
#' @param mapping A mapping from models to composite models for the next stage of pooling.
#' @param pruningMethod The method to use for pruning. Right now, only "error.t.test" is valid.
#' @param tolerance Tolerance factor when computing equality of two numeric values.
#' Not used by this method.
#' @param individualPerformance The score (using the pruning method) for each individual
#' component of the model, indexed by model ID.
#' @param i Model index
#' @param modelRetention Strategy for model retention. "stringent" (the default)
#' retains only models that improve the prediction score. "lenient" also retains models that
#' neither improve nor reduce the prediction score. Not used by this method.
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output.
#' @param zeroOut This parameter zeros out predictors outside of the allowed
#' range.
#' @return A Model object whose \code{pairs}, \code{targets}, and \code{sources}
#' slots hold the entries of the retained previous models, with
#' \code{modelsKept = i}. If no threshold could be scored, those slots are
#' empty lists.
#' @export
PruneStepwiseParallel <- function(pairs, targets, sources, modelResults,
                                  minCutoff, maxCutoff, useCutoff,
                                  weights, pred, verbose, mapping,
                                  trueVal, pruningMethod, previousModels,
                                  individualPerformance, tolerance, i,
                                  modelRetention, averaging = FALSE,
                                  zeroOut = FALSE) {

  # Figure out which of the previous models mapped to this one and flatten
  # their (index-aligned) pairs, targets, and sources.
  previousModelsMapped <- mapping$from[which(mapping$to == i)]
  mappedPerformance <- individualPerformance[previousModelsMapped]
  mappedPairs <- unlist(previousModels@pairs[previousModelsMapped])
  mappedTargets <- unlist(previousModels@targets[previousModelsMapped])
  mappedSources <- unlist(previousModels@sources[previousModelsMapped])

  # Candidate performance thresholds.
  modelPerformanceQuantiles <- stats::quantile(mappedPerformance,
                                               seq(0.1, 1, by = 0.1))

  # Whether any mapped pair lies in the subgraph does not depend on the
  # threshold, so check it once.
  anyRelevant <- length(intersect(mappedPairs, pairs)) > 0

  # Score the composite model built from the mapped models at or above
  # threshold q. NA marks thresholds that cannot be scored. This keeps the
  # scores index-aligned with modelPerformanceQuantiles.
  scoreThreshold <- function(q) {
    if (!anyRelevant) {
      return(NA_real_)
    }
    modelsToInclude <- previousModelsMapped[which(mappedPerformance >= q)]
    keep <- which(mappedPairs %in% unlist(previousModels@pairs[modelsToInclude]) &
                    mappedPairs %in% pairs)
    pairsToInclude <- mappedPairs[keep]
    targetsToInclude <- mappedTargets[keep]
    sourcesToInclude <- mappedSources[keep]

    # De-duplicate pairs, keeping the first occurrence of each.
    firstOccurrence <- !duplicated(pairsToInclude)
    pairsToInclude <- pairsToInclude[firstOccurrence]
    targetsToInclude <- targetsToInclude[firstOccurrence]
    sourcesToInclude <- sourcesToInclude[firstOccurrence]

    if (length(pairsToInclude) == 0) {
      return(NA_real_)
    }
    compositeModel <- GENN::CompositePrediction(pairs = pairsToInclude,
                                                targets = targetsToInclude,
                                                sources = sourcesToInclude,
                                                modelResults = modelResults,
                                                minCutoff = minCutoff,
                                                maxCutoff = maxCutoff,
                                                useCutoff = useCutoff,
                                                weights = weights,
                                                averaging = averaging,
                                                zeroOut = zeroOut)
    as.numeric(GENN::ComputeSignificance(pred = unlist(compositeModel),
                                         trueVal = trueVal,
                                         pruningMethod = pruningMethod))
  }
  modelSignificance <- vapply(modelPerformanceQuantiles, scoreThreshold, numeric(1))

  # Keep the mapped models at or above the best-scoring threshold. which()
  # yields positions within previousModelsMapped. Translate them back to
  # model IDs before indexing previousModels.
  bestThreshold <- modelPerformanceQuantiles[which.max(modelSignificance)]
  retainedModels <- previousModelsMapped[which(mappedPerformance >= bestThreshold)]

  methods::new("Model",
               pairs = previousModels@pairs[retainedModels],
               targets = previousModels@targets[retainedModels],
               sources = previousModels@sources[retainedModels],
               modelsKept = i)
}


#' Given multiple composite predictors, prune the predictors that are not needed.
#'
#' Backward stepwise pruning. The function first scores the composite model
#' built from all of \code{pairs}. It then visits the previous models mapped
#' to composite model \code{i} in increasing order of individual performance,
#' \code{numToRemove} at a time. Each chunk is tentatively removed, and the
#' reduced model is scored. The removal is accepted when the reduced model's
#' score (rounded to two decimals) meets the retention rule relative to the
#' last accepted model:
#' \itemize{
#'   \item "lenient": score > last accepted;
#'   \item otherwise: score >= last accepted + \code{tolerance}.
#' }
#' A removal that would leave no pairs is never accepted.
#'
#' @include AllClasses.R
#' 
#' @param modelResults The ModelResults object that will be filled in during training,
#' obtained from DoModelSetup().
#' @param pred The predicted values using each individual predictor. Not used by this method.
#' @param trueVal The true values of the phenotype.
#' @param previousModels The previous models that were consolidated: an object with
#' \code{pairs}, \code{targets}, and \code{sources} slots (presumably a Model).
#' @param verbose Whether or not to print out each step.
#' @param pairs A list of pairs to include in the composite model.
#' @param targets A list of targets to include in the composite model.
#' @param sources A list of sources to include in the composite model.
#' @param minCutoff Mininum cutoff for the prediction.
#' @param maxCutoff Maximum cutoff for the prediction.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param weights The weights for each predictor, calculated using ComputeMetaFeatureWeights()
#' @param mapping A mapping from models to composite models for the next stage of pooling.
#' @param pruningMethod The method to use for pruning. Right now, only "error.t.test" is valid.
#' @param tolerance Tolerance factor when computing equality of two numeric values.
#' With the stringent rule, a removal must improve the rounded score by at least
#' this amount.
#' @param individualPerformance The score (using the pruning method) for each individual
#' component of the model, indexed by model ID.
#' @param i Model index
#' @param modelRetention Strategy for model retention. "stringent" (the default)
#' retains only models that improve the prediction score. "lenient" also retains models that
#' neither improve nor reduce the prediction score.
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output.
#' @param zeroOut This parameter zeros out predictors outside of the allowed
#' range.
#' @param numToRemove Number of models to remove at once when pruning. Default is 1.
#' @return A Model object whose \code{pairs}, \code{targets}, and \code{sources}
#' slots are single-element lists holding the retained pairs, targets, and
#' sources, with \code{modelsKept = i}.
#' @export
PruneBackwardStepwise <- function(pairs, targets, sources, modelResults,
                                  minCutoff, maxCutoff, useCutoff,
                                  weights, pred, verbose, mapping,
                                  trueVal, pruningMethod, previousModels,
                                  individualPerformance, tolerance, i,
                                  modelRetention, averaging, zeroOut = FALSE,
                                  numToRemove = 1) {

  # Initialize the predictor to include all pairs in the composite subgraph.
  compositeModel <- GENN::CompositePrediction(pairs = pairs,
                                              targets = targets,
                                              sources = sources,
                                              modelResults = modelResults,
                                              minCutoff = minCutoff,
                                              maxCutoff = maxCutoff,
                                              useCutoff = useCutoff,
                                              weights = weights,
                                              averaging = averaging,
                                              zeroOut = zeroOut)
  significance <- GENN::ComputeSignificance(pred = unlist(compositeModel),
                                            trueVal = trueVal,
                                            pruningMethod = pruningMethod)
  if (verbose == TRUE) {
    print(paste(list("Original", pruningMethod, "is", significance), collapse = " "))
  }

  # Score of the most recently accepted model. Each candidate is compared to it.
  significanceFull <- significance

  # Figure out which of the previous models mapped to this one.
  previousModelsMapped <- mapping$from[which(mapping$to == i)]
  importantModels <- previousModelsMapped
  importantPairs <- previousModels@pairs[previousModelsMapped]
  importantTargets <- previousModels@targets[previousModelsMapped]
  importantSources <- previousModels@sources[previousModelsMapped]

  # Visit the models from weakest to strongest individual performance.
  modelRemovalOrder <- order(individualPerformance[importantModels])
  numModels <- length(modelRemovalOrder)
  # seq() errors on an empty range, so handle "no mapped models" explicitly.
  chunkStarts <- if (numModels > 0) seq(1, numModels, numToRemove) else integer(0)

  for (chunk in chunkStarts) {
    m <- modelRemovalOrder[chunk:min(numModels, chunk + numToRemove - 1)]

    # Verify that model to remove is under consideration.
    if (length(intersect(unlist(importantPairs), pairs)) > 0) {
      currentPairs <- unlist(importantPairs)

      # Candidate: every currently retained model except this chunk,
      # restricted to pairs in the subgraph.
      modelsToInclude <- setdiff(importantModels, previousModelsMapped[m])
      whichToInclude <- which(currentPairs %in% unlist(previousModels@pairs[modelsToInclude]) &
                                currentPairs %in% pairs)
      pairsToInclude <- currentPairs[whichToInclude]
      targetsToInclude <- unlist(importantTargets)[whichToInclude]
      sourcesToInclude <- unlist(importantSources)[whichToInclude]

      # De-duplicate pairs, keeping the first occurrence of each.
      firstOccurrence <- !duplicated(pairsToInclude)
      pairsToInclude <- pairsToInclude[firstOccurrence]
      targetsToInclude <- targetsToInclude[firstOccurrence]
      sourcesToInclude <- sourcesToInclude[firstOccurrence]

      if (length(pairsToInclude) > 0) {
        compositeModel <- GENN::CompositePrediction(pairs = pairsToInclude,
                                                    targets = targetsToInclude,
                                                    sources = sourcesToInclude,
                                                    modelResults = modelResults,
                                                    minCutoff = minCutoff,
                                                    maxCutoff = maxCutoff,
                                                    useCutoff = useCutoff,
                                                    weights = weights,
                                                    averaging = averaging,
                                                    zeroOut = zeroOut)
        significance <- GENN::ComputeSignificance(pred = unlist(compositeModel),
                                                  trueVal = trueVal,
                                                  pruningMethod = pruningMethod)
      }

      # Decide whether to remove the chunk:
      # - If the candidate meets the retention rule against the last accepted
      #   model, remove the chunk.
      # - If removing the chunk would leave no pairs, keep it.
      # - Otherwise (the candidate scores worse), keep the chunk.
      significance <- round(significance, 2)
      significanceFull <- round(significanceFull, 2)

      # Pairs this candidate drops relative to the accepted model (verbose
      # output only).
      model <- setdiff(currentPairs[currentPairs %in% pairs], pairsToInclude)
      modelText <- paste(model[seq_len(min(length(model), 10))], collapse = ",")

      meetsCutoffForRemoval <- significance >= significanceFull + tolerance
      if (modelRetention == "lenient") {
        meetsCutoffForRemoval <- significance > significanceFull
      }
      if (isTRUE(meetsCutoffForRemoval) && length(pairsToInclude) > 0) {
        # Print the increase in significance.
        if (verbose == TRUE) {
          print(paste(list("Removed model.", pruningMethod, "after removing", modelText, "is",
                           format(significance, nsmall = 2), ", as compared to",
                           format(significanceFull, nsmall = 2)), collapse = " "))
        }

        importantPairs <- pairsToInclude
        importantTargets <- targetsToInclude
        importantSources <- sourcesToInclude
        importantModels <- modelsToInclude
        significanceFull <- significance
      } else if (length(pairsToInclude) == 0) {
        if (verbose == TRUE) {
          print(paste(list("Kept", modelText, "because it is the last remaining model"),
                      collapse = " "))
        }
      } else {
        # Print the decrease in information gain.
        if (verbose == TRUE) {
          print(paste(list("Kept model.", pruningMethod, "after removing", modelText, "is",
                           format(significance, nsmall = 2), ", as compared to",
                           format(significanceFull, nsmall = 2)), collapse = " "))
        }
      }
    }
  }
  methods::new("Model", pairs = list(importantPairs), targets = list(importantTargets),
               sources = list(importantSources), modelsKept = i)
}

#' Given multiple composite predictors, prune the predictors that are not needed.
#'
#' Forward stepwise selection. The previous models mapped to composite model
#' \code{i} are visited in decreasing order of individual performance,
#' \code{numToAdd} at a time. Each chunk is tentatively added to the models
#' accepted so far. The candidate is restricted to \code{pairs},
#' de-duplicated, and scored. The addition is accepted when the score
#' (rounded to two decimals) meets the retention rule relative to the last
#' accepted score:
#' \itemize{
#'   \item "lenient": score > last;
#'   \item otherwise: score >= last + \code{tolerance}.
#' }
#' The first scorable chunk is always accepted.
#'
#' @include AllClasses.R
#' 
#' @param modelResults The ModelResults object that will be filled in during training,
#' obtained from DoModelSetup().
#' @param pred The predicted values using each individual predictor. Not used by this method.
#' @param trueVal The true values of the phenotype.
#' @param previousModels The previous models that were consolidated: an object with
#' \code{pairs}, \code{targets}, and \code{sources} slots (presumably a Model).
#' @param verbose Whether or not to print out each step.
#' @param pairs A list of pairs to include in the composite model.
#' @param targets A list of targets to include in the composite model. Not used by this method.
#' @param sources A list of sources to include in the composite model. Not used by this method.
#' @param minCutoff Mininum cutoff for the prediction.
#' @param maxCutoff Maximum cutoff for the prediction.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param weights The weights for each predictor, calculated using ComputeMetaFeatureWeights()
#' @param mapping A mapping from models to composite models for the next stage of pooling.
#' @param pruningMethod The method to use for pruning. Right now, only "error.t.test" is valid.
#' @param tolerance Tolerance factor when computing equality of two numeric values.
#' @param individualPerformance The score (using the pruning method) for each individual
#' component of the model, indexed by model ID.
#' @param i Model index
#' @param modelRetention Strategy for model retention. "stringent" (the default)
#' retains only models that improve the prediction score. "lenient" also retains models that
#' neither improve nor reduce the prediction score.
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output.
#' @param zeroOut This parameter zeros out predictors outside of the allowed
#' range.
#' @param numToAdd Number of models to add at once when pruning. Default is 1.
#' @return A Model object whose \code{pairs}, \code{targets}, and \code{sources}
#' slots are single-element lists holding the selected pairs, targets, and
#' sources, with \code{modelsKept = i}.
#' @export
PruneForwardStepwise <- function(pairs, targets, sources, modelResults,
                                 minCutoff, maxCutoff, useCutoff,
                                 weights, pred, verbose, mapping,
                                 trueVal, pruningMethod, previousModels,
                                 individualPerformance, tolerance, i,
                                 numToAdd = 1, modelRetention, averaging,
                                 zeroOut = FALSE) {

  # Figure out which of the previous models mapped to this one and flatten
  # their (index-aligned) pairs, targets, and sources.
  previousModelsMapped <- mapping$from[which(mapping$to == i)]
  pairsToConsider <- unlist(previousModels@pairs[previousModelsMapped])
  targetsToConsider <- unlist(previousModels@targets[previousModelsMapped])
  sourcesToConsider <- unlist(previousModels@sources[previousModelsMapped])

  # Visit the models from strongest to weakest individual performance.
  modelAdditionOrder <- order(-individualPerformance[previousModelsMapped])
  numModels <- length(modelAdditionOrder)

  # Models accepted so far.
  importantModels <- c()
  importantPairs <- list()
  importantTargets <- list()
  importantSources <- list()

  # Initialize significance to be very low so that the first model is better.
  significanceLast <- -1 * .Machine$double.xmax

  # Only pairs in the subgraph can ever be added. That does not change between
  # iterations, so check it once. Also avoid seq() on an empty range, which
  # errors.
  anyRelevant <- length(intersect(pairsToConsider, pairs)) > 0
  chunkStarts <- if (anyRelevant && numModels > 0) {
    seq(1, numModels, numToAdd)
  } else {
    integer(0)
  }

  # Sequentially test addition of each chunk of models.
  for (chunk in chunkStarts) {
    m <- modelAdditionOrder[chunk:min(numModels, chunk + numToAdd - 1)]

    # Candidate: this chunk's pairs followed by the accepted pairs.
    modelsToInclude <- c(importantModels, previousModelsMapped[m])
    chunkPairs <- unlist(previousModels@pairs[previousModelsMapped[m]])
    whichToInclude <- which(pairsToConsider %in% chunkPairs)
    pairsToInclude <- c(pairsToConsider[whichToInclude], unlist(importantPairs))
    targetsToInclude <- c(targetsToConsider[whichToInclude], unlist(importantTargets))
    sourcesToInclude <- c(sourcesToConsider[whichToInclude], unlist(importantSources))

    # Restrict to pairs in the subgraph.
    inSubgraph <- which(pairsToInclude %in% pairs)
    pairsToInclude <- pairsToInclude[inSubgraph]
    targetsToInclude <- targetsToInclude[inSubgraph]
    sourcesToInclude <- sourcesToInclude[inSubgraph]

    # De-duplicate pairs, keeping the first occurrence of each.
    firstOccurrence <- !duplicated(pairsToInclude)
    pairsToInclude <- pairsToInclude[firstOccurrence]
    targetsToInclude <- targetsToInclude[firstOccurrence]
    sourcesToInclude <- sourcesToInclude[firstOccurrence]

    if (length(pairsToInclude) > 0) {
      compositeModel <- GENN::CompositePrediction(pairs = pairsToInclude,
                                                  targets = targetsToInclude,
                                                  sources = sourcesToInclude,
                                                  modelResults = modelResults,
                                                  minCutoff = minCutoff,
                                                  maxCutoff = maxCutoff,
                                                  useCutoff = useCutoff,
                                                  weights = weights,
                                                  averaging = averaging,
                                                  zeroOut = zeroOut)
      significance <- GENN::ComputeSignificance(pred = unlist(compositeModel),
                                                trueVal = trueVal,
                                                pruningMethod = pruningMethod)

      # Add the chunk if the candidate meets the retention rule against the
      # last accepted score. The very low initial score makes the first
      # scorable chunk pass.
      significance <- round(significance, 2)
      significanceLast <- round(significanceLast, 2)

      # Pairs this candidate adds on top of the accepted model (verbose
      # output only).
      model <- setdiff(pairsToInclude, unlist(importantPairs))
      modelText <- paste(model[seq_len(min(length(model), 10))], collapse = ",")

      meetsCutoffForAddition <- significance >= significanceLast + tolerance
      if (modelRetention == "lenient") {
        meetsCutoffForAddition <- significance > significanceLast
      }
      if (isTRUE(meetsCutoffForAddition)) {
        # Print the increase in significance.
        if (verbose == TRUE) {
          print(paste(list("Added model.", pruningMethod, "after adding", modelText, "is",
                           format(significance, nsmall = 2), ", as compared to",
                           format(significanceLast, nsmall = 2)), collapse = " "))
        }

        importantPairs <- pairsToInclude
        importantTargets <- targetsToInclude
        importantSources <- sourcesToInclude
        importantModels <- modelsToInclude
        significanceLast <- significance
      } else {
        # Print the decrease in information gain.
        if (verbose == TRUE) {
          print(paste(list("Did not add model.", pruningMethod, "after adding", modelText, "is",
                           format(significance, nsmall = 2), ", as compared to",
                           format(significanceLast, nsmall = 2)), collapse = " "))
        }
      }
    }
  }
  methods::new("Model", pairs = list(importantPairs), targets = list(importantTargets),
               sources = list(importantSources), modelsKept = i)
}

#' Score each individual predictor with a t-statistic on absolute errors.
#'
#' For every column of \code{predictions}, the absolute errors of that
#' predictor are compared with the absolute errors of a baseline that always
#' predicts the mean of \code{trueVal}. A positive score means the predictor's
#' mean absolute error is lower than the baseline's.
#' @param predictions Predictions for each model. This is a matrix (or anything
#' accepted by \code{as.matrix()}) with one row per sample and one column per
#' model. A plain vector is treated as a single model.
#' @param trueVal The true values (predictions or outcomes) of the input data,
#' one per sample. Non-numeric values are encoded as 0/1: the second distinct
#' value in sorted order becomes 1 and every other value becomes 0.
#' @return A numeric vector with one score per column of \code{predictions},
#' named by the column names when present.
#' 
#' @export
ComputeIndividualPerformance <- function(predictions, trueVal) {
  pred <- as.matrix(predictions)
  numModels <- ncol(pred)
  trueVal <- matrix(rep(trueVal, numModels), ncol = numModels)

  # Encode a categorical outcome as 0/1. as.vector() matters here: unique()
  # on a matrix returns unique rows, which would make catNames[2] the first
  # category whenever there are two or more columns.
  if (!is.numeric(trueVal)) {
    catNames <- sort(unique(as.vector(trueVal)))
    trueValuesChar <- trueVal
    trueVal <- matrix(0, nrow = nrow(trueValuesChar), ncol = ncol(trueValuesChar))
    trueVal[which(trueValuesChar == catNames[2])] <- 1
  }

  # Absolute errors of the mean-only baseline and of each predictor.
  meanAbsError <- abs(trueVal - mean(trueVal))
  predAbsError <- abs(trueVal - pred)

  # Welch-style t-statistic on the difference in mean absolute error.
  # NOTE(review): the standard deviations divide by n (population SD), not by
  # n - 1 as in a textbook Welch test. Confirm this is intended.
  n <- nrow(meanAbsError)
  meanMeanAbsError <- colMeans(meanAbsError)
  meanPredAbsError <- colMeans(predAbsError)
  stdevMeanAbsError <- sqrt(colSums(sweep(meanAbsError, 2, meanMeanAbsError)^2) / n)
  stdevPredAbsError <- sqrt(colSums(sweep(predAbsError, 2, meanPredAbsError)^2) / n)
  stdevMeanAbsErrorBar <- stdevMeanAbsError / sqrt(n)
  stdevPredAbsErrorBar <- stdevPredAbsError / sqrt(n)
  (meanMeanAbsError - meanPredAbsError) /
    sqrt(stdevMeanAbsErrorBar^2 + stdevPredAbsErrorBar^2)
}

#' Compute the weight of each predictor given the weights of different
#' metafeatures.
#'
#' Each metafeature is multiplied by its entry in
#' \code{modelResults@current.metaFeature.weights} (matched by position). The
#' weighted metafeatures are then summed element-wise.
#' @param modelResults A ModelResults object containing the current weights,
#' with at least as many entries as \code{metaFeatures}.
#' @param metaFeatures A non-empty list of metafeatures, such as that found within
#' ModelResults. Each element must be conformable for element-wise addition
#' after \code{as.matrix()}.
#' @return A weight matrix for each sample and each predictor.
#' @export
ComputeMetaFeatureWeights <- function(metaFeatures, modelResults) {
  if (length(metaFeatures) == 0) {
    stop("metaFeatures must contain at least one metafeature.", call. = FALSE)
  }
  featureWeights <- modelResults@current.metaFeature.weights
  # A missing weight would silently turn the whole result into NA.
  if (length(featureWeights) < length(metaFeatures)) {
    stop("modelResults@current.metaFeature.weights has fewer entries than metaFeatures.",
         call. = FALSE)
  }

  # Scale each metafeature by its weight, then sum element-wise.
  weighted <- Map(function(feature, w) as.matrix(feature * w),
                  metaFeatures, featureWeights[seq_along(metaFeatures)])
  Reduce(`+`, weighted)
}
ncats/MultiOmicsGraphPrediction documentation built on Aug. 23, 2023, 9:19 a.m.