#' Calculate confusion matrix for cross-validation.
#'
#' Calculate a confusion matrix based on predictions from cross-validation.
#' Each sample is assigned to the class with the largest cross-validated
#' value at the entry of `cvFit$lambda` closest to `lambda`.
#'
#' @param cvFit cv.glmnet object from [metapredictCv()]. It must contain
#' `fit.preval`, which is indexed here as a samples-by-classes-by-lambda array.
#' @param lambda value of lambda at which to use predictions. The entry of
#' `cvFit$lambda` with the smallest absolute difference from `lambda` is used.
#' @param ematMerged matrix of gene expression for genes by samples. Its
#' column names are used as the sample names of the predictions.
#' @param sampleMetadata data.frame of sample metadata.
#' @param className name of column in `sampleMetadata` containing the true
#' labels.
#' @param classLevels Order of classes in the confusion matrix. If `NA`
#' (default) or `NULL`, then the function uses the order in `cvFit`.
#'
#' @return An object of class `table`, with true classes in rows and predicted
#' classes in columns. Samples whose predictions are all missing at the
#' selected lambda get a missing predicted class and are therefore not counted.
#'
#' @export
calcConfusionCv = function(cvFit, lambda, ematMerged, sampleMetadata,
                           className = 'class', classLevels = NA) {
  classNames = names(cvFit$glmnet.fit$beta)
  if (is.null(classLevels) || is.na(classLevels[1])) {
    classLevels = classNames}

  if (is.null(cvFit$fit.preval)) {
    stop('cvFit does not contain fit.preval; ',
         'cross-validation must be run with keep = TRUE.', call. = FALSE)}

  # drop = FALSE keeps the slice rectangular even when there is one sample.
  lambdaIdx = which.min(abs(cvFit$lambda - lambda))
  probSlice = cvFit$fit.preval[, , lambdaIdx, drop = FALSE]
  cvProbs = matrix(as.vector(probSlice), nrow = dim(probSlice)[1],
                   ncol = dim(probSlice)[2],
                   dimnames = list(colnames(ematMerged), classNames))

  # which.max() returns integer(0) for an all-NA row; map that to NA so the
  # result stays an integer vector instead of silently becoming a list.
  predIdx = vapply(seq_len(nrow(cvProbs)), function(i) {
    best = which.max(cvProbs[i, ])
    if (length(best) == 0L) NA_integer_ else best[[1L]]},
    integer(1))

  predictedClass = factor(classNames[predIdx], levels = classLevels)
  sm = mergeDataTable(colnames(ematMerged), sampleMetadata)
  trueClass = factor(sm[[className]], levels = classLevels)
  return(table(trueClass, predictedClass))}
#' Calculate confusion matrices (or matrix) for validation datasets.
#'
#' Calculate confusion matrices based on predictions for validation datasets.
#' Each sample is assigned to the class with the largest predicted value.
#'
#' @param predsList list of predictions from [metapredict()]. Each element is
#' indexed here as a samples-by-classes-by-lambda array whose row names are
#' sample names and whose column names are class names.
#' @param lambda value of lambda at which to use predictions. Note that this
#' argument is currently not used by the function: predictions are always
#' taken from the first slice along the third dimension of each array.
#' @param sampleMetadata data.frame of sample metadata.
#' @param className name of column in `sampleMetadata` containing the true
#' labels.
#' @param classLevels Order of classes in the confusion matrix. If `NA`
#' (default) or `NULL`, then the function uses the order of the column names
#' of the first element of `predsList`.
#' @param each logical indicating whether to calculate a confusion matrix for
#' each validation dataset (default) or one confusion matrix including all
#' datasets.
#'
#' @return If `isTRUE(each)`, a list of objects of class `table`, named like
#' `predsList`. Otherwise, an object of class `table`. True classes are in
#' rows and predicted classes are in columns.
#'
#' @export
calcConfusionValidation = function(
  predsList, lambda, sampleMetadata, className = 'class', classLevels = NA,
  each = TRUE) {
  if (is.null(classLevels) || is.na(classLevels[1])) {
    classLevels = colnames(predsList[[1]])}

  # Extract the first slice as a samples-by-classes matrix. Rebuilding the
  # matrix (rather than x[, , 1]) keeps row names for single-sample datasets.
  getProbs = function(x) {
    d = dim(x)
    matrix(as.vector(x[, , 1, drop = FALSE]), nrow = d[1], ncol = d[2],
           dimnames = dimnames(x)[1:2])}

  # Build one confusion table from a samples-by-classes matrix.
  calcConfusion = function(predsProb) {
    # which.max() returns integer(0) for an all-NA row; map that to NA.
    predIdx = vapply(seq_len(nrow(predsProb)), function(i) {
      best = which.max(predsProb[i, ])
      if (length(best) == 0L) NA_integer_ else best[[1L]]},
      integer(1))
    predictedClass = factor(colnames(predsProb)[predIdx], levels = classLevels)
    sm = mergeDataTable(rownames(predsProb), sampleMetadata)
    trueClass = factor(sm[[className]], levels = classLevels)
    table(trueClass, predictedClass)}

  probsList = lapply(predsList, getProbs)
  if (isTRUE(each)) {
    # lapply carries over the names of predsList.
    confusion = lapply(probsList, calcConfusion)
  } else {
    # unname() so dataset names are never interpreted as rbind() arguments.
    confusion = calcConfusion(do.call(rbind, unname(probsList)))}
  return(confusion)}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.