# R/02_LearnGraphPredictionModel.R

#' Initialize the graph learning model, including the iteration-tracking data
#' frame and the initial meta-feature weights.
#' @param modelInputs An object of the ModelInput class.
#' @param iterations Maximum number of iterations. Default is 1,000.
#' @param convergenceCutoff Cutoff for convergence. Default is 0.001.
#' @param learningRate Learning rate to use during training. Default is 0.2.
#' @param optimizationType Type of optimization. May be "SGD", "momentum",
#' "adagrad", or "adam". Default is "SGD".
#' @param initialMetaFeatureWeights Initial weights for model meta-features. Default
#' is 0, which results in each meta-feature being given equal weight.
#' @return An object of the ModelResults class. This object will later be filled in
#' with the weights and errors from each iteration as well as the initial
#' settings and data and the final result.
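#' @examples
#' \dontrun{
#' # A minimal usage sketch; `modelInput` is a hypothetical ModelInput object
#' # created by FormatInput().
#' modelResults <- InitializeGraphLearningModel(modelInputs = modelInput,
#'                                              iterations = 500,
#'                                              learningRate = 0.1,
#'                                              optimizationType = "adam")
#' }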
#' @export
InitializeGraphLearningModel <- function(modelInputs,
                                         iterations = 1000,
                                         convergenceCutoff = 0.001,
                                         learningRate = 0.2,
                                         optimizationType = "SGD",
                                         initialMetaFeatureWeights = 0){
  # Get the number and names of the meta-features.
  weights_count <- length(modelInputs@metaFeatures)
  wt_name <- names(modelInputs@metaFeatures)
  
  # Initialize data frame with maximum number of iterations.
  tracking.frame <- as.data.frame(matrix(-1, nrow = iterations, 
                                         ncol = 2 + (2 * weights_count)))
  tracking.frame.cnames <- c("Iteration", "Error")
  tracking.frame.cnames <- c(tracking.frame.cnames, paste("Weight", wt_name, sep = "_"))
  tracking.frame.cnames <- c(tracking.frame.cnames, paste("Gradient", wt_name, sep = "_"))
  colnames(tracking.frame) <- tracking.frame.cnames
  tracking.frame$Error[1] <- .Machine$double.xmax
  tracking.frame$Iteration[1] <- 0

  # Initialize weights equally (each meta-feature receives 1 / weights_count)
  # unless initial weights were provided.
  max_phen <- max(modelInputs@true.phenotypes)
  num_edges <- nrow(modelInputs@edge.wise.prediction)
  weights <- as.matrix(rep(1 / weights_count, weights_count))
  if(!identical(initialMetaFeatureWeights, 0)){
    weights <- initialMetaFeatureWeights
  }
  names(weights) <- names(modelInputs@metaFeatures)
  
  tracking.frame[1,3:(2+weights_count)] <- weights

  # Initialize and return results.
  newModelResults <- methods::new("ModelResults", model.input=modelInputs,
                        iteration.tracking=tracking.frame, max.iterations=iterations,
                        convergence.cutoff=convergenceCutoff, learning.rate=learningRate,
                        previous.metaFeature.weights=as.matrix(rep(0,length(weights))), 
                        current.metaFeature.weights=as.matrix(weights),
                        current.gradient=as.matrix(rep(-1,length(weights))),
                        previous.momentum=as.matrix(rep(0,length(weights))),
                        previous.update.vector=as.matrix(rep(0,length(weights))),
                        sum.square.gradients=as.matrix(rep(0,length(weights))),
                        current.iteration=0, 
                        optimization.type=optimizationType,
                        pairs = "")
  return(newModelResults)
}

#' Format the input for graph-based learning.
#'
#' This input consists of:
#' - The Laplacian of a line graph built from the co-regulation graphs, where 
#' each edge corresponds to a pair of analytes.
#' - A prediction value for each edge of the line graph, for each sample X.
#' - The true prediction values Y for each sample X.
#' - The co-regulation graph.
#' - The line graph built from the co-regulation graphs.
#' - The input data (analyte levels, covariates, and phenotype).
#' - The model results from IntLIM for each predictor.
#' - The metafeatures for each model from IntLIM.
#' - The phenotype / outcome column name in the input data.
#' - The type of the outcome / phenotype ("numeric" or "categorical").
#' - The analyte type used as the outcome variable for each model (1 or 2).
#' - The analyte type used as the independent variable for each model (1 or 2).
#' @param predictionGraphs A list of igraph objects, each of which includes
#' predictions for each edge.
#' @param coregulationGraph An igraph object containing the coregulation graph.
#' @param stype.class The class of the outcome ("numeric" or "categorical").
#' @param stype The phenotype or outcome of interest.
#' @param covariates A list of covariates to include in the model. These will be in the
#' sampleMetaData slot of the inputData variable.
#' @param edgeTypeList List containing one or more of the following to include
#' in the line graph:
#' - "shared.outcome.analyte"
#' - "shared.independent.analyte"
#' - "analyte.chain"
#' @param verbose Whether to print the number of predictions replaced in each sample.
#' TRUE or FALSE. Default is TRUE.
#' @param outcome The outcome used in the IntLIM models.
#' @param independent.var.type The independent variable type used in the IntLIM models.
#' @param errorCorrelationGroupReps Representative predictors for groups of predictors with correlated error.
#' @param metaFeatures A list of meta-feature values computed for each sample
#' during input preparation (DoModelSetup()).
#' @param modelProperties A data frame of model information from RunIntLim(),
#' i.e. R^2, interaction term p-value, and coefficients.
#' @param inputData An IntLimData object that includes slots for the sample
#' data, the analyte data, and the analyte meta-data.
#' @return An object of the ModelInput class. This object will contain the input data,
#' graph information, and individual predictor properties and their metafeatures.
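#' @examples
#' \dontrun{
#' # A minimal usage sketch; all inputs here are hypothetical objects produced
#' # by earlier pipeline steps (e.g., RunIntLim() and DoModelSetup()).
#' modelInput <- FormatInput(predictionGraphs = predictionGraphs,
#'                           coregulationGraph = coregulationGraph,
#'                           metaFeatures = metaFeatures,
#'                           modelProperties = modelProperties,
#'                           inputData = inputData,
#'                           stype.class = "categorical",
#'                           edgeTypeList = list("shared.outcome.analyte"),
#'                           stype = "diagnosis",
#'                           errorCorrelationGroupReps = groupReps)
#' }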
#' @export
FormatInput <- function(predictionGraphs, coregulationGraph, metaFeatures, modelProperties,
                        inputData, stype.class, edgeTypeList, stype, verbose = TRUE, covariates = c(),
                        outcome = 2, independent.var.type = 2, errorCorrelationGroupReps){
  
  # Extract edge-wise predictions.
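  # Each edge is named "<from>__<to>", after converting node names to
  # syntactically valid R names.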
  predictions_by_edge <- lapply(names(predictionGraphs), function(sampName){
    df_predictions <- igraph::as_data_frame(predictionGraphs[[sampName]])
    edge_names <- paste(make.names(df_predictions$from), make.names(df_predictions$to),
                        sep = "__")
    df_predictions_new <- data.frame(Edge = edge_names, Weight = df_predictions$weight)
    return(df_predictions_new)
  })
  names(predictions_by_edge) <- names(predictionGraphs)
  predicted_weights_only <- lapply(predictions_by_edge, function(pred){
    return(pred$Weight)
  })
  predictions_flattened <- t(data.frame(predicted_weights_only))
  colnames(predictions_flattened) <- predictions_by_edge[[1]]$Edge
  
  # Obtain the true phenotype values.
  Y <- inputData@sampleMetaData[,stype]
  if(stype.class == "factor"){
    Y <- as.numeric(Y)-1
  }
  names(Y) <- names(predictions_by_edge)
  
  # Create a ModelInput object and return it.
  if(is.null(covariates)){
    covariates <- ""
  }
  # Convert true values if needed.
  Ymat <- as.matrix(Y)
  if(!is.numeric(Ymat)){
    catNames <- sort(unique(Ymat))
    trueValuesChar <- Ymat
    Y <- matrix(rep(0, length(trueValuesChar)), ncol = ncol(trueValuesChar),
                       nrow = nrow(trueValuesChar))
    Y[which(trueValuesChar == catNames[2])] <- 1
    Y <- as.numeric(Y)
  }
  newModelInput <- methods::new("ModelInput", edge.wise.prediction=predictions_flattened[,errorCorrelationGroupReps],
                                true.phenotypes=Y, coregulation.graph=igraph::get.adjacency(coregulationGraph, sparse = FALSE), 
                                input.data = inputData, model.properties = modelProperties,
                                metaFeatures = metaFeatures, stype = stype, covariates = covariates, stype.class = stype.class,
                                outcome = outcome, independent.var.type = independent.var.type)
  return(newModelInput)
}

#' Optimize the weighted combination of predictors using their meta-features.
#' @param modelResults An object of the ModelResults class.
#' @param verbose Whether to print results while the model runs.
#' @param modelRetention Strategy for model retention. "stringent" (the default)
#' retains only models that improve the prediction score. "lenient" also retains models that
#' neither improve nor reduce the prediction score.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param modelCountCutoff Only consider this number of models, selected by
#' weight. If not provided, all models are used.
#' @param pruningTechnique Pruning technique to use. Possible methods are "backward.stepwise",
#' "forward.stepwise", "individual.performance", and "exhaustive".
#' @param stochastic Whether to use a stochastic model (TRUE) or a batch model (FALSE).
#' @param doPooling Whether or not to pool predictors together using the structure of the graph.
#' @param doPruning Whether or not to prune predictors. If pruning is not done,
#' result is a weighted combination of all predictors.
#' @param averaging Whether to use averaging to combine predictors instead of retaining
#' the same functional form for input and output.
#' @param zeroOut If TRUE, predictors with values outside of the allowed range
#' are zeroed out.
#' @param feedback Whether to implement a feedback layer, i.e. whether a full
#' pruning procedure over the entire graph is allowed to inform the pruning of
#' individual neighborhoods.
#' @param trimming Set to "edgewise" to trim edges at each layer or "modelwise"
#' to trim entire models (neighborhoods, connected components) at each layer.
#' @return A ModelResults object, with all of the tracking information from each
#' iteration filled in.
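#' @examples
#' \dontrun{
#' # A minimal usage sketch; `modelResults` is a hypothetical ModelResults
#' # object returned by InitializeGraphLearningModel().
#' trainedResults <- OptimizeMetaFeatureCombo(modelResults = modelResults,
#'                                            stochastic = TRUE,
#'                                            doPooling = TRUE,
#'                                            doPruning = TRUE)
#' }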
#' @export
OptimizeMetaFeatureCombo <- function(modelResults, verbose = TRUE, modelRetention = "stringent",
                                    useCutoff = FALSE, modelCountCutoff = 0, pruningTechnique = "backward.stepwise",
                                    stochastic = TRUE, doPooling = TRUE, doPruning = TRUE,
                                    averaging = FALSE, zeroOut = FALSE, feedback = FALSE,
                                    trimming = "modelwise"){
  
  # Calculate prediction cutoffs.
  predictionCutoffs <- CalculatePredictionCutoffs(modelResults = modelResults)

  # Start the first iteration and calculate a dummy weight delta.
  modelResults@current.iteration <- 1
  weight.delta <- sqrt(sum((modelResults@current.metaFeature.weights - 
                              modelResults@previous.metaFeature.weights)^2))
  
  # Compute weights for each predictor; these are used to obtain the initial
  # pruned models.
  weights <- ComputeMetaFeatureWeights(modelResults = modelResults,
                                       metaFeatures = modelResults@model.input@metaFeatures)
  weightMax <- apply(weights, 2, max)
  pairsByWeightFiltered <- names(weightMax)[order(weightMax, decreasing = TRUE)]
  # Find the components in the graph.
  graph <- igraph::graph_from_adjacency_matrix(modelResults@model.input@coregulation.graph)
  components <- igraph::components(graph, mode = "weak")
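  # For each weakly connected component, record its edge names, target nodes,
  # and source nodes.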
  componentFull <- lapply(unique(components$membership), function(i){
    componentNodes <- names(components$membership)[which(components$membership == i)]
    component <- igraph::as_edgelist(igraph::subgraph(graph, componentNodes))
    return(paste(component[,1], component[,2], sep = "__"))
  })
  componentTargets <- lapply(unique(components$membership), function(i){
    componentNodes <- names(components$membership)[which(components$membership == i)]
    component <- igraph::as_edgelist(igraph::subgraph(graph, componentNodes))
    return(as.character(component[,2]))
  })
  componentSources <- lapply(unique(components$membership), function(i){
    componentNodes <- names(components$membership)[which(components$membership == i)]
    component <- igraph::as_edgelist(igraph::subgraph(graph, componentNodes))
    return(as.character(component[,1]))
  })
  componentModels <- methods::new("Model", pairs = componentFull, targets = componentTargets, sources = componentSources,
                         modelsKept = 1:length(componentFull))
  pruningMethod <- "error.t.test"
  
  # Set up the predictor input for pruning.
  pairsPredAll <- NULL
  prunedModels <- NULL
  feedbackPairs <- NULL
  feedbackModel <- NULL
  feedbackPrediction <- NULL
  feedbackSignificance <- NULL
  if(doPooling == FALSE || doPruning == FALSE || feedback == TRUE){
    # If we are not pooling, simply make a long list of predictors with no grouping.
    edges <- list(colnames(modelResults@model.input@edge.wise.prediction))
    targets <- list(unlist(lapply(colnames(modelResults@model.input@edge.wise.prediction), function(name){
      return(strsplit(name, split = "__")[[1]][2])
    })))
    sources <- list(unlist(lapply(colnames(modelResults@model.input@edge.wise.prediction), function(name){
      return(strsplit(name, split = "__")[[1]][1])
    })))
    pairsPredAll <- list(edge = edges, target = targets, source = sources)
    if(feedback == TRUE){
      feedbackPairs <- pairsPredAll
    }
  }
  if(doPooling == TRUE && doPruning == TRUE){
    # If we are pooling, group the predictors by neighborhood.
    pairsPredAll <- GENN::ObtainSubgraphNeighborhoods(modelInput = modelResults@model.input)
  }
  # Do first pruning (if applicable).
  if(doPruning == FALSE){
    prunedModels <- pairsPredAll
    prunedModels <- methods::new("Model", pairs = pairsPredAll$edge, 
                                 targets = pairsPredAll$target, 
                                 sources = pairsPredAll$source, modelsKept = as.integer(1))
  }else{
    prunedModels <- GENN::DoSignificancePropagation(pairs = pairsPredAll, modelResults = modelResults,
                                                                         verbose = verbose,
                                                                         modelRetention = modelRetention,
                                                                         useCutoff = useCutoff,
                                                                         minCutoff = predictionCutoffs$min,
                                                                         maxCutoff = predictionCutoffs$max,
                                                                         weights = weights,
                                                                         components = componentModels,
                                                                         pruningTechnique = pruningTechnique,
                                                                         doPooling = doPooling,
                                                                         averaging = averaging,
                                                                         feedbackPairs = feedbackPairs,
                                                                         trimming = trimming)
  }
 
  flatPairs <- unlist(prunedModels@pairs)
  modelResults@pairs <- flatPairs
  
  # Set initial error.
  Y.pred <- DoPrediction(modelResults = modelResults, prunedModels = prunedModels,
                         useCutoff = useCutoff,
                         minCutoff = predictionCutoffs$min,
                         maxCutoff = predictionCutoffs$max,
                         weights = weights,
                         averaging = averaging)

  if(modelResults@model.input@stype.class == "categorical"){
    modelResults@iteration.tracking$Error[1] <- 
      ComputeClassificationError(modelResults@model.input@true.phenotypes, Y.pred)
  }else{
    modelResults@iteration.tracking$Error[1] <- 
      ComputeRMSE(modelResults@model.input@true.phenotypes, Y.pred)
  }
  print(paste("Initial error is", modelResults@iteration.tracking$Error[1]))

  # Repeat the training process for all iterations, until the maximum is reached
  # or until convergence.
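  # Convergence is declared once the weight delta stays below the convergence
  # cutoff for five consecutive iterations.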
  sequential_convergence_count <- 0
  sequential_count_limit <- 5
  while(modelResults@current.iteration <= modelResults@max.iterations
        && (sequential_convergence_count < sequential_count_limit)){
    # For stochastic training, permute the samples, then compute the gradient one
    # sample at a time.
    perm_samples <- sample(1:nrow(modelResults@model.input@edge.wise.prediction),
                          nrow(modelResults@model.input@edge.wise.prediction))

    # Initialize previous weight vector.
    modelResults@previous.metaFeature.weights <- modelResults@current.metaFeature.weights

    if(stochastic == TRUE){
      for(i in perm_samples){
        # Extract data for each sample.
        newModelResults <- modelResults
        newModelResults@model.input@edge.wise.prediction <- 
          as.matrix(modelResults@model.input@edge.wise.prediction[i,])
        newModelResults@model.input@true.phenotypes <- 
          modelResults@model.input@true.phenotypes[i]
        newModelResults@model.input@input.data@analyteType1 <- as.matrix(modelResults@model.input@input.data@analyteType1[,i])
        newModelResults@model.input@input.data@analyteType2 <- as.matrix(modelResults@model.input@input.data@analyteType2[,i])
        newModelResults@model.input@input.data@sampleMetaData <- as.data.frame(modelResults@model.input@input.data@sampleMetaData[i,])
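        # Subset each meta-feature matrix to the current sample, naming the row
        # and columns to match the edge-wise prediction matrix.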
        metaFeaturesSamp <- lapply(1:length(modelResults@model.input@metaFeatures), function(imp){
          df <- t(as.data.frame(modelResults@model.input@metaFeatures[[imp]][i,]))
          rownames(df) <- rownames(modelResults@model.input@edge.wise.prediction)[i]
          colnames(df) <- colnames(modelResults@model.input@edge.wise.prediction)
          return(df)
        })
        weightsLocal <- weights[i,]
        names(metaFeaturesSamp) <- names(modelResults@model.input@metaFeatures)
        newModelResults@model.input@metaFeatures <- metaFeaturesSamp
        
        # Do training iteration.
        newModelResults <- DoSingleTrainingIteration(modelResults = newModelResults,
                                                     prunedModels = prunedModels,
                                                     useCutoff = useCutoff,
                                                     minCutoff = predictionCutoffs$min,
                                                     maxCutoff = predictionCutoffs$max,
                                                     weights = weightsLocal,
                                                     averaging = averaging)
        
        # Update weights and gradient in the model results according to the
        # results of this sample.
        modelResults@current.metaFeature.weights <- newModelResults@current.metaFeature.weights
        modelResults@previous.metaFeature.weights <- newModelResults@previous.metaFeature.weights
        modelResults@current.gradient <- newModelResults@current.gradient
        modelResults@previous.momentum <- newModelResults@previous.momentum
        modelResults@previous.update.vector <- newModelResults@previous.update.vector
        modelResults@iteration.tracking <- newModelResults@iteration.tracking
      }
    }else{
      
      # Do training iteration.
      newModelResults <- DoSingleTrainingIteration(modelResults = modelResults,
                                                   prunedModels = prunedModels,
                                                   useCutoff = useCutoff,
                                                   minCutoff = predictionCutoffs$min,
                                                   maxCutoff = predictionCutoffs$max,
                                                   weights = weights,
                                                   averaging = averaging)
      
      # Update weights and gradient in the model results according to the
      # results of this sample.
      modelResults@current.metaFeature.weights <- newModelResults@current.metaFeature.weights
      modelResults@previous.metaFeature.weights <- newModelResults@previous.metaFeature.weights
      modelResults@current.gradient <- newModelResults@current.gradient
      modelResults@previous.momentum <- newModelResults@previous.momentum
      modelResults@previous.update.vector <- newModelResults@previous.update.vector
      modelResults@iteration.tracking <- newModelResults@iteration.tracking
    }
    
    
    # Recompute meta-feature weights with the updated model.
    weights <- ComputeMetaFeatureWeights(modelResults = modelResults,
                                         metaFeatures = modelResults@model.input@metaFeatures)

    # Fill in the prediction for error calculation.
    Y.pred <- DoPrediction(modelResults = modelResults, prunedModels = prunedModels,
                           useCutoff = useCutoff,
                           minCutoff = predictionCutoffs$min,
                           maxCutoff = predictionCutoffs$max,
                           weights = weights,
                           averaging = averaging)

    # Compute the prediction error over all samples.
    modelResults@iteration.tracking$Iteration[modelResults@current.iteration+1] <- modelResults@current.iteration
    if(modelResults@model.input@stype.class == "categorical"){
      modelResults@iteration.tracking$Error[modelResults@current.iteration+1] <- 
        ComputeClassificationError(modelResults@model.input@true.phenotypes, Y.pred)
    }else{
      modelResults@iteration.tracking$Error[modelResults@current.iteration+1] <- 
        ComputeRMSE(modelResults@model.input@true.phenotypes, Y.pred)
    }
    currentError <- modelResults@iteration.tracking$Error[modelResults@current.iteration+1]

    # Get the new pruned models.
    if(doPruning == TRUE){
      prunedModels <- GENN::DoSignificancePropagation(pairs = pairsPredAll, modelResults = modelResults,
                                                                           verbose = verbose,
                                                                           modelRetention = modelRetention,
                                                                           useCutoff = useCutoff,
                                                                           minCutoff = predictionCutoffs$min,
                                                                           maxCutoff = predictionCutoffs$max,
                                                                           weights = weights,
                                                                           components = componentModels,
                                                                           pruningTechnique = pruningTechnique,
                                                                           doPooling = doPooling,
                                                                           averaging = averaging,
                                                                           zeroOut = zeroOut,
                                                                           trimming = trimming)
    }
    
    flatPairs <- unlist(prunedModels@pairs)
    modelResults@pairs <- flatPairs
    prunedModelSizes <- lengths(prunedModels@pairs)

    # Print the weight delta and error.
    weight.delta <- sqrt(sum((modelResults@current.metaFeature.weights - modelResults@previous.metaFeature.weights)^2))
    if(modelResults@current.iteration %% 1 == 0){
      print(paste("iteration", modelResults@current.iteration, ": weight delta is", weight.delta,
                  "and error is", paste0(currentError, ". Final subgraph has ", sum(prunedModelSizes), " edges.")))
      sortedEdges <- sort(flatPairs)
      if(length(sortedEdges) > 5){
        sortedEdges <- sortedEdges[1:5]
      }
      print(paste0("Edges include: ", paste(sortedEdges, collapse = ", ")))
    }
    
    # Update the iteration.
    modelResults@current.iteration <- modelResults@current.iteration + 1
    modelResults@pairs <- flatPairs
    
    # Increment the number of convergent iterations if applicable.
    if(weight.delta < modelResults@convergence.cutoff){
      sequential_convergence_count <- sequential_convergence_count + 1
    }else{
      sequential_convergence_count <- 0
    }
  }
  # Make sure that what is being returned is the lowest error iteration.
  # If not, grab the lowest error iteration.
  if(min(modelResults@iteration.tracking$Error[which(!is.na(modelResults@iteration.tracking$Error))]) <
     modelResults@iteration.tracking$Error[modelResults@current.iteration]){
    
    # Find index of minimum error.
    which_min <- which.min(modelResults@iteration.tracking$Error[1:modelResults@current.iteration])
    
    # Update model results.
    modelResults@current.iteration <- modelResults@current.iteration + 1
    modelResults@iteration.tracking[modelResults@current.iteration,] <- modelResults@iteration.tracking[which_min,]
    currentError <- modelResults@iteration.tracking$Error[modelResults@current.iteration]
    
    # Get the new pruned models.
    if(doPruning == TRUE){
      prunedModels <- GENN::DoSignificancePropagation(pairs = pairsPredAll, modelResults = modelResults,
                                                                           verbose = verbose,
                                                                           modelRetention = modelRetention,
                                                                           useCutoff = useCutoff,
                                                                           minCutoff = predictionCutoffs$min,
                                                                           maxCutoff = predictionCutoffs$max,
                                                                           weights = weights,
                                                                           components = componentModels,
                                                                           pruningTechnique = pruningTechnique,
                                                                           doPooling = doPooling,
                                                                           averaging = averaging,
                                                                           trimming = trimming)
    }
    
    flatPairs <- unlist(prunedModels@pairs)
    modelResults@pairs <- flatPairs
    prunedModelSizes <- lengths(prunedModels@pairs)
    
    # Print the error.
    weight.delta <- sqrt(sum((modelResults@current.metaFeature.weights - modelResults@previous.metaFeature.weights)^2))
    if(modelResults@current.iteration %% 1 == 0){
      print(paste("Returning to iteration", which_min - 1, "with error of", paste0(currentError, ". Final subgraph has ", sum(prunedModelSizes), " edges.")))
      sortedEdges <- sort(flatPairs)
      if(length(sortedEdges) > 5){
        sortedEdges <- sortedEdges[1:5]
      }
      print(paste0("Edges include: ", paste(sortedEdges, collapse = ", ")))
    }
  }
  
  # If we exited before the maximum number of iterations, remove the rest of the
  # tracking data.
  if((modelResults@current.iteration-1) < modelResults@max.iterations){
    modelResults@iteration.tracking <- modelResults@iteration.tracking[1:(modelResults@current.iteration-1),]
  }

  return(modelResults)
}

#' Obtain the combination of predictors without including any weights.
#' @param modelResults An object of the ModelResults class.
#' @param verbose Whether to print results while the model runs.
#' @param modelRetention Strategy for model retention. "stringent" (the default)
#' retains only models that improve the prediction score. "lenient" also retains models that
#' neither improve nor reduce the prediction score.
#' @param useCutoff Whether or not to use the cutoff for prediction. Default is FALSE.
#' @param pruningTechnique Pruning technique to use. Possible methods are "backward.stepwise",
#' "forward.stepwise", "individual.performance", and "exhaustive".
#' @param doPooling Whether or not to pool predictors together using the structure of the graph.
#' @param doPruning Whether or not to prune predictors. If pruning is not done,
#' result is a weighted combination of all predictors.
#' @param averaging If TRUE, then averaging is used to combine predictors rather than
#' retaining the same functional form for both the input and the output.
#' @param zeroOut If TRUE, predictors with values outside of the allowed range
#' are zeroed out.
#' @return A ModelResults object, with all of the tracking information from each
#' iteration filled in.
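#' @examples
#' \dontrun{
#' # A minimal usage sketch; `modelResults` is a hypothetical ModelResults
#' # object returned by InitializeGraphLearningModel().
#' unweightedResults <- ObtainUnweightedModel(modelResults = modelResults,
#'                                            doPooling = TRUE,
#'                                            doPruning = TRUE)
#' }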
#' @export
ObtainUnweightedModel <- function(modelResults, verbose = TRUE, modelRetention = "stringent",
                                     useCutoff = FALSE, pruningTechnique = "backward.stepwise",
                                     doPooling = TRUE, doPruning = TRUE, averaging = FALSE,
                                  zeroOut = FALSE){
  
  # Calculate prediction cutoffs.
  predictionCutoffs <- CalculatePredictionCutoffs(modelResults = modelResults)
  
  # Start the first iteration and calculate a dummy weight delta.
  modelResults@current.iteration <- 1

  # Find the components in the graph.
  graph <- igraph::graph_from_adjacency_matrix(modelResults@model.input@coregulation.graph)
  components <- igraph::components(graph, mode = "weak")
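  # For each weakly connected component, record its edge names, target nodes,
  # and source nodes.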
  componentFull <- lapply(unique(components$membership), function(i){
    componentNodes <- names(components$membership)[which(components$membership == i)]
    component <- igraph::as_edgelist(igraph::subgraph(graph, componentNodes))
    return(paste(component[,1], component[,2], sep = "__"))
  })
  componentTargets <- lapply(unique(components$membership), function(i){
    componentNodes <- names(components$membership)[which(components$membership == i)]
    component <- igraph::as_edgelist(igraph::subgraph(graph, componentNodes))
    return(as.character(component[,2]))
  })
  componentSources <- lapply(unique(components$membership), function(i){
    componentNodes <- names(components$membership)[which(components$membership == i)]
    component <- igraph::as_edgelist(igraph::subgraph(graph, componentNodes))
    return(as.character(component[,1]))
  })
  componentModels <- methods::new("Model", pairs = componentFull, targets = componentTargets, sources = componentSources,
                                  modelsKept = 1:length(componentFull))
  
  pruningMethod <- "error.t.test"
  
  # Set all weights to be 1.
  weights <- matrix(rep(1, length(modelResults@model.input@edge.wise.prediction)),
                       ncol = ncol(modelResults@model.input@edge.wise.prediction), 
                       nrow = nrow(modelResults@model.input@edge.wise.prediction))
  colnames(weights) <- colnames(modelResults@model.input@edge.wise.prediction)
  rownames(weights) <- rownames(modelResults@model.input@edge.wise.prediction)
  weightMax <- apply(weights, 2, max)
  pairsByWeight <- names(weightMax)[order(weightMax, decreasing = TRUE)]
  modelCountCutoff <- length(pairsByWeight)
  pairsByWeightFiltered <- pairsByWeight[1:modelCountCutoff]

  # Set up the predictor input for pruning.
  pairsPredAll <- NULL
  prunedModels <- NULL
  if(doPooling == FALSE || doPruning == FALSE){
    # If we are not pooling, simply make a long list of predictors with no grouping.
    edges <- list(colnames(modelResults@model.input@edge.wise.prediction))
    targets <- list(unlist(lapply(colnames(modelResults@model.input@edge.wise.prediction), function(name){
      return(strsplit(name, split = "__")[[1]][2])
    })))
    sources <- list(unlist(lapply(colnames(modelResults@model.input@edge.wise.prediction), function(name){
      return(strsplit(name, split = "__")[[1]][1])
    })))
    pairsPredAll <- list(edge = edges, target = targets, source = sources)
  }else{
    # If we are pooling, group the predictors by neighborhood.
    pairsPredAll <- GENN::ObtainSubgraphNeighborhoods(modelInput = modelResults@model.input)
  }
  
  # Do first pruning (if applicable).
  if(doPruning == FALSE){
    prunedModels <- pairsPredAll
    prunedModels <- methods::new("Model", pairs = pairsPredAll$edge, 
                                 targets = pairsPredAll$target, 
                                 sources = pairsPredAll$source, modelsKept = as.integer(1))
  }else{
    prunedModels <- GENN::DoSignificancePropagation(pairs = pairsPredAll, modelResults = modelResults,
                                                                         verbose = verbose,
                                                                         modelRetention = modelRetention,
                                                                         useCutoff = useCutoff,
                                                                         minCutoff = predictionCutoffs$min,
                                                                         maxCutoff = predictionCutoffs$max,
                                                                         weights = weights,
                                                                         components = componentModels,
                                                                         pruningTechnique = pruningTechnique,
                                                                         doPooling = doPooling,
                                                                         averaging = averaging,
                                                                         zeroOut = zeroOut)
  }
  
  flatPairs <- unlist(prunedModels@pairs)
  modelResults@pairs <- flatPairs
  
  # Set initial error.
  Y.pred <- DoPrediction(modelResults = modelResults, prunedModels = prunedModels,
                         useCutoff = useCutoff,
                         minCutoff = predictionCutoffs$min,
                         maxCutoff = predictionCutoffs$max,
                         weights = weights, averaging = averaging)
  
  if(modelResults@model.input@stype.class == "categorical"){
    modelResults@iteration.tracking$Error[1] <- 
      ComputeClassificationError(modelResults@model.input@true.phenotypes, Y.pred)
  }else{
    modelResults@iteration.tracking$Error[1] <- 
      ComputeRMSE(modelResults@model.input@true.phenotypes, Y.pred)
  }
  print(paste("Error is", modelResults@iteration.tracking$Error[1]))
  
  return(modelResults)
}

#' Compute classification error.
#' @param true.Y The true phenotype of each sample.
#' @param pred.Y The predicted phenotype of each sample.
#' @return The classification error rate, i.e. the proportion of misclassified samples.
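#' @examples
#' # Predictions are rounded to 0 or 1; here two of four samples are
#' # misclassified, giving an error rate of 0.5.
#' ComputeClassificationError(true.Y = c(0, 1, 1, 0),
#'                            pred.Y = c(0.2, 0.6, 0.4, 0.9))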
#' @export
ComputeClassificationError <- function(true.Y, pred.Y){
  
  # Round predictions to get 1 or 0.
  pred.Y <- round(pred.Y)

  # Find false and true positives and negatives.
  FP <- length(intersect(which(true.Y == 0), which(pred.Y == 1)))
  TP <- length(intersect(which(true.Y == 1), which(pred.Y == 1)))
  FN <- length(intersect(which(true.Y == 1), which(pred.Y == 0)))
  TN <- length(intersect(which(true.Y == 0), which(pred.Y == 0)))

  # Compute error and return.
  error <- (FP + FN) / (FP + FN + TP + TN)
  return(error)
}

#' Compute the root mean squared error (RMSE).
#' @param true.Y The true phenotype of each sample.
#' @param pred.Y The predicted phenotype of each sample.
#' @return The root mean squared error (a single value).
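#' @examples
#' # RMSE over three samples: sqrt((0.25 + 0 + 0.25) / 3), approximately 0.41.
#' ComputeRMSE(true.Y = c(1, 2, 3), pred.Y = c(1.5, 2, 2.5))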
#' @export
ComputeRMSE <- function(true.Y, pred.Y){
  RMSD <- sqrt(sum((true.Y - pred.Y)^2) / length(true.Y))
  return(RMSD)
}