# R/rankingPlot.R

# S4 generic for rankingPlot. Dispatch happens on the class of `results`;
# any further arguments are handed to the selected method through `...`.
setGeneric(name = "rankingPlot",
           def = function(results, ...)
           {
             standardGeneric("rankingPlot")
           })

# rankingPlot method for a list of ClassifyResult or SelectResult objects.
# For each cut-off in `topRanked`, it computes the average percentage of top-ranked features shared
# either between the rankings belonging to the same result (comparison = "within", i.e. stability)
# or between different results (commonality). It then draws these percentages as a ggplot2 line plot.
# - comparison: "within", or the variable whose levels are compared with each other.
# - referenceLevel: optional level of the comparison variable. When given, every result is compared
#   only to the result(s) of that level in its group, instead of to all other results.
# - lineColourVariable / pointTypeVariable / rowVariable / columnVariable: variables mapped to line
#   colour, point shape and facet rows / columns ("None" disables the mapping).
# - plot: whether to print the plot. parallelParams: BiocParallel parameters for the overlap work.
# Returns the ggplot object.
setMethod("rankingPlot", "list", 
          function(results, topRanked = seq(10, 100, 10),
                   comparison = c("within", "classificationName", "validation", "datasetName", "selectionName"),
                   referenceLevel = NULL,
                   lineColourVariable = c("validation", "datasetName", "classificationName", "selectionName", "None"),
                   lineColours = NULL, lineWidth = 1,
                   pointTypeVariable = c("datasetName", "classificationName", "validation", "selectionName", "None"),
                   pointSize = 2, legendLinesPointsSize = 1,
                   rowVariable = c("None", "datasetName", "classificationName", "validation", "selectionName"),
                   columnVariable = c("classificationName", "datasetName", "validation", "selectionName", "None"),
                   yMax = 100, fontSizes = c(24, 16, 12, 12, 12, 16),
                   title = if(comparison[1] == "within") "Feature Ranking Stability" else "Feature Ranking Commonality",
                   xLabelPositions = seq(10, 100, 10),
                   yLabel = if(is.null(referenceLevel)) "Average Common Features (%)" else paste("Average Common Features with", referenceLevel, "(%)"),
                   margin = grid::unit(c(1, 1, 1, 1), "lines"),
                   showLegend = TRUE, plot = TRUE, parallelParams = bpparam())
{
  comparison <- match.arg(comparison)
  if(!requireNamespace("ggplot2", quietly = TRUE))
    stop("The package 'ggplot2' could not be found. Please install it.")
  if(length(results) == 0)
    stop("'results' must contain at least one ClassifyResult or SelectResult object.")
  if(comparison == "within" && !is.null(referenceLevel))
    stop("'comparison' should not be \"within\" if 'referenceLevel' is not NULL.")

  # Sets the session-wide ggplot2 theme (a global side effect kept for compatibility).
  ggplot2::theme_set(ggplot2::theme_classic() + ggplot2::theme(panel.border = ggplot2::element_rect(fill = NA)))

  lineColourVariable <- match.arg(lineColourVariable)
  pointTypeVariable <- match.arg(pointTypeVariable)
  rowVariable <- match.arg(rowVariable)
  columnVariable <- match.arg(columnVariable)

  # is() rather than class() == so that subclasses of ClassifyResult are also recognised.
  resultsType <- if(is(results[[1]], "ClassifyResult")) "classification" else "selection"

  if(resultsType == "classification")
  {
    analyses <- sapply(results, function(result) result@classificationName)
    selections <- sapply(results, function(result) result@selectResult@selectionName)
    validations <- sapply(results, function(result) .validationText(result))
  } else { # Compare selections.
    # These placeholders are also used as factor levels below, so the plotted labels must come from here.
    analyses <- rep("No classification", length(results))
    selections <- sapply(results, function(result) result@selectionName)
    validations <- rep("No cross-validation", length(results))
  }
  datasets <- sapply(results, function(result) result@datasetName)

  referenceVar <- switch(comparison, classificationName = analyses,
                         selectionName = selections,
                         datasetName = datasets,
                         validation = validations)
  if(!is.null(referenceLevel) && !(referenceLevel %in% referenceVar))
    stop("Reference level is neither a level of the comparison factor nor is it NULL.")

  # Convert every result's rankings into a flat list of character vectors, one per ranking, so that
  # rankings stored in any of the supported representations can be intersected with each other.
  allRankedList <- lapply(results, function(result)
  {
    rankedFeatures <- if(resultsType == "classification") result@selectResult@rankedFeatures else result@rankedFeatures
    firstRanking <- rankedFeatures[[1]]

    if(is.data.frame(firstRanking)) # Data set and feature ID columns.
      rankedFeatures <- lapply(rankedFeatures, function(features) paste(features[, 1], features[, 2]))
    else if(is(firstRanking, "Pairs"))
      rankedFeatures <- lapply(rankedFeatures, function(features) paste(first(features), second(features)))
    else if(is.list(firstRanking) && is.character(firstRanking[[1]])) # Two-level list, such as generated by permuting and folding.
      rankedFeatures <- unlist(rankedFeatures, recursive = FALSE)
    else if(is.list(firstRanking) && is.data.frame(firstRanking[[1]])) # Two-level list of data set and feature ID columns.
      rankedFeatures <- unlist(lapply(rankedFeatures, function(folds) lapply(folds, function(fold) paste(fold[, 1], fold[, 2]))), recursive = FALSE)
    else if(is.list(firstRanking) && is(firstRanking[[1]], "Pairs")) # Two-level list of feature pairs (checks the inner element).
      rankedFeatures <- unlist(lapply(rankedFeatures, function(folds) lapply(folds, function(fold) paste(first(fold), second(fold)))), recursive = FALSE)

    rankedFeatures
  })

  # Percentage of the top `top` features of one ranking that are also among the top `top` of another.
  # A ranking shorter than `top` is truncated instead of padded with NA, so missing positions are never
  # counted as a shared feature.
  topOverlap <- function(features, otherFeatures, top)
  {
    topFeatures <- features[seq_len(min(top, length(features)))]
    topOtherFeatures <- otherFeatures[seq_len(min(top, length(otherFeatures)))]
    length(intersect(topFeatures, topOtherFeatures)) / top * 100
  }

  # Overlap matrix with one row per element of topRanked and one column per ranking in otherRankings.
  # Built with cbind over vapply results, so it stays a matrix even when topRanked has a single value.
  overlapMatrix <- function(features, otherRankings)
    do.call(cbind, lapply(otherRankings, function(otherFeatures)
      vapply(topRanked, function(top) topOverlap(features, otherFeatures, top), numeric(1))))

  if(comparison == "within")
  {
    if(resultsType == "selection")
      stop("'comparison' should not be \"within\" for results that are not cross-validations.")
    if(any(lengths(allRankedList) < 2))
      stop("Each result must contain at least two feature rankings for a \"within\" comparison.")

    plotData <- do.call(rbind, bpmapply(function(result, rankedList)
    {
      # Average overlap over every unordered pair of rankings of this result.
      lastIndex <- length(rankedList)
      averageOverlap <- rowMeans(do.call(cbind, lapply(seq_len(lastIndex - 1), function(index)
        overlapMatrix(rankedList[[index]], rankedList[(index + 1):lastIndex]))))

      data.frame(dataset = rep(result@datasetName, length(topRanked)),
                 analysis = rep(result@classificationName, length(topRanked)),
                 selection = rep(result@selectResult@selectionName, length(topRanked)),
                 validation = rep(.validationText(result), length(topRanked)),
                 top = topRanked,
                 overlap = averageOverlap)
    }, results, allRankedList, BPPARAM = parallelParams, SIMPLIFY = FALSE))
  } else { # Commonality analysis.
    # Results are only compared within groups that share every other plotted variable.
    groupingVariablesValues <- setdiff(c(lineColourVariable, pointTypeVariable, rowVariable, columnVariable), comparison)
    groupingFactor <- paste(if("datasetName" %in% groupingVariablesValues) datasets,
                            if("classificationName" %in% groupingVariablesValues) analyses,
                            if("selectionName" %in% groupingVariablesValues) selections,
                            if("validation" %in% groupingVariablesValues) validations,
                            sep = " ")
    if(length(groupingFactor) == 0) groupingFactor <- rep("None", length(results))
    compareIndices <- split(seq_along(results), groupingFactor)

    # Validate the groups before any parallel work starts, so that problems are reported clearly.
    if(any(lengths(compareIndices) < 2))
      stop("Each group of results being compared must contain at least two results.")
    if(!is.null(referenceLevel))
    {
      hasReference <- vapply(compareIndices, function(indicesSet) referenceLevel %in% referenceVar[indicesSet], logical(1))
      if(!all(hasReference))
        stop("Reference level \"", referenceLevel, "\" is missing from the group(s): ",
             paste(names(compareIndices)[!hasReference], collapse = ", "), ".")
    }

    plotData <- do.call(rbind, unname(unlist(bplapply(compareIndices, function(indicesSet)
    {
      if(is.null(referenceLevel))
      {
        # Each result in turn is compared to all of the other results in its group.
        indicesCombinations <- lapply(seq_along(indicesSet),
                                      function(index) c(indicesSet[index], indicesSet[-index]))
      } else { # Compare each factor level other than the reference level to the reference level.
        referencePosition <- match(referenceLevel, referenceVar[indicesSet])
        indicesCombinations <- list(c(indicesSet[referencePosition], indicesSet[-referencePosition]))
      }

      unname(lapply(indicesCombinations, function(indicesCombination)
      {
        index <- indicesCombination[1]
        rankedList <- allRankedList[[index]]
        otherIndices <- indicesCombination[-1]

        # Rows: topRanked; columns: the other results. Each entry averages the overlaps between every
        # ranking of this result and every ranking of that other result.
        overlapToOther <- do.call(cbind, lapply(otherIndices, function(otherIndex)
          rowMeans(do.call(cbind, lapply(rankedList, function(rankings)
            overlapMatrix(rankings, allRankedList[[otherIndex]]))))))

        if(is.null(referenceLevel))
        {
          # Summarise this result by its average overlap with all other results in the group.
          overlapToOther <- rowMeans(overlapToOther)
          labelIndices <- rep(index, length(topRanked))
        } else { # Each other level has been compared to the reference level of the factor.
          labelIndices <- rep(otherIndices, each = length(topRanked))
        }

        # Labels come from the same vectors used for the factor levels below, so they always match.
        data.frame(dataset = unname(datasets[labelIndices]),
                   analysis = unname(analyses[labelIndices]),
                   selection = unname(selections[labelIndices]),
                   validation = unname(validations[labelIndices]),
                   top = rep(topRanked, length.out = length(labelIndices)),
                   overlap = as.numeric(overlapToOther)) # Column-major order matches labelIndices.
      }))
    }, BPPARAM = parallelParams), recursive = FALSE)))
  }
  rownames(plotData) <- NULL # Easier for viewing during maintenance.

  # Order factors in which they appeared in the user's list.
  plotData[, "dataset"] <- factor(plotData[, "dataset"], levels = unique(datasets))
  plotData[, "analysis"] <- factor(plotData[, "analysis"], levels = unique(analyses))
  plotData[, "selection"] <- factor(plotData[, "selection"], levels = unique(selections))
  plotData[, "validation"] <- factor(plotData[, "validation"], levels = unique(validations))

  # Map each user-facing variable choice to its plotData column and legend title; "None" maps to NULL.
  plotColumns <- c(validation = "validation", datasetName = "dataset",
                   classificationName = "analysis", selectionName = "selection")
  legendTitles <- c(validation = "Validation", datasetName = "Dataset",
                    classificationName = "Analysis", selectionName = "Feature\nSelection")
  plotColumn <- function(variable) if(variable == "None") NULL else plotColumns[[variable]]
  legendTitle <- function(variable) if(variable == "None") NULL else legendTitles[[variable]]

  if(is.null(lineColours) && lineColourVariable != "None")
    lineColours <- scales::hue_pal()(length(unique(plotData[, plotColumn(lineColourVariable)])))
  legendPosition <- if(showLegend) "right" else "none"

  overlapPlot <- ggplot2::ggplot(plotData, ggplot2::aes_string(x = "top", y = "overlap",
                                                               colour = plotColumn(lineColourVariable),
                                                               shape = plotColumn(pointTypeVariable))) +
    ggplot2::geom_line(size = lineWidth) + ggplot2::geom_point(size = pointSize) +
    ggplot2::scale_x_continuous(breaks = xLabelPositions, limits = range(xLabelPositions)) +
    ggplot2::coord_cartesian(ylim = c(0, yMax)) +
    ggplot2::xlab("Top Features") + ggplot2::ylab(yLabel) +
    ggplot2::ggtitle(title) +
    ggplot2::labs(colour = legendTitle(lineColourVariable), shape = legendTitle(pointTypeVariable)) +
    ggplot2::theme(axis.title = ggplot2::element_text(size = fontSizes[2]),
                   axis.text = ggplot2::element_text(colour = "black", size = fontSizes[3]),
                   legend.position = legendPosition,
                   legend.title = ggplot2::element_text(size = fontSizes[4]),
                   legend.text = ggplot2::element_text(size = fontSizes[5]),
                   plot.title = ggplot2::element_text(size = fontSizes[1], hjust = 0.5),
                   plot.margin = margin) +
    ggplot2::guides(colour = ggplot2::guide_legend(override.aes = list(size = legendLinesPointsSize)),
                    shape = ggplot2::guide_legend(override.aes = list(size = legendLinesPointsSize)))

  if(!is.null(lineColours))
    overlapPlot <- overlapPlot + ggplot2::scale_colour_manual(values = lineColours)

  if(rowVariable != "None" || columnVariable != "None")
  {
    # "." means "no faceting" in that direction of the grid (formula is rows ~ columns).
    facetRows <- if(rowVariable == "None") "." else plotColumns[[rowVariable]]
    facetColumns <- if(columnVariable == "None") "." else plotColumns[[columnVariable]]
    overlapPlot <- overlapPlot + ggplot2::facet_grid(reformulate(facetColumns, facetRows)) +
      ggplot2::theme(strip.text = ggplot2::element_text(size = fontSizes[6]))
  }

  if(plot)
    print(overlapPlot)

  overlapPlot
})

# Try the ClassifyR package in your browser
#
# Any scripts or data that you put into this service are public.
#
# ClassifyR documentation built on Nov. 8, 2020, 6:53 p.m.