# R/getPatientPredictions.R

# Defines functions getPatientPredictions

# Documented in getPatientPredictions

#' Calculates patient-level classification accuracy across train/test splits
#'
#' @details Takes all the predictions across the different train/test splits
#' and, for each patient, records the label predicted in every split in which
#' the patient appears in the prediction file. It then computes the
#' percentage of those splits in which the predicted label equals the known
#' label in \code{pheno$STATUS}. The result is a measure of individual
#' classification accuracy across the different train/test splits.
#'
#' Patients listed in a prediction file but absent from \code{pheno} are
#' ignored.
#' @param predFiles (char) vector of paths to all test predictions
#' (e.g. 100 files for a 100 train/test split design). Each file is
#' tab-delimited with a header and must contain the columns \code{ID} and
#' \code{PRED_CLASS}.
#' Alternately, the user can provide a single directory name; every
#' immediate sub-directory whose name contains 'rng' is then expected to hold
#' a file named 'predictionResults.txt'.
#' Format is 'rootDir/rngX/predictionResults.txt'
#' @param pheno (data.frame) ID=patient ID, STATUS=ground truth (known class
#' label). This table is required to get the master list of all patients, as
#' not every patient is classified in every split.
#' @param plotAccuracy (logical) if TRUE, shows the percentage of splits in
#' which each patient is classified correctly, using a dot plot
#' @return (list) with the element \code{predictions} and, if
#' \code{plotAccuracy=TRUE}, the additional element \code{plot} (the ggplot
#' object that was printed).
#' \code{predictions} is a data.frame with one row per patient in
#' \code{pheno} (row names are patient IDs) and (length(predFiles)+2) columns.
#' Columns seq_len(length(predFiles)): predicted label for a given split,
#' named after the directory holding the prediction file (e.g. 'rng1');
#' NA if the patient has no prediction in that split.
#' Column \code{STATUS}: known class label taken from \code{pheno}.
#' Column \code{pctCorrect}: percentage of splits with a prediction for the
#' patient in which the prediction matches \code{STATUS}; NaN if the patient
#' was never predicted.
#' Side effect of plotting a dot plot of % accuracy when
#' \code{plotAccuracy=TRUE}. Each dot is a patient, and the value is
#' '% splits for which patient was classified correctly'.
#' @examples
#' inDir <- system.file("extdata","example_output",package="netDx")
#' data(pheno)
#' all_rngs <- list.dirs(inDir, recursive = FALSE)
#' all_pred_files <- unlist(lapply(all_rngs, function(x) {
#'     paste(x, 'predictionResults.txt',
#'         sep = getFileSep())}))
#' pred_mat <- getPatientPredictions(all_pred_files, pheno)
#' @import ggplot2
#' @export
getPatientPredictions <- function(predFiles, pheno, plotAccuracy = FALSE) {
    if (length(predFiles) == 1 && dir.exists(predFiles)) {
        message("predFiles is of length 1. Assuming directory\n")
        all_rngs <- list.dirs(predFiles, recursive = FALSE)
        # Match on the sub-directory's own name: matching the full path
        # would select every folder whenever the root path contains "rng".
        all_rngs <- all_rngs[grepl("rng", basename(all_rngs))]
        predFiles <- file.path(all_rngs, "predictionResults.txt")
    } else if (length(predFiles) > 1) {
        message("predFiles is of length > 1. Assuming filenames provided\n")
    } else if (length(predFiles) == 1) {
        message("predFiles is a single file. Assuming filename provided\n")
    }
    if (length(predFiles) < 1) {
        stop("No prediction files found in predFiles", call. = FALSE)
    }

    # One row per patient in pheno, one column per split; entries stay NA
    # for splits in which the patient has no prediction.
    patient_ids <- as.character(pheno$ID)
    pred_mat <- matrix(NA, nrow = nrow(pheno), ncol = length(predFiles))
    rownames(pred_mat) <- patient_ids
    for (ctr in seq_along(predFiles)) {
        curFile <- predFiles[ctr]
        dat <- read.delim(curFile, sep = "\t", header = TRUE, as.is = TRUE)
        if (!all(c("ID", "PRED_CLASS") %in% colnames(dat))) {
            stop(sprintf("%s must contain columns ID and PRED_CLASS",
                curFile), call. = FALSE)
        }
        # Patients not listed in pheno have no row and are skipped.
        row_idx <- match(dat$ID, patient_ids)
        known <- !is.na(row_idx)
        pred_mat[row_idx[known], ctr] <- dat$PRED_CLASS[known]
    }

    # Compare labels as text when the ground truth is stored as a factor.
    truth <- pheno$STATUS
    if (is.factor(truth)) truth <- as.character(truth)
    # truth has one entry per row, so column-wise recycling lines each
    # prediction up with the label of its own patient.
    testCt <- rowSums(!is.na(pred_mat))
    numCorrect <- rowSums(pred_mat == truth, na.rm = TRUE)
    # 0/0 yields NaN for patients that were never predicted.
    pctCorr <- unname(numCorrect / testCt * 100)

    output_mat <- as.data.frame(pred_mat, stringsAsFactors = FALSE)
    output_mat <- cbind(output_mat, pheno$STATUS, pctCorr)
    rownames(output_mat) <- patient_ids
    # Name each split after the directory holding its prediction file,
    # assuming the layout rngX/predictionResults.txt.
    fNames <- basename(dirname(predFiles))
    colnames(output_mat) <- c(fNames, "STATUS", "pctCorrect")

    result <- list(predictions = output_mat)
    if (plotAccuracy) {
        # Use the .data pronoun rather than `$` inside aes().
        p <- ggplot(output_mat, aes(x = .data[["pctCorrect"]])) +
            geom_dotplot()
        msg <- sprintf("Patient-level classification accuracy (N=%i)",
            length(predFiles))
        p <- p + ggtitle(msg)
        p <- p + theme(axis.text = element_text(size = 13),
            axis.title = element_text(size = 13))
        print(p)
        result$plot <- p
    }
    result
}
# BaderLab/netDx documentation built on Sept. 26, 2021, 9:13 a.m.