# R/metapredict_confusion.R
#
# Defines functions: calcConfusionValidation, calcConfusionCv
#
# Documented in: calcConfusionCv, calcConfusionValidation

#' Calculate confusion matrix for cross-validation.
#'
#' Calculate a confusion matrix based on predictions from cross-validation.
#' For each sample, the predicted class is the class with the largest
#' prevalidated value at the fitted lambda closest to `lambda`.
#'
#' @param cvFit cv.glmnet object from [metapredictCv()]. It must contain the
#'   prevalidated fits (`fit.preval`), which `cv.glmnet` only stores when it is
#'   called with `keep = TRUE`. Class names are taken from
#'   `names(cvFit$glmnet.fit$beta)`.
#' @param lambda value of lambda at which to use predictions. The value in
#'   `cvFit$lambda` closest to `lambda` is used.
#' @param ematMerged matrix of gene expression for genes by samples. Only its
#'   column names are used: they are assigned as the sample names of the
#'   prevalidated fits (so they are assumed to be in the same order as the
#'   samples in `cvFit`) and are used to look up the true labels.
#' @param sampleMetadata data.frame of sample metadata.
#' @param className name of column in `sampleMetadata` containing the true
#'   labels.
#' @param classLevels Order of classes in the confusion matrix. If `NA`
#'   (default), then the function uses the order in `cvFit`.
#'
#' @return An object of class `table`, with true classes in rows and predicted
#'   classes in columns.
#'
#' @export
calcConfusionCv = function(cvFit, lambda, ematMerged, sampleMetadata,
                           className = 'class', classLevels = NA) {
  classNames = names(cvFit$glmnet.fit$beta)
  if (is.na(classLevels[1])) {
    classLevels = classNames}

  # Without keep = TRUE, cv.glmnet does not retain the prevalidated fits, and
  # the indexing below would fail with an uninformative dimension error.
  preval = cvFit$fit.preval
  if (is.null(preval)) {
    stop('cvFit does not contain prevalidated fits (fit.preval); ',
         'cv.glmnet must be run with keep = TRUE.', call. = FALSE)}

  lambdaIdx = which.min(abs(cvFit$lambda - lambda))

  # Rebuild the slice as a samples-by-classes matrix explicitly, because plain
  # `[, , i]` drops a single-sample slice to a bare vector.
  cvProbs = matrix(preval[, , lambdaIdx],
                   nrow = dim(preval)[1L], ncol = dim(preval)[2L])
  rownames(cvProbs) = colnames(ematMerged)
  colnames(cvProbs) = classNames

  preds = classNames[apply(cvProbs, MARGIN = 1, which.max)]
  predictedClass = factor(preds, levels = classLevels)

  classValues = mergeDataTable(colnames(ematMerged), sampleMetadata)[[className]]
  trueClass = factor(classValues, levels = classLevels)

  # The variable names here become the dimnames names of the returned table.
  return(table(trueClass, predictedClass))}


#' Calculate confusion matrices (or matrix) for validation datasets.
#'
#' Calculate confusion matrices based on predictions for validation datasets.
#' For each sample, the predicted class is the class with the largest predicted
#' value.
#'
#' @param predsList list of predictions from [metapredict()], one element per
#'   validation dataset. Each element is indexed as a three-dimensional array
#'   whose row names are sample names and whose column names are class names.
#'   Only the first slice along the third dimension is used.
#' @param lambda value of lambda at which to use predictions. Currently not
#'   used by this function (predictions are always taken from the first slice
#'   of each element of `predsList`); retained so existing calls keep working.
#' @param sampleMetadata data.frame of sample metadata.
#' @param className name of column in `sampleMetadata` containing the true
#'   labels.
#' @param classLevels Order of classes in the confusion matrix. If `NA`
#'   (default), then the function uses the order of the column names of the
#'   first element of `predsList`.
#' @param each logical indicating whether to calculate a confusion matrix for
#'   each validation dataset (default) or one confusion matrix including all
#'   datasets.
#'
#' @return If `isTRUE(each)`, a list of objects of class `table`, named like
#'   `predsList`. Otherwise, an object of class `table`. In each table, true
#'   classes are in rows and predicted classes are in columns.
#'
#' @export
calcConfusionValidation = function(
  predsList, lambda, sampleMetadata, className = 'class', classLevels = NA,
  each = TRUE) {

  if (is.na(classLevels[1])) {
    classLevels = colnames(predsList[[1]])}

  # Extract the first slice as a samples-by-classes matrix. Plain `x[, , 1]`
  # drops a one-sample dataset to a bare vector, which loses the sample name
  # and breaks the row-wise operations below.
  getProbs = function(x) {
    array(x[, , 1], dim = dim(x)[1:2], dimnames = dimnames(x)[1:2])}

  # Build one confusion table from a samples-by-classes matrix. The names
  # `trueClass` and `predictedClass` become the dimnames names of the table.
  makeConfusion = function(predsProb) {
    predsClass = colnames(predsProb)[apply(predsProb, MARGIN = 1, which.max)]
    predictedClass = factor(predsClass, levels = classLevels)

    sm = mergeDataTable(rownames(predsProb), sampleMetadata)
    trueClass = factor(sm[[className]], levels = classLevels)
    table(trueClass, predictedClass)}

  if (isTRUE(each)) {
    # lapply keeps the names of predsList on the result.
    confusion = lapply(predsList, function(x) makeConfusion(getProbs(x)))
  } else {
    confusion = makeConfusion(do.call(rbind, lapply(predsList, getProbs)))}

  return(confusion)}
# jakejh/metapredict documentation built on Feb. 14, 2023, 7:53 p.m.